in twml/twml/trainers/trainer.py [0:0]
def learn(self, train_input_fn=None, eval_input_fn=None,
          train_max_steps=None,
          train_steps=None, eval_steps=None,
          train_hooks=None, eval_hooks=None,
          early_stop_metric=None, early_stop_patience=-1,
          early_stop_minimize=True, early_stop_tolerance=0, start_epoch=0,
          exporters=None, export_output_fn=None, max_duration=None):
"""
Train and evaluate the estimator for ``train_max_steps`` steps.
Each epoch involves ``train_steps`` training steps followed
by ``eval_steps`` evaluation steps. Note that each step
is a ``session.run()``, that is, each batch is a step.
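
  For instance, a minimal sketch of a single call (assuming ``trainer`` has
  already been constructed and ``train_input_fn`` / ``eval_input_fn`` are
  placeholder input functions you would define):

  .. code-block:: python

    save_dir = trainer.learn(
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn,
      train_max_steps=1000,
      train_steps=100,
      eval_steps=50)
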
  Args:
    train_max_steps:
      maximum number of global steps of training to run.
      Defaults to params.train_max_steps.
      A None value causes learn() to terminate after *one* call to
      train() and evaluate(), which is usually useful when using train_steps=-1.
      Non-positive values train indefinitely in a loop (use with caution),
      which is usually useful when combined with early stopping.
    train_steps:
      number of training steps per epoch. For example, 100 means each
      training epoch will end after processing 100 batches.
      Defaults to params.train_steps.
      Non-positive values and None values go through the entire training set each epoch.
    eval_steps:
      number of evaluation steps per epoch.
      Defaults to params.eval_steps.
      Non-positive values and None values go through the entire evaluation set each epoch.
    train_input_fn:
      Function to iterate through the training set. It is passed to estimator.train.
    eval_input_fn:
      Function to iterate through the evaluation set. It is passed to estimator.evaluate.
    train_hooks:
      List of SessionRunHooks used for training. Defaults to self.get_train_hooks().
    eval_hooks:
      List of SessionRunHooks used for evaluation. Defaults to self.get_eval_hooks().
    start_epoch:
      The epoch from which to start learning. If you want to do training and evaluation
      for N epochs, you can call ``learn()`` in a loop as follows:

      .. code-block:: python

        for epoch in range(1, max_epoch):
          trainer.learn(start_epoch=epoch)

    exporters:
      List of exporters called at the end of each evaluation run.
      Defaults to None. (See the exporter sketch after this argument list.)
    export_output_fn:
      The output format to use for exported models.
      Only used if exporters is not None.
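
  As a sketch of how an exporter might be wired in (``tf.estimator.LatestExporter``
  is the stock TF 1.x exporter, assuming ``tensorflow`` is imported as ``tf``;
  ``serving_input_receiver_fn`` is a placeholder you would define for your model):

  .. code-block:: python

    exporter = tf.estimator.LatestExporter(
      name="latest",
      serving_input_receiver_fn=serving_input_receiver_fn,
      exports_to_keep=3)
    trainer.learn(
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn,
      exporters=[exporter])
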
  Early-stopping arguments:
    early_stop_metric:
      String specifying the metric to early-stop on. Required with positive
      ``early_stop_patience``. For example, 'accuracy', 'accuracy_0', 'loss', etc.
      The string is used to extract the relevant tensor Op from the dict returned by
      the get_eval_metric_ops method. For ``metrics`` passed to the constructor,
      the string is one of those. For multi-class (that is, multi-metric)
      metrics, the string may be appended with a ``_0``, ``_1``, etc. or one
      of the ``multi_metric_names`` (one per class).
    early_stop_patience:
      Maximum number of epochs to wait for an improvement in the early_stop_metric
      before stopping training. For example, a patience of 10 means that
      training has 10 epochs to improve the metric before it is stopped.
      Whenever the metric improves before running out of patience,
      patience is reset to ``early_stop_patience``.
      Defaults to -1 (that is, no early-stopping).
      (See the early-stopping sketch after this argument list.)
    early_stop_minimize:
      Set this to True (the default) for metrics that need to be minimized
      (like ``loss``). Metrics like ``accuracy`` that need to be maximized
      should set this to False.
    early_stop_tolerance:
      A non-negative tolerance for comparing early_stop_metric.
      E.g., when maximizing, the condition is
      ``current_metric > best_metric + tolerance``. Defaults to 0.
    max_duration:
      A float. When this argument is defined, the job will automatically terminate after
      ``max_duration`` seconds if it has not already completed.
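
  For example, a sketch that early-stops on loss and also caps the job
  duration (assuming a constructed ``trainer`` and placeholder input functions):

  .. code-block:: python

    trainer.learn(
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn,
      train_max_steps=-1,  # loop until a stopping condition triggers
      early_stop_metric="loss",
      early_stop_patience=5,
      early_stop_minimize=True,
      max_duration=4 * 3600.0)  # seconds
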
  Returns:
    The directory where the checkpoints were saved.
    That is, save_dir.
    You can point TensorBoard to this directory to get metrics,
    or pass it to another Trainer via ``init_from_dir`` when doing
    multi-phase training.
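
  For instance, a sketch of multi-phase training (the second Trainer's
  constructor arguments here are hypothetical; ``init_from_dir`` is the one
  documented above):

  .. code-block:: python

    save_dir = trainer.learn(
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fn)
    phase2_trainer = twml.trainers.Trainer(
      name="phase2",
      params=phase2_params,
      build_graph_fn=phase2_build_graph,
      init_from_dir=save_dir)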
"""
  # pylint: disable=too-many-branches
  if not callable(train_input_fn):
    raise ValueError("Expecting callable train_input_fn function")
  if not callable(eval_input_fn):
    raise ValueError("Expecting callable eval_input_fn function")
  if os.environ.get('TF_CONFIG'):
    raise ValueError("trainer.learn() cannot be used with distributed / hogwild setups")

  if exporters and export_output_fn:
    self._export_output_fn = export_output_fn

  train_hooks = self.get_train_hooks() if train_hooks is None else train_hooks
  eval_hooks = self.get_eval_hooks() if eval_hooks is None else eval_hooks
  # get_eval_hooks() may itself return None; normalize to a list so hooks can be appended below
  eval_hooks = [] if eval_hooks is None else eval_hooks
  if train_max_steps is None:
    train_max_steps = self.params.get('train_max_steps')

  # non-positive step counts mean "go through the entire set";
  # the estimator APIs expect None in that case
  if train_steps is None:
    train_steps = self.params.train_steps
  if train_steps is not None and train_steps <= 0:
    train_steps = None

  if eval_steps is None:
    eval_steps = self.params.eval_steps
  if eval_steps is not None and eval_steps <= 0:
    eval_steps = None
  if early_stop_patience > 0:
    assert train_max_steps is not None, "Early stopping and max_steps=None are not compatible."
    # prepare the early-stopping hook (which also handles the stopping logic)
    self._is_early_stopping = True

    early_stop_hook = twml.hooks.EarlyStopHook(
      metric=early_stop_metric,
      checkpoint_dir=self._save_dir,
      patience=early_stop_patience,
      minimize=early_stop_minimize,
      tolerance=early_stop_tolerance,
      get_estimator_spec_fn=lambda: self.current_estimator_spec,
      start_epoch=start_epoch)

    # add the early-stop hook to the eval hooks
    eval_hooks.append(early_stop_hook)
  if max_duration is not None:
    # the train and eval loops each get their own duration hook so that a
    # long-running phase of either kind respects the deadline
    train_early_stop_duration_hook = twml.hooks.EarlyStopDuration(
      max_duration=max_duration,
      exit_on_end=False,
      save_dir=self._save_dir,
      overwrite=True,
    )
    train_hooks.append(train_early_stop_duration_hook)

    eval_early_stop_duration_hook = twml.hooks.EarlyStopDuration(
      max_duration=max_duration,
      exit_on_end=False,
      save_dir=self._save_dir,
      overwrite=True,
    )
    eval_hooks.append(eval_early_stop_duration_hook)
  if not self._is_early_stopping:
    if (train_max_steps is not None) and (train_max_steps <= 0):
      if max_duration is None or max_duration < 0:
        logging.warn("train_max_steps is non-positive, and no early or duration stopping is configured. "
                     "Training job will loop forever.")
  if train_max_steps is not None and train_max_steps > 0:
    # we can't pass max_steps AND steps to estimator.train,
    # so we pass steps to estimator.train and max_steps to this hook instead
    stop_at_step_hook = twml.hooks.StopAtStepHook(last_step=train_max_steps)
    train_hooks.append(stop_at_step_hook)
  with self.experiment_tracker.track_experiment(eval_hooks,
                                                lambda: self.current_estimator_spec):
    # alternate training and evaluation epochs
    epoch = start_epoch
    while True:
      logging.info("Training epoch %d", epoch)
      self._estimator.train(train_input_fn, steps=train_steps, hooks=train_hooks)

      logging.info("Evaluating epoch %d", epoch)
      eval_result = self._estimator.evaluate(
        eval_input_fn, steps=eval_steps, hooks=eval_hooks)

      if exporters:
        checkpoint_path = self.estimator.latest_checkpoint()
        for exporter in exporters:
          export_path = os.path.join(self._save_dir, "export", exporter.name)
          exporter.export(
            estimator=self.estimator, export_path=export_path,
            checkpoint_path=checkpoint_path, eval_result=eval_result,
            is_the_final_export=False)
      # if train_max_steps is None, terminate after a single train/evaluate pass
      if train_max_steps is None:
        break

      # if stop_at_step_hook requested a stop, break
      if train_max_steps > 0 and stop_at_step_hook.stop_requested:
        break

      # early-stopping logic is handled internally by the hook,
      # but we still need to break out of the epoch loop here
      if early_stop_patience > 0 and early_stop_hook.should_stop:
        break

      epoch += 1

  self.write_state_to_disk(save_dir=self._save_dir, filename='_SUCCESS')

  return self._save_dir