tfx/components/example_gen/driver.py (128 lines of code) (raw):

# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generic TFX ExampleGen custom driver."""

import copy
import os
from typing import Any, Dict, List, Iterable, Optional

from absl import logging
from tfx import types
from tfx.components.example_gen import input_processor
from tfx.components.example_gen import utils
from tfx.dsl.components.base import base_driver
from tfx.orchestration import data_types
from tfx.orchestration import data_types_utils
from tfx.orchestration import metadata
from tfx.orchestration.portable import base_driver as ir_base_driver
from tfx.orchestration.portable import data_types as portable_data_types
from tfx.proto import example_gen_pb2
from tfx.proto import range_config_pb2
from tfx.proto.orchestration import driver_output_pb2
from tfx.types import standard_component_specs
from tfx.utils import proto_utils

from ml_metadata.proto import metadata_store_pb2


def update_output_artifact(
    exec_properties: Dict[str, Any],
    output_artifact: metadata_store_pb2.Artifact) -> None:
  """Updates output_artifact for FileBasedExampleGen.

  Updates output_artifact custom properties in place, overwriting existing
  entries or creating new ones if they do not exist yet:

  * fingerprint: written as a string only when present and truthy.
  * span: always written as an int; the key must be present.
  * version: written as an int only when present and not None.

  Args:
    exec_properties: execution properties passed to the example gen.
    output_artifact: the example artifact to be output.

  Raises:
    KeyError: if the span entry is missing from exec_properties.
  """
  custom_properties = output_artifact.custom_properties

  fingerprint = exec_properties.get(utils.FINGERPRINT_PROPERTY_NAME)
  if fingerprint:
    custom_properties[
        utils.FINGERPRINT_PROPERTY_NAME].string_value = fingerprint

  # Span is mandatory, so it is read with [] and a missing key fails loudly.
  custom_properties[utils.SPAN_PROPERTY_NAME].int_value = exec_properties[
      utils.SPAN_PROPERTY_NAME]

  # TODO(b/162622803): add default behavior for when version spec not present.
  # Version is optional: an absent key is treated the same as an explicit
  # None instead of raising KeyError, mirroring the fingerprint lookup above.
  version = exec_properties.get(utils.VERSION_PROPERTY_NAME)
  if version is not None:
    custom_properties[utils.VERSION_PROPERTY_NAME].int_value = version


