in projects/twhin/models/models.py [0:0]
def forward(self, batch: EdgeBatch):
    """Score the edges in ``batch`` and, optionally, in-batch negatives.

    Each edge is scored as the dot product between its lhs embedding and
    its rhs embedding shifted by the translation embedding of the edge's
    relation. When ``self.in_batch_negatives`` is truthy, extra scores are
    produced per relation by pairing every lhs (rhs) row of that relation
    with randomly sampled rhs (lhs) rows of the same relation.

    Args:
        batch: Provides ``rels`` (used to index the translation table and
            compared against relation ids) and ``nodes`` (passed to
            ``self.large_embeddings``). Presumably lhs/rhs nodes are
            interleaved so that consecutive row pairs form one edge --
            TODO confirm against the data pipeline.

    Returns:
        A dict with:
            ``"logits"``: 1-D tensor holding the positive scores followed
                by the flattened negative scores of every non-empty
                relation (lhs-corrupted block, then rhs-corrupted block).
            ``"probabilities"``: ``torch.sigmoid`` of the logits.
    """
    # B x D
    # NOTE(review): `.data` bypasses autograd, so no gradient reaches
    # `all_trans_embs` through this lookup -- confirm this is intended.
    trans_embs = self.all_trans_embs.data[batch.rels]

    # KeyedTensor
    outs = self.large_embeddings(batch.nodes)

    # 2B x TD
    x = outs.values()

    # 2B x T x D
    x = x.reshape(2 * self.batch_size, -1, self.embedding_dim)

    # 2B x D (sum over the T axis)
    x = torch.sum(x, 1)

    # B x 2 x D (index 0 on axis 1 is lhs, index 1 is rhs)
    x = x.reshape(self.batch_size, 2, self.embedding_dim)

    # rhs shifted by the relation translation
    translated = x[:, 1, :] + trans_embs

    negs = []
    if self.in_batch_negatives:
        num_negs = self.in_batch_negatives
        # construct dot products for negatives via matmul
        # NOTE(review): negatives use the raw rhs rows, whereas positives
        # use the translated rhs -- confirm this asymmetry is intended.
        for relation in range(self.num_relations):
            rel_mask = batch.rels == relation
            # Convert to a Python int once; the count is needed for the
            # emptiness check and for the repeat factor below, and every
            # implicit tensor -> Python conversion is a separate host sync.
            rel_count = int(rel_mask.sum())
            if not rel_count:
                continue

            # R x D
            lhs_matrix = x[rel_mask, 0, :]
            rhs_matrix = x[rel_mask, 1, :]

            # How many times a permutation of R rows must be tiled so that
            # at least `num_negs` indices are available.
            repeats = math.ceil(num_negs / rel_count)

            # The lhs permutation is drawn before the rhs one; keep this
            # order so seeded runs stay reproducible.
            lhs_perm = torch.randperm(lhs_matrix.shape[0]).repeat(repeats)
            sampled_lhs = lhs_matrix[lhs_perm[:num_negs]]

            rhs_perm = torch.randperm(rhs_matrix.shape[0]).repeat(repeats)
            sampled_rhs = rhs_matrix[rhs_perm[:num_negs]]

            # RS
            negs_rhs = torch.flatten(torch.matmul(lhs_matrix, sampled_rhs.t()))
            negs_lhs = torch.flatten(torch.matmul(rhs_matrix, sampled_lhs.t()))

            negs.append(negs_lhs)
            negs.append(negs_rhs)

    # dot product for positives
    x = (x[:, 0, :] * translated).sum(-1)

    # concat positives and negatives
    x = torch.cat([x, *negs])
    return {
        "logits": x,
        "probabilities": torch.sigmoid(x),
    }