def _launch_cloud_training()

in tfx/extensions/google_cloud_ai_platform/runner.py [0:0]


def _launch_cloud_training(project: str,
                           training_job: Dict[str, Any],
                           enable_vertex: Optional[bool] = False,
                           vertex_region: Optional[str] = None) -> None:
  """Launches and monitors a Cloud custom training job.

  The job is submitted through the client returned by
  `training_clients.get_job_client`, after which its state is polled (first
  immediately, then every `_POLLING_INTERVAL_IN_SECONDS` seconds) until it
  reaches one of `client.JOB_STATES_COMPLETED`.

  Args:
    project: The GCP project under which the training job will be executed.
    training_job: Training job argument for AI Platform training job. See
      https://cloud.google.com/vertex-ai/docs/reference/rest/v1/projects.locations.customJobs#CustomJob
      for the detailed schema of a Vertex CustomJob, and
      https://cloud.google.com/ml-engine/reference/rest/v1/projects.jobs for
      the detailed schema of a CAIP Job.
    enable_vertex: Whether to enable Vertex or not.
    vertex_region: Region for endpoint in Vertex training.

  Raises:
    RuntimeError: if the job finishes in one of `client.JOB_STATES_FAILED`
      (i.e. the Google Cloud AI Platform training job failed/cancelled).
    ConnectionError: if polling the job status keeps failing with a
      connection error after `_CONNECTION_ERROR_RETRY_LIMIT` consecutive
      retries.
  """
  # TODO(b/185159702): Migrate all training jobs to Vertex and remove the
  #                    enable_vertex switch.
  client = training_clients.get_job_client(enable_vertex, vertex_region)
  # Configure and launch AI Platform training job.
  client.launch_job(project, training_job)

  job_id = client.get_job_name()
  response = None
  retry_count = 0
  # Monitors the long-running operation by polling the job state periodically,
  # and retries the polling when a transient connectivity issue is encountered.
  #
  # Long-running operation monitoring:
  #   The possible states of "get job" response can be found at
  #   https://cloud.google.com/ai-platform/training/docs/reference/rest/v1/projects.jobs#State
  #   where SUCCEEDED/FAILED/CANCELLED are considered to be final states.
  #   The following loop keeps polling the state of the job until the job
  #   enters a state in client.JOB_STATES_COMPLETED.
  #
  # Every GET request -- including the very first one issued right after the
  # launch -- is retried on a connection error by recreating the Python API
  # client to refresh the lifecycle of the connection being used. See
  # https://github.com/googleapis/google-api-python-client/issues/218
  # for a detailed description of the problem. The retry counter is reset on
  # every successful request; once _CONNECTION_ERROR_RETRY_LIMIT consecutive
  # retries have been used up, the next ConnectionError is re-raised.
  while True:
    try:
      response = client.get_job()
    # Handle transient connection error.
    except ConnectionError as err:
      if retry_count >= _CONNECTION_ERROR_RETRY_LIMIT:
        logging.error('Request failed after %s retries.',
                      _CONNECTION_ERROR_RETRY_LIMIT)
        raise
      retry_count += 1
      logging.warning(
          'ConnectionError (%s) encountered when polling job: %s. Trying to '
          'recreate the API client.', err, job_id)
      # Recreate the Python API client.
      client.create_client()
    else:
      retry_count = 0
      if client.get_job_state(response) in client.JOB_STATES_COMPLETED:
        break
    time.sleep(_POLLING_INTERVAL_IN_SECONDS)

  if client.get_job_state(response) in client.JOB_STATES_FAILED:
    err_msg = 'Job \'{}\' did not succeed.  Detailed response {}.'.format(
        client.get_job_name(), response)
    logging.error(err_msg)
    raise RuntimeError(err_msg)

  # Cloud training complete
  logging.info('Job \'%s\' successful.', client.get_job_name())