in spotify_tensorflow/dataset.py [0:0]
def _examples(cls,
              file_pattern,  # type: str
              schema_path=None,  # type: str
              feature_spec=None,  # type: Dict[str, Union[tf.FixedLenFeature, tf.VarLenFeature, tf.SparseFeature]]  # noqa: E501
              default_value=0,  # type: float
              compression_type=None,  # type: str
              batch_size=128,  # type: int
              shuffle=True,  # type: bool
              num_epochs=1,  # type: int
              shuffle_buffer_size=10000,  # type: int
              shuffle_seed=None,  # type: int
              prefetch_buffer_size=1,  # type: int
              reader_num_threads=1,  # type: int
              parser_num_threads=2,  # type: int
              sloppy_ordering=False,  # type: bool
              drop_final_batch=False  # type: bool
              ):
    # type: (...) -> Iterator[Dict[str, np.ndarray]]
    """Iterate over batches of examples as dictionaries of NumPy arrays.

    This is a generator: nothing below (including the eager-mode check) runs
    until the first batch is requested.

    Every argument except ``default_value`` is forwarded unchanged, by keyword,
    to ``Datasets._examples``; see that function for their exact semantics.

    :param file_pattern: forwarded to ``Datasets._examples``
    :param schema_path: forwarded to ``Datasets._examples``
    :param feature_spec: forwarded to ``Datasets._examples``
    :param default_value: value used to fill positions that are absent from a
                          ``tf.SparseTensor`` when it is converted to a dense
                          array; not forwarded to ``Datasets._examples``
    :param compression_type: forwarded to ``Datasets._examples``
    :param batch_size: forwarded to ``Datasets._examples``
    :param shuffle: forwarded to ``Datasets._examples``
    :param num_epochs: forwarded to ``Datasets._examples``
    :param shuffle_buffer_size: forwarded to ``Datasets._examples``
    :param shuffle_seed: forwarded to ``Datasets._examples``
    :param prefetch_buffer_size: forwarded to ``Datasets._examples``
    :param reader_num_threads: forwarded to ``Datasets._examples``
    :param parser_num_threads: forwarded to ``Datasets._examples``
    :param sloppy_ordering: forwarded to ``Datasets._examples``
    :param drop_final_batch: forwarded to ``Datasets._examples``

    :return: yields one dictionary per batch, mapping each name in the batch
             to the NumPy conversion of its tensor
    :raises ValueError: if a value in a batch is neither a ``tf.Tensor`` nor a
                        ``tf.SparseTensor``
    """
    # NOTE(review): presumably raises when eager execution is disabled, which
    # `.numpy()` below depends on - confirm against Datasets._assert_eager.
    Datasets._assert_eager("Dictionary")

    def get_numpy(tensor):
        """Convert one batched eager tensor (dense or sparse) to a NumPy array."""
        if isinstance(tensor, tf.Tensor):
            return tensor.numpy()

        if not isinstance(tensor, tf.SparseTensor):
            # Format the message explicitly: passing the type name as a second
            # argument to ValueError would leave the "%s" placeholder unfilled.
            raise ValueError("This type %s is not supported!" % type(tensor).__name__)

        # A SparseTensor is the representation of VarLenFeature and
        # SparseFeature. Convert it to its dense representation and, if it
        # holds a scalar per example, reshape the result to a vector.
        shape = tensor.dense_shape.numpy()
        # The first element of the shape is the batch size.
        if shape[1] == 0:
            # This feature is not defined for any of the examples in the
            # batch, so emit one default value per example.
            return np.repeat(default_value, shape[0])

        numpy_dense = tf.sparse_tensor_to_dense(tensor,
                                                default_value=default_value).numpy()
        if shape[1] == 1:
            # This is a scalar feature: reshape (batch, 1) to (batch,).
            return numpy_dense.reshape(shape[0])
        return numpy_dense

    dataset = Datasets._examples(file_pattern=file_pattern,
                                 schema_path=schema_path,
                                 feature_spec=feature_spec,
                                 compression_type=compression_type,
                                 batch_size=batch_size,
                                 shuffle=shuffle,
                                 num_epochs=num_epochs,
                                 shuffle_buffer_size=shuffle_buffer_size,
                                 shuffle_seed=shuffle_seed,
                                 prefetch_buffer_size=prefetch_buffer_size,
                                 reader_num_threads=reader_num_threads,
                                 parser_num_threads=parser_num_threads,
                                 sloppy_ordering=sloppy_ordering,
                                 drop_final_batch=drop_final_batch)
    for batch in dataset:
        yield {name: get_numpy(eager_tensor) for name, eager_tensor in six.iteritems(batch)}