tfx/benchmarks/tft_benchmark_base.py (284 lines of code) (raw):

# Copyright 2019 Google LLC. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """TFT benchmark base.""" import collections import shutil import tempfile import time from absl import logging import apache_beam as beam from apache_beam.utils import shared import tensorflow as tf import tensorflow_transform as tft from tensorflow_transform import graph_tools from tensorflow_transform import impl_helper import tensorflow_transform.beam as tft_beam from tensorflow_transform.beam import impl as tft_beam_impl from tensorflow_transform.saved import saved_transform_io from tensorflow_transform.saved import saved_transform_io_v2 from tensorflow_transform.tf_metadata import dataset_metadata from tensorflow_transform.tf_metadata import schema_utils import tfx from tfx.benchmarks import benchmark_utils from tfx.benchmarks import benchmark_base from tfx_bsl.coders import example_coder from tfx_bsl.tfxio import tensor_adapter from tfx_bsl.tfxio import tf_example_record class _CopySavedModel(beam.PTransform): """Copies the TFT SavedModel to another directory.""" def __init__(self, dest_path): self._dest_path = dest_path def expand(self, transform_fn): def copy_saved_model(unused_element, source_path, dest_path): shutil.rmtree(dest_path, ignore_errors=True) shutil.copytree(source_path, dest_path) logging.info("Copied SavedModel from %s to %s", source_path, dest_path) return (transform_fn.pipeline | "CreateSole" >> beam.Create([None]) | "CopySavedModel" >> beam.Map( copy_saved_model, source_path=beam.pvalue.AsSingleton(transform_fn), dest_path=self._dest_path)) class _AnalyzeAndTransformDataset(beam.PTransform): """PTransform to run AnalyzeAndTransformDataset.""" def __init__(self, dataset, tfxio, preprocessing_fn, transform_input_dataset_metadata, force_tf_compat_v1=True, max_num_examples=None, generate_dataset=False): """Constructor. Args: dataset: BenchmarkDataset object. tfxio: A `tfx_bsl.TFXIO` instance. preprocessing_fn: preprocessing_fn. transform_input_dataset_metadata: dataset_metadata.DatasetMetadata. force_tf_compat_v1: If False then Transform will use its native TF2 version, if True then Transform will use its TF1 version. max_num_examples: Max number of examples to read from the dataset. generate_dataset: If True, generates the raw dataset and appropriate intermediate outputs (just the TFT SavedModel for now) necessary for other benchmarks. """ self._dataset = dataset self._tfxio = tfxio self._preprocessing_fn = preprocessing_fn self._transform_input_dataset_metadata = transform_input_dataset_metadata self._force_tf_compat_v1 = force_tf_compat_v1 self._max_num_examples = max_num_examples self._generate_dataset = generate_dataset def expand(self, pipeline): # TODO(b/147620802): Consider making this (and other parameters) # configurable to test more variants (e.g. with and without deep-copy # optimisation, with and without cache, etc). with tft_beam.Context( temp_dir=tempfile.mkdtemp(), force_tf_compat_v1=self._force_tf_compat_v1): raw_data = ( pipeline | "ReadDataset" >> beam.Create( self._dataset.read_raw_dataset( deserialize=False, limit=self._max_num_examples)) | "Decode" >> self._tfxio.BeamSource()) transform_fn, output_metadata = ( (raw_data, self._tfxio.TensorAdapterConfig()) | "AnalyzeDataset" >> tft_beam.AnalyzeDataset(self._preprocessing_fn)) if self._generate_dataset: _ = transform_fn | "CopySavedModel" >> _CopySavedModel( dest_path=self._dataset.tft_saved_model_path( self._force_tf_compat_v1)) (transformed_dataset, transformed_metadata) = ( ((raw_data, self._tfxio.TensorAdapterConfig()), (transform_fn, output_metadata)) | "TransformDataset" >> tft_beam.TransformDataset(output_record_batches=True)) return transformed_dataset, transformed_metadata # Tuple for variables common to all benchmarks. CommonVariablesTuple = collections.namedtuple("CommonVariablesTuple", [ "tf_metadata_schema", "preprocessing_fn", "transform_input_dataset_metadata", "tfxio", ]) def _get_common_variables(dataset, force_tf_compat_v1): """Returns metadata schema, preprocessing fn, input dataset metadata.""" tf_metadata_schema = benchmark_utils.read_schema( dataset.tf_metadata_schema_path()) preprocessing_fn = dataset.tft_preprocessing_fn() feature_spec = schema_utils.schema_as_feature_spec( tf_metadata_schema).feature_spec type_spec = impl_helper.get_type_specs_from_feature_specs(feature_spec) transform_input_columns = ( tft.get_transform_input_columns( preprocessing_fn, type_spec, force_tf_compat_v1=force_tf_compat_v1)) transform_input_dataset_metadata = dataset_metadata.DatasetMetadata( schema_utils.schema_from_feature_spec({ feature: feature_spec[feature] for feature in transform_input_columns })) tfxio = tf_example_record.TFExampleBeamRecord( physical_format="tfexamples", schema=transform_input_dataset_metadata.schema, telemetry_descriptors=["TFTransformBenchmark"]) return CommonVariablesTuple( tf_metadata_schema=tf_metadata_schema, preprocessing_fn=preprocessing_fn, transform_input_dataset_metadata=transform_input_dataset_metadata, tfxio=tfxio) def regenerate_intermediates_for_dataset(dataset, force_tf_compat_v1=True, max_num_examples=None): """Regenerate intermediate outputs required for the benchmark.""" common_variables = _get_common_variables(dataset, force_tf_compat_v1) logging.info("Regenerating intermediate outputs required for benchmark.") with beam.Pipeline() as p: _ = p | _AnalyzeAndTransformDataset( dataset, common_variables.tfxio, common_variables.preprocessing_fn, common_variables.transform_input_dataset_metadata, force_tf_compat_v1=force_tf_compat_v1, max_num_examples=max_num_examples, generate_dataset=True) logging.info("Intermediate outputs regenerated.") def _get_batched_records(dataset, force_tf_compat_v1, max_num_examples=None): """Returns a (batch_size, iterator for batched records) tuple for the dataset. Args: dataset: BenchmarkDataset object. force_tf_compat_v1: If False then Transform will use its native TF2 version, if True then Transform will use its TF1 version. max_num_examples: Maximum number of examples to read from the dataset. Returns: Tuple of (batch_size, iterator for batched records), where records are decoded tf.train.Examples. """ batch_size = 1000 common_variables = _get_common_variables(dataset, force_tf_compat_v1) converter = example_coder.ExamplesToRecordBatchDecoder( common_variables.transform_input_dataset_metadata.schema .SerializeToString()) serialized_records = benchmark_utils.batched_iterator( dataset.read_raw_dataset(deserialize=False, limit=max_num_examples), batch_size) records = [converter.DecodeBatch(x) for x in serialized_records] return batch_size, records class TFTBenchmarkBase(benchmark_base.BenchmarkBase): """TFT benchmark base class.""" def __init__(self, dataset, **kwargs): # Benchmark runners may pass extraneous arguments we don't care about. del kwargs super().__init__() self._dataset = dataset def report_benchmark(self, **kwargs): if "extras" not in kwargs: kwargs["extras"] = {} # Note that the GIT_COMMIT_ID is not included in the packages themselves: # it must be injected by an external script. kwargs["extras"]["commit_tfx"] = (getattr(tfx, "GIT_COMMIT_ID", None) or getattr(tfx, "__version__", None)) kwargs["extras"]["commit_tft"] = (getattr(tft, "GIT_COMMIT_ID", None) or getattr(tft, "__version__", None)) super().report_benchmark(**kwargs) def _benchmarkAnalyzeAndTransformDatasetCommon(self, force_tf_compat_v1): """Common implementation to benchmark AnalyzeAndTransformDataset.""" common_variables = _get_common_variables(self._dataset, force_tf_compat_v1) pipeline = self._create_beam_pipeline() _ = pipeline | _AnalyzeAndTransformDataset( self._dataset, common_variables.tfxio, common_variables.preprocessing_fn, common_variables.transform_input_dataset_metadata, force_tf_compat_v1=force_tf_compat_v1, max_num_examples=self._max_num_examples()) start = time.time() result = pipeline.run() result.wait_until_finish() end = time.time() delta = end - start self.report_benchmark( iters=1, wall_time=delta, extras={ "num_examples": self._dataset.num_examples(limit=self._max_num_examples()) }) def benchmarkAnalyzeAndTransformDataset(self): """Benchmark AnalyzeAndTransformDataset for TFT's TF1 implementation. Runs AnalyzeAndTransformDataset in a Beam pipeline. Records the wall time taken for the whole pipeline. """ self._benchmarkAnalyzeAndTransformDatasetCommon(force_tf_compat_v1=True) def benchmarkTF2AnalyzeAndTransformDataset(self): """Benchmark AnalyzeAndTransformDataset for TFT's TF2 implementation. Runs AnalyzeAndTransformDataset in a Beam pipeline. Records the wall time taken for the whole pipeline. """ self._benchmarkAnalyzeAndTransformDatasetCommon(force_tf_compat_v1=False) def _benchmarkRunMetaGraphDoFnManualActuationCommon(self, force_tf_compat_v1): """Common implementation to benchmark RunMetaGraphDoFn "manually".""" common_variables = _get_common_variables(self._dataset, force_tf_compat_v1) batch_size, batched_records = _get_batched_records(self._dataset, force_tf_compat_v1, self._max_num_examples()) fn = tft_beam_impl._RunMetaGraphDoFn( # pylint: disable=protected-access tf_config=None, shared_graph_state_handle=shared.Shared(), passthrough_keys=set(), exclude_outputs=None, use_tf_compat_v1=force_tf_compat_v1, input_tensor_adapter_config=common_variables.tfxio.TensorAdapterConfig( )) fn.setup() start = time.time() for batch in batched_records: _ = list( fn.process( batch, saved_model_dir=self._dataset.tft_saved_model_path( force_tf_compat_v1))) end = time.time() delta = end - start self.report_benchmark( iters=1, wall_time=delta, extras={ "batch_size": batch_size, "num_examples": self._dataset.num_examples(limit=self._max_num_examples()) }) def benchmarkRunMetaGraphDoFnManualActuation(self): """Benchmark RunMetaGraphDoFn "manually" for TFT's TF1 implementation. Runs RunMetaGraphDoFn "manually" outside of a Beam pipeline. Records the wall time taken. """ self._benchmarkRunMetaGraphDoFnManualActuationCommon( force_tf_compat_v1=True) def benchmarkTF2RunMetaGraphDoFnManualActuation(self): """Benchmark RunMetaGraphDoFn "manually" for TFT's TF2 implementation. Runs RunMetaGraphDoFn "manually" outside of a Beam pipeline. Records the wall time taken. """ self._benchmarkRunMetaGraphDoFnManualActuationCommon( force_tf_compat_v1=False) def benchmarkRunMetagraphDoFnAtTFLevel(self): """Benchmark RunMetaGraphDoFn at the TF level for TFT's TF1 implementation. Benchmarks the parts of RunMetaGraphDoFn that involve feeding and fetching from the TFT SavedModel. Records the wall time taken. Note that this benchmark necessarily duplicates code directly from TFT since it's benchmarking the low-level internals of TFT, which are not exposed for use in this way. """ common_variables = _get_common_variables( self._dataset, force_tf_compat_v1=True) tf_config = tft_beam_impl._FIXED_PARALLELISM_TF_CONFIG # pylint: disable=protected-access # This block copied from _GraphStateCompatV1.__init__ with tf.compat.v1.Graph().as_default() as graph: session = tf.compat.v1.Session(graph=graph, config=tf_config) with session.as_default(): inputs, outputs = ( saved_transform_io.partially_apply_saved_transform_internal( self._dataset.tft_saved_model_path(force_tf_compat_v1=True), {})) session.run(tf.compat.v1.global_variables_initializer()) session.run(tf.compat.v1.tables_initializer()) graph.finalize() # We ignore the schema, and assume there are no excluded outputs. outputs_tensor_keys = sorted(set(outputs.keys())) fetches = [outputs[key] for key in outputs_tensor_keys] tensor_inputs = graph_tools.get_dependent_inputs(graph, inputs, fetches) input_tensor_keys = sorted(tensor_inputs.keys()) feed_list = [inputs[key] for key in input_tensor_keys] callable_get_outputs = session.make_callable(fetches, feed_list=feed_list) batch_size, batched_records = _get_batched_records( self._dataset, force_tf_compat_v1=True, max_num_examples=self._max_num_examples()) input_tensor_adapter = tensor_adapter.TensorAdapter( common_variables.tfxio.TensorAdapterConfig()) # This block copied from _RunMetaGraphDoFn._handle_batch start = time.time() for batch in batched_records: feed_by_name = input_tensor_adapter.ToBatchTensors( batch, produce_eager_tensors=False) feed_list = [feed_by_name[name] for name in input_tensor_keys] outputs_list = callable_get_outputs(*feed_list) _ = {key: value for key, value in zip(outputs_tensor_keys, outputs_list)} end = time.time() delta = end - start self.report_benchmark( iters=1, wall_time=delta, extras={ "batch_size": batch_size, "num_examples": self._dataset.num_examples(limit=self._max_num_examples()) }) def benchmarkTF2RunMetagraphDoFnAtTFLevel(self): """Benchmark RunMetaGraphDoFn at the TF level for TFT's TF2 implementation. Benchmarks the parts of RunMetaGraphDoFn that involve feeding and fetching from the TFT SavedModel. Records the wall time taken. Note that this benchmark necessarily duplicates code directly from TFT since it's benchmarking the low-level internals of TFT, which are not exposed for use in this way. """ common_variables = _get_common_variables( self._dataset, force_tf_compat_v1=False) tensor_adapter_config = common_variables.tfxio.TensorAdapterConfig() # This block copied from _GraphStateV2.__init__ saved_model_loader = saved_transform_io_v2.SavedModelLoader( self._dataset.tft_saved_model_path(force_tf_compat_v1=False)) callable_get_outputs = saved_model_loader.apply_transform_model # We ignore the schema, and assume there are no excluded outputs. outputs_tensor_keys = set(saved_model_loader.structured_outputs.keys()) saved_model_loader.finalize( tensor_adapter_config.tensor_representations.keys(), outputs_tensor_keys) batch_size, batched_records = _get_batched_records( self._dataset, force_tf_compat_v1=False, max_num_examples=self._max_num_examples()) input_tensor_adapter = tensor_adapter.TensorAdapter(tensor_adapter_config) # This block copied from _RunMetaGraphDoFn._handle_batch start = time.time() for batch in batched_records: feed_dict = input_tensor_adapter.ToBatchTensors( batch, produce_eager_tensors=True) _ = callable_get_outputs(feed_dict) end = time.time() delta = end - start self.report_benchmark( iters=1, wall_time=delta, extras={ "batch_size": batch_size, "num_examples": self._dataset.num_examples(limit=self._max_num_examples()) })