spotify_tensorflow/tfx/tfdv.py (106 lines of code) (raw):

# -*- coding: utf-8 -*-
#
# Copyright 2017-2019 Spotify AB.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
#

import logging
import os
import time
from os.path import join as pjoin
from typing import List, Optional  # noqa: F401

import tensorflow_data_validation as tfdv
from tensorflow_data_validation.statistics.stats_options import StatsOptions
from apache_beam.options.pipeline_options import GoogleCloudOptions, PipelineOptions, SetupOptions
from tensorflow_metadata.proto.v0.statistics_pb2 import DatasetFeatureStatisticsList  # noqa: F401

from spotify_tensorflow.tfx.utils import create_setup_file, assert_not_empty_string, \
    clean_up_pipeline_args
from spotify_tensorflow.tf_schema_utils import parse_schema_txt_file, parse_schema_file
from tensorflow_metadata.proto.v0 import statistics_pb2  # noqa: F401
from tensorflow.python.lib.io import file_io

logger = logging.getLogger("spotify-tensorflow")


class TfDataValidator(object):
    """Spotify-specific API for using Tensorflow Data Validation in production.

    The typical usage is to create an instance of this class from a Luigi task that produces
    tfrecord files in order to produce statistics, a schema snapshot and any anomalies along with
    the dataset.

    All artifacts are written next to the data, i.e. inside ``data_location``:
    ``stats.pb``, ``schema_snapshot.pb`` and ``anomalies.pb``.
    """

    def __init__(self,
                 schema_path,  # type: Optional[str]
                 data_location,  # type: str
                 binary_schema=False,  # type: bool
                 stats_options=None  # type: Optional[StatsOptions]
                 ):
        """
        :param schema_path: tf.metadata Schema path. Must be in text or binary format.
                            If this is None, a new schema will be inferred automatically from the
                            statistics.
        :param data_location: input data dir containing tfrecord files
        :param binary_schema: specifies if the schema is in a binary format
        :param stats_options: tfdv.StatsOptions for statistics generation settings. When omitted
                              (None), a new default ``StatsOptions()`` is created for this
                              instance.
        """
        self.data_location = data_location

        self.schema = None
        if schema_path:
            if binary_schema:
                self.schema = parse_schema_file(schema_path)
            else:
                self.schema = parse_schema_txt_file(schema_path)

        self.schema_snapshot_path = pjoin(self.data_location, "schema_snapshot.pb")
        self.stats_path = pjoin(self.data_location, "stats.pb")
        self.anomalies_path = pjoin(self.data_location, "anomalies.pb")

        # Build the default per instance rather than in the signature, so that validators never
        # share (and accidentally mutate) one module-level StatsOptions object.
        self.stats_options = stats_options if stats_options is not None else StatsOptions()

        # Populated by validate_stats_against_schema; None until validation has run.
        self.anomalies = None

    def write_stats(self, pipeline_args):
        # type: (List[str]) -> statistics_pb2.DatasetFeatureStatisticsList
        """Generate statistics for the dataset and write them to ``self.stats_path``.

        :param pipeline_args: un-parsed Dataflow arguments
        :return: the value returned by ``generate_statistics_from_tfrecord``
        """
        return generate_statistics_from_tfrecord(pipeline_args=pipeline_args,
                                                 data_location=self.data_location,
                                                 output_path=self.stats_path,
                                                 stats_options=self.stats_options)

    def write_stats_and_schema(self, pipeline_args, infer_feature_shape=False):
        # type: (List[str], bool) -> None
        """Write statistics, then write a schema snapshot next to the data.

        If no schema was loaded in the constructor, one is inferred from the freshly generated
        statistics and kept on ``self.schema``.

        :param pipeline_args: un-parsed Dataflow arguments
        :param infer_feature_shape: forwarded to ``tfdv.infer_schema`` when a schema is inferred
        """
        stats = self.write_stats(pipeline_args)
        if not self.schema:
            logger.warning(
                "Inferring a new schema for this dataset. If you want to use an existing schema, "
                "provide a value for schema_path in the constructor."
            )
            self.schema = tfdv.infer_schema(stats, infer_feature_shape=infer_feature_shape)
        self.upload_schema()

    def validate_stats_against_schema(
            self,
            environment=None,  # type: Optional[str]
            previous_statistics=None,  # type: Optional[DatasetFeatureStatisticsList]
            serving_statistics=None,  # type: Optional[DatasetFeatureStatisticsList]
    ):
        # type: (...) -> bool
        """Validate the statistics stored at ``self.stats_path`` against ``self.schema``.

        The validation result is kept on ``self.anomalies``. When anomalies are found they are
        logged and written to ``self.anomalies_path``.

        :param environment: forwarded to ``tfdv.validate_statistics``
        :param previous_statistics: forwarded to ``tfdv.validate_statistics``
        :param serving_statistics: forwarded to ``tfdv.validate_statistics``
        :return: True if no anomalies were found, False otherwise
        """
        stats = tfdv.load_statistics(self.stats_path)
        self.anomalies = tfdv.validate_statistics(
            stats,
            self.schema,
            environment=environment,
            previous_statistics=previous_statistics,
            serving_statistics=serving_statistics,
        )
        # Same emptiness check as upload_anomalies: an empty anomaly_info map is falsy.
        if self.anomalies.anomaly_info:
            logger.error("Anomalies found in training dataset...")
            logger.error(str(self.anomalies.anomaly_info.items()))
            self.upload_anomalies()
            return False

        logger.info("No anomalies found")
        return True

    def upload_schema(self):
        # type: () -> None
        """Write the serialized schema to ``self.schema_snapshot_path``.

        :raises ValueError: if there is no schema to write
        """
        if not self.schema:
            raise ValueError(
                "Cannot upload a schema since no schema_path was provided. Either provide one, or "
                "use write_stats_and_schema so that a schema can be inferred first."
            )
        file_io.atomic_write_string_to_file(self.schema_snapshot_path,
                                            self.schema.SerializeToString())

    def upload_anomalies(self):
        # type: () -> None
        """Write the serialized anomalies to ``self.anomalies_path`` if there are any.

        Nothing is written when validation has not run yet or when it found no anomalies.
        """
        if self.anomalies is None:
            # Validation has not produced a result yet, so there is nothing to serialize.
            logger.warning("No anomalies to upload: validate_stats_against_schema has not run.")
            return
        if self.anomalies.anomaly_info:
            file_io.atomic_write_string_to_file(self.anomalies_path,
                                                self.anomalies.SerializeToString())


