spotify_tensorflow/example_decoders.py (28 lines of code) (raw):
# -*- coding: utf-8 -*-
#
# Copyright 2017-2019 Spotify AB.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function
import json
import numpy as np
from google.protobuf.json_format import MessageToJson
from tensorflow.core.example import example_pb2
from tensorflow_transform.coders import example_proto_coder
from tensorflow_transform.tf_metadata import dataset_schema
class ExampleDecoder(object):
"""
Decode a tf.Example payload using the example.proto schema
"""
def to_json(self, example_str): # type: (str) -> str
"""
Converts a single tf.Example to JSon a string
:param example_str: tf.Example payload
"""
ex = example_pb2.Example()
ex.ParseFromString(example_str)
return MessageToJson(ex)
class ExampleWithFeatureSpecDecoder(ExampleDecoder):
"""
Decode a tf.Example payload using a TensorFlow feature_spec
"""
def __init__(self, feature_spec):
super(ExampleWithFeatureSpecDecoder, self).__init__()
schema = dataset_schema.from_feature_spec(feature_spec)
self._coder = example_proto_coder.ExampleProtoCoder(schema)
class _NumpyArrayEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, np.ndarray):
return obj.tolist()
if isinstance(obj, bytes):
return obj.decode()
return json.JSONEncoder.default(self, obj)
def to_json(self, example_str): # type: (str) -> str
"""
Converts a single tf.Example to Json a string
:param example_str: tf.Example payload
"""
decoded = self._coder.decode(example_str)
decoded_json = json.dumps(decoded, cls=self._NumpyArrayEncoder)
return decoded_json