in basic_pitch/train.py [0:0]
import argparse

import numpy as np

# DATASET_SAMPLING_FREQUENCY and main are assumed to be defined elsewhere in this module.


def console_entry_point() -> None:
"""Entry point for the pip-installed training script."""
parser = argparse.ArgumentParser(description="Train a Basic Pitch model.")
parser.add_argument("--source", required=True, help="Path to directory containing train/validation splits.")
parser.add_argument("--output", required=True, help="Directory to save the model in.")
parser.add_argument("-e", "--epochs", type=int, default=500, help="Number of training epochs.")
parser.add_argument(
"-b",
"--batch-size",
type=int,
default=16,
help="batch size of training. Unlike Estimator API, this specifies the batch size per-GPU.",
)
parser.add_argument(
"-l",
"--learning-rate",
type=float,
default=0.001,
help="ADAM optimizer learning rate",
)
parser.add_argument(
"-s",
"--steps-per-epoch",
type=int,
default=100,
help="steps_per_epoch (batch) of each training loop",
)
parser.add_argument(
"-v",
"--validation-steps",
type=int,
default=10,
help="validation steps (number of BATCHES) for each validation run. MUST be a positive integer",
)
parser.add_argument(
"-z",
"--training-shuffle-size",
type=int,
default=100,
help="training dataset shuffle size",
)
parser.add_argument(
"--size-evaluation-callback-datasets",
type=int,
default=4,
help="number of elements in the dataset used by the evaluation callback function",
)
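# Dynamically generate one opt-in boolean flag per dataset named in
# DATASET_SAMPLING_FREQUENCY, so adding a dataset needs no extra CLI code.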
for dataset in DATASET_SAMPLING_FREQUENCY.keys():
parser.add_argument(
f"--{dataset.lower()}",
action="store_true",
default=False,
help=f"Use {dataset} dataset in training",
)
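# Optional switches controlling visualization output and loss supervision.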
parser.add_argument(
"--no-sonify",
action="store_true",
default=False,
help="if given, exclude sonifications from the tensorboard / data visualization",
)
parser.add_argument(
"--no-contours",
action="store_true",
default=False,
help="if given, trains without supervising the contour layer",
)
parser.add_argument(
"--weighted-onset-loss",
action="store_true",
default=False,
help="if given, trains onsets with a class-weighted loss",
)
parser.add_argument(
"--positive-onset-weight",
type=float,
default=0.5,
help="Positive class onset weight. Only applies when weignted onset loss is true.",
)
args = parser.parse_args()
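# Collect the datasets whose flags were set. argparse stores --foo-bar as
# args.foo_bar, hence the hyphen-to-underscore replacement below.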
datasets_to_use = [
dataset.lower()
for dataset in DATASET_SAMPLING_FREQUENCY.keys()
if getattr(args, dataset.lower().replace("-", "_"))
]
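# Gather the sampling weight of each selected dataset, then normalize the
# weights into a probability distribution (they sum to 1 after the division).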
dataset_sampling_frequency = np.array(
[
frequency
for dataset, frequency in DATASET_SAMPLING_FREQUENCY.items()
if getattr(args, dataset.lower().replace("-", "_"))
]
)
dataset_sampling_frequency = dataset_sampling_frequency / np.sum(dataset_sampling_frequency)
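# Sanity-check the loop-length arguments before launching training.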
assert args.steps_per_epoch > 0
assert args.validation_steps > 0
main(
args.source,
args.output,
args.training_shuffle_size,
args.batch_size,
args.learning_rate,
args.epochs,
args.steps_per_epoch,
args.validation_steps,
args.size_evaluation_callback_datasets,
datasets_to_use,
dataset_sampling_frequency,
args.no_sonify,
args.no_contours,
args.weighted_onset_loss,
args.positive_onset_weight,
)
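
# Example invocation (a sketch; the entry-point path and the --guitarset dataset
# flag are assumptions for illustration, since the real flag names come from the
# keys of DATASET_SAMPLING_FREQUENCY):
#
#   python -m basic_pitch.train \
#       --source /data/splits --output /models/run1 \
#       --epochs 100 --batch-size 16 --guitarset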