class Driver(base_driver.BaseDriver, ir_base_driver.BaseDriver):
  """Custom driver for ExampleGen.

  Implements both the legacy driver API (resolve_exec_properties and
  _prepare_output_artifacts) and the portable driver API (run). Subclasses
  provide the concrete InputProcessor through get_input_processor().
  """

  def __init__(self, metadata_handler: metadata.Metadata):
    """Initializes both base drivers with the same metadata handler."""
    # Each base class is initialized explicitly with the handler.
    base_driver.BaseDriver.__init__(self, metadata_handler)
    ir_base_driver.BaseDriver.__init__(self, metadata_handler)

  def get_input_processor(
      self,
      splits: Iterable[example_gen_pb2.Input.Split],
      range_config: Optional[range_config_pb2.RangeConfig] = None,
      input_base_uri: Optional[str] = None) -> input_processor.InputProcessor:
    """Returns the custom InputProcessor for this driver.

    Args:
      splits: the input splits taken from the input config.
      range_config: optional range config parsed from the exec properties.
      input_base_uri: optional input base taken from the exec properties.

    Raises:
      NotImplementedError: always; subclasses must override this method.
    """
    raise NotImplementedError

  def resolve_exec_properties(
      self,
      exec_properties: Dict[str, Any],
      pipeline_info: data_types.PipelineInfo,
      component_info: data_types.ComponentInfo,
  ) -> Dict[str, Any]:
    """Overrides BaseDriver.resolve_exec_properties().

    Resolves span, version and input fingerprint through the driver's
    InputProcessor, rewrites every split pattern of the input config for the
    resolved span/version, and stores the results in exec_properties.

    Args:
      exec_properties: execution properties of the component. Must contain
        the serialized input config; may contain the input base and a
        serialized range config. The dict is modified in place.
      pipeline_info: unused.
      component_info: unused.

    Returns:
      The same exec_properties dict, with the input config rewritten and the
      span, version and fingerprint entries set.
    """
    del pipeline_info, component_info

    input_config = example_gen_pb2.Input()
    proto_utils.json_to_proto(
        exec_properties[standard_component_specs.INPUT_CONFIG_KEY],
        input_config)

    input_base = exec_properties.get(standard_component_specs.INPUT_BASE_KEY)
    logging.debug('Processing input %s.', input_base)

    # The range config is optional; it is only parsed when an entry exists.
    range_config = None
    range_config_entry = exec_properties.get(
        standard_component_specs.RANGE_CONFIG_KEY)
    if range_config_entry:
      range_config = range_config_pb2.RangeConfig()
      proto_utils.json_to_proto(range_config_entry, range_config)

    processor = self.get_input_processor(
        splits=input_config.splits,
        range_config=range_config,
        input_base_uri=input_base)

    span, version = processor.resolve_span_and_version()
    fingerprint = processor.get_input_fingerprint(span, version)

    # Updates the input_config.splits.pattern.
    for split in input_config.splits:
      split.pattern = processor.get_pattern_for_span_version(
          split.pattern, span, version)

    exec_properties[standard_component_specs.INPUT_CONFIG_KEY] = (
        proto_utils.proto_to_json(input_config))
    exec_properties[utils.SPAN_PROPERTY_NAME] = span
    exec_properties[utils.VERSION_PROPERTY_NAME] = version
    exec_properties[utils.FINGERPRINT_PROPERTY_NAME] = fingerprint

    return exec_properties

  def _prepare_output_artifacts(
      self,
      input_artifacts: Dict[str, List[types.Artifact]],
      output_dict: Dict[str, types.Channel],
      exec_properties: Dict[str, Any],
      execution_id: int,
      pipeline_info: data_types.PipelineInfo,
      component_info: data_types.ComponentInfo,
  ) -> Dict[str, List[types.Artifact]]:
    """Overrides BaseDriver._prepare_output_artifacts().

    Args:
      input_artifacts: unused.
      output_dict: output channels; only the examples channel is read, to
        obtain the artifact type to instantiate.
      exec_properties: resolved execution properties, used to populate the
        custom properties of the output artifact.
      execution_id: id of the execution, used to generate the output uri.
      pipeline_info: supplies the pipeline root for the output directory.
      component_info: supplies the component id for the output directory.

    Returns:
      A dict mapping the examples key to a list holding the single prepared
      example artifact.
    """
    del input_artifacts

    example_artifact = output_dict[standard_component_specs.EXAMPLES_KEY].type()
    base_output_dir = os.path.join(pipeline_info.pipeline_root,
                                   component_info.component_id)

    example_artifact.uri = base_driver._generate_output_uri(  # pylint: disable=protected-access
        base_output_dir, standard_component_specs.EXAMPLES_KEY, execution_id)
    update_output_artifact(exec_properties, example_artifact.mlmd_artifact)
    base_driver._prepare_output_paths(example_artifact)  # pylint: disable=protected-access

    return {standard_component_specs.EXAMPLES_KEY: [example_artifact]}

  def run(
      self, execution_info: portable_data_types.ExecutionInfo
  ) -> driver_output_pb2.DriverOutput:
    """Resolves exec properties and the output example artifact.

    Args:
      execution_info: execution info whose exec_properties are resolved and
        whose first examples output artifact is used as the template for the
        returned artifact.

    Returns:
      A DriverOutput holding every non-None resolved exec property and a
      copy of the examples artifact updated with the resolved properties.
    """
    # Populate exec_properties
    result = driver_output_pb2.DriverOutput()
    # PipelineInfo and ComponentInfo are not actually used; two placeholder
    # instances are created just to be compatible with the old API.
    pipeline_info = data_types.PipelineInfo('', '')
    component_info = data_types.ComponentInfo('', '', pipeline_info)

    exec_properties = self.resolve_exec_properties(
        execution_info.exec_properties, pipeline_info, component_info)
    for k, v in exec_properties.items():
      # None-valued properties are left out of the driver output.
      if v is not None:
        data_types_utils.set_metadata_value(result.exec_properties[k], v)

    # Populate output_dict
    # Work on a deep copy so the artifact held by execution_info is not
    # modified.
    output_example = copy.deepcopy(execution_info.output_dict[
        standard_component_specs.EXAMPLES_KEY][0].mlmd_artifact)
    update_output_artifact(exec_properties, output_example)
    result.output_artifacts[
        standard_component_specs.EXAMPLES_KEY].artifacts.append(output_example)
    return result


class FileBasedDriver(Driver):
  """Custom Driver for file based ExampleGen, e.g., ImportExampleGen."""

  def get_input_processor(
      self,
      splits: Iterable[example_gen_pb2.Input.Split],
      range_config: Optional[range_config_pb2.RangeConfig] = None,
      input_base_uri: Optional[str] = None) -> input_processor.InputProcessor:
    """Returns FileBasedInputProcessor.

    Args:
      splits: the input splits taken from the input config.
      range_config: optional range config forwarded to the processor.
      input_base_uri: base uri of the input files; required here.

    Raises:
      ValueError: if input_base_uri is None or empty.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`
    # and would let a missing uri reach the processor unchecked.
    if not input_base_uri:
      raise ValueError(
          'input_base_uri must be provided for file based ExampleGen.')
    return input_processor.FileBasedInputProcessor(input_base_uri, splits,
                                                   range_config)


class QueryBasedDriver(Driver):
  """Custom Driver for query based ExampleGen, e.g., BigQueryExampleGen."""

  def get_input_processor(
      self,
      splits: Iterable[example_gen_pb2.Input.Split],
      range_config: Optional[range_config_pb2.RangeConfig] = None,
      input_base_uri: Optional[str] = None) -> input_processor.InputProcessor:
    """Returns QueryBasedInputProcessor.

    Args:
      splits: the input splits taken from the input config.
      range_config: optional range config forwarded to the processor.
      input_base_uri: accepted for signature compatibility; not used.
    """
    return input_processor.QueryBasedInputProcessor(splits, range_config)