def progress()

in core/train_pipeline.py


  def progress(self, dataloader_iter: Iterator[In]) -> Out:
    """
    NOTE: This method has been updated to perform gradient accumulation.
    If `_grad_accum` is set, then loss values are scaled by this amount and
    optimizer update/reset is skipped for `_grad_accum` calls of `progress`
    (congruent to training steps), and then update/reset on every `_grad_accum`th
    step.

    """
    should_step_optimizer = (
      self._grad_accum is not None
      and self._progress_calls > 0
      and (self._progress_calls + 1) % self._grad_accum == 0
    ) or self._grad_accum is None
    should_reset_optimizer = (
      self._grad_accum is not None
      and self._progress_calls > 0
      and (self._progress_calls + 2) % self._grad_accum == 0
    ) or self._grad_accum is None

    if not self._connected:
      self._connect(dataloader_iter)
    elif self.__class__.synced_pipeline_id.get(id(self._model), None) != id(self):
      self._sync_pipeline()
      self.__class__.synced_pipeline_id[id(self._model)] = id(self)

    if self._model.training and should_reset_optimizer:
      with record_function("## zero_grad ##"):
        self._optimizer.zero_grad()

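    # Prefetch the next batch (i+2) and copy it to the device on the memcpy
    # stream so the transfer overlaps with training on batch i.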
    with record_function("## copy_batch_to_gpu ##"):
      with torch.cuda.stream(self._memcpy_stream):
        batch_ip2 = next(dataloader_iter)
        self._batch_ip2 = batch_ip2 = _to_device(
          batch_ip2, self._device, non_blocking=True
        )
    batch_i = cast(In, self._batch_i)
    batch_ip1 = cast(In, self._batch_ip1)

    with record_function("## wait_for_batch ##"):
      _wait_for_batch(batch_i, self._data_dist_stream)

    # Forward
    with record_function("## forward ##"):
      # If using multiple streams (i.e. CUDA), record an event in the default
      # stream before starting the forward pass.
      if self._data_dist_stream:
        event = torch.cuda.current_stream().record_event()
      if self._enable_amp:
        # Run the model on the batch inside the autocast context only when AMP
        # is enabled. In principle `enabled=self._enable_amp` alone should
        # handle this, but in practice it does not, hence the explicit branch.
        with torch.autocast(
          device_type=self._device.type,
          dtype=torch.bfloat16,
          enabled=self._enable_amp,
        ):
          (losses, output) = cast(Tuple[torch.Tensor, Out], self._model(batch_i))
      else:
        (losses, output) = cast(Tuple[torch.Tensor, Out], self._model(batch_i))

    # Data Distribution
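    # Start the sparse data distribution for batch i+1 on the data-dist stream
    # so it overlaps with the rest of batch i's step.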
    with record_function("## sparse_data_dist ##"):
      with torch.cuda.stream(self._data_dist_stream):
        _wait_for_batch(batch_ip1, self._memcpy_stream)
        # Make the data-dist stream wait on the event recorded in the default
        # stream before starting data dist.
        if self._data_dist_stream:
          # pyre-ignore [61]: Local variable `event` is undefined, or not always defined
          self._data_dist_stream.wait_event(event)
        _start_data_dist(self._pipelined_modules, batch_ip1, self._context)

    if self._model.training:
      # Backward
      with record_function("## backward ##"):
        # The loss is normalized by the number of accumulation steps; the
        # reported loss in `output['loss']` remains the unnormalized value.
        if self._grad_accum is not None:
          losses = losses / self._grad_accum
        self._grad_scaler.scale(torch.sum(losses, dim=0)).backward()

      if should_step_optimizer:
        # Update
        with record_function("## optimizer ##"):
          self._grad_scaler.step(self._optimizer)
          self._grad_scaler.update()

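    # Rotate the pipelined batches: i+1 becomes the current batch and i+2
    # becomes the upcoming one.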
    self._batch_i = batch_ip1
    self._batch_ip1 = batch_ip2

    if self._model.training:
      self._progress_calls += 1

    return output
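
Usage sketch (assumptions): the snippet below illustrates how a training loop
might drive `progress` with gradient accumulation enabled. The enclosing class
name (`TrainPipelineSparseDist`), the constructor keywords (`enable_amp`,
`grad_accum`), and the placeholder variables (`sharded_model`, `optimizer`,
`train_dataloader`, `num_training_steps`) are illustrative assumptions, not the
exact API of core/train_pipeline.py.

  import torch

  # Hypothetical wiring; adjust names to the real constructor signature.
  pipeline = TrainPipelineSparseDist(
    model=sharded_model,
    optimizer=optimizer,
    device=torch.device("cuda"),
    enable_amp=True,  # assumed to back self._enable_amp
    grad_accum=4,     # assumed to back self._grad_accum
  )

  dataloader_iter = iter(train_dataloader)
  for _ in range(num_training_steps):
    # Each call consumes one batch. With grad_accum=4, losses are scaled by
    # 1/4 inside `progress` and the optimizer only steps every fourth call.
    output = pipeline.progress(dataloader_iter)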