def _get_step_fn()

in core/custom_training_loop.py [0:0]


def _get_step_fn(pipeline, data_iterator, training: bool):
  def step_fn():
    # It turns out that model.train() and model.eval() simply switch a single field inside the model
    # class,so it's somewhat safer to wrap in here.
    if training:
      pipeline._model.train()
    else:
      pipeline._model.eval()

    outputs = pipeline.progress(data_iterator)
    return tree.map_structure(lambda elem: elem.detach(), outputs)

  return step_fn