def _infer_time_size()

in basic_pitch/data/tf_example_deserialization.py [0:0]


def _infer_time_size(onsets: tf.Tensor, contour: tf.Tensor, notes: tf.Tensor) -> tf.Tensor:
    """Return the number of time frames of the (possibly empty) target tensors.

    Some of the targets might be empty (leading dimension of 0), so the time
    size is taken as the largest leading dimension among the three targets
    rather than read from any single one of them.

    Args:
        onsets: onset target tensor; its first dimension is read as time.
        contour: contour target tensor; its first dimension is read as time.
        notes: note target tensor; its first dimension is read as time.
        NOTE(review): each target is assumed to be at least rank 1, since
        ``tf.shape(x)[0]`` is indexed unconditionally — confirm with callers.

    Returns:
        A scalar ``tf.int32`` tensor: the largest first-dimension size of the
        three inputs.
    """
    # Ask tf.shape for int32 explicitly so every operand of tf.math.maximum
    # shares one dtype; this makes the previous follow-up tf.cast calls unnecessary.
    onset_frames = tf.shape(onsets, out_type=tf.int32)[0]
    contour_frames = tf.shape(contour, out_type=tf.int32)[0]
    note_frames = tf.shape(notes, out_type=tf.int32)[0]

    return tf.math.maximum(tf.math.maximum(onset_frames, contour_frames), note_frames)