def get_all_tensors_from_saved_model()

in realbook/layers/compatibility.py [0:0]


def get_all_tensors_from_saved_model(saved_model_or_path: Union[tf.keras.Model, str]) -> List[tf.Tensor]:
    """
    Given a path to a SavedModel or an already loaded SavedModel,
    return a list of all of its tensors. Useful for figuring out the
    names of intermediate tensors for use with SavedModelLayer.

    To extract the output of a given Keras layer (which isn't stored in the
    SavedModel, as TensorFlow SavedModels don't actually save layer
    information), try something like:

    ```
        tensors = get_all_tensors_from_saved_model('./my-saved-model')

        layer_name = "some_named_layer"
        probable_output_of_layer = [t for t in tensors if layer_name in t.name][-1]

        # Should probably check that this is the model's input, as expected:
        probable_input_to_model = tensors[0]

        # Create a sub-graph of the loaded model from the input and output tensors you want:
        sub_graph = create_function_from_tensors(probable_input_to_model, probable_output_of_layer)

        # Use that sub-graph wherever you'd like:
        my_model = tf.keras.models.Sequential([
            tf.keras.layers.Lambda(lambda input_tensor: sub_graph(input_tensor)),
        ])

        # (Note that if you're extracting multiple inputs, your lambda function
        # must pass each input to sub_graph as a separate argument,
        # i.e.: `lambda input_tensors: sub_graph(*input_tensors)`.)
    ```

    Args:
        saved_model_or_path: Either a filesystem path to a SavedModel, which
            is loaded with `tf.saved_model.load`, or an already-loaded object.
            If the object has a `signatures` attribute, its "serving_default"
            signature is used. Otherwise the object itself is handed to
            `convert_variables_to_constants_v2_as_graph`, so it is presumably
            expected to be a concrete function (TODO confirm).

    Returns:
        Every distinct tensor, deduplicated by name, that appears as an input
        or output of an operation in the frozen graph. Tensors are listed in
        order of first appearance while walking the graph's operations, with
        each operation's inputs visited before its outputs.

    Raises:
        KeyError: If the model's `signatures` mapping has no
            "serving_default" entry.
    """
    if isinstance(saved_model_or_path, str):
        savedmodel = tf.saved_model.load(saved_model_or_path)
        model = savedmodel.signatures["serving_default"]
        model._backref = savedmodel  # Without this, the SavedModel will be GC'd too early
    else:
        model = saved_model_or_path
    if hasattr(model, "signatures"):
        model = model.signatures["serving_default"]
    frozen_func, _ = convert_variables_to_constants_v2_as_graph(model)

    # tf.Tensor is not hashable, so deduplicate by tensor name rather than
    # by putting tensors into a set(). Walking the operations directly,
    # instead of concatenating per-op lists with sum(..., []), keeps this
    # linear in the number of tensors rather than quadratic.
    seen_names = set()
    unique_tensors: List[tf.Tensor] = []
    for op in frozen_func.graph.get_operations():
        for tensor in list(op.inputs) + list(op.outputs):
            if tensor.name not in seen_names:
                seen_names.add(tensor.name)
                unique_tensors.append(tensor)
    return unique_tensors