basic_pitch/data/tf_example_serialization.py (78 lines of code) (raw):
#!/usr/bin/env python
# encoding: utf-8
#
# Copyright 2022 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple, Union
import sox
import numpy as np
import tensorflow as tf
from basic_pitch.constants import AUDIO_N_CHANNELS, AUDIO_SAMPLE_RATE
def int64_feature(value: Union[List[int], int]) -> tf.train.Feature:
    """Wrap an int, or a list of ints, in an ``int64_list`` ``tf.train.Feature``.

    Any argument that is not a ``list`` is treated as a single element.
    """
    values = value if isinstance(value, list) else [value]
    int_list = tf.train.Int64List(value=values)
    return tf.train.Feature(int64_list=int_list)
def float_feature(value: Union[List[float], float]) -> tf.train.Feature:
    """Wrap a float, or a list of floats, in a ``float_list`` ``tf.train.Feature``.

    Any argument that is not a ``list`` is treated as a single element.
    """
    values = value if isinstance(value, list) else [value]
    float_list = tf.train.FloatList(value=values)
    return tf.train.Feature(float_list=float_list)
def bytes_feature(value: Union[tf.Tensor, List[bytes], bytes]) -> tf.train.Feature:
    """Wrap bytes, a list of bytes, or an eager tensor in a ``bytes_list`` Feature.

    A value whose type matches that of ``tf.constant(0)`` (an eager tensor) is
    first converted with ``.numpy()``. After that, anything that is not a
    ``list`` is wrapped as a single-element list.
    """
    # The eager-tensor type is obtained from a throwaway constant, as before.
    eager_tensor_type = type(tf.constant(0))
    if isinstance(value, eager_tensor_type):
        value = value.numpy()
    payload = value if isinstance(value, list) else [value]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=payload))
def _to_transcription_tfex(
    file_id: str,
    source: str,
    encoded_wav: bytes,
    notes_indices: List[Tuple[int, int]],
    notes_values: List[float],
    onsets_indices: List[int],
    onsets_values: List[float],
    contours_indices: List[int],
    contours_values: List[float],
    notes_onsets_shape: Tuple[int, int],
    contours_shape: Tuple[int, int],
) -> tf.train.Example:
    """Assemble a transcription ``tf.train.Example`` from encoded audio and label data.

    ``file_id`` and ``source`` are stored as UTF-8 bytes, and ``encoded_wav`` is
    stored unchanged. Every label argument is converted to a NumPy array (int64
    for indices and shapes, float32 for values), serialized with
    ``tf.io.serialize_tensor``, and stored as a single bytes feature.
    """

    def _serialized(data, dtype) -> tf.train.Feature:
        # Serialize ``data`` as a typed NumPy array into a bytes feature.
        return bytes_feature(tf.io.serialize_tensor(np.array(data, dtype)))

    feature_map = {
        "file_id": bytes_feature(bytes(file_id, "utf-8")),
        "source": bytes_feature(bytes(source, "utf-8")),
        "audio_wav": bytes_feature(encoded_wav),
        "notes_indices": _serialized(notes_indices, np.int64),
        "notes_values": _serialized(notes_values, np.float32),
        "onsets_indices": _serialized(onsets_indices, np.int64),
        "onsets_values": _serialized(onsets_values, np.float32),
        "contours_indices": _serialized(contours_indices, np.int64),
        "contours_values": _serialized(contours_values, np.float32),
        "notes_onsets_shape": _serialized(notes_onsets_shape, np.int64),
        "contours_shape": _serialized(contours_shape, np.int64),
    }
    return tf.train.Example(features=tf.train.Features(feature=feature_map))
def to_transcription_tfexample(
    file_id: str,
    source: str,
    audio_wav_file_path: str,
    notes_indices: List[Tuple[int, int]],
    notes_values: List[float],
    onsets_indices: List[int],
    onsets_values: List[float],
    contours_indices: List[int],
    contours_values: List[float],
    notes_onsets_shape: Tuple[int, int],
    contours_shape: Tuple[int, int],
) -> tf.train.Example:
    """Build a transcription ``tf.train.Example`` from a local WAV file and its labels.

    Args:
        file_id: identifier string, stored UTF-8 encoded.
        source: dataset name string (e.g. "maestro"), stored UTF-8 encoded.
        audio_wav_file_path: path to a local WAV file. Its sample rate must equal
            ``AUDIO_SAMPLE_RATE`` and its channel count must equal
            ``AUDIO_N_CHANNELS`` (both checked via sox). The raw file bytes are
            embedded unchanged.
        notes_indices: [(time, freq)] integer index pairs.
        notes_values: [float].
        onsets_indices: same layout as ``notes_indices`` (NOTE(review): the
            annotation says List[int]; confirm against callers).
        onsets_values: [float].
        contours_indices: same layout as ``notes_indices`` (NOTE(review): the
            annotation says List[int]; confirm against callers).
        contours_values: [float].
        notes_onsets_shape: (time, freq) ints.
        contours_shape: (time, freq) ints.

    Returns:
        The assembled ``tf.train.Example``.

    Raises:
        ValueError: if the WAV file's sample rate or channel count does not match
            the expected constants.
    """
    # Explicit checks instead of ``assert``: asserts are stripped under ``python -O``.
    sample_rate = sox.file_info.sample_rate(audio_wav_file_path)
    if sample_rate != AUDIO_SAMPLE_RATE:
        raise ValueError(
            f"{audio_wav_file_path}: expected sample rate {AUDIO_SAMPLE_RATE}, got {sample_rate}"
        )
    n_channels = sox.file_info.channels(audio_wav_file_path)
    if n_channels != AUDIO_N_CHANNELS:
        raise ValueError(
            f"{audio_wav_file_path}: expected {AUDIO_N_CHANNELS} channel(s), got {n_channels}"
        )

    # Context manager guarantees the file handle is closed, even if read() fails.
    with open(audio_wav_file_path, "rb") as wav_file:
        encoded_wav = wav_file.read()

    return _to_transcription_tfex(
        file_id,
        source,
        encoded_wav,
        notes_indices,
        notes_values,
        onsets_indices,
        onsets_values,
        contours_indices,
        contours_values,
        notes_onsets_shape,
        contours_shape,
    )