tfx/extensions/google_cloud_ai_platform/tuner/executor.py
def Do(self, input_dict: Dict[str, List[types.Artifact]],
       output_dict: Dict[str, List[types.Artifact]],
       exec_properties: Dict[str, Any]) -> None:
  """Starts a Tuner component as a job on Google Cloud AI Platform.

  Raises:
    ValueError: If custom_config is missing or lacks TUNING_ARGS_KEY.
    TypeError: If custom_config does not deserialize to a dict.
  """
  self._log_startup(input_dict, output_dict, exec_properties)
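  # custom_config travels through exec_properties as a JSON-serialized string;
  # the 'null' default below deserializes to None when the key is absent.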
  custom_config = json_utils.loads(
      exec_properties.get(standard_component_specs.CUSTOM_CONFIG_KEY, 'null'))
  if custom_config is None:
    raise ValueError('custom_config is not provided.')

  if not isinstance(custom_config, dict):
    raise TypeError('custom_config in execution properties must be a dict, '
                    'but received %s' % type(custom_config))
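  # The payload under TUNING_ARGS_KEY carries the training job configuration
  # that is forwarded to the runner below: 'job_spec' for Vertex AI, or scale
  # tier fields for legacy AI Platform Training.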
  training_inputs = custom_config.get(TUNING_ARGS_KEY)
  if training_inputs is None:
    err_msg = '\'%s\' not found in custom_config.' % TUNING_ARGS_KEY
    logging.error(err_msg)
    raise ValueError(err_msg)
  # Copy so that the per-run overrides below do not mutate the user-supplied
  # custom_config.
  training_inputs = training_inputs.copy()

  tune_args = tuner_executor.get_tune_args(exec_properties)

  enable_vertex = custom_config.get(constants.ENABLE_VERTEX_KEY, False)
  vertex_region = custom_config.get(constants.VERTEX_REGION_KEY, None)

  num_parallel_trials = 1 if not tune_args else tune_args.num_parallel_trials
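  # Only reshape the job when trials actually run in parallel; with a single
  # trial the user-provided worker configuration is used as-is.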
  if num_parallel_trials > 1:
    # The chief (primary) replica also conducts the tuning loop, so only
    # num_parallel_trials - 1 additional workers are needed.
    desired_worker_count = num_parallel_trials - 1
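    # Vertex AI and legacy AI Platform Training describe worker shape
    # differently, so the two branches below configure each form separately.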
    if enable_vertex:
      # worker_pool_specs follows the order detailed below. We make sure the
      # number of workers in pool 1 is consistent with num_parallel_trials.
      # https://cloud.google.com/vertex-ai/docs/training/distributed-training#configure_a_distributed_training_job
      worker_pool_specs = training_inputs['job_spec'].get('worker_pool_specs')
      if worker_pool_specs is None or len(worker_pool_specs) < 1:
        # No pools were given: supply both the primary replica and the worker
        # pool with a default machine type.
        training_inputs['job_spec']['worker_pool_specs'] = [
            # `WorkerPoolSpec` for worker pool 0, the primary replica.
            {
                'machine_spec': {
                    'machine_type': 'n1-standard-8'
                },
                'replica_count': 1
            },
            # `WorkerPoolSpec` for worker pool 1, the workers.
            {
                'machine_spec': {
                    'machine_type': 'n1-standard-8'
                },
                'replica_count': desired_worker_count
            }
        ]
        logging.warning('worker_pool_specs are overridden with %s.',
                        training_inputs['job_spec']['worker_pool_specs'])
      elif len(worker_pool_specs) < 2:
        # The primary replica is configured but worker pool 1 is missing;
        # clone pool 0's spec and set its replica count for the workers.
        worker_specs = {**training_inputs['job_spec']['worker_pool_specs'][0]}
        worker_specs['replica_count'] = desired_worker_count
        training_inputs['job_spec']['worker_pool_specs'].append(worker_specs)
        logging.warning('worker_pool_specs[1] is overridden with %s.',
                        training_inputs['job_spec']['worker_pool_specs'][1])
      elif training_inputs['job_spec']['worker_pool_specs'][1].get(
          'replica_count') != desired_worker_count:
        # Both pools exist; only reconcile the worker replica count.
        training_inputs['job_spec']['worker_pool_specs'][1][
            'replica_count'] = desired_worker_count
        logging.warning(
            'replica_count in worker_pool_specs[1] is overridden with %s.',
            desired_worker_count)
    else:
      # Legacy (non-Vertex) AI Platform Training describes the job through
      # scale tier fields instead of worker pools.
      if training_inputs.get('workerCount') != desired_worker_count:
        logging.warning('workerCount is overridden with %s.',
                        desired_worker_count)
        training_inputs['workerCount'] = desired_worker_count
      # Explicit master/worker machine types are only honored with the
      # CUSTOM scale tier.
      training_inputs['scaleTier'] = 'CUSTOM'
      training_inputs['masterType'] = (
          training_inputs.get('masterType') or 'standard')
      training_inputs['workerType'] = (
          training_inputs.get('workerType') or 'standard')
  # 'tfx_tuner_YYYYmmddHHMMSS' is the default job ID if not specified.
  job_id = (
      custom_config.get(ai_platform_trainer_executor.JOB_ID_KEY) or
      'tfx_tuner_{}'.format(datetime.datetime.now().strftime('%Y%m%d%H%M%S')))
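  # The class path below tells the remote job which executor to run on each
  # replica; _WorkerExecutor wraps the in-process Tuner logic.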
  # TODO(b/160059039): Factor out label creation to a utility function.
  executor_class = _WorkerExecutor
  executor_class_path = name_utils.get_full_name(executor_class)

  # Note: exec_properties['custom_config'] here is a dict.
  return runner.start_cloud_training(input_dict, output_dict, exec_properties,
                                     executor_class_path, training_inputs,
                                     job_id, None, enable_vertex,
                                     vertex_region)
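
# For reference, a minimal sketch of the custom_config a pipeline author might
# pass to the Tuner component so that this executor targets Vertex AI. The
# project ID, machine type, and region are illustrative placeholders, and the
# import paths assume the public TFX extension layout:
#
#   from tfx.extensions.google_cloud_ai_platform import constants
#   from tfx.extensions.google_cloud_ai_platform.tuner import executor
#
#   custom_config = {
#       executor.TUNING_ARGS_KEY: {
#           'project': 'my-gcp-project',  # hypothetical project ID
#           'job_spec': {
#               'worker_pool_specs': [{
#                   'machine_spec': {'machine_type': 'n1-standard-8'},
#                   'replica_count': 1,
#               }],
#           },
#       },
#       constants.ENABLE_VERTEX_KEY: True,
#       constants.VERTEX_REGION_KEY: 'us-central1',
#   }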