def generate_statistics_from_tfrecord(pipeline_args,  # type: List[str]
                                      data_location,  # type: str
                                      output_path,  # type: str
                                      stats_options  # type: StatsOptions
                                      ):
    # type: (...) -> statistics_pb2.DatasetFeatureStatisticsList
    """
    Generate stats file from a tfrecord dataset using TFDV

    :param pipeline_args: un-parsed Dataflow arguments
    :param data_location: input data dir containing tfrecord files
    :param output_path: output path for the stats file
    :param stats_options: tfdv.StatsOptions for statistics generation settings
    :return: a DatasetFeatureStatisticsList proto.
    """
    assert_not_empty_string(data_location)
    assert_not_empty_string(output_path)

    args_in_snake_case = clean_up_pipeline_args(pipeline_args)
    pipeline_options = PipelineOptions(flags=args_in_snake_case)

    # Fill in a job name and a setup file only when they were not supplied in pipeline_args.
    # .get() treats a missing key the same as an option that was left unset.
    all_options = pipeline_options.get_all_options()

    if all_options.get("job_name") is None:
        gcloud_options = pipeline_options.view_as(GoogleCloudOptions)
        gcloud_options.job_name = "generatestats-%s" % str(int(time.time()))

    if all_options.get("setup_file") is None:
        setup_file_path = create_setup_file()
        setup_options = pipeline_options.view_as(SetupOptions)
        setup_options.setup_file = setup_file_path

    # Glob over every tfrecords file (including suffixed variants) in the data directory.
    input_files = os.path.join(data_location, "*.tfrecords*")
    return tfdv.generate_statistics_from_tfrecord(data_location=input_files,
                                                  output_path=output_path,
                                                  stats_options=stats_options,
                                                  pipeline_options=pipeline_options)