basic_pitch/data/datasets/slakh.py (160 lines of code) (raw):
#!/usr/bin/env python
# encoding: utf-8
#
# Copyright 2024 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
import time
from typing import List, Tuple, Any
import apache_beam as beam
import mirdata
from basic_pitch.data import commandline, pipeline
class SlakhFilterInvalidTracks(beam.DoFn):
    """Beam ``DoFn`` that keeps only usable Slakh tracks, tagged by data split.

    Each input element is a ``(track_id, split)`` pair. Nothing is emitted for
    a track when:

    * its split is ``"omitted"``,
    * the local track has an empty path for any of ``DOWNLOAD_ATTRIBUTES``,
    * the track reports ``is_drum``,
    * the ffmpeg conversion of its audio raises, or
    * it has no note annotations.

    Every other track ID is emitted as a ``TaggedOutput`` whose tag is the
    split name.
    """

    # Track attributes whose files are copied from the source dataset into a
    # temporary local dataset before the track is inspected.
    DOWNLOAD_ATTRIBUTES = ["audio_path", "metadata_path", "midi_path"]

    def __init__(self, source: str):
        """Store ``source``, later used as ``data_home`` for the Slakh dataset."""
        self.source = source

    def setup(self) -> None:
        """Create the source dataset handle and the Beam filesystem accessor."""
        import mirdata

        self.slakh_remote = mirdata.initialize("slakh", data_home=self.source)
        self.filesystem = beam.io.filesystems.FileSystems()

    def process(self, element: Tuple[str, str]) -> Any:
        """Validate one track and emit its ID under its split tag.

        Args:
            element: ``(track_id, split)`` pair.

        Yields:
            ``beam.pvalue.TaggedOutput(split, track_id)`` if the track passes
            every check; otherwise the generator ends without yielding.
        """
        import tempfile

        import apache_beam as beam
        import ffmpeg

        from basic_pitch.constants import (
            AUDIO_N_CHANNELS,
            AUDIO_SAMPLE_RATE,
        )

        track_id, split = element
        if split == "omitted":
            return

        logging.info(f"Processing (track_id, split): ({track_id}, {split})")

        remote_track = self.slakh_remote.track(track_id)

        with tempfile.TemporaryDirectory() as scratch_dir:
            # A throw-away dataset rooted in the temp dir supplies the local
            # destination paths for this track's files.
            local_dataset = mirdata.initialize("slakh", scratch_dir)
            local_track = local_dataset.track(track_id)

            for attr in self.DOWNLOAD_ATTRIBUTES:
                source = getattr(remote_track, attr)
                dest = getattr(local_track, attr)
                if not dest:
                    # No path for this attribute: the track cannot be used.
                    return
                logging.info(f"Downloading {attr} from {source} to {dest}")
                os.makedirs(os.path.dirname(dest), exist_ok=True)
                with self.filesystem.open(source) as reader, open(dest, "wb") as writer:
                    writer.write(reader.read())

            if local_track.is_drum:
                return

            local_wav_path = "{}_tmp.wav".format(local_track.audio_path)
            try:
                # Only success or failure of the conversion matters here; the
                # resulting wav file is discarded with the temp directory.
                conversion = ffmpeg.input(local_track.audio_path).output(
                    local_wav_path, ar=AUDIO_SAMPLE_RATE, ac=AUDIO_N_CHANNELS
                )
                conversion.run()
            except Exception as e:
                logging.info(f"Could not process {local_wav_path}. Exception: {e}")
                return

            # Tracks without any note annotation are skipped.
            notes = local_track.notes
            has_notes = notes is not None and len(notes.intervals) != 0
            if not has_notes:
                return

        yield beam.pvalue.TaggedOutput(split, track_id)
