in projects/twhin/models/models.py [0:0]
def forward(self, batch: "RecapBatch"):  # type: ignore[name-defined]
    """Runs model forward and calculates loss according to given loss_fn.

    NOTE: The input signature here needs to be a Pipelineable object for
    prefetching purposes during training using torchrec's pipeline. However
    the underlying model signature needs to be exportable to onnx, requiring
    generic python types. see https://pytorch.org/docs/stable/onnx.html#types.

    Labels and weights are assembled as ``[positives, negatives]``:

    * ``self.batch_size`` positives, taking ``batch.labels`` / ``batch.weights``;
    * ``2 * self.batch_size * self.in_batch_negatives`` in-batch negatives,
      labelled ``0``. Each negative is weighted ``batch_size / num_negatives``,
      so the negatives together carry a total weight of ``batch_size``.

    The logits returned by ``self.model`` are therefore assumed to follow the
    same ``[positives, negatives]`` order (TODO confirm against the model).

    Args:
        batch: Batch object exposing ``labels`` and ``weights`` tensors for
            the positive examples; passed unchanged to ``self.model``.

    Returns:
        ``(losses, outputs)`` where ``losses`` is the result of
        ``self.loss_fn(logits, labels, weights)``. ``outputs`` is the model's
        output dict updated with the ``"loss"``, ``"labels"`` and
        ``"weights"`` that were fed to the loss.
    """
    outputs = self.model(batch)
    logits = outputs["logits"]

    num_positives = self.batch_size
    # Must equal the number of negative logits the model emits (assumed:
    # `in_batch_negatives` per positive, for each of two sides).
    num_negatives = 2 * self.batch_size * self.in_batch_negatives
    # With in-batch negatives disabled there is nothing to weight; guard the
    # ratio instead of raising ZeroDivisionError.
    neg_weight = float(num_positives) / num_negatives if num_negatives else 0.0

    # Sampled negatives are labelled 0. Labelling them 1 would make every
    # target positive and give the loss nothing to discriminate.
    labels = torch.cat(
        [batch.labels.float(), torch.zeros(num_negatives, device=self.device)]
    )
    weights = torch.cat(
        [
            batch.weights.float(),
            torch.full((num_negatives,), neg_weight, device=self.device),
        ]
    )

    losses = self.loss_fn(logits, labels, weights)
    outputs.update(
        {
            "loss": losses,
            "labels": labels,
            "weights": weights,
        }
    )
    # The loss is returned on its own (as well as inside `outputs`) so callers
    # can consume it directly.
    return losses, outputs