spotify_tensorflow/scripts/tfr_read.py (55 lines of code) (raw):

# -*- coding: utf-8 -*-
#
# Copyright 2017-2019 Spotify AB.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import absolute_import, division, print_function

import argparse
import errno
import os
import sys

import tensorflow as tf
from tensorflow.python.lib.io import file_io
from spotify_tensorflow.example_decoders import ExampleWithFeatureSpecDecoder, ExampleDecoder
from spotify_tensorflow.tf_schema_utils import schema_file_to_feature_spec


def resolve_schema(dir, default_schema=None):
    """Pick the schema file to use for TFRecords stored in ``dir``.

    :param dir: directory to search for a schema file. (The name shadows the
                builtin but is kept so keyword callers keep working.)
    :param default_schema: if not None, returned as-is without looking at ``dir``.
    :return: ``default_schema`` when given; otherwise the path of the first of
             ``_schema.pb`` / ``_inferred_schema.pb`` that exists in ``dir``;
             ``None`` when neither exists.
    """
    if default_schema is not None:
        return default_schema
    # Order matters: "_schema.pb" is checked before "_inferred_schema.pb".
    for schema_file_name in ["_schema.pb", "_inferred_schema.pb"]:
        s = os.path.join(dir, schema_file_name)
        if file_io.file_exists(s):
            return s
    return None


def list_tf_records(paths, default_schema):
    """Yield ``(tf_record_file, schema)`` pairs for every path or glob in ``paths``.

    Only matches whose name ends with ``.tfrecords`` are kept. The schema of each
    file is resolved with :func:`resolve_schema` against the file's directory.

    :param paths: iterable of paths/globs handed to ``file_io.get_matching_files``.
    :param default_schema: schema path forwarded to :func:`resolve_schema`; may be None.
    :raises Exception: when a path/glob matches no ``.tfrecords`` file.
    """
    for p in paths:
        files = [f for f in file_io.get_matching_files(p) if f.endswith(".tfrecords")]
        if not files:
            raise Exception("Couldn't find any .tfrecords file in path or glob [{}]".format(p))
        for f in files:
            yield f, resolve_schema(os.path.dirname(f), default_schema)


def get_decoder_from_schema(schema):
    """Build the decoder used to turn serialized records into JSON.

    :param schema: path to a schema file, or None.
    :return: a plain ``ExampleDecoder`` when ``schema`` is None, otherwise an
             ``ExampleWithFeatureSpecDecoder`` built from the feature spec that
             ``schema_file_to_feature_spec`` derives from ``schema``.
    """
    if schema is None:
        return ExampleDecoder()
    feature_spec = schema_file_to_feature_spec(schema)
    return ExampleWithFeatureSpecDecoder(feature_spec)


def _ensure_file_exists(path):
    """Raise ``AssertionError`` if ``path`` does not exist.

    An explicit raise is used instead of an ``assert`` statement so the check is
    not compiled away under ``python -O``; the exception type and message are
    the same as an ``assert`` would produce.
    """
    if not file_io.file_exists(path):
        raise AssertionError("File not found: {}".format(path))


def tfr_read_to_json(tf_records_paths, schema_path=None):
    """Generate one JSON string per record found under ``tf_records_paths``.

    This is a generator: nothing (including the existence checks) runs until the
    first item is requested.

    :param tf_records_paths: iterable of paths/globs, see :func:`list_tf_records`.
    :param schema_path: optional schema file applied to every TFRecords file; when
                        None a schema is looked up next to each file.
    :raises AssertionError: when ``schema_path`` or a listed TFRecords file is missing.
    :raises Exception: propagated from :func:`list_tf_records` when nothing matches.
    """
    if schema_path is not None:
        _ensure_file_exists(schema_path)
    for tf_record_file, schema in list_tf_records(tf_records_paths, schema_path):
        _ensure_file_exists(tf_record_file)
        decoder = get_decoder_from_schema(schema)
        for record in tf.python_io.tf_record_iterator(tf_record_file):
            yield decoder.to_json(record)


def _silence_stdout():
    """Best-effort: point the stdout file descriptor at ``os.devnull``.

    After a broken pipe the stdout buffer may still hold unwritten data; the
    interpreter flushes it on shutdown, which would fail again and print a
    spurious error. Redirecting the descriptor lets that final flush succeed.
    """
    try:
        devnull = os.open(os.devnull, os.O_WRONLY)
        os.dup2(devnull, sys.stdout.fileno())
    except (AttributeError, ValueError, EnvironmentError):
        # stdout has no usable descriptor (e.g. it was replaced by an in-memory
        # stream); nothing to redirect, and we are about to exit anyway.
        pass


def main():
    """Command-line entry point: print every record of the given TFRecords as JSON."""
    parser = argparse.ArgumentParser(description="Output TFRecords as JSON")
    parser.add_argument("-s", "--schema",
                        help="Path to Schema protobuf file. Uses Example if not "
                             "supplied.")
    parser.add_argument("tf_records_paths",
                        metavar="TF_RECORDS_PATH",
                        nargs="+",
                        help="TFRecords file (or directory containing .tfrecords files)")
    args = parser.parse_args()

    try:
        for json_str in tfr_read_to_json(args.tf_records_paths, args.schema):
            print(json_str)
    except IOError as e:
        if e.errno != errno.EPIPE:
            # Bare raise keeps the original traceback intact.
            raise
        # The reader closed the pipe early (e.g. `... | head`): not an error.
        _silence_stdout()
        sys.exit(0)


if __name__ == "__main__":
    main()