# spotify_tensorflow/luigi/tensorflow_task.py
# -*- coding: utf-8 -*-
#
# Copyright 2017-2019 Spotify AB.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from __future__ import absolute_import, division, print_function

import getpass
import logging
import uuid

import luigi

from spotify_tensorflow.luigi.utils import get_uri, run_with_logging

logger = logging.getLogger("luigi-interface")


class TensorFlowTask(luigi.Task):
"""Luigi wrapper for a TensorFlow task. To use, extend this class and provide values for the
following properties:
model_package = None The name of the python package containing your model.
model_name = None The name of the python module containing your model.
Ex: if the model is in /foo/models/main.py, you would set
model_package = "models" and model_name = "main"
gcp_project = None The Google Cloud project id to run with ai-platform
region = None The GCP region if running with ai-platform, e.g. europe-west1
model_name_suffix = None A string suffix representing the model name, which will be appended
to the job name.
runtime_version = None The Google Cloud AI Platform runtime version for this job. Defaults
to the latest stable version. See
https://cloud.google.com/ml/docs/concepts/runtime-version-list for a
list of accepted versions.
python_version = None The Google Cloud AI Platform python version for this job. See
https://cloud.google.com/ml-engine/docs/versioning#set-python-version-training
for more information.
scale_tier = None Specifies the machine types, the number of replicas for workers and
parameter servers. SCALE_TIER must be one of:
basic, basic-gpu, basic-tpu, custom, premium-1, standard-1.
Also, you can specify command line arguments for your trainer by overriding the
`def tf_task_args(self)` method.
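
    Example (a minimal sketch; the names, paths and flags are illustrative):

        class MyTrainingTask(TensorFlowTask):
            model_package = "models"
            model_name = "main"
            model_package_path = "/foo/models"
            gcp_project = "my-gcp-project"
            region = "europe-west1"

            def tf_task_args(self):
                # Hypothetical flags understood by models/main.py
                return ["--learning-rate=0.01"]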
"""
    # Task properties
    model_name = luigi.Parameter(description="Name of the python model file")
    model_package = luigi.Parameter(description="Python package containing your model")
    model_package_path = luigi.Parameter(description="Absolute path to the model package")
    gcp_project = luigi.Parameter(description="GCP project", default=None)
    region = luigi.Parameter(description="GCP region", default=None)
    model_name_suffix = luigi.Parameter(description="String which will be appended to the job"
                                                    " name. Useful for finding jobs in the"
                                                    " ai-platform UI.", default=None)

    # Task parameters
    cloud = luigi.BoolParameter(description="Run on ai-platform")
    blocking = luigi.BoolParameter(default=True, description="Run in stream-logs/blocking mode")
    job_dir = luigi.Parameter(description="A job directory, used to store snapshots, logs and any "
                                          "other artifacts. A trailing '/' is required for "
                                          "'gs://' paths.")
    ai_platform_conf = luigi.Parameter(default=None,
                                       description="An ai-platform YAML configuration file.")
    tf_debug = luigi.BoolParameter(default=False, description="Run tf in debug mode")
    runtime_version = luigi.Parameter(default=None,
                                      description="The Google Cloud AI Platform runtime version "
                                                  "for this job.")
    python_version = luigi.Parameter(
        default=None,
        description="The Google Cloud AI Platform python version for this job. See "
                    "https://cloud.google.com/ml-engine/docs/versioning#set-python-version-training"
                    " for more information."
    )
    scale_tier = luigi.Parameter(default=None,
                                 description="Specifies the machine types, the number of replicas "
                                             "for workers and parameter servers.")
    def __init__(self, *args, **kwargs):
        super(TensorFlowTask, self).__init__(*args, **kwargs)

    def tf_task_args(self):
        """A list of args to pass to the tf main module."""
        return []

    def run(self):
        cmd = self._mk_cmd()
        logger.info("Running:\n```\n%s\n```", cmd)
        run_with_logging(cmd, logger)
        logger.info("Training finished.")

    def get_job_dir(self):
        """Get job directory used to store snapshots, logs, final output and any other
        artifacts."""
        return self.job_dir
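
    # The assembled command takes one of two shapes (a sketch; the module name and
    # values are illustrative, not produced verbatim):
    #
    #   cloud: gcloud ai-platform jobs submit training <job_name> --job-dir=... \
    #          --package-path=... --module-name=models.main -- <input args> <tf_task_args>
    #   local: gcloud ai-platform local train --package-path=... \
    #          --module-name=models.main -- <input args> --job-dir=... <tf_task_args>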
    def _mk_cmd(self):
        cmd = ["gcloud", "ai-platform"]
        if self.cloud:
            cmd.extend(self._mk_cloud_params())
        else:
            cmd.extend(["local", "train"])
        cmd.extend(self._get_model_args())
        if self.tf_debug:
            cmd.append("--verbosity=debug")
        cmd.extend(self._get_job_args())
        return cmd

    def _mk_cloud_params(self):
        params = []
        if self.gcp_project:
            params.append("--project=%s" % self.gcp_project)
        params.extend(["jobs", "submit", "training", self._get_job_name()])
        if self.region:
            params.append("--region=%s" % self.region)
        if self.ai_platform_conf:
            params.append("--config=%s" % self.ai_platform_conf)
        params.append("--job-dir=%s" % self.get_job_dir())
        if self.blocking:
            params.append("--stream-logs")  # makes the execution "blocking"
        if self.runtime_version:
            params.append("--runtime-version=%s" % self.runtime_version)
        if self.python_version:
            params.append("--python-version=%s" % self.python_version)
        if self.scale_tier:
            params.append("--scale-tier=%s" % self.scale_tier)
        return params

    def _get_model_args(self):
        args = []
        if self.model_package_path:
            args.append("--package-path=%s" % self.model_package_path)
        if self.model_name:
            module_name = self.model_name
            if self.model_package:
                module_name = "{package}.{module}".format(package=self.model_package,
                                                          module=module_name)
            args.append("--module-name=" + module_name)
        return args

    def _get_job_args(self):
        args = ["--"]
        args.extend(self._get_input_args())
        if not self.cloud:
            args.append("--job-dir=%s" % self.get_job_dir())
        args.extend(self.tf_task_args())
        return args
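
    # Job names follow "<user>_<TaskClass>_<model_name_suffix>_<uuid>", e.g.
    # (illustrative): alice_MyTrainingTask_v1_0f8fad5b_d9cb_469f_a165_70867728950e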
    def _get_job_name(self):
        job_name = "%s_%s_%s_%s" % (
            getpass.getuser(),
            self.__class__.__name__,
            self.model_name_suffix,
            str(uuid.uuid4()).replace("-", "_"))
        return job_name
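
    # Sketch of the mapping below (names illustrative): a requires() returning
    # {"training_data": task_a} yields ["--training_data=gs://..."]; if task_a's
    # output is itself a dict such as {"train": t1, "eval": t2}, the args become
    # ["--training_data-train=gs://...", "--training_data-eval=gs://..."].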
    def _get_input_args(self):
        # TODO(brianm): this doesn't work when subclass yields from `requires`
        job_input = self.input()
        if isinstance(job_input, luigi.Target):
            job_input = {"input": job_input}
        if len(job_input) == 0:  # default requires()
            return []
        if not isinstance(job_input, dict):
            raise ValueError("Input (requires()) must be dict type")
        input_args = []
        for (name, targets) in job_input.items():
            uris = [get_uri(target) for target in luigi.task.flatten(targets)]
            if isinstance(targets, dict):
                # If targets is a dict, the required task had multiple named outputs;
                # in that case build the input args as "<input key>-<task output key>".
                names = ["%s-%s" % (name, key) for key in targets.keys()]
            else:
                names = [name] * len(uris)
            for (arg_name, uri) in zip(names, uris):
                input_args.append("--%s=%s" % (arg_name, uri))
        return input_args