def loss_fn()

in core/losses.py [0:0]


  def loss_fn(logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor):
    """Compute one loss per task plus a global aggregate loss.

    Column ``task_idx`` (dim 1) of ``logits``, ``labels`` and ``weights`` is
    paired with ``tasks[task_idx]`` and scored with the loss callable ``f``.
    Configuration (``tasks``, ``pos_weights``, ``f``, ``task_loss_reduction``,
    ``loss_reduction_fns``, ``global_reduction``) is read from the enclosing
    scope.

    Args:
      logits: Model outputs; indexed as ``logits[:, task_idx]``, presumably
        shaped (batch, num_tasks) — TODO confirm against callers.
      labels: Targets indexed like ``logits``; each column is converted with
        ``type_as`` to match the corresponding logits column.
      weights: Weights indexed like ``logits``; each column is forwarded to
        ``f`` as ``weight``.

    Returns:
      A dict mapping ``"loss/<task>"`` to each task's loss, plus ``"loss"``
      holding ``loss_reduction_fns[global_reduction]`` applied to the stacked
      per-task losses.
    """
    if pos_weights is None:
      # One pos_weight of 1 per task (neutral if f is BCE-with-logits —
      # assumption, f is defined outside this block).
      torch_weights = torch.ones([len(tasks)])
    else:
      # as_tensor accepts lists, ndarrays and tensors alike; unlike
      # torch.tensor it does not warn about (or perform) a copy-construct
      # when pos_weights is already a Tensor.
      torch_weights = torch.as_tensor(pos_weights)

    losses = {}
    for task_idx, task in enumerate(tasks):
      task_logits = logits[:, task_idx]
      label = labels[:, task_idx].type_as(task_logits)

      loss = f(
        task_logits,
        label,
        reduction=task_loss_reduction,
        pos_weight=torch_weights[task_idx],
        weight=weights[:, task_idx],
      )
      losses[f"loss/{task}"] = loss

    # At this point `losses` holds only the per-task entries, so stacking its
    # values aggregates exactly the task losses (in task order). torch.stack
    # requires every per-task loss to have the same shape.
    losses["loss"] = loss_reduction_fns[global_reduction](torch.stack(list(losses.values())))
    return losses