def __init__()

in tfx/components/trainer/component.py [0:0]


  def __init__(
      self,
      examples: Optional[types.BaseChannel] = None,
      transformed_examples: Optional[types.BaseChannel] = None,
      transform_graph: Optional[types.BaseChannel] = None,
      schema: Optional[types.BaseChannel] = None,
      base_model: Optional[types.BaseChannel] = None,
      hyperparameters: Optional[types.BaseChannel] = None,
      module_file: Optional[Union[str, data_types.RuntimeParameter]] = None,
      run_fn: Optional[Union[str, data_types.RuntimeParameter]] = None,
      # TODO(b/147702778): deprecate trainer_fn.
      trainer_fn: Optional[Union[str, data_types.RuntimeParameter]] = None,
      train_args: Optional[Union[trainer_pb2.TrainArgs,
                                 data_types.RuntimeParameter]] = None,
      eval_args: Optional[Union[trainer_pb2.EvalArgs,
                                data_types.RuntimeParameter]] = None,
      custom_config: Optional[Union[Dict[str, Any],
                                    data_types.RuntimeParameter]] = None,
      custom_executor_spec: Optional[executor_spec.ExecutorSpec] = None):
    """Construct a Trainer component.

    Args:
      examples: A BaseChannel of type `standard_artifacts.Examples`, serving as
        the source of examples used in training (required). May be raw or
        transformed.
      transformed_examples: Deprecated (no compatibility guarantee). Please set
        'examples' instead.
      transform_graph: An optional BaseChannel of type
        `standard_artifacts.TransformGraph`, serving as the input transform
        graph if present.
      schema: An optional BaseChannel of type `standard_artifacts.Schema`,
        serving as the schema of training and eval data. Schema is optional when
        1) transform_graph is provided which contains schema. 2) user module
        bypasses the usage of schema, e.g., hardcoded.
      base_model: A BaseChannel of type `Model`, containing model that will be
        used for training. This can be used for warmstart, transfer learning or
        model ensembling.
      hyperparameters: A BaseChannel of type
        `standard_artifacts.HyperParameters`, serving as the hyperparameters for
        training module. Tuner's output best hyperparameters can be fed into
        this.
      module_file: A path to python module file containing UDF model definition.
        The module_file must implement a function named `run_fn` at its top
        level with function signature:
          `def run_fn(trainer.fn_args_utils.FnArgs)`,
        and the trained model must be saved to FnArgs.serving_model_dir when
        this function is executed.

        For Estimator based Executor, the module_file must implement a function
        named `trainer_fn` at its top level. The function must have the
        following signature.
          def trainer_fn(trainer.fn_args_utils.FnArgs,
                         tensorflow_metadata.proto.v0.schema_pb2) -> Dict:
            ...
          where the returned Dict has the following key-values.
            'estimator': an instance of tf.estimator.Estimator
            'train_spec': an instance of tf.estimator.TrainSpec
            'eval_spec': an instance of tf.estimator.EvalSpec
            'eval_input_receiver_fn': an instance of tfma EvalInputReceiver.
        Exactly one of 'module_file' or 'run_fn' must be supplied if Trainer
        uses GenericExecutor (default). Use of a RuntimeParameter for this
        argument is experimental.
      run_fn: A python path to UDF model definition function for generic
        trainer. See 'module_file' for details. Exactly one of 'module_file' or
        'run_fn' must be supplied if Trainer uses GenericExecutor (default). Use
        of a RuntimeParameter for this argument is experimental.
      trainer_fn: A python path to UDF model definition function for estimator
        based trainer. See 'module_file' for the required signature of the UDF.
        Exactly one of 'module_file' or 'trainer_fn' must be supplied if Trainer
        uses Estimator based Executor. Use of a RuntimeParameter for this
        argument is experimental.
      train_args: A proto.TrainArgs instance, containing args used for
        training. Currently only splits and num_steps are available. Default
        behavior (when splits is empty) is train on `train` split. When not
        supplied, an empty `trainer_pb2.TrainArgs()` is used.
      eval_args: A proto.EvalArgs instance, containing args used for evaluation.
        Currently only splits and num_steps are available. Default behavior
        (when splits is empty) is evaluate on `eval` split. When not supplied,
        an empty `trainer_pb2.EvalArgs()` is used.
      custom_config: A dict which contains additional training job parameters
        that will be passed into user module. A RuntimeParameter is forwarded
        as-is; any other value (including None) is JSON-serialized.
      custom_executor_spec: Optional custom executor spec. Deprecated (no
        compatibility guarantee), please customize component directly.

    Raises:
      ValueError:
        - When not exactly one of 'module_file', 'run_fn' and 'trainer_fn' is
          supplied.
        - When both or neither of 'examples' and 'transformed_examples' is
          supplied.
        - When 'transformed_examples' is supplied but 'transform_graph' is not
          supplied.
    """
    # Exactly one source of user training code is allowed. The check is
    # truthiness-based, so e.g. an empty string counts as "not supplied".
    if sum(bool(fn) for fn in (module_file, run_fn, trainer_fn)) != 1:
      raise ValueError(
          "Exactly one of 'module_file', 'trainer_fn', or 'run_fn' must be "
          "supplied.")

    # `transformed_examples` is a deprecated alias of `examples`; exactly one
    # of them must be provided.
    if bool(examples) == bool(transformed_examples):
      raise ValueError(
          "Exactly one of 'examples' or 'transformed_examples' must be "
          "supplied.")

    if transformed_examples and not transform_graph:
      raise ValueError("If 'transformed_examples' is supplied, "
                       "'transform_graph' must be supplied too.")

    if custom_executor_spec:
      logging.warning(
          "`custom_executor_spec` is deprecated. Please customize component "
          "directly.")
    if transformed_examples:
      logging.warning(
          "`transformed_examples` is deprecated. Please use `examples` "
          "instead.")

    # Fold the deprecated alias into the canonical `examples` input.
    examples = examples or transformed_examples
    model = types.Channel(type=standard_artifacts.Model)
    model_run = types.Channel(type=standard_artifacts.ModelRun)

    # RuntimeParameters are resolved later by the orchestrator, so they are
    # passed through untouched; concrete values (dict or None) are serialized
    # to a JSON string here.
    if isinstance(custom_config, data_types.RuntimeParameter):
      serialized_custom_config = custom_config
    else:
      serialized_custom_config = json_utils.dumps(custom_config)

    spec = standard_component_specs.TrainerSpec(
        examples=examples,
        transform_graph=transform_graph,
        schema=schema,
        base_model=base_model,
        hyperparameters=hyperparameters,
        train_args=train_args or trainer_pb2.TrainArgs(),
        eval_args=eval_args or trainer_pb2.EvalArgs(),
        module_file=module_file,
        run_fn=run_fn,
        trainer_fn=trainer_fn,
        custom_config=serialized_custom_config,
        model=model,
        model_run=model_run)
    super().__init__(spec=spec, custom_executor_spec=custom_executor_spec)

    if udf_utils.should_package_user_modules():
      # In this case, the `MODULE_PATH_KEY` execution property will be injected
      # as a reference to the given user module file after packaging, at which
      # point the `MODULE_FILE_KEY` execution property will be removed.
      udf_utils.add_user_module_dependency(
          self, standard_component_specs.MODULE_FILE_KEY,
          standard_component_specs.MODULE_PATH_KEY)