def setup_and_get_device()

in common/device.py [0:0]


import os

import torch
import torch.distributed as dist


def setup_and_get_device(tf_ok: bool = True) -> torch.device:
  if tf_ok:
    maybe_setup_tensorflow()

  # Default to CPU with the gloo backend; switch to the local GPU and nccl
  # when CUDA is available.
  device = torch.device("cpu")
  backend = "gloo"
  if torch.cuda.is_available():
    local_rank = int(os.environ["LOCAL_RANK"])
    device = torch.device(f"cuda:{local_rank}")
    backend = "nccl"
    torch.cuda.set_device(device)

  # Initialize the default process group once per process.
  if not dist.is_initialized():
    dist.init_process_group(backend)

  return device
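
A minimal sketch of how this function might be used in a training entry point launched with torchrun (which sets LOCAL_RANK for each worker). The script name, model, and training loop below are illustrative placeholders, not part of the module:

# Hypothetical launch: torchrun --nproc_per_node=4 train.py
import torch
import torch.distributed as dist

from common.device import setup_and_get_device


def main():
  # Picks cuda:<LOCAL_RANK> with nccl when GPUs are present, else CPU with gloo,
  # and initializes the default process group.
  device = setup_and_get_device()

  model = torch.nn.Linear(16, 1).to(device)  # placeholder model
  ddp_model = torch.nn.parallel.DistributedDataParallel(model)

  # ... training loop using ddp_model ...

  dist.destroy_process_group()


if __name__ == "__main__":
  main()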