def _train_file_generator()

in basic_pitch/data/tf_example_deserialization.py [0:0]


def _train_file_generator(x: Dict[str, tf.data.Dataset], weights: np.ndarray) -> Iterator[tf.Tensor]:
    """file generator for training sets"""
    x = {k: list(v) for (k, v) in x.items()}
    keys = list(x.keys())
    # shuffle each list
    for k in keys:
        np.random.shuffle(x[k])

    while all(x.values()):
        # choose a random dataset and yield the last file
        fpath = x[np.random.choice(keys, p=weights)].pop()
        yield fpath