def forward()

in projects/twhin/models/models.py
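Scores a batch of graph edges. Each edge contributes one positive logit: the dot product between its lhs node embedding and its rhs node embedding translated by a learned per-relation embedding (a TransE-style score). When in_batch_negatives is set, each relation additionally contributes negative logits built by pairing the batch's positive embeddings against lhs/rhs embeddings sampled from other edges of the same relation in the batch. Returns a dict with the concatenated logits and their sigmoid probabilities.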


  def forward(self, batch: EdgeBatch):

    # B x D: translation embedding for each edge's relation
    trans_embs = self.all_trans_embs.data[batch.rels]

    # KeyedTensor of embeddings for the batch's nodes
    outs = self.large_embeddings(batch.nodes)

    # 2B x TD: one row per node (lhs and rhs of each edge), T tables of width D concatenated
    x = outs.values()

    # 2B x T x D: split the table dimension back out
    x = x.reshape(2 * self.batch_size, -1, self.embedding_dim)

    # 2B x D: sum each node's embeddings across the T tables
    x = torch.sum(x, 1)

    # B x 2 x D: regroup rows into (lhs, rhs) pairs, one per edge
    x = x.reshape(self.batch_size, 2, self.embedding_dim)

    # B x D: translate each rhs embedding by its relation's translation embedding
    translated = x[:, 1, :] + trans_embs

    negs = []
    if self.in_batch_negatives:
      # construct dot products for negatives via matmul
      for relation in range(self.num_relations):
        rel_mask = batch.rels == relation
        rel_count = rel_mask.sum()

        # skip relations absent from this batch
        if not rel_count:
          continue

        # R x D, where R = rel_count
        lhs_matrix = x[rel_mask, 0, :]
        rhs_matrix = x[rel_mask, 1, :]

        lhs_perm = torch.randperm(lhs_matrix.shape[0])
        # tile the permutation so at least in_batch_negatives indices are available
        lhs_perm = lhs_perm.repeat(math.ceil(float(self.in_batch_negatives) / rel_count))
        lhs_indices = lhs_perm[: self.in_batch_negatives]
        sampled_lhs = lhs_matrix[lhs_indices]

        rhs_perm = torch.randperm(rhs_matrix.shape[0])
        # tile the permutation so at least in_batch_negatives indices are available
        rhs_perm = rhs_perm.repeat(math.ceil(float(self.in_batch_negatives) / rel_count))
        rhs_indices = rhs_perm[: self.in_batch_negatives]
        sampled_rhs = rhs_matrix[rhs_indices]

        # R x S scores, flattened to length R*S: each positive lhs against the
        # S sampled rhs, and each positive rhs against the S sampled lhs
        # (note: the relation translation is not applied to negative scores)
        negs_rhs = torch.flatten(torch.matmul(lhs_matrix, sampled_rhs.t()))
        negs_lhs = torch.flatten(torch.matmul(rhs_matrix, sampled_lhs.t()))

        negs.append(negs_lhs)
        negs.append(negs_rhs)

    # B: dot product of each lhs with its translated rhs gives the positive logits
    x = (x[:, 0, :] * translated).sum(-1)

    # concat positives and negatives
    x = torch.cat([x, *negs])
    return {
      "logits": x,
      "probabilities": torch.sigmoid(x),
    }
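Below is a minimal, self-contained sketch of the tensor flow above, using plain torch with made-up sizes and random stand-ins for the model state (all_trans_embs, the output of large_embeddings, and batch.rels). It covers only the rhs-corrupted negatives; the real method also samples lhs negatives symmetrically. Note that reshape(B, 2, D) pairs consecutive rows, so the 2B node embeddings must arrive interleaved as [lhs_0, rhs_0, lhs_1, rhs_1, ...].

import math

import torch

B, T, D = 4, 3, 8            # batch size, number of embedding tables, embedding dim
num_relations = 2
in_batch_negatives = 5

all_trans_embs = torch.randn(num_relations, D)   # stand-in for self.all_trans_embs.data
rels = torch.randint(0, num_relations, (B,))     # stand-in for batch.rels
x = torch.randn(2 * B, T * D)                    # stand-in for outs.values(): 2B x TD

x = x.reshape(2 * B, T, D).sum(1)   # 2B x D: sum over embedding tables
x = x.reshape(B, 2, D)              # B x 2 x D: (lhs, rhs) per edge

translated = x[:, 1, :] + all_trans_embs[rels]   # rhs + relation translation
pos = (x[:, 0, :] * translated).sum(-1)          # B positive logits

negs = []
for relation in range(num_relations):
    rel_mask = rels == relation
    rel_count = int(rel_mask.sum())
    if not rel_count:
        continue
    lhs_matrix = x[rel_mask, 0, :]   # R x D
    rhs_matrix = x[rel_mask, 1, :]   # R x D

    # tile a permutation of the R rhs rows until S = in_batch_negatives
    # indices are available, then truncate and gather the samples
    perm = torch.randperm(rel_count).repeat(math.ceil(in_batch_negatives / rel_count))
    sampled_rhs = rhs_matrix[perm[:in_batch_negatives]]

    # R x S score matrix, flattened to R*S rhs-corrupted negative logits
    negs.append(torch.matmul(lhs_matrix, sampled_rhs.t()).flatten())

logits = torch.cat([pos, *negs])
out = {"logits": logits, "probabilities": torch.sigmoid(logits)}
print(out["logits"].shape)   # B plus, per relation, R*S negatives

Because the negatives reuse embeddings already gathered for the batch, each relation's negative scores cost one R x D by D x S matmul rather than additional embedding lookups.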