basic_pitch/inference.py:
#!/usr/bin/env python
# encoding: utf-8
#
# Copyright 2022 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import csv
import enum
import json
import logging
import os
import pathlib
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union, cast
from basic_pitch import CT_PRESENT, ICASSP_2022_MODEL_PATH, ONNX_PRESENT, TF_PRESENT, TFLITE_PRESENT
try:
import tensorflow as tf
except ImportError:
pass
try:
import coremltools as ct
except ImportError:
pass
try:
import tflite_runtime.interpreter as tflite
except ImportError:
if TF_PRESENT:
import tensorflow.lite as tflite
try:
import onnxruntime as ort
except ImportError:
pass
import numpy as np
import numpy.typing as npt
import librosa
import pretty_midi
from basic_pitch.constants import (
AUDIO_SAMPLE_RATE,
AUDIO_N_SAMPLES,
ANNOTATIONS_FPS,
FFT_HOP,
)
from basic_pitch.commandline_printing import (
generating_file_message,
no_tf_warnings,
file_saved_confirmation,
failed_to_save,
)
import basic_pitch.note_creation as infer
class Model:
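    """Wrapper around a serialized Basic Pitch model from any supported backend.

    The constructor tries, in order, to load ``model_path`` as a TensorFlow
    saved model, a CoreML model, a TFLite model, and an ONNX model, keeping
    the first backend that succeeds. ``predict`` then normalizes the outputs
    of whichever backend was loaded into a dictionary keyed by "note",
    "onset", and "contour".
    """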
class MODEL_TYPES(enum.Enum):
TENSORFLOW = enum.auto()
COREML = enum.auto()
TFLITE = enum.auto()
ONNX = enum.auto()
def __init__(self, model_path: Union[pathlib.Path, str]):
present = []
if TF_PRESENT:
present.append("TensorFlow")
try:
self.model_type = Model.MODEL_TYPES.TENSORFLOW
self.model = tf.saved_model.load(str(model_path))
return
except Exception as e:
if os.path.isdir(model_path) and {"saved_model.pb", "variables"} & set(os.listdir(model_path)):
                    logging.warning(
                        "Could not load TensorFlow saved model %s even "
                        "though it looks like a saved model directory. Error: %s. "
                        "Are you sure it's a TensorFlow saved model?",
                        model_path,
                        repr(e),
                    )
if CT_PRESENT:
present.append("CoreML")
try:
self.model_type = Model.MODEL_TYPES.COREML
self.model = ct.models.MLModel(str(model_path), compute_units=ct.ComputeUnit.CPU_ONLY)
return
except Exception as e:
if str(model_path).endswith(".mlpackage"):
                    logging.warning(
                        "Could not load CoreML file %s even "
                        "though it looks like a CoreML file. Error: %s. "
                        "Are you sure it's a CoreML file?",
                        model_path,
                        repr(e),
                    )
if TFLITE_PRESENT or TF_PRESENT:
present.append("TensorFlowLite")
try:
self.model_type = Model.MODEL_TYPES.TFLITE
self.interpreter = tflite.Interpreter(str(model_path))
self.model = self.interpreter.get_signature_runner()
return
except Exception as e:
if str(model_path).endswith(".tflite"):
                    logging.warning(
                        "Could not load TensorFlowLite file %s even "
                        "though it looks like a TFLite file. Error: %s. "
                        "Are you sure it's a TFLite file?",
                        model_path,
                        repr(e),
                    )
if ONNX_PRESENT:
present.append("ONNX")
try:
self.model_type = Model.MODEL_TYPES.ONNX
providers = ["CPUExecutionProvider"]
if "CUDAExecutionProvider" in ort.get_available_providers():
providers.insert(0, "CUDAExecutionProvider")
self.model = ort.InferenceSession(str(model_path), providers=providers)
return
except Exception as e:
if str(model_path).endswith(".onnx"):
                    logging.warning(
                        "Could not load ONNX file %s even "
                        "though it looks like an ONNX file. Error: %s. "
                        "Are you sure it's an ONNX file?",
                        model_path,
                        repr(e),
                    )
        raise ValueError(
            f"File {model_path} cannot be loaded as a "
            "TensorFlow, CoreML, TFLite, or ONNX model. "
            "Please check that it is a supported, valid serialized model "
            "and that the corresponding package is installed. On this system, "
            f"the following backends are available: {present}."
        )
def predict(self, x: npt.NDArray[np.float32]) -> Dict[str, npt.NDArray[np.float32]]:
if self.model_type == Model.MODEL_TYPES.TENSORFLOW:
return {k: v.numpy() for k, v in cast(tf.keras.Model, self.model(x)).items()}
elif self.model_type == Model.MODEL_TYPES.COREML:
print(f"isfinite: {np.all(np.isfinite(x))}", flush=True)
print(f"shape: {x.shape}", flush=True)
print(f"dtype: {x.dtype}", flush=True)
result = cast(ct.models.MLModel, self.model).predict({"input_2": x})
return {
"note": result["Identity_1"],
"onset": result["Identity_2"],
"contour": result["Identity"],
}
elif self.model_type == Model.MODEL_TYPES.TFLITE:
return self.model(input_2=x) # type: ignore
elif self.model_type == Model.MODEL_TYPES.ONNX:
return {
k: v
for k, v in zip(
["note", "onset", "contour"],
cast(ort.InferenceSession, self.model).run(
[
"StatefulPartitionedCall:1",
"StatefulPartitionedCall:2",
"StatefulPartitionedCall:0",
],
{"serving_default_input_2:0": x},
),
)
}
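

# Illustrative sketch (not part of the library API): driving Model.predict
# directly. The input must be a float32 batch of shape (1, AUDIO_N_SAMPLES, 1);
# the output dictionary is keyed by "note", "onset", and "contour" regardless
# of which backend loaded the model.
def _example_model_predict() -> None:
    model = Model(ICASSP_2022_MODEL_PATH)  # any supported serialization works here
    window = np.zeros((1, AUDIO_N_SAMPLES, 1), dtype=np.float32)  # one silent window
    for name, array in model.predict(window).items():
        print(name, array.shape)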
def window_audio_file(
audio_original: npt.NDArray[np.float32], hop_size: int
) -> Iterable[Tuple[npt.NDArray[np.float32], Dict[str, float]]]:
"""
Pad appropriately an audio file, and return as
windowed signal, with window length = AUDIO_N_SAMPLES
Returns:
audio_windowed: tensor with shape (n_windows, AUDIO_N_SAMPLES, 1)
audio windowed into fixed length chunks
window_times: list of {'start':.., 'end':...} objects (times in seconds)
"""
for i in range(0, audio_original.shape[0], hop_size):
window = audio_original[i : i + AUDIO_N_SAMPLES]
if len(window) < AUDIO_N_SAMPLES:
window = np.pad(
window,
pad_width=[[0, AUDIO_N_SAMPLES - len(window)]],
)
t_start = float(i) / AUDIO_SAMPLE_RATE
window_time = {
"start": t_start,
"end": t_start + (AUDIO_N_SAMPLES / AUDIO_SAMPLE_RATE),
}
yield np.expand_dims(window, axis=-1), window_time
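

# Illustrative sketch (not part of the library API): windowing a short
# synthetic signal. With hop_size < AUDIO_N_SAMPLES, consecutive windows
# overlap by AUDIO_N_SAMPLES - hop_size samples, and the final window is
# zero-padded to full length.
def _example_window_audio_file() -> None:
    rng = np.random.default_rng(0)
    audio = rng.standard_normal(5 * AUDIO_SAMPLE_RATE).astype(np.float32)  # 5 s of noise
    hop_size = AUDIO_N_SAMPLES - 30 * FFT_HOP  # the hop used by run_inference below
    for window, window_time in window_audio_file(audio, hop_size):
        print(window.shape, window_time["start"], window_time["end"])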
def get_audio_input(
audio_path: Union[pathlib.Path, str], overlap_len: int, hop_size: int
) -> Iterable[Tuple[npt.NDArray[np.float32], Dict[str, float], int]]:
"""
Read wave file (as mono), pad appropriately, and return as
windowed signal, with window length = AUDIO_N_SAMPLES
Returns:
audio_windowed: tensor with shape (n_windows, AUDIO_N_SAMPLES, 1)
audio windowed into fixed length chunks
window_times: list of {'start':.., 'end':...} objects (times in seconds)
audio_original_length: int
length of original audio file, in frames, BEFORE padding.
"""
assert overlap_len % 2 == 0, f"overlap_length must be even, got {overlap_len}"
audio_original, _ = librosa.load(str(audio_path), sr=AUDIO_SAMPLE_RATE, mono=True)
original_length = audio_original.shape[0]
audio_original = np.concatenate([np.zeros((int(overlap_len / 2),), dtype=np.float32), audio_original])
for window, window_time in window_audio_file(audio_original, hop_size):
yield np.expand_dims(window, axis=0), window_time, original_length
def unwrap_output(
output: npt.NDArray[np.float32],
audio_original_length: int,
n_overlapping_frames: int,
) -> Optional[npt.NDArray[np.float32]]:
"""Unwrap batched model predictions to a single matrix.
Args:
output: array (n_batches, n_times_short, n_freqs)
audio_original_length: length of original audio signal (in samples)
n_overlapping_frames: number of overlapping frames in the output
    Returns:
        array (n_times, n_freqs), or None if output is not 3-dimensional
    """
    if len(output.shape) != 3:
        return None
n_olap = int(0.5 * n_overlapping_frames)
if n_olap > 0:
# remove half of the overlapping frames from beginning and end
output = output[:, n_olap:-n_olap, :]
output_shape = output.shape
n_output_frames_original = int(np.floor(audio_original_length * (ANNOTATIONS_FPS / AUDIO_SAMPLE_RATE)))
unwrapped_output = output.reshape(output_shape[0] * output_shape[1], output_shape[2])
return unwrapped_output[:n_output_frames_original, :] # trim to original audio length
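

# Illustrative sketch (not part of the library API): unwrap_output on dummy
# data. Half of n_overlapping_frames is trimmed from both ends of every batch
# item, the remaining frames are concatenated, and the result is cut down to
# the frame count of the original, unpadded audio.
def _example_unwrap_output() -> None:
    batched = np.zeros((4, 100, 88), dtype=np.float32)  # (n_batches, n_times_short, n_freqs)
    one_second = AUDIO_SAMPLE_RATE  # pretend the original clip was 1 s long
    unwrapped = unwrap_output(batched, one_second, n_overlapping_frames=30)
    print(unwrapped.shape)  # (ANNOTATIONS_FPS, 88), i.e. about 1 s of frames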
def run_inference(
audio_path: Union[pathlib.Path, str],
model_or_model_path: Union[Model, pathlib.Path, str],
debug_file: Optional[pathlib.Path] = None,
) -> Dict[str, npt.NDArray[np.float32]]:
"""Run the model on the input audio path.
Args:
audio_path: The audio to run inference on.
model_or_model_path: A loaded Model or path to a serialized model to load.
debug_file: An optional path to output debug data to. Useful for testing/verification.
Returns:
A dictionary with the notes, onsets and contours from model inference.
"""
if isinstance(model_or_model_path, Model):
model = model_or_model_path
else:
model = Model(model_or_model_path)
# overlap 30 frames
n_overlapping_frames = 30
overlap_len = n_overlapping_frames * FFT_HOP
hop_size = AUDIO_N_SAMPLES - overlap_len
output: Dict[str, Any] = {"note": [], "onset": [], "contour": []}
for audio_windowed, _, audio_original_length in get_audio_input(audio_path, overlap_len, hop_size):
for k, v in model.predict(audio_windowed).items():
output[k].append(v)
unwrapped_output = {
k: unwrap_output(np.concatenate(output[k]), audio_original_length, n_overlapping_frames) for k in output
}
if debug_file:
with open(debug_file, "w") as f:
json.dump(
{
"audio_windowed": audio_windowed.numpy().tolist(),
"audio_original_length": audio_original_length,
"hop_size_samples": hop_size,
"overlap_length_samples": overlap_len,
"unwrapped_output": {k: v.tolist() for k, v in unwrapped_output.items()},
},
f,
)
return unwrapped_output
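

# Illustrative sketch (not part of the library API): raw inference on one
# file. The audio path is a hypothetical placeholder; each returned array has
# shape (n_times, n_freqs).
def _example_run_inference() -> None:
    outputs = run_inference("my_audio.wav", ICASSP_2022_MODEL_PATH)  # hypothetical input file
    print(sorted(outputs))  # ['contour', 'note', 'onset']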
class OutputExtensions(enum.Enum):
MIDI = "mid"
MODEL_OUTPUT_NPZ = "npz"
MIDI_SONIFICATION = "wav"
NOTE_EVENTS = "csv"
def verify_input_path(audio_path: Union[pathlib.Path, str]) -> None:
"""Verify that an input path is valid and can be processed
Args:
audio_path: Path to an audio file.
Raises:
ValueError: If the audio file is invalid.
"""
    if not os.path.exists(audio_path):
        raise ValueError(f"🚨 {audio_path} does not exist.")
    if not os.path.isfile(audio_path):
        raise ValueError(f"🚨 {audio_path} is not a file path.")
def verify_output_dir(output_dir: Union[pathlib.Path, str]) -> None:
"""Verify that an output directory is valid and can be processed
Args:
output_dir: Path to an output directory.
Raises:
ValueError: If the output directory is invalid.
"""
    if not os.path.exists(output_dir):
        raise ValueError(f"🚨 {output_dir} does not exist.")
    if not os.path.isdir(output_dir):
        raise ValueError(f"🚨 {output_dir} is not a directory.")
def build_output_path(
audio_path: Union[pathlib.Path, str],
output_directory: Union[pathlib.Path, str],
output_type: OutputExtensions,
) -> pathlib.Path:
"""Create an output path and make sure it doesn't already exist.
Args:
audio_path: The original file path.
output_directory: The directory we will output to.
output_type: The type of output file we are creating.
Raises:
IOError: If the generated path already exists.
Returns:
        A new path in output_directory with the stem of audio_path and an extension
        based on output_type.
"""
audio_path = str(audio_path)
if not isinstance(output_directory, pathlib.Path):
output_directory = pathlib.Path(output_directory)
basename, _ = os.path.splitext(os.path.basename(audio_path))
output_path = output_directory / f"{basename}_basic_pitch.{output_type.value}"
generating_file_message(output_type.name)
if output_path.exists():
raise IOError(
f" 🚨 {str(output_path)} already exists and would be overwritten. Skipping output files for {audio_path}."
)
return output_path
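

# Illustrative sketch (not part of the library API): build_output_path derives
# the output name from the input stem, so "take1.wav" written to "/tmp/out" as
# MIDI becomes "/tmp/out/take1_basic_pitch.mid". It raises IOError rather than
# overwrite an existing file.
def _example_build_output_path() -> None:
    path = build_output_path("take1.wav", "/tmp/out", OutputExtensions.MIDI)  # hypothetical paths
    print(path)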
def save_note_events(
note_events: List[Tuple[float, float, int, float, Optional[List[int]]]],
save_path: Union[pathlib.Path, str],
) -> None:
"""Save note events to file
Args:
note_events: A list of note event tuples to save. Tuples have the format
("start_time_s", "end_time_s", "pitch_midi", "velocity", "list of pitch bend values")
        save_path: The path to write the CSV file to.
"""
with open(save_path, "w") as fhandle:
writer = csv.writer(fhandle, delimiter=",")
writer.writerow(["start_time_s", "end_time_s", "pitch_midi", "velocity", "pitch_bend"])
for start_time, end_time, note_number, amplitude, pitch_bend in note_events:
row = [start_time, end_time, note_number, int(np.round(127 * amplitude))]
if pitch_bend:
row.extend(pitch_bend)
writer.writerow(row)
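

# Illustrative sketch (not part of the library API): save_note_events expects
# tuples of (start_s, end_s, midi_pitch, amplitude in [0, 1], optional pitch
# bends); the amplitude is scaled to a 0-127 MIDI velocity on write.
def _example_save_note_events() -> None:
    events = [
        (0.00, 0.50, 60, 0.8, None),  # middle C, no pitch bend
        (0.50, 1.25, 64, 0.6, [0, 512, 0]),  # E4 with a small bend
    ]
    save_note_events(events, "notes.csv")  # hypothetical output path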
def predict(
audio_path: Union[pathlib.Path, str],
model_or_model_path: Union[Model, pathlib.Path, str] = ICASSP_2022_MODEL_PATH,
onset_threshold: float = 0.5,
frame_threshold: float = 0.3,
minimum_note_length: float = 127.70,
minimum_frequency: Optional[float] = None,
maximum_frequency: Optional[float] = None,
multiple_pitch_bends: bool = False,
melodia_trick: bool = True,
debug_file: Optional[pathlib.Path] = None,
midi_tempo: float = 120,
) -> Tuple[
    Dict[str, npt.NDArray[np.float32]],
pretty_midi.PrettyMIDI,
List[Tuple[float, float, int, float, Optional[List[int]]]],
]:
"""Run a single prediction.
Args:
audio_path: File path for the audio to run inference on.
model_or_model_path: A loaded Model or path to a serialized model to load.
onset_threshold: Minimum energy required for an onset to be considered present.
frame_threshold: Minimum energy requirement for a frame to be considered present.
minimum_note_length: The minimum allowed note length in milliseconds.
        minimum_frequency: Minimum allowed output frequency, in Hz. If None, all frequencies are used.
        maximum_frequency: Maximum allowed output frequency, in Hz. If None, all frequencies are used.
        multiple_pitch_bends: If True, allow overlapping notes in midi file to have pitch bends.
        melodia_trick: Use the melodia post-processing step.
        debug_file: An optional path to output debug data to. Useful for testing/verification.
        midi_tempo: The tempo of the output MIDI file, in beats per minute.
Returns:
The model output, midi data and note events from a single prediction
"""
with no_tf_warnings():
print(f"Predicting MIDI for {audio_path}...")
model_output = run_inference(audio_path, model_or_model_path, debug_file)
min_note_len = int(np.round(minimum_note_length / 1000 * (AUDIO_SAMPLE_RATE / FFT_HOP)))
midi_data, note_events = infer.model_output_to_notes(
model_output,
onset_thresh=onset_threshold,
frame_thresh=frame_threshold,
min_note_len=min_note_len, # convert to frames
min_freq=minimum_frequency,
max_freq=maximum_frequency,
multiple_pitch_bends=multiple_pitch_bends,
melodia_trick=melodia_trick,
midi_tempo=midi_tempo,
)
if debug_file:
with open(debug_file) as f:
debug_data = json.load(f)
with open(debug_file, "w") as f:
json.dump(
{
**debug_data,
"min_note_length": min_note_len,
"onset_thresh": onset_threshold,
"frame_thresh": frame_threshold,
"estimated_notes": [
(
float(start_time),
float(end_time),
int(pitch),
float(amplitude),
[int(b) for b in pitch_bends] if pitch_bends else None,
)
for start_time, end_time, pitch, amplitude, pitch_bends in note_events
],
},
f,
)
return model_output, midi_data, note_events
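

# Illustrative sketch (not part of the library API): the high-level predict
# call with its defaults. The audio path is a hypothetical placeholder; with
# no model argument, the bundled ICASSP 2022 checkpoint is used.
def _example_predict() -> None:
    model_output, midi_data, note_events = predict("my_audio.wav")  # hypothetical input file
    midi_data.write("my_audio.mid")  # hypothetical output path
    print(f"{len(note_events)} notes; output keys: {sorted(model_output)}")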
def predict_and_save(
audio_path_list: Sequence[Union[pathlib.Path, str]],
output_directory: Union[pathlib.Path, str],
save_midi: bool,
sonify_midi: bool,
save_model_outputs: bool,
save_notes: bool,
model_or_model_path: Union[Model, str, pathlib.Path],
onset_threshold: float = 0.5,
frame_threshold: float = 0.3,
minimum_note_length: float = 127.70,
minimum_frequency: Optional[float] = None,
maximum_frequency: Optional[float] = None,
multiple_pitch_bends: bool = False,
melodia_trick: bool = True,
debug_file: Optional[pathlib.Path] = None,
sonification_samplerate: int = 44100,
midi_tempo: float = 120,
) -> None:
"""Make a prediction and save the results to file.
Args:
audio_path_list: List of file paths for the audio to run inference on.
output_directory: Directory to output MIDI and all other outputs derived from the model to.
save_midi: True to save midi.
sonify_midi: Whether or not to render audio from the MIDI and output it to a file.
save_model_outputs: True to save contours, onsets and notes from the model prediction.
save_notes: True to save note events.
model_or_model_path: A loaded Model or path to a serialized model to load.
onset_threshold: Minimum energy required for an onset to be considered present.
frame_threshold: Minimum energy requirement for a frame to be considered present.
minimum_note_length: The minimum allowed note length in milliseconds.
        minimum_frequency: Minimum allowed output frequency, in Hz. If None, all frequencies are used.
        maximum_frequency: Maximum allowed output frequency, in Hz. If None, all frequencies are used.
        multiple_pitch_bends: If True, allow overlapping notes in midi file to have pitch bends.
        melodia_trick: Use the melodia post-processing step.
        debug_file: An optional path to output debug data to. Useful for testing/verification.
        sonification_samplerate: Sample rate for rendering audio from MIDI.
        midi_tempo: The tempo of the output MIDI file, in beats per minute.
"""
for audio_path in audio_path_list:
print("")
try:
model_output, midi_data, note_events = predict(
pathlib.Path(audio_path),
model_or_model_path,
onset_threshold,
frame_threshold,
minimum_note_length,
minimum_frequency,
maximum_frequency,
multiple_pitch_bends,
melodia_trick,
debug_file,
midi_tempo,
)
if save_model_outputs:
model_output_path = build_output_path(audio_path, output_directory, OutputExtensions.MODEL_OUTPUT_NPZ)
try:
np.savez(model_output_path, basic_pitch_model_output=model_output)
file_saved_confirmation(OutputExtensions.MODEL_OUTPUT_NPZ.name, model_output_path)
except Exception as e:
failed_to_save(OutputExtensions.MODEL_OUTPUT_NPZ.name, model_output_path)
raise e
if save_midi:
                midi_path = build_output_path(audio_path, output_directory, OutputExtensions.MIDI)
try:
midi_data.write(str(midi_path))
file_saved_confirmation(OutputExtensions.MIDI.name, midi_path)
except Exception as e:
failed_to_save(OutputExtensions.MIDI.name, midi_path)
raise e
if sonify_midi:
midi_sonify_path = build_output_path(audio_path, output_directory, OutputExtensions.MIDI_SONIFICATION)
try:
infer.sonify_midi(midi_data, midi_sonify_path, sr=sonification_samplerate)
file_saved_confirmation(OutputExtensions.MIDI_SONIFICATION.name, midi_sonify_path)
except Exception as e:
failed_to_save(OutputExtensions.MIDI_SONIFICATION.name, midi_sonify_path)
raise e
if save_notes:
note_events_path = build_output_path(audio_path, output_directory, OutputExtensions.NOTE_EVENTS)
try:
save_note_events(note_events, note_events_path)
file_saved_confirmation(OutputExtensions.NOTE_EVENTS.name, note_events_path)
except Exception as e:
failed_to_save(OutputExtensions.NOTE_EVENTS.name, note_events_path)
raise e
        except Exception:
            raise
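

# Illustrative sketch (not part of the library API): a minimal batch run,
# guarded so it only executes when this module is run directly. The input
# path is a hypothetical placeholder.
if __name__ == "__main__":
    predict_and_save(
        audio_path_list=["my_audio.wav"],  # hypothetical input file
        output_directory=".",
        save_midi=True,
        sonify_midi=False,
        save_model_outputs=False,
        save_notes=True,
        model_or_model_path=ICASSP_2022_MODEL_PATH,
    )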