def pa_to_batch()

in projects/twhin/data/edges.py [0:0]


  def pa_to_batch(self, batch: pa.RecordBatch):
    """Turn one pyarrow ``RecordBatch`` of edges into an ``EdgeBatch``.

    The left-hand node, right-hand node, relation and label columns are looked
    up by the names stored on ``self``. Each is converted Arrow -> NumPy ->
    torch. The lhs/rhs/rel tensors are passed to ``self._to_kjt`` to build
    ``nodes``. Every row gets a weight of 1 (``torch.ones``, default torch
    dtype).

    Args:
      batch: Arrow record batch containing the four edge columns.

    Returns:
      An ``EdgeBatch`` holding ``nodes``, ``rels``, ``labels`` and ``weights``.
    """

    def column_tensor(name):
      # NOTE(review): with default arguments, Arrow's ``to_numpy()`` may hand
      # back a read-only zero-copy view; ``torch.from_numpy`` then shares
      # that memory. Confirm whether callers ever mutate these tensors.
      return torch.from_numpy(batch.column(name).to_numpy())

    column_names = (
      self.lhs_column_name,
      self.rhs_column_name,
      self.rel_column_name,
      self.label_column_name,
    )
    lhs, rhs, rel, label = (column_tensor(name) for name in column_names)

    return EdgeBatch(
      nodes=self._to_kjt(lhs, rhs, rel),
      rels=rel,
      labels=label,
      weights=torch.ones(batch.num_rows),
    )