def get_global_loss_detached()

in core/losses.py


import logging

import torch
import torch.distributed


def get_global_loss_detached(local_loss, reduction="mean"):
  """
  Perform all_reduce to obtain the global loss using the provided reduction.
  :param local_loss: The local loss of the current rank.
  :param reduction: The reduction to use for all_reduce. Should match the reduction used by DDP.
  :return: The reduced & detached global loss.
  """
  if reduction != "mean":
    logging.warning(
      "The reduction used in this function should be the same as the one used by "
      "the DDP model. By default DDP uses mean, so ensure that DDP is appropriately "
      f"modified for reduction {reduction}."
    )

  if reduction not in ["mean", "sum"]:
    raise ValueError(f"Reduction {reduction} is currently unsupported.")

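  # Operate on a detached tensor so the collective is not tracked by autograd.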
  global_loss = local_loss.detach()

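  # all_reduce defaults to summation, so pre-dividing by the world size makes
  # the summed result equal to the mean across ranks.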
  if reduction == "mean":
    global_loss.div_(torch.distributed.get_world_size())

  torch.distributed.all_reduce(global_loss)
  return global_loss
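
A minimal usage sketch (hypothetical, not from the repo): it assumes the default process group has already been initialized via torch.distributed.init_process_group and that `model` is wrapped in DistributedDataParallel. The `training_step` helper and its arguments are illustrative only.

import torch
import torch.distributed as dist


def training_step(model, criterion, optimizer, inputs, targets):
  # Hypothetical helper for illustration; `model` is assumed to be DDP-wrapped.
  optimizer.zero_grad()
  local_loss = criterion(model(inputs), targets)
  local_loss.backward()  # DDP averages gradients across ranks (mean reduction)
  optimizer.step()

  # Detached, so safe to use for logging/metrics without touching the autograd graph.
  global_loss = get_global_loss_detached(local_loss, reduction="mean")
  if dist.get_rank() == 0:
    print(f"global mean loss: {global_loss.item():.4f}")
  return global_loss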