def forward()

in projects/twhin/models/models.py [0:0]


  def forward(self, batch: "RecapBatch"):  # type: ignore[name-defined]
    """Runs model forward and calculates loss according to given loss_fn.

    The label and weight vectors read from ``batch`` are extended with one
    entry per in-batch negative (``2 * batch_size * in_batch_negatives``
    entries in total) and passed, together with the model's logits, to
    ``self.loss_fn``.  Every appended entry carries the weight
    ``batch_size / num_negatives``, so the appended weights sum to
    ``batch_size``.

    NOTE: The input signature here needs to be a Pipelineable object for
    prefetching purposes during training using torchrec's pipeline.  However
    the underlying model signature needs to be exportable to onnx, requiring
    generic python types.  see https://pytorch.org/docs/stable/onnx.html#types.

    Args:
      batch: input batch.  It is forwarded unchanged to ``self.model``; its
        ``labels`` and ``weights`` attributes are cast to float and used as the
        leading entries of the label / weight vectors.

    Returns:
      A ``(losses, outputs)`` tuple.  ``losses`` is whatever ``self.loss_fn``
      returns; ``outputs`` is the mapping returned by ``self.model`` updated
      in place with the keys ``"loss"``, ``"labels"`` and ``"weights"``.
    """
    outputs = self.model(batch)
    logits = outputs["logits"]

    num_positives = self.batch_size
    num_negatives = 2 * num_positives * self.in_batch_negatives

    # With zero negatives the appended tensors below are empty, so the weight
    # value is never used; guard the division instead of raising
    # ZeroDivisionError for that configuration.
    neg_weight = float(num_positives) / num_negatives if num_negatives else 0.0

    # Allocate the constant tensors directly on the target device rather than
    # building them on the default device and copying them over.
    # NOTE(review): the appended negatives are labelled 1.0, same as a positive
    # label.  Presumably the sign convention of the negative logits produced by
    # self.model makes this correct -- confirm against the model code.
    negative_labels = torch.ones(num_negatives, device=self.device)
    negative_weights = torch.full((num_negatives,), neg_weight, device=self.device)

    labels = torch.cat([batch.labels.float(), negative_labels])
    weights = torch.cat([batch.weights.float(), negative_weights])

    losses = self.loss_fn(logits, labels, weights)

    outputs.update(
      {
        "loss": losses,
        "labels": labels,
        "weights": weights,
      }
    )

    # Allow multiple losses.
    return losses, outputs