in projects/twhin/data/edges.py [0:0]
def to_batches(self):
    """Yield the parent's record batches with a constant label column added.

    Each batch produced by ``super().to_batches()`` is re-emitted as a new
    ``pa.RecordBatch`` holding the lhs, rhs and rel columns of the incoming
    batch plus an int64 label column filled with ones (every row is passed
    along as a positive edge).

    Yields:
        A ``pa.RecordBatch`` whose columns are named, in order,
        ``lhs_column_name``, ``rhs_column_name``, ``rel_column_name`` and
        ``label_column_name``.
    """
    names = [
        self.lhs_column_name,
        self.rhs_column_name,
        self.rel_column_name,
        self.label_column_name,
    ]
    for batch in super().to_batches():
        # Pass along positive edges
        lhs = batch.column(self.lhs_column_name)
        rhs = batch.column(self.rhs_column_name)
        rel = batch.column(self.rel_column_name)
        # Size the labels from the batch itself rather than the configured
        # batch size: a batch may hold fewer rows (e.g. the last one), and
        # RecordBatch.from_arrays requires all arrays to have equal length.
        label = pa.array(np.ones(batch.num_rows, dtype=np.int64))
        yield pa.RecordBatch.from_arrays(
            arrays=[lhs, rhs, rel, label],
            names=names,
        )