in tfx/components/trainer/fn_args_utils.py [0:0]
def _all_file_patterns(examples, splits):
  """Returns file glob patterns covering every requested split of `examples`.

  Iterates `splits` in order and, for each, expands every split URI into an
  all-files pattern, preserving the split-major ordering callers rely on.
  """
  return [
      io_utils.all_files_pattern(uri)
      for split in splits
      for uri in artifact_utils.get_split_uris(examples, split)
  ]


def get_common_fn_args(input_dict: Dict[str, List[types.Artifact]],
                       exec_properties: Dict[str, Any],
                       working_dir: Optional[str] = None) -> FnArgs:
  """Get common args of training and tuning.

  Args:
    input_dict: Input artifacts keyed by standard component spec keys. Must
      contain an `examples` entry; `transform_graph`, `schema` and
      `base_model` entries are optional.
    exec_properties: Execution properties. Must contain JSON-serialized
      `train_args` and `eval_args` protos; `custom_config` is optional and
      defaults to the JSON literal 'null'.
    working_dir: Optional working directory, passed through to the result.

  Returns:
    An FnArgs carrying resolved train/eval file patterns and step counts,
    optional schema/transform-graph/base-model paths, a DataAccessor for the
    examples, and the deserialized custom config.
  """
  # Optional inputs resolve to None when absent so downstream code can branch.
  if input_dict.get(standard_component_specs.TRANSFORM_GRAPH_KEY):
    transform_graph_path = artifact_utils.get_single_uri(
        input_dict[standard_component_specs.TRANSFORM_GRAPH_KEY])
  else:
    transform_graph_path = None

  if input_dict.get(standard_component_specs.SCHEMA_KEY):
    # The schema artifact directory is expected to hold exactly one file.
    schema_path = io_utils.get_only_uri_in_dir(
        artifact_utils.get_single_uri(
            input_dict[standard_component_specs.SCHEMA_KEY]))
  else:
    schema_path = None

  # Deserialize the train/eval args protos from their JSON exec properties.
  train_args = trainer_pb2.TrainArgs()
  eval_args = trainer_pb2.EvalArgs()
  proto_utils.json_to_proto(
      exec_properties[standard_component_specs.TRAIN_ARGS_KEY], train_args)
  proto_utils.json_to_proto(
      exec_properties[standard_component_specs.EVAL_ARGS_KEY], eval_args)

  # Default behavior is train on `train` split (when splits is empty in train
  # args) and evaluate on `eval` split (when splits is empty in eval args).
  if not train_args.splits:
    train_args.splits.append('train')
    absl.logging.info("Train on the 'train' split when train_args.splits is "
                      'not set.')
  if not eval_args.splits:
    eval_args.splits.append('eval')
    absl.logging.info("Evaluate on the 'eval' split when eval_args.splits is "
                      'not set.')

  examples = input_dict[standard_component_specs.EXAMPLES_KEY]
  train_files = _all_file_patterns(examples, train_args.splits)
  eval_files = _all_file_patterns(examples, eval_args.splits)

  data_accessor = DataAccessor(
      tf_dataset_factory=tfxio_utils.get_tf_dataset_factory_from_artifact(
          examples, _TELEMETRY_DESCRIPTORS),
      record_batch_factory=tfxio_utils.get_record_batch_factory_from_artifact(
          examples, _TELEMETRY_DESCRIPTORS),
      data_view_decode_fn=tfxio_utils.get_data_view_decode_fn_from_artifact(
          examples, _TELEMETRY_DESCRIPTORS))

  # https://github.com/tensorflow/tfx/issues/45: Replace num_steps=0 with
  # num_steps=None. Conversion of the proto to python will set the default
  # value of an int as 0 so modify the value here. Tensorflow will raise an
  # error if num_steps <= 0.
  train_steps = train_args.num_steps or None
  eval_steps = eval_args.num_steps or None

  # Load and deserialize custom config from execution properties.
  # Note that in the component interface the default serialization of custom
  # config is 'null' instead of '{}'. Therefore we need to default the
  # json_utils.loads to 'null' then populate it with an empty dict when
  # needed.
  custom_config = json_utils.loads(
      exec_properties.get(standard_component_specs.CUSTOM_CONFIG_KEY, 'null'))

  # TODO(ruoyu): Make this a dict of tag -> uri instead of list.
  if input_dict.get(standard_component_specs.BASE_MODEL_KEY):
    base_model_artifact = artifact_utils.get_single_instance(
        input_dict[standard_component_specs.BASE_MODEL_KEY])
    base_model = path_utils.serving_model_path(
        base_model_artifact.uri,
        path_utils.is_old_model_artifact(base_model_artifact))
  else:
    base_model = None

  return FnArgs(
      working_dir=working_dir,
      train_files=train_files,
      eval_files=eval_files,
      train_steps=train_steps,
      eval_steps=eval_steps,
      schema_path=schema_path,
      transform_graph_path=transform_graph_path,
      data_accessor=data_accessor,
      base_model=base_model,
      custom_config=custom_config,
  )