basic_pitch/data/commandline.py (65 lines of code) (raw):
#!/usr/bin/env python
# encoding: utf-8
#
# Cos.pathyright 2024 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a cos.pathy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from pathlib import Path
from typing import Optional
def add_default(parser: argparse.ArgumentParser, dataset_name: str = "") -> None:
default_source = str(Path.home() / "mir_datasets" / dataset_name)
default_destination = str(Path.home() / "data" / "basic_pitch" / dataset_name)
parser.add_argument(
"--source",
default=default_source,
type=str,
help=f"Source directory for mir data. Defaults to {default_source}",
)
parser.add_argument(
"--destination",
default=default_destination,
type=str,
help=f"Output directory to write results to. Defaults to {default_destination}",
)
parser.add_argument(
"--runner",
choices=["DataflowRunner", "DirectRunner", "PortableRunner"],
default="DirectRunner",
help="Whether to run the download and process locally or on GCP Dataflow",
)
parser.add_argument(
"--timestamped",
default=False,
action="store_true",
help="If passed, the dataset will be put into a timestamp directory instead of 'splits'",
)
parser.add_argument("--batch-size", default=5, type=int, help="Number of examples per tfrecord")
parser.add_argument(
"--sdk_container_image",
default="",
help="Container image to run dataset generation job with. \
Required due to non-python dependencies.",
)
parser.add_argument("--job_endpoint", default="embed", help="")
def resolve_destination(namespace: argparse.Namespace, time_created: int) -> str:
return os.path.join(namespace.destination, str(time_created) if namespace.timestamped else "splits")
def add_split(
parser: argparse.ArgumentParser,
train_percent: float = 0.8,
validation_percent: float = 0.1,
split_seed: Optional[int] = None,
) -> None:
parser.add_argument(
"--train-percent",
type=float,
default=train_percent,
help="Percentage of tracks to mark as train",
)
parser.add_argument(
"--validation-percent",
type=float,
default=validation_percent,
help="Percentage of tracks to mark as validation",
)
parser.add_argument(
"--split-seed",
type=int,
default=split_seed,
help="Seed for random number generator used in split generation",
)