spotify_tensorflow/tfx/utils.py (58 lines of code) (raw):

# -*- coding: utf-8 -*-
#
# Copyright 2017-2019 Spotify AB.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

import os
import tempfile
import textwrap
from typing import Any  # noqa: F401

from pbr.version import VersionInfo
from spotify_tensorflow.luigi.utils import to_snake_case


def assert_not_none(arg):
    # type: (Any) -> None
    """Raise ``TypeError`` if ``arg`` is ``None``; otherwise do nothing."""
    if arg is None:
        raise TypeError("Argument can't be a None")


def assert_not_empty_string(arg):
    # type: (Any) -> None
    """Validate that ``arg`` is a non-empty ``str``.

    Raises:
        TypeError: if ``arg`` is not an instance of ``str``.
        ValueError: if ``arg`` is the empty string.
    """
    if not isinstance(arg, str):
        raise TypeError("Argument should be a string")
    if arg == "":
        raise ValueError("Argument can't be an empty string")


def create_setup_file():
    """Write a minimal ``setup.py`` into a fresh temporary directory.

    The generated file declares a ``spotify_tensorflow_dataflow`` package
    whose only requirement is ``spotify-tensorflow`` pinned to the version
    reported by ``pbr`` for the ``spotify_tensorflow`` distribution.

    Returns:
        The path of the generated ``setup.py``. The temporary directory is
        created with ``tempfile.mkdtemp`` and is not removed by this function.
    """
    lib_version = VersionInfo("spotify_tensorflow").version_string()
    contents_for_setup_file = """
    import setuptools

    if __name__ == "__main__":
        setuptools.setup(
            name="spotify_tensorflow_dataflow",
            packages=setuptools.find_packages(),
            install_requires=[
                "spotify-tensorflow=={version}"
            ])
    """.format(version=lib_version)
    setup_file_path = os.path.join(tempfile.mkdtemp(), "setup.py")
    with open(setup_file_path, "w") as f:
        # The content is a single string, so write() is the right call;
        # writelines() would iterate over it one character at a time.
        f.write(textwrap.dedent(contents_for_setup_file))
    return setup_file_path


def clean_up_pipeline_args(pipeline_args):
    """Normalise pipeline arguments and keep only the supported ones.

    Both ``--key=value`` and ``--key value`` spellings are accepted. Each flag
    name is passed through ``to_snake_case`` and the pair is kept only when the
    resulting name is in ``SUPPORTED_DATAFLOW_PIPELINE_ARGS``.

    Only the first ``=`` separates the flag from its value, so values that
    contain ``=`` themselves (for example
    ``--experiments=shuffle_mode=service``) are preserved intact. Flags that
    carry no value and positional values that do not follow a flag are
    ignored, so they cannot shift the pairing of the arguments after them.

    Args:
        pipeline_args: iterable of command-line argument strings.

    Returns:
        A list of ``"--key=value"`` strings, in input order.
    """
    pairs = []
    # Flag seen without "=", still waiting for its value in the next argument.
    pending_key = None
    for arg in pipeline_args:
        if arg.startswith("--"):
            key, sep, value = arg.partition("=")
            if sep:
                pairs.append((to_snake_case(key), value))
                pending_key = None
            else:
                # Any previously pending flag had no value and is dropped.
                pending_key = to_snake_case(arg)
        elif pending_key is not None:
            pairs.append((pending_key, arg))
            pending_key = None
        # Otherwise: a positional value with no preceding flag; skip it.
    return ["%s=%s" % (key, val) for (key, val) in pairs
            if key in SUPPORTED_DATAFLOW_PIPELINE_ARGS]


# Flag names (after snake-casing) that clean_up_pipeline_args lets through;
# every other argument is filtered out.
SUPPORTED_DATAFLOW_PIPELINE_ARGS = {
    "--runner",
    "--project",
    "--staging_location",
    "--zone",
    "--region",
    "--temp_location",
    "--num_workers",
    "--autoscaling_algorithm",
    "--max_num_workers",
    "--network",
    "--subnetwork",
    "--disk_size_gb",
    "--worker_machine_type",
    "--job_name",
    "--worker_disk_type",
    "--service_account_email",
    "--requirements_file",
    "--setup_file",
    "--experiments",
}