in core/train_pipeline.py [0:0]
def _connect(self, dataloader_iter: Iterator[In]) -> None:
    """Prime the pipeline with the first two batches from ``dataloader_iter``.

    The first batch is pulled, moved to ``self._device`` (non-blocking) under
    ``self._memcpy_stream``, and then used to start input data distribution
    under ``self._data_dist_stream``. The second batch is pulled and moved to
    the device under ``self._memcpy_stream``. Finally the pipeline is flagged
    as connected and ``id(self)`` is stored in the class-level
    ``synced_pipeline_id`` mapping under the key ``id(self._model)``.

    Args:
        dataloader_iter: Iterator that yields input batches.

    Raises:
        StopIteration: Propagated from ``next(dataloader_iter)`` when the
            iterator is exhausted. If this happens on the second fetch, the
            first batch has already been stored in ``self._batch_i`` and data
            dist has been started for it, but ``self._connected`` is not set.
    """
    # -- first batch: device copy on the memcpy stream --
    with torch.cuda.stream(self._memcpy_stream):
        first_batch = _to_device(
            next(dataloader_iter), self._device, non_blocking=True
        )
        self._batch_i = first_batch
        # Try to pipeline input data dist: the modules returned here are the
        # ones data dist is started on below.
        self._pipelined_modules = _rewrite_model(
            self._model, self._context, self._data_dist_stream
        )

    # -- first batch: kick off input data dist on its own stream --
    with torch.cuda.stream(self._data_dist_stream):
        # NOTE(review): presumably makes the data-dist stream wait for the
        # non-blocking copy issued on the memcpy stream — helper not visible.
        _wait_for_batch(first_batch, self._memcpy_stream)
        _start_data_dist(self._pipelined_modules, first_batch, self._context)

    # -- second batch: device copy only --
    with torch.cuda.stream(self._memcpy_stream):
        next_batch = _to_device(
            next(dataloader_iter), self._device, non_blocking=True
        )
        self._batch_ip1 = next_batch

    self._connected = True
    # Class-level registry: model id -> id of the pipeline that connected it.
    self.__class__.synced_pipeline_id[id(self._model)] = id(self)