spotify_tensorflow/tf_schema_utils.py (29 lines of code) (raw):
# -*- coding: utf-8 -*-
#
# Copyright 2017-2019 Spotify AB.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function
from typing import Union, Dict # noqa: F401
import google.protobuf.text_format
import tensorflow as tf # noqa: F401
from tensorflow.python.lib.io import file_io
from tensorflow_metadata.proto.v0.schema_pb2 import Schema
from tensorflow_transform.tf_metadata import schema_utils
def feature_spec_to_schema(feature_spec):
    # type: (Dict[str, Union[tf.FixedLenFeature, tf.VarLenFeature, tf.SparseFeature]]) -> Schema
    """
    Build a tf.metadata Schema proto from a TensorFlow feature_spec.

    Delegates to tensorflow_transform's
    ``schema_utils.schema_from_feature_spec``.

    :param feature_spec: mapping of feature name to a ``tf.FixedLenFeature``,
        ``tf.VarLenFeature`` or ``tf.SparseFeature``.
    :return: the ``Schema`` produced by ``schema_from_feature_spec``.
    """
    converted = schema_utils.schema_from_feature_spec(feature_spec)
    return converted
def parse_schema_file(schema_path):  # type: (str) -> Schema
    """
    Read a binary-serialized tf.metadata Schema file and return the proto object.

    :param schema_path: path of the schema file, accessed through TensorFlow's
        ``file_io`` module.
    :return: a ``Schema`` populated via ``Schema.ParseFromString``.
    :raises AssertionError: if ``schema_path`` does not exist.
    """
    # Explicit raise instead of a bare ``assert``: asserts are stripped when
    # Python runs with -O, which would silently drop this check. AssertionError
    # is kept so callers that already handle it keep working.
    if not file_io.file_exists(schema_path):
        raise AssertionError("File not found: {}".format(schema_path))
    schema = Schema()
    # Opened in "rb" because ParseFromString consumes the binary wire format.
    with file_io.FileIO(schema_path, "rb") as f:
        schema.ParseFromString(f.read())
    return schema
def parse_schema_txt_file(schema_path):  # type: (str) -> Schema
    """
    Parse a tf.metadata Schema text file into its in-memory representation.

    The file is expected to contain the protobuf text format and is parsed
    with ``google.protobuf.text_format.Parse``.

    :param schema_path: path of the schema text file, accessed through
        TensorFlow's ``file_io`` module.
    :return: a ``Schema`` populated from the file's contents.
    :raises AssertionError: if ``schema_path`` does not exist.
    Errors raised by ``text_format.Parse`` for malformed input propagate
    unchanged.
    """
    # Explicit raise instead of a bare ``assert``: asserts are stripped when
    # Python runs with -O, which would silently drop this check. AssertionError
    # is kept so callers that already handle it keep working.
    if not file_io.file_exists(schema_path):
        raise AssertionError("File not found: {}".format(schema_path))
    schema = Schema()
    schema_text = file_io.read_file_to_string(schema_path)
    google.protobuf.text_format.Parse(schema_text, schema)
    return schema
def schema_to_feature_spec(schema):
    # type: (Schema) -> Dict[str, Union[tf.FixedLenFeature, tf.VarLenFeature, tf.SparseFeature]]
    """
    Derive a TensorFlow feature_spec from a tf.metadata Schema.

    Delegates to tensorflow_transform's ``schema_utils.schema_as_feature_spec``.

    :param schema: the ``Schema`` proto to convert.
    :return: the ``feature_spec`` field of the conversion result.
    """
    conversion = schema_utils.schema_as_feature_spec(schema)
    # Only the feature_spec part of the conversion result is handed back.
    return conversion.feature_spec
def schema_file_to_feature_spec(schema_path):
    # type: (str) -> Dict[str, Union[tf.FixedLenFeature, tf.VarLenFeature, tf.SparseFeature]]
    """
    Load a binary-serialized tf.metadata Schema file and convert it to a
    TensorFlow feature_spec.

    :param schema_path: path of the serialized schema file.
    :return: the feature_spec derived from the parsed schema.
    """
    return schema_to_feature_spec(parse_schema_file(schema_path))
def schema_txt_file_to_feature_spec(schema_path):
    # type: (str) -> Dict[str, Union[tf.FixedLenFeature, tf.VarLenFeature, tf.SparseFeature]]
    """
    Load a tf.metadata Schema text-format file and convert it to a
    TensorFlow feature_spec.

    :param schema_path: path of the schema text file.
    :return: the feature_spec derived from the parsed schema.
    """
    return schema_to_feature_spec(parse_schema_txt_file(schema_path))