def _RunBeamImpl()

in tfx/components/transform/executor.py [0:0]


  def _RunBeamImpl(
      self, analyze_data_list: List[_Dataset],
      transform_data_list: List[_Dataset], preprocessing_fn: Any,
      stats_options_updater_fn: Callable[
          [stats_options_util.StatsType, tfdv.StatsOptions],
          tfdv.StatsOptions], force_tf_compat_v1: bool,
      input_dataset_metadata: dataset_metadata.DatasetMetadata,
      transform_output_path: str, raw_examples_data_format: int, temp_path: str,
      input_cache_dir: Optional[str], output_cache_dir: Optional[str],
      disable_statistics: bool, per_set_stats_output_paths: Sequence[str],
      materialization_format: Optional[str], analyze_paths_count: int,
      stats_output_paths: Dict[str, str],
      make_beam_pipeline_fn: Callable[[], beam.Pipeline]) -> _Status:
    """Perform data preprocessing with TFT.

    Builds and runs a single Beam pipeline that:
      * analyzes the analysis datasets with `preprocessing_fn`, passing along
        any analyzer cache found via `input_cache_dir`,
      * writes the raw input metadata and the resulting transform_fn under
        `transform_output_path`,
      * optionally writes analyzer cache to `output_cache_dir`,
      * unless statistics are disabled, computes pre-transform statistics over
        the analysis data and post-transform statistics (with validation) over
        the transformed data,
      * if `materialization_format` is set, writes the transformed data.

    Args:
      analyze_data_list: List of datasets for analysis.
      transform_data_list: List of datasets for transform.
      preprocessing_fn: The tf.Transform preprocessing_fn.
      stats_options_updater_fn: The user-specified function for updating stats
        options.
      force_tf_compat_v1: If True, call Transform's API to use Tensorflow in
        tf.compat.v1 mode.
      input_dataset_metadata: A DatasetMetadata object for the input data.
      transform_output_path: An absolute path to write the output to.
      raw_examples_data_format: The data format of the raw examples. One of the
        enums from example_gen_pb2.PayloadFormat.
      temp_path: A path to a temporary dir.
      input_cache_dir: A dir containing the input analysis cache. May be None.
      output_cache_dir: A dir to write the analysis cache to. May be None.
      disable_statistics: A bool indicating whether or not to disable
        statistics.
      per_set_stats_output_paths: Paths to per-set statistics output. If empty,
        per-set statistics is not produced.
      materialization_format: A string describing the format of the materialized
        data or None if materialization is not enabled.
      analyze_paths_count: An integer, the number of paths that should be used
        for analysis.
      stats_output_paths: A dictionary specifying the output paths to use when
        computing statistics. If the dictionary is empty, the stats will be
        placed within the transform_output_path to preserve backwards
        compatibility.
      make_beam_pipeline_fn: A callable that can create a beam pipeline.

    Returns:
      Status of the execution.

    Raises:
      ValueError: If sharded statistics output is disabled but the (possibly
        user-updated) stats options request a result partition count other
        than 1.
    """

    def _SetStatsResultPartitions(stats_options):
      """Applies/validates the TFDV result partition count on stats_options."""
      if self._TFDVWriteShardedOutput():
        stats_options.experimental_result_partitions = (
            _SHARDED_OUTPUT_PARTITIONS)
      elif stats_options.experimental_result_partitions != 1:
        raise ValueError('Sharded output disabled requires '
                         'experimental_result_partitions=1.')

    def _GetFeatureStatsLocation(stats_path_label, schema_path_label,
                                 legacy_relative_path,
                                 anomalies_path_label=None):
      """Returns the output location(s) for a stats generation step.

      When `stats_output_paths` is empty, returns a single path inside
      `transform_output_path` (the backwards-compatible layout). Otherwise
      returns a dict keyed by the _*_KEY constants, with paths taken from
      `stats_output_paths`; the anomalies entry is only added when
      `anomalies_path_label` is given.
      """
      if not stats_output_paths:
        return os.path.join(transform_output_path, legacy_relative_path)
      stats_dir = stats_output_paths[stats_path_label]
      location = {
          _STATS_KEY: os.path.join(stats_dir, STATS_FILE),
          _SHARDED_STATS_KEY: os.path.join(stats_dir, SHARDED_STATS_PREFIX),
          _SCHEMA_KEY: os.path.join(stats_output_paths[schema_path_label],
                                    _SCHEMA_FILE),
      }
      if anomalies_path_label is not None:
        location[_ANOMALIES_KEY] = os.path.join(
            stats_output_paths[anomalies_path_label], _ANOMALIES_FILE)
      return location

    self._AssertSameTFXIOSchema(analyze_data_list)
    unprojected_typespecs = (
        analyze_data_list[0].tfxio.TensorAdapter().OriginalTypeSpecs())

    analyze_input_columns = tft.get_analyze_input_columns(
        preprocessing_fn,
        unprojected_typespecs,
        force_tf_compat_v1=force_tf_compat_v1)
    # Counted before the column list is (possibly) widened below, so the
    # pipeline metrics report only the columns the analyzers need.
    analyze_columns_count = len(analyze_input_columns)

    transform_input_columns = tft.get_transform_input_columns(
        preprocessing_fn,
        unprojected_typespecs,
        force_tf_compat_v1=force_tf_compat_v1)
    # Use the same dataset (same columns) for AnalyzeDataset and computing
    # pre-transform stats so that the data will only be read once for these
    # two operations. dict.fromkeys de-duplicates while keeping a stable
    # first-seen order; iterating a set of strings would give an order that
    # can differ between processes.
    if not disable_statistics:
      analyze_input_columns = list(
          dict.fromkeys(
              list(analyze_input_columns) + list(transform_input_columns)))

    for d in analyze_data_list:
      d.tfxio = d.tfxio.Project(analyze_input_columns)

    self._AssertSameTFXIOSchema(analyze_data_list)
    analyze_data_tensor_adapter_config = (
        analyze_data_list[0].tfxio.TensorAdapterConfig())

    for d in transform_data_list:
      d.tfxio = d.tfxio.Project(transform_input_columns)

    desired_batch_size = self._GetDesiredBatchSize(
        raw_examples_data_format, input_dataset_metadata.schema)

    with make_beam_pipeline_fn() as pipeline:
      with tft_beam.Context(
          temp_dir=temp_path,
          desired_batch_size=desired_batch_size,
          passthrough_keys=self._GetTFXIOPassthroughKeys(),
          use_deep_copy_optimization=True,
          force_tf_compat_v1=force_tf_compat_v1):
        (new_analyze_data_dict, input_cache,
         estimated_stage_count_with_cache) = (
             pipeline
             | 'OptimizeRun' >> self._OptimizeRun(
                 input_cache_dir, output_cache_dir, analyze_data_list,
                 unprojected_typespecs, preprocessing_fn,
                 self._GetCacheSource(), force_tf_compat_v1))

        _ = (
            pipeline
            | 'IncrementPipelineMetrics' >> self._IncrementPipelineMetrics(
                total_columns_count=len(unprojected_typespecs),
                analyze_columns_count=analyze_columns_count,
                transform_columns_count=len(transform_input_columns),
                analyze_paths_count=analyze_paths_count,
                analyzer_cache_enabled=input_cache is not None,
                disable_statistics=disable_statistics,
                materialize=materialization_format is not None,
                estimated_stage_count_with_cache=(
                    estimated_stage_count_with_cache)))

        if input_cache:
          logging.debug('Analyzing data with cache.')

        # Captured before analyze_data_list may be pruned below; the cache
        # writer is given every analysis dataset key.
        full_analyze_dataset_keys_list = [
            dataset.dataset_key for dataset in analyze_data_list
        ]

        # Removing unneeded datasets if they won't be needed for statistics or
        # materialization. new_analyze_data_dict maps a dataset key to None
        # for datasets that need not be read due to cache.
        # NOTE(review): materialization below re-reads transform_data_list
        # instead of reusing these reads, so requiring
        # `materialization_format is None` here may be over-conservative —
        # confirm before relaxing.
        if materialization_format is None and disable_statistics:
          if None in new_analyze_data_dict.values():
            logging.debug(
                'Not reading the following datasets due to cache: %s', [
                    dataset.file_pattern
                    for dataset in analyze_data_list
                    if new_analyze_data_dict[dataset.dataset_key] is None
                ])
          analyze_data_list = [
              d for d in new_analyze_data_dict.values() if d is not None
          ]

        for dataset in analyze_data_list:
          infix = 'AnalysisIndex{}'.format(dataset.index)
          dataset.standardized = (
              pipeline
              | 'TFXIOReadAndDecode[{}]'.format(infix) >>
              dataset.tfxio.BeamSource(desired_batch_size))

        # Datasets that were not read (None) are passed to the analyzer as
        # None alongside the cache.
        input_analysis_data = {}
        for key, dataset in new_analyze_data_dict.items():
          input_analysis_data[key] = (None if dataset is None else
                                      dataset.standardized)

        transform_fn, cache_output = (
            (input_analysis_data, input_cache,
             analyze_data_tensor_adapter_config)
            | 'Analyze' >> tft_beam.AnalyzeDatasetWithCache(
                preprocessing_fn, pipeline=pipeline))

        # Write the raw/input metadata.
        (input_dataset_metadata
         | 'WriteMetadata' >> tft_beam.WriteMetadata(
             os.path.join(transform_output_path,
                          tft.TFTransformOutput.RAW_METADATA_DIR), pipeline))

        # WriteTransformFn writes transform_fn and metadata to subdirectories
        # tensorflow_transform.SAVED_MODEL_DIR and
        # tensorflow_transform.TRANSFORMED_METADATA_DIR respectively.
        completed_transform = (
            transform_fn
            | 'WriteTransformFn' >>
            tft_beam.WriteTransformFn(transform_output_path))

        if output_cache_dir is not None and cache_output is not None:
          fileio.makedirs(output_cache_dir)
          if input_cache_dir is not None:
            logging.debug('Using existing cache in: %s', input_cache_dir)
            # Only copy cache that is relevant to this iteration. This is
            # assuming that this pipeline operates on rolling ranges, so those
            # cache entries may also be relevant for future iterations.
            for dataset_key in input_analysis_data:
              full_span_cache_dir = os.path.join(input_cache_dir,
                                                 dataset_key.key)
              if fileio.isdir(full_span_cache_dir):
                self._CopyCache(
                    full_span_cache_dir,
                    os.path.join(output_cache_dir, dataset_key.key))

          # TODO(b/157479287, b/171165988): Remove this condition when beam
          # 2.26 is used.
          if cache_output:
            (cache_output
             | 'WriteCache' >> analyzer_cache.WriteAnalysisCacheToFS(
                 pipeline=pipeline,
                 cache_base_dir=output_cache_dir,
                 sink=self._GetCacheSink(),
                 dataset_keys=full_analyze_dataset_keys_list))

        if not disable_statistics or materialization_format is not None:
          # Do not compute pre-transform stats if the input format is raw
          # proto, as StatsGen would treat any input as tf.Example. Note that
          # tf.SequenceExamples are wire-format compatible with tf.Examples.
          if (not disable_statistics and
              not self._IsDataFormatProto(raw_examples_data_format)):
            # Aggregated feature stats before transformation.
            decodes_as_raw_records = (
                self._DecodesSequenceExamplesAsRawRecords(
                    raw_examples_data_format, input_dataset_metadata.schema))
            schema_proto = (None if decodes_as_raw_records else
                            input_dataset_metadata.schema)

            if decodes_as_raw_records:

              def _ExtractRawExampleBatches(record_batch):
                return record_batch.column(
                    record_batch.schema.get_field_index(
                        RAW_EXAMPLE_KEY)).flatten().to_pylist()

              # Make use of the fact that tf.SequenceExample is wire-format
              # compatible with tf.Example
              stats_input = []
              for dataset in analyze_data_list:
                infix = 'AnalysisIndex{}'.format(dataset.index)
                stats_input.append(
                    dataset.standardized
                    | 'ExtractRawExampleBatches[{}]'.format(
                        infix) >> beam.Map(_ExtractRawExampleBatches)
                    | 'DecodeSequenceExamplesAsExamplesIntoRecordBatches[{}]'
                    .format(infix) >> beam.ParDo(
                        self._ToArrowRecordBatchesFn(schema_proto)))
            else:
              stats_input = [
                  dataset.standardized for dataset in analyze_data_list
              ]

            pre_transform_stats_options = _InvokeStatsOptionsUpdaterFn(
                stats_options_updater_fn,
                stats_options_util.StatsType.PRE_TRANSFORM, schema_proto)
            _SetStatsResultPartitions(pre_transform_stats_options)

            pre_transform_feature_stats_loc = _GetFeatureStatsLocation(
                labels.PRE_TRANSFORM_OUTPUT_STATS_PATH_LABEL,
                labels.PRE_TRANSFORM_OUTPUT_SCHEMA_PATH_LABEL,
                tft.TFTransformOutput.PRE_TRANSFORM_FEATURE_STATS_PATH)

            (stats_input
             | 'FlattenAnalysisDatasets' >> beam.Flatten(pipeline=pipeline)
             | 'GenerateStats[FlattenedAnalysisDataset]' >>
             self._GenerateAndMaybeValidateStats(
                 pre_transform_feature_stats_loc,
                 stats_options=pre_transform_stats_options,
                 enable_validation=False))

          # transform_data_list is a superset of analyze_data_list, we pay the
          # cost to read the same dataset (analyze_data_list) again here to
          # prevent certain beam runner from doing large temp materialization.
          for dataset in transform_data_list:
            infix = 'TransformIndex{}'.format(dataset.index)
            dataset.standardized = (
                pipeline | 'TFXIOReadAndDecode[{}]'.format(infix) >>
                dataset.tfxio.BeamSource(desired_batch_size))
            # The per-dataset output metadata is unused; the transformed
            # schema is taken from transform_fn below.
            dataset.transformed, _ = (
                ((dataset.standardized, dataset.tfxio.TensorAdapterConfig()),
                 transform_fn)
                | 'Transform[{}]'.format(infix) >>
                tft_beam.TransformDataset(output_record_batches=True))

          _, metadata = transform_fn

          # TODO(b/70392441): Retain tf.Metadata (e.g., IntDomain) in
          # schema. Currently input dataset schema only contains dtypes,
          # and other metadata is dropped due to roundtrip to tensors.
          transformed_schema_proto = metadata.schema

          if not disable_statistics:
            # Aggregated feature stats after transformation.
            for dataset in transform_data_list:
              infix = 'TransformIndex{}'.format(dataset.index)
              dataset.transformed_and_standardized = (
                  dataset.transformed
                  | 'ExtractRecordBatches[{}]'.format(infix) >> beam.Keys())

            post_transform_stats_options = _InvokeStatsOptionsUpdaterFn(
                stats_options_updater_fn,
                stats_options_util.StatsType.POST_TRANSFORM,
                transformed_schema_proto, metadata.asset_map,
                transform_output_path)
            _SetStatsResultPartitions(post_transform_stats_options)

            post_transform_feature_stats_loc = _GetFeatureStatsLocation(
                labels.POST_TRANSFORM_OUTPUT_STATS_PATH_LABEL,
                labels.POST_TRANSFORM_OUTPUT_SCHEMA_PATH_LABEL,
                tft.TFTransformOutput.POST_TRANSFORM_FEATURE_STATS_PATH,
                anomalies_path_label=(
                    labels.POST_TRANSFORM_OUTPUT_ANOMALIES_PATH_LABEL))

            # The AsSingleton side input on completed_transform makes stats
            # generation depend on WriteTransformFn having produced output.
            ([
                dataset.transformed_and_standardized
                for dataset in transform_data_list
            ]
             | 'FlattenTransformedDatasets' >> beam.Flatten(pipeline=pipeline)
             | 'WaitForTransformWrite' >> beam.Map(
                 lambda x, completion: x,
                 completion=beam.pvalue.AsSingleton(completed_transform))
             | 'GenerateAndValidateStats[FlattenedTransformedDatasets]' >>
             self._GenerateAndMaybeValidateStats(
                 post_transform_feature_stats_loc,
                 stats_options=post_transform_stats_options,
                 enable_validation=True))

            if per_set_stats_output_paths:
              # TODO(b/130885503): Remove duplicate stats gen compute that is
              # done both on a flattened view of the data, and on each span
              # below.
              for dataset in transform_data_list:
                infix = 'TransformIndex{}'.format(dataset.index)
                (dataset.transformed_and_standardized
                 | 'WaitForTransformWrite[{}]'.format(infix) >> beam.Map(
                     lambda x, completion: x,
                     completion=beam.pvalue.AsSingleton(completed_transform))
                 | 'GenerateAndValidateStats[{}]'.format(infix) >>
                 self._GenerateAndMaybeValidateStats(
                     dataset.stats_output_path,
                     stats_options=post_transform_stats_options,
                     enable_validation=True))

          if materialization_format is not None:
            for dataset in transform_data_list:
              infix = 'TransformIndex{}'.format(dataset.index)
              (dataset.transformed
               | 'EncodeAndWrite[{}]'.format(infix) >> self._EncodeAndWrite(
                   schema=transformed_schema_proto,
                   file_format=materialization_format,
                   output_path=dataset.materialize_output_path))

    return _Status.OK()