in tfx/extensions/google_cloud_ai_platform/prediction_clients.py [0:0]
def deploy_model(self,
                 serving_path: str,
                 model_version_name: str,
                 ai_platform_serving_args: Dict[str, Any],
                 labels: Dict[str, str],
                 serving_container_image_uri: str,
                 endpoint_region: str,
                 skip_model_endpoint_creation: Optional[bool] = False,
                 set_default: Optional[bool] = True,
                 **kwargs) -> str:
  """Deploys a model for serving with AI Platform.

  Args:
    serving_path: The path to the model. Must be a GCS URI. Required for model
      creation. If not specified, it is assumed that model with model_name
      exists in AIP.
    model_version_name: Name of the Vertex model being deployed. Must be
      different from what is currently being served, if there is an existing
      model at the specified endpoint with endpoint_name.
    ai_platform_serving_args: Dictionary containing arguments for pushing to
      AI Platform. The full set of parameters supported can be found at
      https://googleapis.dev/python/aiplatform/latest/aiplatform.html?highlight=deploy#google.cloud.aiplatform.Model.deploy.
      Most keys are forwarded as-is, but following keys are handled specially:
      - endpoint_name: Name of the endpoint.
      - traffic_percentage: Desired traffic to newly deployed model.
        Forwarded as-is if specified. If not specified, it is set to 100 if
        set_default_version is True, or set to 0 otherwise.
      - labels: a list of job labels will be merged with user's input.
    labels: The dict of labels that will be attached to this endpoint. They
      are merged with optional labels from `ai_platform_serving_args`.
    serving_container_image_uri: The path to the serving container image URI.
      Container registry for prediction is available at:
      https://gcr.io/cloud-aiplatform/prediction.
    endpoint_region: Region for Vertex Endpoint. For available regions, please
      see https://cloud.google.com/vertex-ai/docs/general/locations
    skip_model_endpoint_creation: If true, the method assumes endpoint already
      exists in AI platform, therefore skipping endpoint creation.
    set_default: Whether set the newly deployed model as the default (i.e.
      100% traffic).
    **kwargs: Extra keyword args.

  Returns:
    The resource name of the deployed model.

  Raises:
    RuntimeError: if deploying the model version fails for any reason other
      than the version already existing at the endpoint (HTTP 409).
  """
  logging.info(
      'Deploying model to AI Platform for serving: %s',
      ai_platform_serving_args)
  if sys.version_info[:2] != (3, 7):
    logging.warning(
        'Current python version is not the same as default of 3.7.')
  if ai_platform_serving_args.get('project_id'):
    # `project_id` is a legacy alias for `project`; accept it but normalize
    # to `project`. Setting both is ambiguous and therefore rejected.
    assert 'project' not in ai_platform_serving_args, ('`project` and '
                                                       '`project_id` should '
                                                       'not be set at the '
                                                       'same time in serving '
                                                       'args')
    logging.warning('Replacing `project_id` with `project` in serving args.')
    # pop() both removes the legacy key and yields its value in one step.
    ai_platform_serving_args['project'] = ai_platform_serving_args.pop(
        'project_id')
  project = ai_platform_serving_args['project']
  # Initialize the AI Platform client.
  # location defaults to 'us-central-1' if not specified.
  aiplatform.init(project=project, location=endpoint_region)
  endpoint_name = ai_platform_serving_args['endpoint_name']
  if not skip_model_endpoint_creation:
    self.create_model_for_aip_prediction_if_not_exist(
        labels, ai_platform_serving_args)
  endpoint = self._get_endpoint(ai_platform_serving_args)
  # Copy the serving args and strip keys that are consumed here rather than
  # forwarded to Model.deploy().
  deploy_body = dict(ai_platform_serving_args)
  for unneeded_key in ['endpoint_name', 'project', 'regions', 'labels']:
    deploy_body.pop(unneeded_key, None)
  # Default traffic split: all traffic to the new model iff set_default.
  deploy_body['traffic_percentage'] = deploy_body.get(
      'traffic_percentage', 100 if set_default else 0)
  logging.info(
      'Creating model_name %s in project %s at endpoint %s, request body: %s',
      model_version_name, project, endpoint_name, deploy_body)
  model = aiplatform.Model.upload(
      display_name=model_version_name,
      artifact_uri=serving_path,
      serving_container_image_uri=serving_container_image_uri)
  model.wait()
  try:
    # Push to AI Platform and wait for deployment to be complete.
    model.deploy(endpoint=endpoint, **deploy_body)
    model.wait()
  except errors.HttpError as e:
    # If the error is to create an already existing model, it's ok to
    # ignore.
    if e.resp.status == 409:
      logging.warning('Model %s already exists at endpoint %s',
                      model_version_name, endpoint_name)
    else:
      raise RuntimeError(
          'Creating model version to AI Platform failed.') from e
  logging.info(
      'Successfully deployed model %s to endpoint %s, serving from %s',
      model_version_name, endpoint_name, endpoint.resource_name)
  return model.resource_name