def on_train_begin()

in realbook/callbacks/spectrogram_visualization.py [0:0]


    def on_train_begin(self, logs: Any = None) -> None:
        """Render example spectrograms to TensorBoard at the start of training.

        Builds a temporary model from the longest prefix of non-trainable layers
        of ``self.model`` (assumed to be the audio-to-spectrogram "frontend"),
        runs each batch from ``self.example_batches`` through it, and writes the
        resulting spectrogram images under ``self.name`` via
        ``self.tensorboard_writer``.

        Args:
            logs: Unused; present to satisfy the Keras callback signature.

        Any exception is logged (or re-raised when ``self.raise_on_error`` is
        set) so that a visualization failure never aborts training.
        """
        try:
            # Create a temporary model using only the frontend of the model,
            # as defined by "the largest sequence of non-trainable layers at the start."
            non_trainable_input_layers = []
            for layer in self.model.layers:
                if len(layer.trainable_variables) + len(layer.trainable_weights) > 0:
                    break
                else:
                    non_trainable_input_layers.append(layer)

            if not non_trainable_input_layers:
                raise ValueError("No non-trainable input layers could be inferred for spectrogram visualization.")

            # Don't use tf.keras.models.Sequential here, as the input may not be traditional Layers.
            # (Yes, you'd think that self.model.layers returns all layers - but that doesn't seem to be the case.)
            input_to_image = tf.keras.models.Model(
                inputs=non_trainable_input_layers[0].input, outputs=non_trainable_input_layers[-1].output
            )

            with self.tensorboard_writer.as_default():
                # Pull n random batches from the dataset and send them to TensorBoard.
                for data, _ in self.example_batches:
                    assert tf.rank(data) == 2, "Expected input data to be of rank 2, with shape (batch, audio)."
                    assert tf.shape(data)[0] < tf.shape(data)[1], (
                        "Expected input data to be of rank 2, with shape (batch, audio), but got shape"
                        f" {tf.shape(data)}."
                    )

                    spectrograms = input_to_image(data)
                    assert tf.rank(spectrograms) in (3, 4), (
                        "Expected non-trainable input layers to produce output of shape (batch, x, y) "
                        f"or (batch, x, y, 1), but got {tf.shape(spectrograms)}"
                    )
                    if tf.rank(spectrograms) == 4:
                        assert tf.shape(spectrograms)[-1] == 1, (
                            "Expected non-trainable input layers to produce output with one channel, but shape is"
                            f" {tf.shape(spectrograms)}"
                        )
                        # Ignore the single channel dimension, if it exists.
                        spectrograms = spectrograms[:, :, :, 0]

                    # We can infer the hop length, as we know the input audio length
                    # and sample rate used in the spectrogram
                    length_in_samples = data.shape[-1]
                    length_in_frames = spectrograms.shape[-2]
                    hop_length = int(tf.math.ceil(length_in_samples / length_in_frames))

                    figs = []
                    for spectrogram in spectrograms:
                        fig, ax = plt.subplots()
                        # specshow expects (frequency, time); the model emits (time, frequency).
                        spectrogram = np.abs(spectrogram).T

                        if self.convert_to_dB is True:
                            spectrogram = librosa.amplitude_to_db(spectrogram, ref=np.max)
                        elif callable(self.convert_to_dB):
                            spectrogram = self.convert_to_dB(spectrogram)

                        librosa.display.specshow(
                            spectrogram,
                            sr=self.sample_rate_hz,
                            hop_length=hop_length,
                            ax=ax,
                            **self.specshow_arguments,
                        )

                        figs.append(plot_to_image(fig))
                        # Close each figure once rendered. plt.clf() only clears the
                        # *current* figure's contents and does not release figures,
                        # so figures created by plt.subplots() would otherwise
                        # accumulate in pyplot's global registry (a memory leak,
                        # plus matplotlib's "more than 20 figures" warning).
                        plt.close(fig)
                    tf.summary.image(
                        self.name,
                        np.concatenate(figs),
                        step=0,  # We only output this once, so epoch doesn't matter.
                        max_outputs=1000000,
                    )
            self.tensorboard_writer.flush()
        except Exception as e:
            if self.raise_on_error:
                raise
            logging.error(f"{self.__class__.__name__} failed: ", exc_info=e)