in realbook/layers/compatibility.py [0:0]
def get_all_tensors_from_saved_model(saved_model_or_path: Union[tf.keras.Model, str]) -> List[tf.Tensor]:
    """
    Given a path to a SavedModel or an already loaded SavedModel,
    return a list of all of its tensors. Useful for figuring out the
    names of intermediate tensors for use with SavedModelLayer.

    To extract the output of a given Keras layer (which isn't stored in the
    SavedModel, as TensorFlow SavedModels don't actually save layer
    information), try something like:
    ```
    tensors = get_all_tensors_from_saved_model('./my-saved-model')
    layer_name = "some_named_layer"
    probable_output_of_layer = [t for t in tensors if layer_name in t.name][-1]
    # Should probably check that this is the model's input, as expected:
    probable_input_to_model = tensors[0]
    # Create a sub-graph of the loaded model from the input and output tensors you want:
    sub_graph = create_function_from_tensors(probable_input_to_model, probable_output_of_layer)
    # Use that sub-graph wherever you'd like:
    my_model = tf.keras.models.Sequential([
        tf.keras.layers.Lambda(lambda input_tensor: sub_graph(input_tensor)),
    ])
    # (Note that if you're extracting multiple inputs, your lambda function
    #  must pass each input to sub_graph as a separate argument,
    #  i.e.: `lambda input_tensors: sub_graph(*input_tensors)`.)
    ```

    Args:
        saved_model_or_path: Either a string path, which is loaded with
            `tf.saved_model.load`, or an already loaded object. If the
            object has a `signatures` attribute, its "serving_default"
            signature is used; otherwise the object itself is passed on
            to be frozen as-is.

    Returns:
        Every input and output tensor of every operation in the frozen
        graph, de-duplicated by tensor name and kept in first-seen order
        (operations in `get_operations()` order; for each operation,
        inputs before outputs).
    """
    if isinstance(saved_model_or_path, str):
        savedmodel = tf.saved_model.load(saved_model_or_path)
        model = savedmodel.signatures["serving_default"]
        model._backref = savedmodel  # Without this, the SavedModel will be GC'd too early
    else:
        model = saved_model_or_path
        if hasattr(model, "signatures"):
            model = model.signatures["serving_default"]

    # Only the frozen function is needed here; the second return value is unused.
    frozen_func, _ = convert_variables_to_constants_v2_as_graph(model)

    # De-duplicate while streaming over the graph rather than first flattening
    # everything with `sum(list_of_lists, [])`: that idiom re-copies the
    # accumulated list on every addition and is quadratic in the tensor count.
    # Names are tracked instead of the tensors themselves because tf.Tensor
    # is not hashable.
    seen_names = set()
    unique_tensors: List[tf.Tensor] = []
    for op in frozen_func.graph.get_operations():
        for tensor_group in (op.inputs, op.outputs):
            for tensor in tensor_group:
                tensor_name = tensor.name
                if tensor_name not in seen_names:
                    seen_names.add(tensor_name)
                    unique_tensors.append(tensor)
    return unique_tensors