def get_common_fn_args()

in tfx/components/trainer/fn_args_utils.py [0:0]


def _get_file_patterns_for_splits(examples: List[types.Artifact],
                                  splits) -> List[str]:
  """Returns file patterns covering every file of the given example splits.

  Args:
    examples: Examples artifacts whose per-split URIs are resolved.
    splits: Iterable of split names (e.g. the repeated `splits` field of a
      TrainArgs or EvalArgs proto).

  Returns:
    One `io_utils.all_files_pattern` per split URI, ordered by split first
    and then by the order returned from `artifact_utils.get_split_uris`.
  """
  file_patterns = []
  for split in splits:
    file_patterns.extend(
        io_utils.all_files_pattern(uri)
        for uri in artifact_utils.get_split_uris(examples, split))
  return file_patterns


def get_common_fn_args(input_dict: Dict[str, List[types.Artifact]],
                       exec_properties: Dict[str, Any],
                       working_dir: Optional[str] = None) -> FnArgs:
  """Get common args of training and tuning.

  Args:
    input_dict: Input artifacts keyed by component spec key. The examples
      entry is required; transform graph, schema and base model entries are
      optional and may be absent or empty.
    exec_properties: Execution properties. Train args and eval args
      (JSON-serialized protos) are required; custom config (JSON) is
      optional.
    working_dir: Optional working directory, forwarded unchanged to the
      returned FnArgs.

  Returns:
    An FnArgs holding train/eval file patterns, train/eval step counts,
    schema and transform graph paths, a DataAccessor for the examples, the
    base model serving path and the deserialized custom config.
  """
  examples = input_dict[standard_component_specs.EXAMPLES_KEY]

  if input_dict.get(standard_component_specs.TRANSFORM_GRAPH_KEY):
    transform_graph_path = artifact_utils.get_single_uri(
        input_dict[standard_component_specs.TRANSFORM_GRAPH_KEY])
  else:
    transform_graph_path = None

  if input_dict.get(standard_component_specs.SCHEMA_KEY):
    # The schema artifact URI is a directory expected to hold a single file.
    schema_path = io_utils.get_only_uri_in_dir(
        artifact_utils.get_single_uri(
            input_dict[standard_component_specs.SCHEMA_KEY]))
  else:
    schema_path = None

  train_args = trainer_pb2.TrainArgs()
  eval_args = trainer_pb2.EvalArgs()
  proto_utils.json_to_proto(
      exec_properties[standard_component_specs.TRAIN_ARGS_KEY], train_args)
  proto_utils.json_to_proto(
      exec_properties[standard_component_specs.EVAL_ARGS_KEY], eval_args)

  # Default behavior is train on `train` split (when splits is empty in train
  # args) and evaluate on `eval` split (when splits is empty in eval args).
  if not train_args.splits:
    train_args.splits.append('train')
    absl.logging.info("Train on the 'train' split when train_args.splits is "
                      'not set.')
  if not eval_args.splits:
    eval_args.splits.append('eval')
    absl.logging.info("Evaluate on the 'eval' split when eval_args.splits is "
                      'not set.')

  train_files = _get_file_patterns_for_splits(examples, train_args.splits)
  eval_files = _get_file_patterns_for_splits(examples, eval_args.splits)

  data_accessor = DataAccessor(
      tf_dataset_factory=tfxio_utils.get_tf_dataset_factory_from_artifact(
          examples, _TELEMETRY_DESCRIPTORS),
      record_batch_factory=tfxio_utils.get_record_batch_factory_from_artifact(
          examples, _TELEMETRY_DESCRIPTORS),
      data_view_decode_fn=tfxio_utils.get_data_view_decode_fn_from_artifact(
          examples, _TELEMETRY_DESCRIPTORS),
  )

  # https://github.com/tensorflow/tfx/issues/45: Replace num_steps=0 with
  # num_steps=None.  Conversion of the proto to python will set the default
  # value of an int as 0 so modify the value here.  Tensorflow will raise an
  # error if num_steps <= 0.
  train_steps = train_args.num_steps or None
  eval_steps = eval_args.num_steps or None

  # Load and deserialize custom config from execution properties.
  # In the component interface the default serialization of an unset custom
  # config is 'null' instead of '{}'. A missing or None-valued property is
  # treated the same way, so an unset custom config yields None (not an
  # empty dict); `or` guards against json_utils.loads(None) raising.
  custom_config = json_utils.loads(
      exec_properties.get(standard_component_specs.CUSTOM_CONFIG_KEY) or
      'null')

  # TODO(ruoyu): Make this a dict of tag -> uri instead of list.
  if input_dict.get(standard_component_specs.BASE_MODEL_KEY):
    base_model_artifact = artifact_utils.get_single_instance(
        input_dict[standard_component_specs.BASE_MODEL_KEY])
    base_model = path_utils.serving_model_path(
        base_model_artifact.uri,
        path_utils.is_old_model_artifact(base_model_artifact))
  else:
    base_model = None

  return FnArgs(
      working_dir=working_dir,
      train_files=train_files,
      eval_files=eval_files,
      train_steps=train_steps,
      eval_steps=eval_steps,
      schema_path=schema_path,
      transform_graph_path=transform_graph_path,
      data_accessor=data_accessor,
      base_model=base_model,
      custom_config=custom_config,
  )