spotify_tensorflow/luigi/python_dataflow_task.py (46 lines of code) (raw):

# -*- coding: utf-8 -*- # # Copyright 2017-2019 Spotify AB. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import logging import os import subprocess import time import luigi from luigi.task import MixinNaiveBulkComplete from spotify_tensorflow.luigi.utils import get_uri, run_with_logging logger = logging.getLogger("luigi-interface") class PythonDataflowTask(MixinNaiveBulkComplete, luigi.Task): """"Luigi wrapper for a dataflow job The following properties can be set: python_script = None # Python script for the dataflow task. project = None # Name of the project owning the dataflow job. staging_location = None # GCS path for staging code packages needed by workers. zone = None # GCE availability zone for launching workers. region = None # GCE region for creating the dataflow job. temp_location = None # GCS path for saving temporary workflow jobs. num_workers = None # The number of workers to start the task with. autoscaling_algorithm = None # Set to "NONE" to disable autoscaling. `num_workers` # will then be used for the job. max_num_workers = None # Used if the autoscaling is enabled. network = None # Network in GCE to be used for launching workers. subnetwork = None # Subnetwork in GCE to be used for launching workers. disk_size_gb = None # Remote worker disk size, if not defined uses default size. worker_machine_type = None # Machine type to create Dataflow worker VMs. If unset, # the Dataflow service will choose a reasonable default. worker_disk_type = None # Specify SSD for local disk or defaults to hard disk. service_account = None # Service account of Dataflow VMs/workers. Default is a default GCE service account. job_name = None # Name of the dataflow job requirements_file = None # Path to a requirements file containing package dependencies. local_runner = False # If local_runner = True, the job uses DirectRunner, otherwise it uses DataflowRunner setup_file = None # Path to a setup Python file containing package dependencies. :Example: class AwesomeJob(PythonDataflowJobTask): python_script = "/path/to/python_script" project = "gcp-project" staging_location = "gs://gcp-project-playground/user/staging" temp_location = "gs://gcp-project-playground/user/tmp" max_num_workers = 20 region = "europe-west1" service_account_email = "service_account@gcp-project.iam.gserviceaccount.com" def output(self): ... """ # Required dataflow args python_script = None # type: str project = None # type: str staging_location = None # type: str # Dataflow requires one and only one of: zone = None # type: str region = None # type: str # Optional dataflow args temp_location = None # type: str num_workers = None # type: int autoscaling_algorithm = None # type: str max_num_workers = None # type: int network = None # type: str subnetwork = None # type: str disk_size_gb = None # type: int worker_machine_type = None # type: str worker_disk_type = None # type: str service_account = None # type: str job_name = None # type: str requirements_file = None # type: str local_runner = False # type: bool setup_file = None # type: str def __init__(self, *args, **kwargs): super(PythonDataflowTask, self).__init__(*args, **kwargs) self._output = self.output() if isinstance(self._output, luigi.Target): self._output = {"output": self._output} if self.job_name is None: # job_name must consist of only the characters [-a-z0-9] cls_name = self.__class__.__name__.replace("_", "-").lower() self.job_name = "{cls_name}-{timestamp}".format(cls_name=cls_name, timestamp=str(int(time.time()))) def on_successful_run(self): """ Callback that gets called right after the dataflow job has finished successfully but before validate_output is run. """ pass def validate_output(self): """ Callback that can be used to validate your output before it is moved to it's final location. Returning false here will cause the job to fail, and output to be removed instead of published. :return: Whether the output is valid or not :rtype: Boolean """ return True def file_pattern(self): """ If one/some of the input target files are not in the pattern of part-*, we can add the key of the required target and the correct file pattern that should be appended in the command line here. If the input target key is not found in this dict, the file pattern will be assumed to be part-* for that target. :return A dictionary of overrided file pattern that is not part-* for the inputs :rtype: Dict of String to String """ return {} def run(self): cmd_line = self._mk_cmd_line() logger.info(" ".join(cmd_line)) try: run_with_logging(cmd_line, logger) except subprocess.CalledProcessError as e: logging.error(e, exc_info=True) # exit luigi with the same exit code as the python dataflow job proccess # In this way users can easily exit the job with code 50 to avoid Styx retries # https://github.com/spotify/styx/blob/master/doc/design-overview.md#workflow-state-graph os._exit(e.returncode) self.on_successful_run() if self.validate_output(): self._publish_outputs() else: raise ValueError("Output is not valid") def _publish_outputs(self): for (name, target) in self._output.items(): if hasattr(target, "publish"): target.publish(self._output_uris[name]) def _mk_cmd_line(self): cmd_line = self._dataflow_executable() cmd_line.extend(self._get_dataflow_args()) cmd_line.extend(self._get_input_args()) cmd_line.extend(self._get_output_args()) cmd_line.extend(self.args()) return cmd_line def _dataflow_executable(self): """ Defines the executable used to run the python dataflow job. """ return ["python", self.python_script] def _get_input_uri(self, file_pattern, target): uri = get_uri(target) uri = uri.rstrip("/") + "/" + file_pattern return uri def _get_file_pattern(self): file_pattern = self.file_pattern() if not isinstance(file_pattern, dict): raise ValueError("file_pattern() must return a dict type") return file_pattern def _get_input_args(self): """ Collects outputs from requires() and converts them to input arguments. file_pattern() is called to construct input file path glob with default value "part-*" """ job_input = self.input() if isinstance(job_input, luigi.Target): job_input = {"input": job_input} if not isinstance(job_input, dict): raise ValueError("Input (requires()) must be dict type") input_args = [] file_pattern_dict = self._get_file_pattern() for (name, targets) in job_input.items(): uri_targets = luigi.task.flatten(targets) pattern = file_pattern_dict.get(name, "part-*") uris = [self._get_input_uri(pattern, uri_target) for uri_target in uri_targets] if isinstance(targets, dict): # If targets is a dict that means it had multiple outputs. # Make the input args in that case "<input key>-<task output key>" names = ["%s-%s" % (name, key) for key in targets.keys()] else: names = [name] * len(uris) for (arg_name, uri) in zip(names, uris): input_args.append("--%s=%s" % (arg_name, uri)) return input_args def _get_output_args(self): if not isinstance(self._output, dict): raise ValueError("Output must be dict type") output_args = [] self._output_uris = {} for (name, target) in self._output.items(): uri = target.generate_uri() if hasattr(target, "generate_uri") else get_uri(target) uri = uri.rstrip("/") output_args.append("--%s=%s" % (name, uri)) self._output_uris[name] = uri return output_args def _get_runner(self): return "DirectRunner" if self.local_runner else "DataflowRunner" def _get_dataflow_args(self): dataflow_args = [] _runner = self._get_runner() if _runner: dataflow_args += ["--runner={}".format(_runner)] if self.project: dataflow_args += ["--project={}".format(self.project)] if self.staging_location: dataflow_args += ["--staging_location={}".format(self.staging_location)] if self.zone: dataflow_args += ["--zone={}".format(self.zone)] if self.region: dataflow_args += ["--region={}".format(self.region)] if self.temp_location: dataflow_args += ["--temp_location={}".format(self.temp_location)] if self.num_workers: dataflow_args += ["--num_workers={}".format(self.num_workers)] if self.autoscaling_algorithm: dataflow_args += ["--autoscaling_algorithm={}".format(self.autoscaling_algorithm)] if self.max_num_workers: dataflow_args += ["--max_num_workers={}".format(self.max_num_workers)] if self.network: dataflow_args += ["--network={}".format(self.network)] if self.subnetwork: dataflow_args += ["--subnetwork={}".format(self.subnetwork)] if self.disk_size_gb: dataflow_args += ["--disk_size_gb={}".format(self.disk_size_gb)] if self.worker_machine_type: dataflow_args += ["--worker_machine_type={}".format(self.worker_machine_type)] if self.job_name: dataflow_args += ["--job_name={}".format(self.job_name)] if self.worker_disk_type: dataflow_args += ["--worker_disk_type={}".format(self.worker_disk_type)] if self.service_account: dataflow_args += ["--service_account_email={}".format(self.service_account)] if self.requirements_file: dataflow_args += ["--requirements_file={}".format(self.requirements_file)] if self.setup_file: dataflow_args += ["--setup_file={}".format(self.setup_file)] return dataflow_args def args(self): """ Extra arguments that will be passed to your dataflow job. Example: return ["--project=my-gcp-project", "--zone=a-zone", "--staging_location=gs://my-gcp-project/dataflow"] Note that: * You "set" args by overriding this method in your subclass. * This function should return an iterable of strings. """ return [] def get_output_uris(self): """ Returns a dictionary that contains output uris. The key is the name of the output target defined in output(), and the value is the path/uri of the output target. It can be used to write data to different sub directories under one output target. :return A dictionary of output uris :rtype: Dict of String to String """ return self._output_uris