in tfx/components/trainer/component.py [0:0]
def __init__(
    self,
    examples: Optional[types.BaseChannel] = None,
    transformed_examples: Optional[types.BaseChannel] = None,
    transform_graph: Optional[types.BaseChannel] = None,
    schema: Optional[types.BaseChannel] = None,
    base_model: Optional[types.BaseChannel] = None,
    hyperparameters: Optional[types.BaseChannel] = None,
    module_file: Optional[Union[str, data_types.RuntimeParameter]] = None,
    run_fn: Optional[Union[str, data_types.RuntimeParameter]] = None,
    # TODO(b/147702778): deprecate trainer_fn.
    trainer_fn: Optional[Union[str, data_types.RuntimeParameter]] = None,
    train_args: Optional[Union[trainer_pb2.TrainArgs,
                               data_types.RuntimeParameter]] = None,
    eval_args: Optional[Union[trainer_pb2.EvalArgs,
                              data_types.RuntimeParameter]] = None,
    custom_config: Optional[Union[Dict[str, Any],
                                  data_types.RuntimeParameter]] = None,
    custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None):
    """Construct a Trainer component.

    Validates the mutually exclusive arguments, builds a `TrainerSpec` with
    freshly created `Model` and `ModelRun` output channels, and hands the spec
    to the base component constructor.

    Args:
      examples: A BaseChannel of type `standard_artifacts.Examples`, serving as
        the source of examples used in training (required). May be raw or
        transformed.
      transformed_examples: Deprecated (no compatibility guarantee). Please set
        'examples' instead.
      transform_graph: An optional BaseChannel of type
        `standard_artifacts.TransformGraph`, serving as the input transform
        graph if present.
      schema: An optional BaseChannel of type `standard_artifacts.Schema`,
        serving as the schema of training and eval data. Schema is optional
        when 1) transform_graph is provided which contains schema. 2) user
        module bypasses the usage of schema, e.g., hardcoded.
      base_model: A BaseChannel of type `Model`, containing model that will be
        used for training. This can be used for warmstart, transfer learning
        or model ensembling.
      hyperparameters: A BaseChannel of type
        `standard_artifacts.HyperParameters`, serving as the hyperparameters
        for training module. Tuner's output best hyperparameters can be fed
        into this.
      module_file: A path to python module file containing UDF model
        definition. The module_file must implement a function named `run_fn`
        at its top level with function signature:
          `def run_fn(trainer.fn_args_utils.FnArgs)`,
        and the trained model must be saved to FnArgs.serving_model_dir when
        this function is executed.
        For Estimator based Executor, The module_file must implement a
        function named `trainer_fn` at its top level. The function must have
        the following signature.
          def trainer_fn(trainer.fn_args_utils.FnArgs,
                         tensorflow_metadata.proto.v0.schema_pb2) -> Dict:
            ...
        where the returned Dict has the following key-values.
          'estimator': an instance of tf.estimator.Estimator
          'train_spec': an instance of tf.estimator.TrainSpec
          'eval_spec': an instance of tf.estimator.EvalSpec
          'eval_input_receiver_fn': an instance of tfma EvalInputReceiver.
        Exactly one of 'module_file' or 'run_fn' must be supplied if Trainer
        uses GenericExecutor (default). Use of a RuntimeParameter for this
        argument is experimental.
      run_fn: A python path to UDF model definition function for generic
        trainer. See 'module_file' for details. Exactly one of 'module_file'
        or 'run_fn' must be supplied if Trainer uses GenericExecutor
        (default). Use of a RuntimeParameter for this argument is
        experimental.
      trainer_fn: A python path to UDF model definition function for estimator
        based trainer. See 'module_file' for the required signature of the
        UDF. Exactly one of 'module_file' or 'trainer_fn' must be supplied if
        Trainer uses Estimator based Executor. Use of a RuntimeParameter for
        this argument is experimental.
      train_args: A proto.TrainArgs instance, containing args used for
        training. Currently only splits and num_steps are available. Default
        behavior (when splits is empty) is train on `train` split. When not
        supplied, an empty `trainer_pb2.TrainArgs()` is used.
      eval_args: A proto.EvalArgs instance, containing args used for
        evaluation. Currently only splits and num_steps are available. Default
        behavior (when splits is empty) is evaluate on `eval` split. When not
        supplied, an empty `trainer_pb2.EvalArgs()` is used.
      custom_config: A dict which contains additional training job parameters
        that will be passed into user module.
      custom_executor_spec: Optional custom executor spec. Deprecated (no
        compatibility guarantee), please customize component directly.

    Raises:
      ValueError:
        - When not exactly one of 'module_file', 'trainer_fn' and 'run_fn'
          is supplied.
        - When both or neither of 'examples' and 'transformed_examples'
          is supplied.
        - When 'transformed_examples' is supplied but 'transform_graph'
          is not supplied.
    """
    # The user code entry point is mandatory, and the three ways of naming it
    # are mutually exclusive.
    if [bool(module_file), bool(run_fn), bool(trainer_fn)].count(True) != 1:
        raise ValueError(
            "Exactly one of 'module_file', 'trainer_fn', or 'run_fn' must be "
            "supplied.")

    # `transformed_examples` is a deprecated alias of `examples`, so exactly
    # one of the two has to be given. The message names the real keyword
    # arguments so that it can be acted on directly.
    if bool(examples) == bool(transformed_examples):
        raise ValueError(
            "Exactly one of 'examples' or 'transformed_examples' must be "
            "supplied.")

    if transformed_examples and not transform_graph:
        raise ValueError("If 'transformed_examples' is supplied, "
                         "'transform_graph' must be supplied too.")

    # Deprecated arguments are still honored, but flagged to the user.
    if custom_executor_spec:
        logging.warning(
            "`custom_executor_spec` is deprecated. "
            "Please customize component directly.")
    if transformed_examples:
        logging.warning(
            "`transformed_examples` is deprecated. "
            "Please use `examples` instead.")

    # Fold the deprecated alias into `examples`; from here on only `examples`
    # is used.
    examples = examples or transformed_examples

    # Output channels created for this component.
    model = types.Channel(type=standard_artifacts.Model)
    model_run = types.Channel(type=standard_artifacts.ModelRun)

    # A RuntimeParameter is passed through untouched; any other value
    # (including None) is serialized with `json_utils.dumps`.
    if isinstance(custom_config, data_types.RuntimeParameter):
        custom_config_value = custom_config
    else:
        custom_config_value = json_utils.dumps(custom_config)

    spec = standard_component_specs.TrainerSpec(
        examples=examples,
        transform_graph=transform_graph,
        schema=schema,
        base_model=base_model,
        hyperparameters=hyperparameters,
        # Fall back to empty protos when no args are given.
        train_args=train_args or trainer_pb2.TrainArgs(),
        eval_args=eval_args or trainer_pb2.EvalArgs(),
        module_file=module_file,
        run_fn=run_fn,
        trainer_fn=trainer_fn,
        custom_config=custom_config_value,
        model=model,
        model_run=model_run)
    super().__init__(spec=spec, custom_executor_spec=custom_executor_spec)

    if udf_utils.should_package_user_modules():
        # In this case, the `MODULE_PATH_KEY` execution property will be
        # injected as a reference to the given user module file after
        # packaging, at which point the `MODULE_FILE_KEY` execution property
        # will be removed.
        udf_utils.add_user_module_dependency(
            self, standard_component_specs.MODULE_FILE_KEY,
            standard_component_specs.MODULE_PATH_KEY)