in realbook/callbacks/spectrogram_visualization.py [0:0]
def on_train_begin(self, logs: Any = None) -> None:
    """
    At the start of training, render example input batches as spectrogram images
    and write them to TensorBoard.

    A temporary "frontend" model is built from the largest run of non-trainable
    layers at the start of ``self.model`` (e.g. an audio-to-spectrogram stack),
    then applied to each batch from ``self.example_batches`` to produce images.

    Args:
        logs: Unused; accepted for Keras callback API compatibility.

    Raises:
        Exception: Any failure (including the internal ``ValueError`` when no
            non-trainable input layers exist) is re-raised only when
            ``self.raise_on_error`` is truthy; otherwise it is logged and
            swallowed so training is not interrupted.
    """
    try:
        # Create a temporary model using only the frontend of the model,
        # as defined by "the largest sequence of non-trainable layers at the start."
        non_trainable_input_layers = []
        for layer in self.model.layers:
            # NOTE(review): in modern Keras, `trainable_variables` and
            # `trainable_weights` are typically aliases, so this sum may
            # double-count — harmless here since only the > 0 check matters,
            # but confirm against the targeted Keras version.
            if len(layer.trainable_variables) + len(layer.trainable_weights) > 0:
                break
            else:
                non_trainable_input_layers.append(layer)
        if not non_trainable_input_layers:
            raise ValueError("No non-trainable input layers could be inferred for spectrogram visualization.")
        # Don't use tf.keras.models.Sequential here, as the input may not be traditional Layers.
        # (Yes, you'd think that self.model.layers returns all layers - but that doesn't seem to be the case.)
        input_to_image = tf.keras.models.Model(
            inputs=non_trainable_input_layers[0].input, outputs=non_trainable_input_layers[-1].output
        )
        with self.tensorboard_writer.as_default():
            # Pull n random batches from the dataset and send them to TensorBoard.
            # Each element is assumed to be a (data, label) pair; the label is ignored.
            for data, _ in self.example_batches:
                # These asserts run eagerly on tensors; they would be stripped
                # under `python -O`, so they are sanity checks, not validation.
                assert tf.rank(data) == 2, "Expected input data to be of rank 2, with shape (batch, audio)."
                # Heuristic: batches are expected to be smaller than the audio
                # length, which distinguishes (batch, audio) from (audio, batch).
                assert tf.shape(data)[0] < tf.shape(data)[1], (
                    "Expected input data to be of rank 2, with shape (batch, audio), but got shape"
                    f" {tf.shape(data)}."
                )
                # Run only the non-trainable frontend to get spectrogram-like output.
                spectrograms = input_to_image(data)
                assert tf.rank(spectrograms) in (3, 4), (
                    "Expected non-trainable input layers to produce output of shape (batch, x, y) "
                    f"or (batch, x, y, 1), but got {tf.shape(spectrograms)}"
                )
                if tf.rank(spectrograms) == 4:
                    assert tf.shape(spectrograms)[-1] == 1, (
                        "Expected non-trainable input layers to produce output with one channel, but shape is"
                        f" {tf.shape(spectrograms)}"
                    )
                    # Ignore the single channel dimension, if it exists.
                    spectrograms = spectrograms[:, :, :, 0]
                # We can infer the hop length, as we know the input audio length
                # and sample rate used in the spectrogram
                length_in_samples = data.shape[-1]
                # assumes axis -2 is the time/frame axis of the spectrogram — TODO confirm
                length_in_frames = spectrograms.shape[-2]
                hop_length = int(tf.math.ceil(length_in_samples / length_in_frames))
                figs = []
                for spectrogram in spectrograms:
                    plt.clf()
                    fig, ax = plt.subplots()
                    # Transpose so frequency is on the y axis, as specshow expects.
                    spectrogram = np.abs(spectrogram).T
                    # convert_to_dB may be a bool (use librosa's default dB
                    # conversion) or a user-supplied callable.
                    if self.convert_to_dB is True:
                        spectrogram = librosa.amplitude_to_db(spectrogram, ref=np.max)
                    elif callable(self.convert_to_dB):
                        spectrogram = self.convert_to_dB(spectrogram)
                    librosa.display.specshow(
                        spectrogram,
                        sr=self.sample_rate_hz,
                        hop_length=hop_length,
                        ax=ax,
                        **self.specshow_arguments,
                    )
                    # NOTE(review): figures created by plt.subplots() are never
                    # explicitly closed here — presumably plot_to_image closes
                    # the figure it is given; otherwise this leaks figures.
                    # Verify plot_to_image's implementation.
                    figs.append(plot_to_image(fig))
                tf.summary.image(
                    self.name,
                    np.concatenate(figs),
                    step=0,  # We only output this once, so epoch doesn't matter.
                    max_outputs=1000000,
                )
                plt.clf()
        self.tensorboard_writer.flush()
    except Exception as e:
        # Visualization is best-effort: unless configured otherwise, log the
        # failure and let training proceed.
        if self.raise_on_error:
            raise
        logging.error(f"{self.__class__.__name__} failed: ", exc_info=e)