in tfx/extensions/google_cloud_ai_platform/runner.py [0:0]
def _launch_cloud_training(project: str,
                           training_job: Dict[str, Any],
                           enable_vertex: Optional[bool] = False,
                           vertex_region: Optional[str] = None) -> None:
  """Submits a Cloud custom training job and blocks until it finishes.

  The job is launched through a client obtained from `training_clients`, then
  its state is polled every `_POLLING_INTERVAL_IN_SECONDS` seconds until the
  client reports a completed state.

  Args:
    project: The GCP project under which the training job will be executed.
    training_job: Training job argument for AI Platform training job. See
      https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.customJobs#CustomJob
      for the detailed schema of the Vertex CustomJob, and
      https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs for
      the detailed schema of the CAIP Job.
    enable_vertex: Whether to enable Vertex or not.
    vertex_region: Region for endpoint in Vertex training.

  Raises:
    RuntimeError: if the job ends in one of the client's failed states.
    ConnectionError: if polling the job keeps failing with a connection error
      after `_CONNECTION_ERROR_RETRY_LIMIT` consecutive retries.
  """
  # TODO(b/185159702): Migrate all training jobs to Vertex and remove the
  # enable_vertex switch.
  job_client = training_clients.get_job_client(enable_vertex, vertex_region)

  # Submit the job, then take a first snapshot of it.
  job_client.launch_job(project, training_job)
  latest_response = job_client.get_job()
  consecutive_failures = 0
  polled_job_name = job_client.get_job_name()

  # Poll the long-running job until it reaches a final state. The possible
  # states of a "get job" response are listed at
  # https://cloud.google.com/ai-platform/training/docs/reference/rest/v1/projects.jobs#State
  # where SUCCEEDED/FAILED/CANCELLED are considered to be final states.
  #
  # A connection error during polling is treated as transient: the Python API
  # client is recreated to refresh the underlying connection (see
  # https://github.com/googleapis/google-api-python-client/issues/218) and the
  # GET is attempted again on the next iteration. Once the error has persisted
  # past _CONNECTION_ERROR_RETRY_LIMIT consecutive retries, the
  # ConnectionError is propagated to the caller.
  while True:
    current_state = job_client.get_job_state(latest_response)
    if current_state in job_client.JOB_STATES_COMPLETED:
      break

    time.sleep(_POLLING_INTERVAL_IN_SECONDS)
    try:
      latest_response = job_client.get_job()
    except ConnectionError as err:
      # Retry budget exhausted: report and re-raise the original error.
      if consecutive_failures >= _CONNECTION_ERROR_RETRY_LIMIT:
        logging.error('Request failed after %s retries.',
                      _CONNECTION_ERROR_RETRY_LIMIT)
        raise
      consecutive_failures += 1
      logging.warning(
          'ConnectionError (%s) encountered when polling job: %s. Trying to '
          'recreate the API client.', err, polled_job_name)
      # Recreate the Python API client before the next poll.
      job_client.create_client()
    else:
      # A successful poll ends the current streak of connection errors.
      consecutive_failures = 0

  final_state = job_client.get_job_state(latest_response)
  if final_state not in job_client.JOB_STATES_FAILED:
    # Cloud training complete.
    logging.info('Job \'%s\' successful.', job_client.get_job_name())
    return

  err_msg = 'Job \'{}\' did not succeed. Detailed response {}.'.format(
      job_client.get_job_name(), latest_response)
  logging.error(err_msg)
  raise RuntimeError(err_msg)