def Do()

in tfx/extensions/google_cloud_ai_platform/tuner/executor.py [0:0]


  def Do(self, input_dict: Dict[str, List[types.Artifact]],
         output_dict: Dict[str, List[types.Artifact]],
         exec_properties: Dict[str, Any]) -> None:
    """Starts a Tuner component as a job on Google Cloud AI Platform.

    The tuning job configuration is decoded from the serialized
    `custom_config` execution property. When more than one trial runs in
    parallel, the worker configuration is adjusted so that one chief node plus
    `num_parallel_trials - 1` workers are requested: on Vertex AI through
    `job_spec.worker_pool_specs[1].replica_count`, otherwise through
    `workerCount` with a CUSTOM scale tier.

    Args:
      input_dict: Input artifacts, forwarded to the cloud training launcher.
      output_dict: Output artifacts, forwarded to the cloud training launcher.
      exec_properties: Execution properties. Its `custom_config` entry must
        decode to a dict containing the TUNING_ARGS_KEY entry; that dict may
        also carry the job ID, the Vertex enable flag and the Vertex region.

    Returns:
      Whatever `runner.start_cloud_training` returns.

    Raises:
      ValueError: If `custom_config` is missing or has no TUNING_ARGS_KEY
        entry.
      TypeError: If `custom_config` does not decode to a dict.
    """
    import copy  # pylint: disable=g-import-not-at-top

    self._log_startup(input_dict, output_dict, exec_properties)

    custom_config = json_utils.loads(
        exec_properties.get(standard_component_specs.CUSTOM_CONFIG_KEY, 'null'))
    if custom_config is None:
      raise ValueError('custom_config is not provided')

    if not isinstance(custom_config, dict):
      raise TypeError('custom_config in execution properties must be a dict, '
                      'but received %s' % type(custom_config))

    training_inputs = custom_config.get(TUNING_ARGS_KEY)
    if training_inputs is None:
      err_msg = ('\'%s\' not found in custom_config.' % TUNING_ARGS_KEY)
      logging.error(err_msg)
      raise ValueError(err_msg)
    # Deep copy so the worker adjustments below never write into the nested
    # dicts (e.g. `job_spec`) still owned by the parsed custom_config.
    training_inputs = copy.deepcopy(training_inputs)

    tune_args = tuner_executor.get_tune_args(exec_properties)

    enable_vertex = custom_config.get(constants.ENABLE_VERTEX_KEY, False)
    vertex_region = custom_config.get(constants.VERTEX_REGION_KEY, None)
    num_parallel_trials = (1
                           if not tune_args else tune_args.num_parallel_trials)
    if num_parallel_trials > 1:
      # Chief node is also responsible for conducting tuning loop.
      desired_worker_count = num_parallel_trials - 1

      if enable_vertex:
        # worker_pool_specs follows the order detailed below. We make sure the
        # number of workers in pool 1 is consistent with num_parallel_trials.
        # https://cloud.google.com/vertex-ai/docs/training/distributed-training#configure_a_distributed_training_job
        # A missing `job_spec` is handled like one without worker_pool_specs.
        job_spec = training_inputs.setdefault('job_spec', {})
        worker_pool_specs = job_spec.get('worker_pool_specs')
        if not worker_pool_specs:
          job_spec['worker_pool_specs'] = [
              # `WorkerPoolSpec` for worker pool 0, primary replica
              {
                  'machine_spec': {
                      'machine_type': 'n1-standard-8'
                  },
                  'replica_count': 1
              },
              # `WorkerPoolSpec` for worker pool 1
              {
                  'machine_spec': {
                      'machine_type': 'n1-standard-8'
                  },
                  'replica_count': desired_worker_count
              }
          ]
          logging.warning('worker_pool_specs are overridden with %s.',
                          job_spec['worker_pool_specs'])
        elif len(worker_pool_specs) < 2:
          # Primary replica set but workers missing: clone pool 0. Deep copy
          # so the two pools do not share nested dicts such as machine_spec.
          worker_specs = copy.deepcopy(worker_pool_specs[0])
          worker_specs['replica_count'] = desired_worker_count
          worker_pool_specs.append(worker_specs)
          logging.warning('worker_pool_specs[1] are overridden with %s.',
                          worker_pool_specs[1])
        elif worker_pool_specs[1].get('replica_count') != desired_worker_count:
          worker_pool_specs[1]['replica_count'] = desired_worker_count
          logging.warning(
              'replica_count in worker_pool_specs[1] is overridden with %s.',
              desired_worker_count)
      else:
        if training_inputs.get('workerCount') != desired_worker_count:
          logging.warning('workerCount is overridden with %s',
                          desired_worker_count)
          training_inputs['workerCount'] = desired_worker_count

        # Explicit master/worker machine types require the CUSTOM scale tier;
        # fall back to 'standard' machines when none are configured.
        training_inputs['scaleTier'] = 'CUSTOM'
        training_inputs['masterType'] = (
            training_inputs.get('masterType') or 'standard')
        training_inputs['workerType'] = (
            training_inputs.get('workerType') or 'standard')

    # 'tfx_tuner_YYYYmmddHHMMSS' is the default job ID if not specified.
    job_id = (
        custom_config.get(ai_platform_trainer_executor.JOB_ID_KEY) or
        'tfx_tuner_{}'.format(datetime.datetime.now().strftime('%Y%m%d%H%M%S')))

    # TODO(b/160059039): Factor out label creation to a utility function.
    executor_class = _WorkerExecutor
    executor_class_path = name_utils.get_full_name(executor_class)

    # This method does not modify exec_properties, so its custom_config entry
    # is still the serialized form decoded at the top. The `None` argument
    # presumably stands for job labels (see the TODO above) -- confirm against
    # runner.start_cloud_training.
    return runner.start_cloud_training(input_dict, output_dict, exec_properties,
                                       executor_class_path, training_inputs,
                                       job_id, None, enable_vertex,
                                       vertex_region)