in tfx/extensions/google_cloud_ai_platform/bulk_inferrer/executor.py [0:0]
def Do(self, input_dict: Dict[str, List[types.Artifact]],
output_dict: Dict[str, List[types.Artifact]],
exec_properties: Dict[str, Any]) -> None:
"""Runs batch inference on a given model with given input examples.
This function creates a new model (if necessary) and a new model version
before inference, and cleans up resources after inference. It provides
re-executability as it cleans up (only) the model resources that are created
during the process even inference job failed.
Args:
input_dict: Input dict from input key to a list of Artifacts.
- examples: examples for inference.
- model: exported model.
- model_blessing: model blessing result.
output_dict: Output dict from output key to a list of Artifacts.
- inference_result: bulk inference results.
- output_examples: output examples containing the inference results.
exec_properties: A dict of execution properties.
- data_spec: JSON string of bulk_inferrer_pb2.DataSpec instance.
- output_example_spec: JSON string of bulk_inferrer_pb2.OutputExampleSpec
instance, used when output_examples is requested.
- custom_config: custom_config.ai_platform_serving_args needs to contain
the serving job parameters sent to Google Cloud AI Platform. For the
full set of parameters, refer to
https://cloud.google.com/ml-engine/reference/rest/v1/projects.models

Returns:
None
"""
self._log_startup(input_dict, output_dict, exec_properties)
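# Both outputs are optional: `inference_result` carries the raw prediction
# logs, while `output_examples` carries examples with predictions attached.
# Only the artifacts actually requested by the component are populated.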
if output_dict.get('inference_result'):
inference_result = artifact_utils.get_single_instance(
output_dict['inference_result'])
else:
inference_result = None
if output_dict.get('output_examples'):
output_examples = artifact_utils.get_single_instance(
output_dict['output_examples'])
else:
output_examples = None
if 'examples' not in input_dict:
raise ValueError('`examples` is missing in input dict.')
if 'model' not in input_dict:
raise ValueError('Input models are not valid, `model` '
'needs to be specified.')
if 'model_blessing' in input_dict:
model_blessing = artifact_utils.get_single_instance(
input_dict['model_blessing'])
if not model_utils.is_model_blessed(model_blessing):
logging.info('Model on %s was not blessed', model_blessing.uri)
return
else:
logging.info('Model blessing is not provided; the exported model will '
'be used.')
if _CUSTOM_CONFIG_KEY not in exec_properties:
raise ValueError('Input exec properties are not valid, {} '
'needs to be specified.'.format(_CUSTOM_CONFIG_KEY))
custom_config = json_utils.loads(
exec_properties.get(_CUSTOM_CONFIG_KEY, 'null'))
if not isinstance(custom_config, dict):
raise ValueError('custom_config in execution properties needs to be a '
'dict.')
ai_platform_serving_args = custom_config.get(SERVING_ARGS_KEY)
if not ai_platform_serving_args:
raise ValueError(
'`ai_platform_serving_args` is missing in `custom_config`')
service_name, api_version = runner.get_service_name_and_api_version(
ai_platform_serving_args)
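# The service name and API version are derived from the serving args; for
# the Cloud AI Platform prediction service this is typically ('ml', 'v1').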
executor_class_path = name_utils.get_full_name(self.__class__)
with telemetry_utils.scoped_labels(
{telemetry_utils.LABEL_TFX_EXECUTOR: executor_class_path}):
job_labels = telemetry_utils.make_labels_dict()
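# Attach TFX telemetry labels (including the executor class path) to the
# Cloud resources created below.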
model = artifact_utils.get_single_instance(input_dict['model'])
model_path = path_utils.serving_model_path(
model.uri, path_utils.is_old_model_artifact(model))
logging.info('Use exported model from %s.', model_path)
# Use model artifact uri to generate model version to guarantee the
# 1:1 mapping from model version to model.
model_version = 'version_' + hashlib.sha256(model.uri.encode()).hexdigest()
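# The resulting name is 'version_' followed by a 64-character hex digest,
# e.g. 'version_9f86d081...' (digest shown here is hypothetical).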
inference_spec = self._get_inference_spec(model_path, model_version,
ai_platform_serving_args)
data_spec = bulk_inferrer_pb2.DataSpec()
proto_utils.json_to_proto(exec_properties['data_spec'], data_spec)
output_example_spec = bulk_inferrer_pb2.OutputExampleSpec()
if exec_properties.get('output_example_spec'):
proto_utils.json_to_proto(exec_properties['output_example_spec'],
output_example_spec)
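# `data_spec` selects which example splits to run inference on;
# `output_example_spec`, when provided, describes how predictions are
# written back into the output examples.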
endpoint = custom_config.get(constants.ENDPOINT_ARGS_KEY)
if endpoint and 'regions' in ai_platform_serving_args:
raise ValueError(
'`endpoint` and `ai_platform_serving_args.regions` cannot be set simultaneously'
)
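# Build the prediction service client. When a regional endpoint such as
# 'https://us-central1-ml.googleapis.com' (hypothetical value) is supplied,
# it replaces the default global API endpoint.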
api = discovery.build(
service_name,
api_version,
requestBuilder=telemetry_utils.TFXHttpRequest,
client_options=client_options.ClientOptions(api_endpoint=endpoint),
)
new_model_endpoint_created = False
try:
new_model_endpoint_created = runner.create_model_for_aip_prediction_if_not_exist(
job_labels, ai_platform_serving_args, api)
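# The model endpoint was created above if it did not already exist, so
# endpoint creation is skipped here; set_default=False keeps the freshly
# deployed version from serving regular online traffic.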
runner.deploy_model_for_aip_prediction(
serving_path=model_path,
model_version_name=model_version,
ai_platform_serving_args=ai_platform_serving_args,
api=api,
labels=job_labels,
skip_model_endpoint_creation=True,
set_default=False,
)
self._run_model_inference(data_spec, output_example_spec,
input_dict['examples'], output_examples,
inference_result, inference_spec)
except Exception as e:
logging.error('Error in executing CloudAIBulkInferrerComponent: %s',
str(e))
raise
finally:
# Guarantee newly created resources are cleaned up even if the inference
# job failed.
# Clean up the newly deployed model.
runner.delete_model_from_aip_if_exists(
model_version_name=model_version,
ai_platform_serving_args=ai_platform_serving_args,
api=api,
delete_model_endpoint=new_model_endpoint_created)