class SlakhToTfExample(beam.DoFn):
    """Beam ``DoFn`` that converts batches of Slakh track IDs into examples.

    For every track in a batch, the audio, metadata and MIDI files are copied
    into a temporary directory, the audio is converted to a wav file with
    ffmpeg, and the note / onset / contour annotations are turned into sparse
    indices that are passed to
    ``tf_example_serialization.to_transcription_tfexample``.
    """

    # Track attributes whose files are copied from the source dataset into a
    # temporary local dataset before the track is processed.
    DOWNLOAD_ATTRIBUTES = ["audio_path", "metadata_path", "midi_path"]

    def __init__(self, source: str, download: bool) -> None:
        """Store the configuration.

        Args:
            source: Used as ``data_home`` for the Slakh dataset in ``setup``.
            download: When true, ``setup`` calls ``download()`` on that dataset.
        """
        self.source = source
        self.download = download

    def setup(self) -> None:
        """Create the dataset handle and filesystem accessor; optionally download."""
        import apache_beam as beam
        import mirdata

        self.slakh_remote = mirdata.initialize("slakh", data_home=self.source)
        self.filesystem = beam.io.filesystems.FileSystems()  # TODO: replace with fsspec
        if self.download:
            self.slakh_remote.download()

    def process(self, element: List[str]) -> List[Any]:
        """Build one example per track ID in ``element``.

        Args:
            element: Batch of Slakh track IDs.

        Returns:
            A one-element list whose only item is the list of examples built
            for this batch, in input order.
        """
        import tempfile

        import numpy as np
        import ffmpeg

        from basic_pitch.constants import (
            AUDIO_N_CHANNELS,
            AUDIO_SAMPLE_RATE,
            FREQ_BINS_CONTOURS,
            FREQ_BINS_NOTES,
            ANNOTATION_HOP,
            N_FREQ_BINS_NOTES,
            N_FREQ_BINS_CONTOURS,
        )
        from basic_pitch.data import tf_example_serialization

        logging.info(f"Processing {element}")
        batch = []

        for track_id in element:
            remote_track = self.slakh_remote.track(track_id)

            # Everything that touches the local files, including building the
            # example from the wav path, must happen before the temp dir is removed.
            with tempfile.TemporaryDirectory() as scratch_dir:
                local_dataset = mirdata.initialize("slakh", scratch_dir)
                local_track = local_dataset.track(track_id)

                for attr in self.DOWNLOAD_ATTRIBUTES:
                    source = getattr(remote_track, attr)
                    dest = getattr(local_track, attr)
                    logging.info(f"Downloading {attr} from {source} to {dest}")
                    os.makedirs(os.path.dirname(dest), exist_ok=True)
                    with self.filesystem.open(source) as reader, open(dest, "wb") as writer:
                        writer.write(reader.read())

                # Convert the audio to a wav file using the project's sample
                # rate and channel count.
                local_wav_path = "{}_tmp.wav".format(local_track.audio_path)
                conversion = ffmpeg.input(local_track.audio_path).output(
                    local_wav_path, ar=AUDIO_SAMPLE_RATE, ac=AUDIO_N_CHANNELS
                )
                conversion.run()

                # Time grid starting at 0 with step ANNOTATION_HOP, extended by
                # one hop past the probed duration of the converted file.
                probe_info = ffmpeg.probe(local_wav_path)
                duration = float(probe_info["format"]["duration"])
                frame_times = np.arange(0, duration + ANNOTATION_HOP, ANNOTATION_HOP)
                n_frames = len(frame_times)

                note_indices, note_values = local_track.notes.to_sparse_index(
                    frame_times, "s", FREQ_BINS_NOTES, "hz"
                )
                onset_indices, onset_values = local_track.notes.to_sparse_index(
                    frame_times, "s", FREQ_BINS_NOTES, "hz", onsets_only=True
                )
                contour_indices, contour_values = local_track.multif0.to_sparse_index(
                    frame_times, "s", FREQ_BINS_CONTOURS, "hz"
                )

                notes_shape = (n_frames, N_FREQ_BINS_NOTES)
                contours_shape = (n_frames, N_FREQ_BINS_CONTOURS)
                example = tf_example_serialization.to_transcription_tfexample(
                    track_id,
                    "slakh",
                    local_wav_path,
                    note_indices,
                    note_values,
                    onset_indices,
                    onset_values,
                    contour_indices,
                    contour_values,
                    notes_shape,
                    contours_shape,
                )
                batch.append(example)

        logging.info(f"Finished processing batch of length {len(batch)}")
        return [batch]
def create_input_data() -> List[Tuple[str, str]]:
    """Return a ``(track_id, data_split)`` pair for every track of the Slakh dataset.

    The dataset is initialised with mirdata's default settings (no explicit
    ``data_home``).
    """
    dataset = mirdata.initialize("slakh")
    pairs: List[Tuple[str, str]] = []
    for track_id, track in dataset.load_tracks().items():
        pairs.append((track_id, track.data_split))
    return pairs
def main(known_args: argparse.Namespace, pipeline_args: List[str]) -> None:
    """Assemble the pipeline options and run the Slakh pipeline.

    Args:
        known_args: Parsed arguments; this function reads ``runner``,
            ``sdk_container_image``, ``job_endpoint``, ``source`` and
            ``batch_size`` and passes the namespace to
            ``commandline.resolve_destination``.
        pipeline_args: Remaining command-line arguments, forwarded to
            ``pipeline.run`` unchanged.
    """
    # One timestamp is shared by the destination and the job name.
    time_created = int(time.time())
    destination = commandline.resolve_destination(known_args, time_created)
    track_split_pairs = create_input_data()

    options = {
        "runner": known_args.runner,
        "job_name": f"slakh-tfrecords-{time_created}",
        "machine_type": "e2-standard-4",
        "num_workers": 25,
        "disk_size_gb": 128,
        "experiments": ["use_runner_v2"],
        "save_main_session": True,
        "sdk_container_image": known_args.sdk_container_image,
        "job_endpoint": known_args.job_endpoint,
        "environment_type": "DOCKER",
        "environment_config": known_args.sdk_container_image,
    }

    to_tf_example = SlakhToTfExample(known_args.source, download=True)
    filter_invalid_tracks = SlakhFilterInvalidTracks(known_args.source)

    pipeline.run(
        options,
        pipeline_args,
        track_split_pairs,
        to_tf_example,
        filter_invalid_tracks,
        destination,
        known_args.batch_size,
    )
if __name__ == "__main__":
    # Script entry point: register the shared command-line options, using this
    # file's base name (without extension) as the dataset name, then run.
    parser = argparse.ArgumentParser()
    dataset_name = os.path.basename(os.path.splitext(__file__)[0])
    commandline.add_default(parser, dataset_name)
    commandline.add_split(parser)

    # Arguments the parser does not recognise are returned separately and
    # forwarded to main() as pipeline arguments.
    known_args, pipeline_args = parser.parse_known_args()
    main(known_args, pipeline_args)