def _wait_for_batch()

in core/train_pipeline.py


def _wait_for_batch(batch: In, stream: Optional[torch.cuda.streams.Stream]) -> None:
  if stream is None:
    return
  torch.cuda.current_stream().wait_stream(stream)
  # As mentioned in https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html,
  # PyTorch uses the "caching allocator" for memory allocation for tensors. When a tensor is
  # freed, its memory is likely to be reused by newly constructed tensors. By default,
  # this allocator tracks whether a tensor is still in use only on the CUDA stream where it
  # was created. When a tensor is used by additional CUDA streams, we need to call record_stream
  # to tell the allocator about all of these streams. Otherwise, the allocator might free the
  # underlying memory of the tensor once it is no longer used by the creator stream. This is
  # a notable programming trick when writing programs that use multiple CUDA streams.
  cur_stream = torch.cuda.current_stream()
  assert isinstance(
    batch, (torch.Tensor, Multistreamable)
  ), f"{type(batch)} must implement Multistreamable interface"
  batch.record_stream(cur_stream)
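

A minimal usage sketch (not from the source; copy_stream, cpu_batch, and gpu_batch are hypothetical names) of the pattern this helper supports: a batch is copied host-to-device on a dedicated copy stream while the default stream keeps computing, and _wait_for_batch is called on the default stream before the batch is consumed. A plain torch.Tensor is used here, which satisfies the isinstance check above since torch.Tensor already provides record_stream.

import torch

# Side stream used only for host-to-device copies of the next batch.
copy_stream = torch.cuda.Stream()
cpu_batch = torch.randn(1024, 128, pin_memory=True)

with torch.cuda.stream(copy_stream):
  # Enqueue the copy on the side stream so it can overlap with compute
  # already running on the default stream.
  gpu_batch = cpu_batch.to("cuda", non_blocking=True)

# On the default stream: wait for the copy to finish and record that
# gpu_batch is now also used by this stream, so the caching allocator does
# not recycle its memory while the copy stream is done with it but the
# default stream is not.
_wait_for_batch(gpu_batch, copy_stream)
out = gpu_batch.sum()  # stand-in for the model forward pass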