in core/custom_training_loop.py [0:0]
def _get_step_fn(pipeline, data_iterator, training: bool):
def step_fn():
# It turns out that model.train() and model.eval() simply switch a single field inside the model
# class,so it's somewhat safer to wrap in here.
if training:
pipeline._model.train()
else:
pipeline._model.eval()
outputs = pipeline.progress(data_iterator)
return tree.map_structure(lambda elem: elem.detach(), outputs)
return step_fn