spotify_tensorflow/luigi/utils.py (63 lines of code) (raw):
# -*- coding: utf-8 -*-
#
# Copyright 2017-2019 Spotify AB.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
from __future__ import absolute_import, division, print_function
import re
import os
import subprocess
import tempfile
import requests
def to_snake_case(s, sep="_"):
# type: (str, str) -> str
p = r"\1" + sep + r"\2"
s1 = re.sub("(.)([A-Z][a-z]+)", p, s)
return re.sub("([a-z0-9])([A-Z])", p, s1).lower()
def is_gcs_path(path):
# type: (str) -> bool
"""Returns True if given path is GCS path, False otherwise."""
return path.strip().lower().startswith("gs://")
def get_uri(target):
if hasattr(target, "uri"):
return target.uri()
elif hasattr(target, "path"):
return target.path
else:
raise ValueError("Unknown input target type: %s" % target.__class__.__name__)
def run_with_logging(cmd, logger):
"""
Run cmd and wait for it to finish. While cmd is running, we read it's
output and print it to a logger.
"""
process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
output_lines = []
while True:
line = process.stdout.readline()
if not line:
break
line = line.decode("utf-8")
output_lines += [line]
logger.info(line.rstrip("\n"))
exit_code = process.wait()
if exit_code:
output = "".join(output_lines)
raise subprocess.CalledProcessError(exit_code, cmd, output=output)
return exit_code
def _fetch_file(url, output_path=None):
# type: (str, str) -> str
"""Fetches a file from the url and saves it to a temp file (or at the provided output path)."""
rep = requests.get(url, allow_redirects=True)
if rep.status_code / 100 != 2:
raise Exception("Got [status_code:{}] fetching file at [url:{}]".format(rep.status_code,
url))
if output_path is None:
output_path = tempfile.NamedTemporaryFile(delete=False).name
with open(output_path, "wb") as out:
out.write(rep.content)
return output_path
def fetch_tfdv_whl(version=None, output_path=None, platform="manylinux1"):
# type: (str, str, str) -> str
"""Fetches the TFDV pip package from PyPI and saves it to a temporary file (or the provided
output path). Returns the path to the fetched package."""
package_name = "tensorflow_data_validation"
if version is None:
import tensorflow_data_validation as tfdv
version = tfdv.__version__
pypi_base = "https://pypi.org/simple/{}".format(package_name)
package_url = None
with open(_fetch_file(pypi_base)) as listing_html:
for line in listing_html:
if version in line and platform in line:
package_url = re.findall(".*href=\"([^ ]*)#[^ ]*\".*", line)[0]
break
if package_url is None:
raise Exception("Problem fetching package. Couldn't parse listing at [url:{}]"
.format(pypi_base))
if output_path is None:
temp_dir = tempfile.mkdtemp()
# Note: output_path file name must exactly match the remote wheel name.
output_path = os.path.join(temp_dir, package_url.split("/")[-1])
return _fetch_file(package_url, output_path=output_path)