in basic_pitch/nn.py [0:0]
def call(self, x: tf.Tensor) -> tf.Tensor:
    """Stack frequency-shifted copies of ``x`` along the last axis.

    For every entry of ``self.shifts`` a copy of ``x`` is displaced along
    axis 2 by that many bins (output bin ``i`` holds input bin
    ``i + shift``); bins shifted in from outside the input are filled by
    ``tf.pad`` with its default constant padding. The copies are
    concatenated on the last axis and the result is cropped to the first
    ``self.n_output_freqs`` bins of axis 2.

    Args:
        x: Rank-4 tensor; the original note gives its layout as
            (n_batch, n_times, n_freqs, 1).

    Returns:
        The concatenated, cropped tensor.

    Raises:
        ValueError: If a shift is neither positive, negative, nor zero
            (e.g. NaN).
    """
    # tf.shape(x) is a vector with one entry per dimension, so its own
    # static shape is [rank]; this checks that x is rank 4.
    tf.debugging.assert_equal(tf.shape(x).shape, 4)

    shifted_copies = []
    for offset in self.shifts:
        if offset > 0:
            # Drop the first `offset` bins and pad the same number at the end,
            # keeping the length of axis 2 unchanged.
            pad_spec = tf.constant([[0, 0], [0, 0], [0, offset], [0, 0]])
            shifted = tf.pad(x[:, :, offset:, :], pad_spec)
        elif offset < 0:
            # Drop the last `-offset` bins and pad the same number at the start.
            pad_spec = tf.constant([[0, 0], [0, 0], [-offset, 0], [0, 0]])
            shifted = tf.pad(x[:, :, :offset, :], pad_spec)
        elif offset == 0:
            # No displacement: use the input as-is.
            shifted = x
        else:
            raise ValueError
        shifted_copies.append(shifted)

    stacked = tf.concat(shifted_copies, axis=-1)
    # return only the first n_output_freqs frequency channels
    return stacked[:, :, : self.n_output_freqs, :]