in tfx/components/transform/executor.py [0:0]
def _RunBeamImpl(
self, analyze_data_list: List[_Dataset],
transform_data_list: List[_Dataset], preprocessing_fn: Any,
stats_options_updater_fn: Callable[
[stats_options_util.StatsType, tfdv.StatsOptions],
tfdv.StatsOptions], force_tf_compat_v1: bool,
input_dataset_metadata: dataset_metadata.DatasetMetadata,
transform_output_path: str, raw_examples_data_format: int, temp_path: str,
input_cache_dir: Optional[str], output_cache_dir: Optional[str],
disable_statistics: bool, per_set_stats_output_paths: Sequence[str],
materialization_format: Optional[str], analyze_paths_count: int,
stats_output_paths: Dict[str, str],
make_beam_pipeline_fn: Callable[[], beam.Pipeline]) -> _Status:
"""Perform data preprocessing with TFT.
Args:
analyze_data_list: List of datasets for analysis.
transform_data_list: List of datasets for transform.
preprocessing_fn: The tf.Transform preprocessing_fn.
stats_options_updater_fn: The user-specified function for updating stats
options.
force_tf_compat_v1: If True, call Transform's API to use TensorFlow in
tf.compat.v1 mode.
input_dataset_metadata: A DatasetMetadata object for the input data.
transform_output_path: An absolute path to write the output to.
raw_examples_data_format: The data format of the raw examples. One of the
enums from example_gen_pb2.PayloadFormat.
temp_path: A path to a temporary dir.
input_cache_dir: A dir containing the input analysis cache. May be None.
output_cache_dir: A dir to write the analysis cache to. May be None.
disable_statistics: A bool indicating whether or not to disable statistics.
per_set_stats_output_paths: Paths to per-set statistics output. If empty,
per-set statistics are not produced.
materialization_format: A string describing the format of the materialized
data or None if materialization is not enabled.
analyze_paths_count: An integer, the number of paths that should be used
for analysis.
stats_output_paths: A dictionary specifying the output paths to use when
computing statistics. If the dictionary is empty, the stats will be
placed within the transform_output_path to preserve backwards
compatibility.
make_beam_pipeline_fn: A callable that creates a Beam pipeline.

Returns:
Status of the execution.
"""
self._AssertSameTFXIOSchema(analyze_data_list)
unprojected_typespecs = (
analyze_data_list[0].tfxio.TensorAdapter().OriginalTypeSpecs())
analyze_input_columns = tft.get_analyze_input_columns(
preprocessing_fn,
unprojected_typespecs,
force_tf_compat_v1=force_tf_compat_v1)
analyze_columns_count = len(analyze_input_columns)
transform_input_columns = tft.get_transform_input_columns(
preprocessing_fn,
unprojected_typespecs,
force_tf_compat_v1=force_tf_compat_v1)
# Use the same dataset (same columns) for AnalyzeDataset and computing
# pre-transform stats so that the data will only be read once for these
# two operations.
if not disable_statistics:
analyze_input_columns = list(
set(list(analyze_input_columns) + list(transform_input_columns)))
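# Project each TFXIO source down to the columns it actually needs so the
# readers only decode the required features.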
for d in analyze_data_list:
d.tfxio = d.tfxio.Project(analyze_input_columns)
self._AssertSameTFXIOSchema(analyze_data_list)
analyze_data_tensor_adapter_config = (
analyze_data_list[0].tfxio.TensorAdapterConfig())
for d in transform_data_list:
d.tfxio = d.tfxio.Project(transform_input_columns)
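# Pick a decoding batch size appropriate for the raw data format and
# input schema.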
desired_batch_size = self._GetDesiredBatchSize(
raw_examples_data_format, input_dataset_metadata.schema)
with make_beam_pipeline_fn() as pipeline:
with tft_beam.Context(
temp_dir=temp_path,
desired_batch_size=desired_batch_size,
passthrough_keys=self._GetTFXIOPassthroughKeys(),
use_deep_copy_optimization=True,
force_tf_compat_v1=force_tf_compat_v1):
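# Resolve the analyzer cache: determine which analysis datasets still
# need to be read, load any usable cache entries, and estimate the
# resulting number of pipeline stages.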
(new_analyze_data_dict, input_cache,
estimated_stage_count_with_cache) = (
pipeline
| 'OptimizeRun' >> self._OptimizeRun(
input_cache_dir, output_cache_dir, analyze_data_list,
unprojected_typespecs, preprocessing_fn,
self._GetCacheSource(), force_tf_compat_v1))
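# Report column/path counts and the cache, statistics and materialization
# settings as Beam pipeline metrics.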
_ = (
pipeline
| 'IncrementPipelineMetrics' >> self._IncrementPipelineMetrics(
total_columns_count=len(unprojected_typespecs),
analyze_columns_count=analyze_columns_count,
transform_columns_count=len(transform_input_columns),
analyze_paths_count=analyze_paths_count,
analyzer_cache_enabled=input_cache is not None,
disable_statistics=disable_statistics,
materialize=materialization_format is not None,
estimated_stage_count_with_cache=(
estimated_stage_count_with_cache)))
if input_cache:
logging.debug('Analyzing data with cache.')
full_analyze_dataset_keys_list = [
dataset.dataset_key for dataset in analyze_data_list
]
# Remove datasets that will not be needed for statistics or
# materialization.
if materialization_format is None and disable_statistics:
if None in new_analyze_data_dict.values():
logging.debug(
'Not reading the following datasets due to cache: %s', [
dataset.file_pattern
for dataset in analyze_data_list
if new_analyze_data_dict[dataset.dataset_key] is None
])
analyze_data_list = [
d for d in new_analyze_data_dict.values() if d is not None
]
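# Read and decode each remaining analysis dataset into standardized
# Arrow RecordBatches.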
for dataset in analyze_data_list:
infix = 'AnalysisIndex{}'.format(dataset.index)
dataset.standardized = (
pipeline
| 'TFXIOReadAndDecode[{}]'.format(infix) >>
dataset.tfxio.BeamSource(desired_batch_size))
input_analysis_data = {}
for key, dataset in new_analyze_data_dict.items():
input_analysis_data[key] = (None if dataset is None else
dataset.standardized)
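# Run the TFT analyzers over the (possibly cache-reduced) analysis data
# to produce the transform_fn and any new cache entries.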
transform_fn, cache_output = (
(input_analysis_data, input_cache,
analyze_data_tensor_adapter_config)
| 'Analyze' >> tft_beam.AnalyzeDatasetWithCache(
preprocessing_fn, pipeline=pipeline))
# Write the raw/input metadata.
(input_dataset_metadata
| 'WriteMetadata' >> tft_beam.WriteMetadata(
os.path.join(transform_output_path,
tft.TFTransformOutput.RAW_METADATA_DIR), pipeline))
# WriteTransformFn writes transform_fn and metadata to subdirectories
# tensorflow_transform.SAVED_MODEL_DIR and
# tensorflow_transform.TRANSFORMED_METADATA_DIR respectively.
completed_transform = (
transform_fn
| 'WriteTransformFn' >>
tft_beam.WriteTransformFn(transform_output_path))
if output_cache_dir is not None and cache_output is not None:
fileio.makedirs(output_cache_dir)
logging.debug('Using existing cache in: %s', input_cache_dir)
if input_cache_dir is not None:
# Only copy cache that is relevant to this iteration. This assumes
# that the pipeline operates on rolling ranges, so those cache entries
# may also be relevant for future iterations.
for span_cache_dir in input_analysis_data:
full_span_cache_dir = os.path.join(input_cache_dir,
span_cache_dir.key)
if fileio.isdir(full_span_cache_dir):
self._CopyCache(
full_span_cache_dir,
os.path.join(output_cache_dir, span_cache_dir.key))
# TODO(b/157479287, b/171165988): Remove this condition when beam
# 2.26 is used.
if cache_output:
(cache_output
| 'WriteCache' >> analyzer_cache.WriteAnalysisCacheToFS(
pipeline=pipeline,
cache_base_dir=output_cache_dir,
sink=self._GetCacheSink(),
dataset_keys=full_analyze_dataset_keys_list))
if not disable_statistics or materialization_format is not None:
# Do not compute pre-transform stats if the input format is raw
# proto, as StatsGen would treat any input as tf.Example. Note that
# tf.SequenceExamples are wire-format compatible with tf.Examples.
if (not disable_statistics and
not self._IsDataFormatProto(raw_examples_data_format)):
# Aggregated feature stats before transformation.
if (self._DecodesSequenceExamplesAsRawRecords(
raw_examples_data_format, input_dataset_metadata.schema)):
schema_proto = None
else:
schema_proto = input_dataset_metadata.schema
if (self._DecodesSequenceExamplesAsRawRecords(
raw_examples_data_format, input_dataset_metadata.schema)):
def _ExtractRawExampleBatches(record_batch):
return record_batch.column(
record_batch.schema.get_field_index(
RAW_EXAMPLE_KEY)).flatten().to_pylist()
# Make use of the fact that tf.SequenceExample is wire-format
# compatible with tf.Example.
stats_input = []
for dataset in analyze_data_list:
infix = 'AnalysisIndex{}'.format(dataset.index)
stats_input.append(
dataset.standardized
| 'ExtractRawExampleBatches[{}]'.format(
infix) >> beam.Map(_ExtractRawExampleBatches)
| 'DecodeSequenceExamplesAsExamplesIntoRecordBatches[{}]'
.format(infix) >> beam.ParDo(
self._ToArrowRecordBatchesFn(schema_proto)))
else:
stats_input = [
dataset.standardized for dataset in analyze_data_list
]
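# Let the user-supplied updater adjust the StatsOptions used for
# pre-transform statistics.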
pre_transform_stats_options = _InvokeStatsOptionsUpdaterFn(
stats_options_updater_fn,
stats_options_util.StatsType.PRE_TRANSFORM, schema_proto)
if self._TFDVWriteShardedOutput():
pre_transform_stats_options.experimental_result_partitions = (
_SHARDED_OUTPUT_PARTITIONS)
else:
if (pre_transform_stats_options.experimental_result_partitions !=
1):
raise ValueError('Sharded output disabled requires '
'experimental_result_partitions=1.')
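# Pre-transform stats and the inferred schema go either to the explicitly
# provided output paths or, for backwards compatibility, under the
# transform output directory.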
if stats_output_paths:
pre_transform_feature_stats_loc = {
_STATS_KEY:
os.path.join(
stats_output_paths[
labels.PRE_TRANSFORM_OUTPUT_STATS_PATH_LABEL],
STATS_FILE),
_SHARDED_STATS_KEY:
os.path.join(
stats_output_paths[
labels.PRE_TRANSFORM_OUTPUT_STATS_PATH_LABEL],
SHARDED_STATS_PREFIX),
_SCHEMA_KEY:
os.path.join(
stats_output_paths[
labels.PRE_TRANSFORM_OUTPUT_SCHEMA_PATH_LABEL],
_SCHEMA_FILE)
}
else:
pre_transform_feature_stats_loc = os.path.join(
transform_output_path,
tft.TFTransformOutput.PRE_TRANSFORM_FEATURE_STATS_PATH)
(stats_input
| 'FlattenAnalysisDatasets' >> beam.Flatten(pipeline=pipeline)
| 'GenerateStats[FlattenedAnalysisDataset]' >>
self._GenerateAndMaybeValidateStats(
pre_transform_feature_stats_loc,
stats_options=pre_transform_stats_options,
enable_validation=False))
# transform_data_list is a superset of analyze_data_list; we pay the
# cost of reading the same datasets (analyze_data_list) again here to
# prevent certain Beam runners from doing a large temp materialization.
for dataset in transform_data_list:
infix = 'TransformIndex{}'.format(dataset.index)
dataset.standardized = (
pipeline | 'TFXIOReadAndDecode[{}]'.format(infix) >>
dataset.tfxio.BeamSource(desired_batch_size))
(dataset.transformed, metadata) = (
((dataset.standardized, dataset.tfxio.TensorAdapterConfig()),
transform_fn)
| 'Transform[{}]'.format(infix) >>
tft_beam.TransformDataset(output_record_batches=True))
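# transform_fn is a (saved model, metadata) pair; keep only the metadata
# describing the transformed features.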
_, metadata = transform_fn
# TODO(b/70392441): Retain tf.Metadata (e.g., IntDomain) in
# schema. Currently input dataset schema only contains dtypes,
# and other metadata is dropped due to roundtrip to tensors.
transformed_schema_proto = metadata.schema
if not disable_statistics:
# Aggregated feature stats after transformation.
for dataset in transform_data_list:
infix = 'TransformIndex{}'.format(dataset.index)
dataset.transformed_and_standardized = (
dataset.transformed
| 'ExtractRecordBatches[{}]'.format(infix) >> beam.Keys())
post_transform_stats_options = _InvokeStatsOptionsUpdaterFn(
stats_options_updater_fn,
stats_options_util.StatsType.POST_TRANSFORM,
transformed_schema_proto, metadata.asset_map,
transform_output_path)
if self._TFDVWriteShardedOutput():
post_transform_stats_options.experimental_result_partitions = (
_SHARDED_OUTPUT_PARTITIONS)
else:
if (post_transform_stats_options.experimental_result_partitions !=
1):
raise ValueError('Sharded output disabled requires '
'experimental_result_partitions=1.')
if stats_output_paths:
post_transform_feature_stats_loc = {
_STATS_KEY:
os.path.join(
stats_output_paths[
labels.POST_TRANSFORM_OUTPUT_STATS_PATH_LABEL],
STATS_FILE),
_SHARDED_STATS_KEY:
os.path.join(
stats_output_paths[
labels.POST_TRANSFORM_OUTPUT_STATS_PATH_LABEL],
SHARDED_STATS_PREFIX),
_SCHEMA_KEY:
os.path.join(
stats_output_paths[
labels.POST_TRANSFORM_OUTPUT_SCHEMA_PATH_LABEL],
_SCHEMA_FILE),
_ANOMALIES_KEY:
os.path.join(
stats_output_paths[
labels
.POST_TRANSFORM_OUTPUT_ANOMALIES_PATH_LABEL],
_ANOMALIES_FILE)
}
else:
post_transform_feature_stats_loc = os.path.join(
transform_output_path,
tft.TFTransformOutput.POST_TRANSFORM_FEATURE_STATS_PATH)
([
dataset.transformed_and_standardized
for dataset in transform_data_list
]
| 'FlattenTransformedDatasets' >> beam.Flatten(pipeline=pipeline)
| 'WaitForTransformWrite' >> beam.Map(
lambda x, completion: x,
completion=beam.pvalue.AsSingleton(completed_transform))
| 'GenerateAndValidateStats[FlattenedTransformedDatasets]' >>
self._GenerateAndMaybeValidateStats(
post_transform_feature_stats_loc,
stats_options=post_transform_stats_options,
enable_validation=True))
if per_set_stats_output_paths:
# TODO(b/130885503): Remove duplicate stats gen compute that is
# done both on a flattened view of the data, and on each span
# below.
for dataset in transform_data_list:
infix = 'TransformIndex{}'.format(dataset.index)
(dataset.transformed_and_standardized
| 'WaitForTransformWrite[{}]'.format(infix) >> beam.Map(
lambda x, completion: x,
completion=beam.pvalue.AsSingleton(completed_transform))
| 'GenerateAndValidateStats[{}]'.format(infix) >>
self._GenerateAndMaybeValidateStats(
dataset.stats_output_path,
stats_options=post_transform_stats_options,
enable_validation=True))
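# Materialize the transformed examples in the requested file format, one
# output per transform dataset.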
if materialization_format is not None:
for dataset in transform_data_list:
infix = 'TransformIndex{}'.format(dataset.index)
(dataset.transformed
| 'EncodeAndWrite[{}]'.format(infix) >> self._EncodeAndWrite(
schema=transformed_schema_proto,
file_format=materialization_format,
output_path=dataset.materialize_output_path))
return _Status.OK()