def Do()

in tfx/extensions/google_cloud_ai_platform/bulk_inferrer/executor.py


  def Do(self, input_dict: Dict[str, List[types.Artifact]],
         output_dict: Dict[str, List[types.Artifact]],
         exec_properties: Dict[str, Any]) -> None:
    """Runs batch inference on a given model with given input examples.

    This function creates a new model (if necessary) and a new model version
    before running inference, and cleans up those resources afterwards. It
    supports re-execution because it deletes only the model resources it
    created during the run, even if the inference job failed.

    Args:
      input_dict: Input dict from input key to a list of Artifacts.
        - examples: examples for inference.
        - model: exported model.
        - model_blessing: model blessing result.
      output_dict: Output dict from output key to a list of Artifacts.
        - output: bulk inference results.
      exec_properties: A dict of execution properties.
        - data_spec: JSON string of bulk_inferrer_pb2.DataSpec instance.
        - custom_config: custom_config.ai_platform_serving_args needs to
          contain the serving job parameters sent to Google Cloud AI Platform.
          For the full set of parameters, refer to
          https://cloud.google.com/ml-engine/reference/rest/v1/projects.models

    Returns:
      None
    """
    self._log_startup(input_dict, output_dict, exec_properties)

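    # Resolve the optional output artifacts; either may be absent depending on
    # how the component was configured.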
    if output_dict.get('inference_result'):
      inference_result = artifact_utils.get_single_instance(
          output_dict['inference_result'])
    else:
      inference_result = None
    if output_dict.get('output_examples'):
      output_examples = artifact_utils.get_single_instance(
          output_dict['output_examples'])
    else:
      output_examples = None

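    # Validate that the required `examples` and `model` inputs are present.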
    if 'examples' not in input_dict:
      raise ValueError('`examples` is missing in input dict.')
    if 'model' not in input_dict:
      raise ValueError('Input model is not valid, `model` needs to be '
                       'specified.')
    if 'model_blessing' in input_dict:
      model_blessing = artifact_utils.get_single_instance(
          input_dict['model_blessing'])
      if not model_utils.is_model_blessed(model_blessing):
        logging.info('Model on %s was not blessed', model_blessing.uri)
        return
    else:
      logging.info('Model blessing is not provided, exported model will be '
                   'used.')
    if _CUSTOM_CONFIG_KEY not in exec_properties:
      raise ValueError('Input exec properties are not valid, {} needs '
                       'to be specified.'.format(_CUSTOM_CONFIG_KEY))

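    # Parse the JSON-serialized custom_config and extract the AI Platform
    # serving arguments it must carry.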
    custom_config = json_utils.loads(
        exec_properties.get(_CUSTOM_CONFIG_KEY, 'null'))
    if custom_config is not None and not isinstance(custom_config, Dict):
      raise ValueError('custom_config in execution properties needs to be a '
                       'dict.')
    ai_platform_serving_args = custom_config.get(SERVING_ARGS_KEY)
    if not ai_platform_serving_args:
      raise ValueError(
          '`ai_platform_serving_args` is missing in `custom_config`')
    service_name, api_version = runner.get_service_name_and_api_version(
        ai_platform_serving_args)
    executor_class_path = name_utils.get_full_name(self.__class__)
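    # Build telemetry labels identifying this executor; they are attached as
    # labels to the AI Platform resources created below.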
    with telemetry_utils.scoped_labels(
        {telemetry_utils.LABEL_TFX_EXECUTOR: executor_class_path}):
      job_labels = telemetry_utils.make_labels_dict()
    model = artifact_utils.get_single_instance(input_dict['model'])
    model_path = path_utils.serving_model_path(
        model.uri, path_utils.is_old_model_artifact(model))
    logging.info('Use exported model from %s.', model_path)
    # Use model artifact uri to generate model version to guarantee the
    # 1:1 mapping from model version to model.
    model_version = 'version_' + hashlib.sha256(model.uri.encode()).hexdigest()
    inference_spec = self._get_inference_spec(model_path, model_version,
                                              ai_platform_serving_args)
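    # Deserialize the DataSpec (and optional OutputExampleSpec) protos from
    # their JSON-encoded form in exec_properties.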
    data_spec = bulk_inferrer_pb2.DataSpec()
    proto_utils.json_to_proto(exec_properties['data_spec'], data_spec)
    output_example_spec = bulk_inferrer_pb2.OutputExampleSpec()
    if exec_properties.get('output_example_spec'):
      proto_utils.json_to_proto(exec_properties['output_example_spec'],
                                output_example_spec)
    endpoint = custom_config.get(constants.ENDPOINT_ARGS_KEY)
    if endpoint and 'regions' in ai_platform_serving_args:
      raise ValueError(
          '`endpoint` and `ai_platform_serving_args.regions` cannot be set simultaneously'
      )
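    # Build the Cloud AI Platform API client, optionally pointed at a custom
    # (e.g. regional) endpoint.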
    api = discovery.build(
        service_name,
        api_version,
        requestBuilder=telemetry_utils.TFXHttpRequest,
        client_options=client_options.ClientOptions(api_endpoint=endpoint),
    )
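    # Track whether this execution created the model so that cleanup deletes
    # only resources created here.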
    new_model_endpoint_created = False
    try:
      new_model_endpoint_created = runner.create_model_for_aip_prediction_if_not_exist(
          job_labels, ai_platform_serving_args, api)
      runner.deploy_model_for_aip_prediction(
          serving_path=model_path,
          model_version_name=model_version,
          ai_platform_serving_args=ai_platform_serving_args,
          api=api,
          labels=job_labels,
          skip_model_endpoint_creation=True,
          set_default=False,
      )
      self._run_model_inference(data_spec, output_example_spec,
                                input_dict['examples'], output_examples,
                                inference_result, inference_spec)
    except Exception as e:
      logging.error('Error in executing CloudAIBulkInferrerComponent: %s',
                    str(e))
      raise
    finally:
      # Guarantee newly created resources are cleaned up even if the inference
      # job failed.

      # Clean up the newly deployed model.
      runner.delete_model_from_aip_if_exists(
          model_version_name=model_version,
          ai_platform_serving_args=ai_platform_serving_args,
          api=api,
          delete_model_endpoint=new_model_endpoint_created)
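
For reference, here is a minimal sketch of how `custom_config` might be wired into the corresponding `CloudAIBulkInferrerComponent`. The import path, upstream component outputs, and the project/model/region values below are illustrative assumptions, not taken from this file:

# Hypothetical pipeline snippet; project, model name, and region are assumptions.
from tfx.extensions.google_cloud_ai_platform.bulk_inferrer.component import (
    CloudAIBulkInferrerComponent)

bulk_inferrer = CloudAIBulkInferrerComponent(
    examples=example_gen.outputs['examples'],      # assumed upstream ExampleGen
    model=trainer.outputs['model'],                # assumed upstream Trainer
    model_blessing=evaluator.outputs['blessing'],  # assumed upstream Evaluator
    custom_config={
        'ai_platform_serving_args': {
            'model_name': 'my_model',          # assumed model name
            'project_id': 'my-gcp-project',    # assumed GCP project
            'regions': ['us-central1'],        # cannot be combined with `endpoint`
        },
    },
)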