basic_pitch/data/pipeline.py
#!/usr/bin/env python
# encoding: utf-8
#
# Copyright 2024 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import uuid
from typing import Any, Dict, List, Tuple, Callable, Union

import apache_beam as beam
import tensorflow as tf
from apache_beam.options.pipeline_options import PipelineOptions


# Because beam.GroupIntoBatches isn't supported as of Beam 2.29
class Batch(beam.DoFn):
    """Split a list element into consecutive sub-lists of at most batch_size items."""
    def __init__(self, batch_size: int) -> None:
        self.batch_size = batch_size

    def process(self, element: List[Any], *args: Tuple[Any, Any], **kwargs: Dict[str, Any]) -> Any:
for i in range(0, len(element), self.batch_size):
yield element[i : i + self.batch_size]
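
# Illustrative usage sketch (not part of the original module): Batch is applied with
# beam.ParDo to elements that are themselves lists, e.g.
#     list_elements | "Chunk" >> beam.ParDo(Batch(batch_size=2))
# so a single [1, 2, 3, 4, 5] element is emitted as [1, 2], [3, 4], [5].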


class WriteBatchToTfRecord(beam.DoFn):
    """Write each incoming batch of protos to its own uniquely named TFRecord file."""
    def __init__(self, destination: str) -> None:
        self.destination = destination

    def process(self, element: Any, *args: Tuple[Any, Any], **kwargs: Dict[str, Any]) -> None:
if not isinstance(element, list):
element = [element]
logging.info(f"Writing to file batch of length {len(element)}")
# hopefully uuids are unique enough
with tf.io.TFRecordWriter(os.path.join(self.destination, f"{uuid.uuid4()}.tfrecord")) as writer:
for example in element:
writer.write(example.SerializeToString())
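
# Illustrative sketch (not part of the original module): elements reaching
# WriteBatchToTfRecord are expected to be protos exposing SerializeToString(),
# typically tf.train.Example instances such as
#     tf.train.Example(
#         features=tf.train.Features(
#             feature={"track_id": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"track_1"]))}
#         )
#     )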


def transcription_dataset_writer(
    p: beam.Pipeline,
    input_data: List[Tuple[str, str]],
    to_tf_example: Union[beam.DoFn, Callable[[List[Any]], Any]],
    filter_invalid_tracks: beam.DoFn,
    destination: str,
    batch_size: int,
) -> None:
    """Split tracks into train/test/validation, then write batched TFRecords and a track-id index per split."""
valid_track_ids = (
p
| "Create PCollection of track IDS" >> beam.Create(input_data)
| "Remove invalid track IDs"
>> beam.ParDo(filter_invalid_tracks).with_outputs(
"train",
"test",
"validation",
)
)
for split in ["train", "test", "validation"]:
(
getattr(valid_track_ids, split)
| f"Batch {split}" >> beam.BatchElements(min_batch_size=batch_size, max_batch_size=batch_size)
| f"Reshuffle {split}" >> beam.Reshuffle() # To prevent fuses
| f"Create tf.Example {split} batch" >> beam.ParDo(to_tf_example)
| f"Write {split} batch to tfrecord" >> beam.ParDo(WriteBatchToTfRecord(os.path.join(destination, split)))
)
getattr(valid_track_ids, split) | f"Write {split} index file" >> beam.io.textio.WriteToText(
os.path.join(destination, split, "index.csv"),
num_shards=1,
header="track_id",
shard_name_template="",
)


def run(
    pipeline_options: Dict[str, str],
    pipeline_args: List[str],
    input_data: List[Tuple[str, str]],
    to_tf_example: beam.DoFn,
    filter_invalid_tracks: beam.DoFn,
    destination: str,
    batch_size: int,
) -> None:
    """Build a Beam pipeline from the given options and run the transcription dataset writer."""
logging.info(f"pipeline_options = {pipeline_options}")
logging.info(f"pipeline_args = {pipeline_args}")
with beam.Pipeline(options=PipelineOptions(flags=pipeline_args, **pipeline_options)) as p:
transcription_dataset_writer(p, input_data, to_tf_example, filter_invalid_tracks, destination, batch_size)
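

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). MyToTfExample
# and MyFilterInvalidTracks are hypothetical DoFns standing in for the
# dataset-specific transforms a caller would supply; filter_invalid_tracks
# must tag its outputs "train", "test", and "validation". The input_data
# pairs, runner, and destination below are likewise placeholder assumptions.
# ---------------------------------------------------------------------------
# run(
#     pipeline_options={"runner": "DirectRunner"},
#     pipeline_args=[],
#     input_data=[("track_1", "train"), ("track_2", "validation")],
#     to_tf_example=MyToTfExample(),
#     filter_invalid_tracks=MyFilterInvalidTracks(),
#     destination="/tmp/transcription_dataset",
#     batch_size=8,
# )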