def to_batches()

in projects/twhin/data/edges.py [0:0]


  def to_batches(self):
    """Yield edge record batches, each with an all-ones label column appended.

    Wraps the parent class's ``to_batches()`` iterator. For every batch it
    yields, only the lhs, rhs and relation columns are kept (looked up by
    ``self.lhs_column_name``, ``self.rhs_column_name`` and
    ``self.rel_column_name``). A new int64 column of ones is appended under
    ``self.label_column_name``, marking every edge as positive.

    Yields:
      pa.RecordBatch with columns ``[lhs, rhs, rel, label]``, in that order.
    """
    names = [
      self.lhs_column_name,
      self.rhs_column_name,
      self.rel_column_name,
      self.label_column_name,
    ]
    for batch in super().to_batches():
      # Pass along positive edges. Size the label column from the batch
      # itself, not from the configured batch_size. batch_size is only an
      # upper bound, so the final (or any short) batch has fewer rows.
      # RecordBatch.from_arrays requires all columns to have equal length.
      label = pa.array(np.ones(batch.num_rows, dtype=np.int64))

      yield pa.RecordBatch.from_arrays(
        arrays=[
          batch.column(self.lhs_column_name),
          batch.column(self.rhs_column_name),
          batch.column(self.rel_column_name),
          label,
        ],
        names=names,
      )