in core/losses.py [0:0]
def loss_fn(logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor):
    """Compute one loss per task plus a global reduction over them.

    This function reads names that are bound outside this block (presumably
    the closure of an enclosing factory; not visible here): ``tasks``,
    ``pos_weights``, ``f``, ``task_loss_reduction``, ``global_reduction``
    and ``loss_reduction_fns``.

    Args:
        logits: Model outputs. Column ``i`` (``logits[:, i]``) holds the
            logits for ``tasks[i]``.
        labels: Targets, indexed like ``logits``. Each column is cast to the
            dtype of the matching logits column before being passed to ``f``.
        weights: Per-element weights, indexed like ``logits``. Column ``i``
            is passed to ``f`` as ``weight``.

    Returns:
        A dict with one ``"loss/<task>"`` entry per task holding the result of
        ``f`` for that task, plus ``"loss"`` holding
        ``loss_reduction_fns[global_reduction]`` applied to the stacked
        per-task losses.
    """
    # Build the positive-class weights on the same device and dtype as the
    # logits rather than on the CPU with an inferred dtype. ``torch.as_tensor``
    # also accepts an existing tensor without the copy-construct warning that
    # ``torch.tensor(tensor)`` emits.
    if pos_weights is None:
        torch_weights = torch.ones(
            [len(tasks)], dtype=logits.dtype, device=logits.device
        )
    else:
        # NOTE(review): a pos_weights shorter than tasks raises IndexError in the
        # loop below; a longer one has its extra entries silently ignored.
        # Confirm whether a length check is wanted.
        torch_weights = torch.as_tensor(
            pos_weights, dtype=logits.dtype, device=logits.device
        )

    losses = {}
    for task_idx, task in enumerate(tasks):
        task_logits = logits[:, task_idx]
        label = labels[:, task_idx].type_as(task_logits)
        # `f` is opaque here; presumably a BCE-with-logits style loss that
        # accepts `reduction`, `pos_weight` and `weight` keywords.
        losses[f"loss/{task}"] = f(
            task_logits,
            label,
            reduction=task_loss_reduction,
            pos_weight=torch_weights[task_idx],
            weight=weights[:, task_idx],
        )

    # At this point `losses` holds only the per-task entries, so the global
    # reduction runs over exactly those before the "loss" key is added.
    losses["loss"] = loss_reduction_fns[global_reduction](
        torch.stack(list(losses.values()))
    )
    return losses