in tfx/components/transform/component.py [0:0]
def __init__(
    self,
    examples: types.BaseChannel,
    schema: types.BaseChannel,
    module_file: Optional[Union[str, data_types.RuntimeParameter]] = None,
    preprocessing_fn: Optional[Union[str,
                                     data_types.RuntimeParameter]] = None,
    splits_config: Optional[transform_pb2.SplitsConfig] = None,
    analyzer_cache: Optional[types.BaseChannel] = None,
    materialize: bool = True,
    disable_analyzer_cache: bool = False,
    force_tf_compat_v1: bool = False,
    custom_config: Optional[Dict[str, Any]] = None,
    disable_statistics: bool = False,
    stats_options_updater_fn: Optional[str] = None):
  """Construct a Transform component.

  Args:
    examples: A BaseChannel of type `standard_artifacts.Examples` (required).
      This should contain custom splits specified in splits_config. If custom
      split is not provided, this should contain two splits 'train' and
      'eval'.
    schema: A BaseChannel of type `standard_artifacts.Schema`. This should
      contain a single schema artifact.
    module_file: The file path to a python module file, from which the
      'preprocessing_fn' function will be loaded.
      Exactly one of 'module_file' or 'preprocessing_fn' must be supplied.
      The function needs to have the following signature:
      ```
      def preprocessing_fn(inputs: Dict[Text, Any]) -> Dict[Text, Any]:
        ...
      ```
      where the values of input and returned Dict are either tf.Tensor or
      tf.SparseTensor.
      If additional inputs are needed for preprocessing_fn, they can be passed
      in custom_config:
      ```
      def preprocessing_fn(inputs: Dict[Text, Any], custom_config:
                           Dict[Text, Any]) -> Dict[Text, Any]:
        ...
      ```
      To update the stats options used to compute the pre-transform or
      post-transform statistics, optionally define the
      'stats_options_updater_fn' within the same module. If implemented,
      this function needs to have the following signature:
      ```
      def stats_options_updater_fn(stats_type: tfx.components.transform
          .stats_options_util.StatsType, stats_options: tfdv.StatsOptions)
          -> tfdv.StatsOptions:
        ...
      ```
      Use of a RuntimeParameter for this argument is experimental.
    preprocessing_fn: The path to python function that implements a
      'preprocessing_fn'. See 'module_file' for expected signature of the
      function. Exactly one of 'module_file' or 'preprocessing_fn' must be
      supplied. Use of a RuntimeParameter for this argument is experimental.
    splits_config: A transform_pb2.SplitsConfig instance, providing splits
      that should be analyzed and splits that should be transformed. Note
      analyze and transform splits can have overlap. Default behavior (when
      splits_config is not set) is analyze the 'train' split and transform all
      splits. If splits_config is set, analyze cannot be empty.
    analyzer_cache: Optional input 'TransformCache' channel containing cached
      information from previous Transform runs. When provided, Transform will
      try to use the cached calculation if possible.
    materialize: If True, write transformed examples as an output.
    disable_analyzer_cache: If False, Transform will use input cache if
      provided and write cache output. If True, `analyzer_cache` must not be
      provided.
    force_tf_compat_v1: (Optional) If True and/or TF2 behaviors are disabled
      Transform will use Tensorflow in compat.v1 mode irrespective of
      installed version of Tensorflow. Defaults to `False`.
    custom_config: A dict which contains additional parameters that will be
      passed to preprocessing_fn.
    disable_statistics: If True, do not invoke TFDV to compute pre-transform
      and post-transform statistics. When statistics are computed, they will
      be stored in the `pre_transform_feature_stats/` and
      `post_transform_feature_stats/` subfolders of the `transform_graph`
      export.
    stats_options_updater_fn: The path to a python function that implements a
      'stats_options_updater_fn'. See 'module_file' for expected signature of
      the function. 'stats_options_updater_fn' cannot be defined if
      'module_file' is specified.

  Raises:
    ValueError: When both or neither of 'module_file' and 'preprocessing_fn'
      is supplied.
  """
  # The user code must come from exactly one source: a module file or an
  # importable function path — having both (or neither) is ambiguous.
  if bool(module_file) == bool(preprocessing_fn):
    raise ValueError(
        "Exactly one of 'module_file' or 'preprocessing_fn' must be supplied."
    )
  # A module file is expected to contain the stats-options updater itself,
  # so an additional standalone function path is rejected as conflicting.
  if bool(module_file) and bool(stats_options_updater_fn):
    raise ValueError(
        "'stats_options_updater_fn' cannot be specified together with "
        "'module_file'")

  transform_graph = types.Channel(type=standard_artifacts.TransformGraph)
  # Only allocate an output Examples channel when materialization was
  # requested; otherwise the spec field stays None.
  transformed_examples = None
  if materialize:
    transformed_examples = types.Channel(type=standard_artifacts.Examples)
    # Ties the output channel to the 'examples' input channel so their
    # artifact counts/splits can be matched up downstream.
    transformed_examples.matching_channel_name = "examples"

  # Statistics-related outputs default to None (disabled) and are only
  # created as channels when statistics computation is enabled.
  (pre_transform_schema, pre_transform_stats, post_transform_schema,
   post_transform_stats, post_transform_anomalies) = (None,) * 5
  if not disable_statistics:
    pre_transform_schema = types.Channel(type=standard_artifacts.Schema)
    post_transform_schema = types.Channel(type=standard_artifacts.Schema)
    pre_transform_stats = types.Channel(
        type=standard_artifacts.ExampleStatistics)
    post_transform_stats = types.Channel(
        type=standard_artifacts.ExampleStatistics)
    post_transform_anomalies = types.Channel(
        type=standard_artifacts.ExampleAnomalies)

  if disable_analyzer_cache:
    updated_analyzer_cache = None
    # Passing an input cache while also disabling caching is contradictory;
    # fail fast rather than silently ignoring the provided channel.
    if analyzer_cache:
      raise ValueError(
          "`analyzer_cache` is set when disable_analyzer_cache is True.")
  else:
    updated_analyzer_cache = types.Channel(
        type=standard_artifacts.TransformCache)

  # Booleans are serialized as ints because component spec execution
  # properties do not carry a bool type.
  spec = standard_component_specs.TransformSpec(
      examples=examples,
      schema=schema,
      module_file=module_file,
      preprocessing_fn=preprocessing_fn,
      stats_options_updater_fn=stats_options_updater_fn,
      force_tf_compat_v1=int(force_tf_compat_v1),
      splits_config=splits_config,
      transform_graph=transform_graph,
      transformed_examples=transformed_examples,
      analyzer_cache=analyzer_cache,
      updated_analyzer_cache=updated_analyzer_cache,
      custom_config=json_utils.dumps(custom_config),
      disable_statistics=int(disable_statistics),
      pre_transform_schema=pre_transform_schema,
      pre_transform_stats=pre_transform_stats,
      post_transform_schema=post_transform_schema,
      post_transform_stats=post_transform_stats,
      post_transform_anomalies=post_transform_anomalies)
  super().__init__(spec=spec)

  if udf_utils.should_package_user_modules():
    # In this case, the `MODULE_PATH_KEY` execution property will be injected
    # as a reference to the given user module file after packaging, at which
    # point the `MODULE_FILE_KEY` execution property will be removed.
    udf_utils.add_user_module_dependency(
        self, standard_component_specs.MODULE_FILE_KEY,
        standard_component_specs.MODULE_PATH_KEY)