in tfx/extensions/google_cloud_ai_platform/prediction_clients.py [0:0]
def deploy_model(self,
                 serving_path: str,
                 model_version_name: str,
                 ai_platform_serving_args: Dict[str, Any],
                 labels: Dict[str, str],
                 skip_model_endpoint_creation: Optional[bool] = False,
                 set_default: Optional[bool] = True,
                 **kwargs) -> str:
    """Deploys a model for serving with AI Platform.

    Args:
      serving_path: The path to the model. Must be a GCS URI.
      model_version_name: Version of the model being deployed. Must be
        different from what is currently being served.
      ai_platform_serving_args: Dictionary containing arguments for pushing to
        AI Platform. Must contain `model_name` and `project_id`. The full set
        of parameters supported can be found at
        https://cloud.google.com/ml-engine/reference/rest/v1/projects.models.versions#Version.
        Most keys are forwarded as-is, but following keys are handled
        specially:
          - model_name, project_id, regions: used to address the model and
            removed from the version request body.
          - name: overwritten with `model_version_name`.
          - deployment_uri: overwritten with `serving_path`.
          - python_version: when absent, defaults to '3.7'.
          - runtime_version: when absent, derived from the TensorFlow version
            of the current environment.
          - labels: merged with `labels`; on key conflicts `labels` wins.
      labels: The dict of labels that will be attached to this job. They are
        merged with optional labels from `ai_platform_serving_args`.
      skip_model_endpoint_creation: If true, the method assumes model already
        exists in AI platform, therefore skipping model creation.
      set_default: Whether set the newly deployed model version as the
        default version.
      **kwargs: Extra keyword args. Not used by this method.

    Returns:
      The resource name of the deployed model version.

    Raises:
      KeyError: if `model_name` or `project_id` is missing from
        `ai_platform_serving_args`.
      RuntimeError: if creating the model version fails with an HTTP error
        other than 409 (version already exists).
    """
    logging.info(
        'Deploying to model with version %s to AI Platform for serving: %s',
        model_version_name, ai_platform_serving_args)
    # Warn whenever the interpreter is anything other than 3.7. Comparing the
    # (major, minor) pair avoids the `!= 3 and != 7` trap, which stayed silent
    # on every 3.x release.
    if sys.version_info[:2] != (3, 7):
        logging.warning(
            'Current python version is not the same as default of 3.7.')

    model_name = ai_platform_serving_args['model_name']
    project_id = ai_platform_serving_args['project_id']
    default_runtime_version = _get_tf_runtime_version(tf.__version__)
    default_python_version = '3.7'

    if not skip_model_endpoint_creation:
        self.create_model_for_aip_prediction_if_not_exist(
            labels, ai_platform_serving_args)

    # Work on a copy so the caller's dict is never mutated.
    version_body = dict(ai_platform_serving_args)
    # These keys address the model, not the version, so they do not belong in
    # the version request body.
    for model_only_key in ('model_name', 'project_id', 'regions'):
        version_body.pop(model_only_key, None)
    version_body['name'] = model_version_name
    version_body['deployment_uri'] = serving_path
    # User-provided values take precedence over the defaults.
    version_body.setdefault('runtime_version', default_runtime_version)
    version_body.setdefault('python_version', default_python_version)
    version_body['labels'] = {**version_body.get('labels', {}), **labels}
    logging.info(
        'Creating new version of model_name %s in project %s, request body: %s',
        model_name, project_id, version_body)

    # Fully-qualified parent resource of the version. Kept separate from the
    # short `model_name`, which is still needed to build the return value.
    model_resource = 'projects/{}/models/{}'.format(project_id, model_name)
    try:
        # Push to AIP, then block until the long-running operation finishes.
        operation = self._client.projects().models().versions().create(
            body=version_body, parent=model_resource).execute()
        self._wait_for_operation(
            operation, 'projects.models.versions.create')
    except errors.HttpError as e:
        # If the error is to create an already existing model version, it's ok
        # to ignore.
        if e.resp.status == 409:
            logging.warning(
                'Model version %s already exists', model_version_name)
        else:
            raise RuntimeError(
                'Creating model version to AI Platform failed: {}'
                .format(e)) from e

    if set_default:
        # Set the new version as default.
        self._client.projects().models().versions().setDefault(
            name='{}/versions/{}'.format(model_resource,
                                         model_version_name)).execute()

    logging.info(
        'Successfully deployed model %s with version %s, serving from %s',
        model_resource, model_version_name, serving_path)

    # NOTE(review): the short model name is passed here because the template
    # also receives `project_id` separately; passing the fully-qualified
    # resource would embed the project/model prefix twice. Confirm against the
    # definition of _CAIP_MODEL_VERSION_PATH_FORMAT.
    return _CAIP_MODEL_VERSION_PATH_FORMAT.format(
        project_id=project_id, model=model_name, version=model_version_name)