basic_pitch/models.py
#!/usr/bin/env python
# encoding: utf-8
#
# Copyright 2022 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict
import numpy as np
import tensorflow as tf
from basic_pitch import nn
from basic_pitch.constants import (
ANNOTATIONS_BASE_FREQUENCY,
ANNOTATIONS_N_SEMITONES,
AUDIO_N_SAMPLES,
AUDIO_SAMPLE_RATE,
CONTOURS_BINS_PER_SEMITONE,
FFT_HOP,
N_FREQ_BINS_CONTOURS,
)
from basic_pitch.layers import signal, nnaudio
tfkl = tf.keras.layers
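# Largest number of semitones the CQT can span above ANNOTATIONS_BASE_FREQUENCY while
# staying below the Nyquist frequency (AUDIO_SAMPLE_RATE / 2).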
MAX_N_SEMITONES = int(np.floor(12.0 * np.log2(0.5 * AUDIO_SAMPLE_RATE / ANNOTATIONS_BASE_FREQUENCY)))


def transcription_loss(y_true: tf.Tensor, y_pred: tf.Tensor, label_smoothing: float) -> tf.Tensor:
    """The transcription loss: a binary cross entropy between the predicted
    posteriorgrams and the ground truth matrices.

    Args:
        y_true: The true labels.
        y_pred: The predicted labels.
        label_smoothing: Smoothing factor. Squeezes labels towards 0.5.

    Returns:
        The transcription loss.
    """
bce = tf.keras.losses.binary_crossentropy(y_true, y_pred, label_smoothing=label_smoothing)
return bce


def weighted_transcription_loss(
    y_true: tf.Tensor, y_pred: tf.Tensor, label_smoothing: float, positive_weight: float = 0.5
) -> tf.Tensor:
    """The transcription loss where the positive and negative true labels are balanced by a weighting factor.

    Args:
        y_true: The true labels.
        y_pred: The predicted labels.
        label_smoothing: Smoothing factor. Squeezes labels towards 0.5.
        positive_weight: Weighting factor for the positive labels.

    Returns:
        The weighted transcription loss.
    """
negative_mask = tf.equal(y_true, 0)
nonnegative_mask = tf.logical_not(negative_mask)
bce_negative = tf.keras.losses.binary_crossentropy(
tf.boolean_mask(y_true, negative_mask),
tf.boolean_mask(y_pred, negative_mask),
label_smoothing=label_smoothing,
)
bce_nonnegative = tf.keras.losses.binary_crossentropy(
tf.boolean_mask(y_true, nonnegative_mask),
tf.boolean_mask(y_pred, nonnegative_mask),
label_smoothing=label_smoothing,
)
return ((1 - positive_weight) * bce_negative) + (positive_weight * bce_nonnegative)


def onset_loss(
    weighted: bool, label_smoothing: float, positive_weight: float
) -> Callable[[tf.Tensor, tf.Tensor], tf.Tensor]:
    """
    Args:
        weighted: Whether or not to use a weighted cross entropy loss.
        label_smoothing: Smoothing factor. Squeezes labels towards 0.5.
        positive_weight: Weighting factor for the positive labels.

    Returns:
        A function that calculates the transcription loss: weighted_transcription_loss
        if weighted is True, otherwise transcription_loss.
    """
if weighted:
return lambda x, y: weighted_transcription_loss(
x, y, label_smoothing=label_smoothing, positive_weight=positive_weight
)
return lambda x, y: transcription_loss(x, y, label_smoothing=label_smoothing)


def loss(label_smoothing: float = 0.2, weighted: bool = False, positive_weight: float = 0.5) -> Dict[str, Any]:
    """Creates a keras-compatible dictionary of loss functions to calculate
    the loss for the contour, note and onset posteriorgrams.

    Args:
        label_smoothing: Smoothing factor. Squeezes labels towards 0.5.
        weighted: Whether or not to use a weighted cross entropy loss.
        positive_weight: Weighting factor for the positive labels.

    Returns:
        A dictionary with keys "contour", "note" and "onset", whose values are the
        functions used to calculate the corresponding transcription losses.
    """
loss_fn = lambda x, y: transcription_loss(x, y, label_smoothing=label_smoothing)
loss_onset = onset_loss(weighted, label_smoothing, positive_weight)
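    # The keys below match the model's output head names ("contour", "note", "onset"),
    # so the returned dictionary can be passed directly as the `loss` argument of
    # tf.keras.Model.compile (see the usage sketch at the end of this file).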
return {
"contour": loss_fn,
"note": loss_fn,
"onset": loss_onset,
}


def _initializer() -> tf.keras.initializers.VarianceScaling:
    return tf.keras.initializers.VarianceScaling(scale=2.0, mode="fan_avg", distribution="uniform", seed=None)


def _kernel_constraint() -> tf.keras.constraints.UnitNorm:
    return tf.keras.constraints.UnitNorm(axis=[0, 1, 2])


def get_cqt(inputs: tf.Tensor, n_harmonics: int, use_batchnorm: bool) -> tf.Tensor:
    """Calculate the CQT of the input audio.

    Input shape: (batch, number of audio samples, 1)
    Output shape: (batch, number of frequency bins, number of time frames)

    Args:
        inputs: The audio input.
        n_harmonics: The number of harmonics to capture above the maximum output frequency.
            Used to calculate the number of semitones for the CQT.
        use_batchnorm: If True, applies batch normalization after computing the CQT.

    Returns:
        The log-normalized CQT of the input audio.
    """
n_semitones = np.min(
[
int(np.ceil(12.0 * np.log2(n_harmonics)) + ANNOTATIONS_N_SEMITONES),
MAX_N_SEMITONES,
]
)
x = nn.FlattenAudioCh()(inputs)
x = nnaudio.CQT(
sr=AUDIO_SAMPLE_RATE,
hop_length=FFT_HOP,
fmin=ANNOTATIONS_BASE_FREQUENCY,
n_bins=n_semitones * CONTOURS_BINS_PER_SEMITONE,
bins_per_octave=12 * CONTOURS_BINS_PER_SEMITONE,
)(x)
x = signal.NormalizedLog()(x)
x = tf.expand_dims(x, -1)
if use_batchnorm:
x = tfkl.BatchNormalization()(x)
return x


def model(
    n_harmonics: int = 8,
    n_filters_contour: int = 32,
    n_filters_onsets: int = 32,
    n_filters_notes: int = 32,
    no_contours: bool = False,
) -> tf.keras.Model:
    """Basic Pitch's model implementation.

    Args:
        n_harmonics: The number of harmonics to use in the harmonic stacking layer.
        n_filters_contour: Number of filters for the contour convolutional layer.
        n_filters_onsets: Number of filters for the onsets convolutional layer.
        n_filters_notes: Number of filters for the notes convolutional layer.
        no_contours: Whether or not to include contours in the output.

    Returns:
        A tf.keras.Model with "contour", "note" and "onset" outputs.
    """
# input representation
inputs = tf.keras.Input(shape=(AUDIO_N_SAMPLES, 1)) # (batch, time, ch)
x = get_cqt(inputs, n_harmonics, True)
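    # Stack frequency-shifted copies of the CQT along the channel axis (a 0.5 sub-harmonic
    # plus harmonics 1 .. n_harmonics - 1 when n_harmonics > 1) so that harmonically
    # related energy is aligned for the convolutions that follow.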
if n_harmonics > 1:
x = nn.HarmonicStacking(
CONTOURS_BINS_PER_SEMITONE,
[0.5] + list(range(1, n_harmonics)),
N_FREQ_BINS_CONTOURS,
)(x)
else:
x = nn.HarmonicStacking(
CONTOURS_BINS_PER_SEMITONE,
[1],
N_FREQ_BINS_CONTOURS,
)(x)
# contour layers - fully convolutional
x_contours = tfkl.Conv2D(
n_filters_contour,
(5, 5),
padding="same",
kernel_initializer=_initializer(),
kernel_constraint=_kernel_constraint(),
)(x)
x_contours = tfkl.BatchNormalization()(x_contours)
x_contours = tfkl.ReLU()(x_contours)
x_contours = tfkl.Conv2D(
8,
(3, 3 * 13),
padding="same",
kernel_initializer=_initializer(),
kernel_constraint=_kernel_constraint(),
)(x)
x_contours = tfkl.BatchNormalization()(x_contours)
x_contours = tfkl.ReLU()(x_contours)
if not no_contours:
contour_name = "contour"
x_contours = tfkl.Conv2D(
1,
(5, 5),
padding="same",
activation="sigmoid",
kernel_initializer=_initializer(),
kernel_constraint=_kernel_constraint(),
name="contours-reduced",
)(x_contours)
x_contours = nn.FlattenFreqCh(name=contour_name)(x_contours) # contour output
# reduced contour output as input to notes
x_contours_reduced = tf.expand_dims(x_contours, -1)
else:
x_contours_reduced = x_contours
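    # Note head: strides=(1, 3) downsamples the frequency axis by a factor of 3 relative
    # to the contour resolution.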
x_contours_reduced = tfkl.Conv2D(
n_filters_notes,
(7, 7),
padding="same",
strides=(1, 3),
kernel_initializer=_initializer(),
kernel_constraint=_kernel_constraint(),
)(x_contours_reduced)
x_contours_reduced = tfkl.ReLU()(x_contours_reduced)
# note output layer
note_name = "note"
x_notes_pre = tfkl.Conv2D(
1,
(7, 3),
padding="same",
kernel_initializer=_initializer(),
kernel_constraint=_kernel_constraint(),
activation="sigmoid",
)(x_contours_reduced)
x_notes = nn.FlattenFreqCh(name=note_name)(x_notes_pre)
# onset output layer
# onsets - fully convolutional
x_onset = tfkl.Conv2D(
n_filters_onsets,
(5, 5),
padding="same",
strides=(1, 3),
kernel_initializer=_initializer(),
kernel_constraint=_kernel_constraint(),
)(x)
x_onset = tfkl.BatchNormalization()(x_onset)
x_onset = tfkl.ReLU()(x_onset)
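    # Condition the onset head on the note posteriorgram by concatenating it with the
    # onset convolution features along the channel axis.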
x_onset = tfkl.Concatenate(axis=3, name="concat")([x_notes_pre, x_onset])
x_onset = tfkl.Conv2D(
1,
(3, 3),
padding="same",
activation="sigmoid",
kernel_initializer=_initializer(),
kernel_constraint=_kernel_constraint(),
)(x_onset)
onset_name = "onset"
x_onset = nn.FlattenFreqCh(
name=onset_name,
)(x_onset)
outputs = {"onset": x_onset, "contour": x_contours, "note": x_notes}
return tf.keras.Model(inputs=inputs, outputs=outputs)
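

# A minimal usage sketch (illustrative only; this is not the project's training or
# inference entry point, and the optimizer, smoothing value, and dummy input below are
# assumptions for demonstration):
if __name__ == "__main__":
    demo_model = model(n_harmonics=8)
    demo_model.compile(optimizer=tf.keras.optimizers.Adam(), loss=loss(label_smoothing=0.2))
    # Run a forward pass on silent dummy audio; each output is a posteriorgram.
    dummy_audio = np.zeros((1, AUDIO_N_SAMPLES, 1), dtype=np.float32)
    predictions = demo_model.predict(dummy_audio)
    print({name: array.shape for name, array in predictions.items()})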