basic_pitch/train.py
#!/usr/bin/env python
# encoding: utf-8
#
# Copyright 2024 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
from datetime import datetime, timezone
from typing import List

import numpy as np
import tensorflow as tf

from basic_pitch import models
from basic_pitch.callbacks import VisualizeCallback
from basic_pitch.constants import DATASET_SAMPLING_FREQUENCY
from basic_pitch.data import tf_example_deserialization

logging.basicConfig(level=logging.INFO)


def main(
source: str,
output: str,
batch_size: int,
shuffle_size: int,
learning_rate: float,
epochs: int,
steps_per_epoch: int,
validation_steps: int,
size_evaluation_callback_datasets: int,
datasets_to_use: List[str],
dataset_sampling_frequency: np.ndarray,
no_sonify: bool,
no_contours: bool,
weighted_onset_loss: bool,
positive_onset_weight: float,
) -> None:
"""Parse config and run training or evaluation.
Args:
source: source directory for data
output: output directory for trained model / checkpoints / tensorboard
batch_size: batch size for data.
shuffle_size: size of shuffle buffer (only for training set) for the data shuffling mechanism
learning_rate: learning_rate for training
epochs: number of epochs to train for
steps_per_epoch: the number of batches to process per epoch during training
validation_steps: the number of validation batches to evaluate on per epoch
size_evaluation_callback_datasets: the batch size to use for visualization / logging
datasets_to_use: which datasets to train / evaluate on e.g. guitarset, medleydb_pitch, slakh
dataset_sampling_frequency: distribution weighting vector corresponding to datasets determining how they
are sampled from during training / validation dataset creation.
no_sonify: Whether or not to include sonifications in tensorboard.
no_contours: Whether or not to include contours in the output.
weighted_onset_loss: whether or not to use a weighted cross entropy loss.
positive_onset_weight: weighting factor for the positive labels.
"""
# configuration.add_externals()
logging.info(f"source directory: {source}")
logging.info(f"output directory: {output}")
logging.info(f"tensorflow version: {tf.__version__}")
logging.info("parameters to train.main() function:")
logging.info(f"batch_size: {batch_size}")
logging.info(f"shuffle_size: {shuffle_size}")
logging.info(f"learning_rate: {learning_rate}")
logging.info(f"epochs: {epochs}")
logging.info(f"steps_per_epoch: {steps_per_epoch}")
logging.info(f"validation_steps: {validation_steps}")
logging.info(f"size_evaluation_callback_datasets: {size_evaluation_callback_datasets}")
logging.info(f"using datasets: {datasets_to_use} with frequencies {dataset_sampling_frequency}")
logging.info(f"no_contours: {no_contours}")
logging.info(f"weighted_onset_loss: {weighted_onset_loss}")
logging.info(f"positive_onset_weight: {positive_onset_weight}")
# model
model = models.model(no_contours=no_contours)
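    # Log the model's input/output shapes, replacing the unspecified (None) batch dimension
    # with the concrete batch size.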
input_shape = list(model.input_shape)
if input_shape[0] is None:
input_shape[0] = batch_size
logging.info("input_shape" + str(input_shape))
output_shape = model.output_shape
for k, v in output_shape.items():
output_shape[k] = list(v)
if v[0] is None:
output_shape[k][0] = batch_size
logging.info("output_shape" + str(output_shape))
# data loaders
train_ds, validation_ds = tf_example_deserialization.prepare_datasets(
source,
shuffle_size,
batch_size,
validation_steps,
datasets_to_use,
dataset_sampling_frequency,
)
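    # The visualization datasets are kept small: the batch size is capped at MAX_EVAL_CBF_BATCH_SIZE,
    # and any remaining requested examples are spread across additional validation steps.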
MAX_EVAL_CBF_BATCH_SIZE = 4
(
train_visualization_ds,
validation_visualization_ds,
) = tf_example_deserialization.prepare_visualization_datasets(
source,
batch_size=min(size_evaluation_callback_datasets, MAX_EVAL_CBF_BATCH_SIZE),
validation_steps=max(1, size_evaluation_callback_datasets // MAX_EVAL_CBF_BATCH_SIZE),
datasets_to_use=datasets_to_use,
dataset_sampling_frequency=dataset_sampling_frequency,
)
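    # Write each run into a timestamped subdirectory of `output` so that checkpoints and
    # TensorBoard logs from separate runs do not collide.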
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M")
tensorboard_log_dir = os.path.join(output, timestamp, "tensorboard")
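    # Callbacks: TensorBoard logging, early stopping, learning-rate reduction on plateau,
    # best-model and per-epoch checkpointing, and VisualizeCallback, which writes example
    # visualizations (and, unless --no-sonify is given, sonifications) to TensorBoard.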
callbacks = [
tf.keras.callbacks.TensorBoard(log_dir=tensorboard_log_dir, histogram_freq=1),
tf.keras.callbacks.EarlyStopping(patience=25, verbose=2),
tf.keras.callbacks.ReduceLROnPlateau(verbose=1, patience=10, factor=0.5),
tf.keras.callbacks.ModelCheckpoint(filepath=os.path.join(output, timestamp, "model.best"), save_best_only=True),
tf.keras.callbacks.ModelCheckpoint(
filepath=os.path.join(output, timestamp, "checkpoints", "model.{epoch:02d}")
),
VisualizeCallback(
train_visualization_ds,
validation_visualization_ds,
tensorboard_log_dir,
not no_sonify,
not no_contours,
),
]
# if no_contours:
# loss = models.loss_no_contour(weighted=weighted_onset_loss, positive_weight=positive_onset_weight)
# else:
# loss = models.loss(weighted=weighted_onset_loss, positive_weight=positive_onset_weight)
loss = models.loss(weighted=weighted_onset_loss, positive_weight=positive_onset_weight)
# train
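    # Compile with the Adam optimizer; sample_weight_mode is set to None explicitly for each
    # of the three output heads (contour, note, onset).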
model.compile(
loss=loss,
optimizer=tf.keras.optimizers.Adam(learning_rate),
sample_weight_mode={"contour": None, "note": None, "onset": None},
)
logging.info("--- Model Training specs ---")
logging.info(f" train_ds: {train_ds}")
logging.info(f" validation_ds: {validation_ds}")
model.summary()
model.fit(
train_ds,
epochs=epochs,
callbacks=callbacks,
steps_per_epoch=steps_per_epoch,
validation_data=validation_ds,
validation_steps=validation_steps,
    )


def console_entry_point() -> None:
"""From pip installed script."""
    parser = argparse.ArgumentParser(description="Train a basic_pitch model.")
parser.add_argument("--source", required=True, help="Path to directory containing train/validation splits.")
parser.add_argument("--output", required=True, help="Directory to save the model in.")
parser.add_argument("-e", "--epochs", type=int, default=500, help="Number of training epochs.")
parser.add_argument(
"-b",
"--batch-size",
type=int,
default=16,
help="batch size of training. Unlike Estimator API, this specifies the batch size per-GPU.",
)
parser.add_argument(
"-l",
"--learning-rate",
type=float,
default=0.001,
help="ADAM optimizer learning rate",
)
parser.add_argument(
"-s",
"--steps-per-epoch",
type=int,
default=100,
help="steps_per_epoch (batch) of each training loop",
)
parser.add_argument(
"-v",
"--validation-steps",
type=int,
default=10,
help="validation steps (number of BATCHES) for each validation run. MUST be a positive integer",
)
parser.add_argument(
"-z",
"--training-shuffle-size",
type=int,
default=100,
help="training dataset shuffle size",
)
parser.add_argument(
"--size-evaluation-callback-datasets",
type=int,
default=4,
help="number of elements in the dataset used by the evaluation callback function",
)
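    # One boolean flag per dataset in DATASET_SAMPLING_FREQUENCY (e.g. --guitarset) controls
    # whether that dataset is used for training.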
for dataset in DATASET_SAMPLING_FREQUENCY.keys():
parser.add_argument(
f"--{dataset.lower()}",
action="store_true",
default=False,
help=f"Use {dataset} dataset in training",
)
parser.add_argument(
"--no-sonify",
action="store_true",
default=False,
help="if given, exclude sonifications from the tensorboard / data visualization",
)
parser.add_argument(
"--no-contours",
action="store_true",
default=False,
help="if given, trains without supervising the contour layer",
)
parser.add_argument(
"--weighted-onset-loss",
action="store_true",
default=False,
help="if given, trains onsets with a class-weighted loss",
)
parser.add_argument(
"--positive-onset-weight",
type=float,
default=0.5,
help="Positive class onset weight. Only applies when weignted onset loss is true.",
)
args = parser.parse_args()
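    # Collect the dataset names whose command-line flags were set.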
datasets_to_use = [
dataset.lower()
for dataset in DATASET_SAMPLING_FREQUENCY.keys()
if getattr(args, dataset.lower().replace("-", "_"))
]
dataset_sampling_frequency = np.array(
[
frequency
for dataset, frequency in DATASET_SAMPLING_FREQUENCY.items()
if getattr(args, dataset.lower().replace("-", "_"))
]
)
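    # Normalize the sampling weights of the selected datasets so they sum to 1.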
dataset_sampling_frequency = dataset_sampling_frequency / np.sum(dataset_sampling_frequency)
assert args.steps_per_epoch is not None
assert args.validation_steps > 0
    main(
        source=args.source,
        output=args.output,
        batch_size=args.batch_size,
        shuffle_size=args.training_shuffle_size,
        learning_rate=args.learning_rate,
        epochs=args.epochs,
        steps_per_epoch=args.steps_per_epoch,
        validation_steps=args.validation_steps,
        size_evaluation_callback_datasets=args.size_evaluation_callback_datasets,
        datasets_to_use=datasets_to_use,
        dataset_sampling_frequency=dataset_sampling_frequency,
        no_sonify=args.no_sonify,
        no_contours=args.no_contours,
        weighted_onset_loss=args.weighted_onset_loss,
        positive_onset_weight=args.positive_onset_weight,
    )


if __name__ == "__main__":
console_entry_point()