def _connect()

in core/train_pipeline.py

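_connect primes the pipeline before the first real step: it prefetches two batches ahead and kicks off the first sparse input data distribution. Two dedicated CUDA streams (a memcpy stream for host-to-device copies and a data dist stream for input redistribution) let that work overlap with compute on the default stream.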

  def _connect(self, dataloader_iter: Iterator[In]) -> None:
    # Batch i: copy the first batch host-to-device on the dedicated
    # memcpy stream so the copy runs off the default stream.
    with torch.cuda.stream(self._memcpy_stream):
      batch_i = next(dataloader_iter)
      self._batch_i = batch_i = _to_device(batch_i, self._device, non_blocking=True)
      # Try to pipeline input data dist: rewrite the model so its sparse
      # modules can consume inputs distributed on the data dist stream.
      self._pipelined_modules = _rewrite_model(self._model, self._context, self._data_dist_stream)

    # Start the input data distribution for batch i on its own stream,
    # but only after the device copy of the batch has finished.
    with torch.cuda.stream(self._data_dist_stream):
      _wait_for_batch(batch_i, self._memcpy_stream)
      _start_data_dist(self._pipelined_modules, batch_i, self._context)

    # Batch i+1: prefetch the next batch onto the device so it is ready
    # when the pipeline advances.
    with torch.cuda.stream(self._memcpy_stream):
      batch_ip1 = next(dataloader_iter)
      self._batch_ip1 = batch_ip1 = _to_device(batch_ip1, self._device, non_blocking=True)

    self._connected = True
    # Mark this pipeline instance as the one currently synced to the model.
    self.__class__.synced_pipeline_id[id(self._model)] = id(self)
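
For reference, a minimal sketch of the stream-synchronization semantics the helpers above rely on. The names and signatures mirror the calls in _connect, but the bodies are illustrative assumptions (the batch type is assumed to expose Tensor-style .to() and .record_stream()), not the library's actual implementations:

  from typing import Optional, TypeVar

  import torch

  In = TypeVar("In")  # assumption: batch supports .to() and .record_stream()


  def _to_device(batch: In, device: torch.device, non_blocking: bool) -> In:
    # Runs on whatever stream is current (the memcpy stream in _connect).
    # With pinned host memory, non_blocking=True makes this an async copy.
    return batch.to(device=device, non_blocking=non_blocking)


  def _wait_for_batch(batch: In, stream: Optional[torch.cuda.Stream]) -> None:
    if stream is None:
      return
    cur_stream = torch.cuda.current_stream(stream.device)
    # Device-side fence: the consuming (current) stream waits until the
    # producing stream has finished writing the batch; the host does not block.
    cur_stream.wait_stream(stream)
    # Tell the caching allocator the batch is still in use on the consuming
    # stream, so its memory is not reclaimed while this stream reads it.
    batch.record_stream(cur_stream)

This fence is what _connect applies before starting the data dist for batch i: wait_stream orders the two streams, and record_stream keeps the allocator from recycling the batch's memory while the data dist stream still reads it.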