# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utility functions for manipulating Keras models."""
|
|
|
|
import tensorflow as tf
|
|
|
|
|
|
def extract_submodel(model, inputs, outputs, name=None):
  """Extracts a section of a Keras model into a new model.

  This method walks an existing model from the specified outputs back to the
  specified inputs in order to construct a new model containing only a portion
  of the old model, while sharing the layers and weights with the original
  model.

  Tensors are tracked by object identity (`id()` / `is`) rather than by hashing
  or `==`, so the walk does not depend on the tensor type being hashable or on
  how it overloads equality.

  WARNING: This method does not work for submodels containing layers that have
  been used multiple times in the original model, or in other models beyond
  the original model. (E.g. does not work for submodels that contain layers that
  use shared weights). This also means that multiple overlapping submodels
  cannot be extracted from the same model.

  It also relies on recursion and will hit python's recursion limit for large
  submodels.

  Args:
    model: The existing Keras model this method extracts a submodel from.
    inputs: The layer input (a single tensor, or a list/tuple of tensors) in
      the existing model that starts the submodel.
    outputs: The layer output (a single tensor, or a list/tuple of tensors) in
      the existing model that should be output by the submodel.
    name: The name for the extracted model.

  Returns:
    The extracted submodel specified by the given inputs and outputs.

  Raises:
    KeyError: If the walk reaches a tensor that is neither one of `inputs` nor
      the output of a layer in `model.layers`, or if one of `inputs` is never
      reached while walking back from `outputs`.
    ValueError: If the walk reaches a layer whose input is its own output and
      that tensor is not one of `inputs` (following it would recurse forever).
  """
  multiple_inputs = isinstance(inputs, (list, tuple))
  boundary_tensors = list(inputs) if multiple_inputs else [inputs]
  boundary_ids = {id(tensor) for tensor in boundary_tensors}

  # Maps id(layer output) -> (layer output, layer, layer input). The output
  # object itself is stored next to the layer so that it stays alive for the
  # whole walk and its id cannot be recycled for an unrelated object.
  producers = {}
  for layer in model.layers:
    layer_output = layer.output
    producers[id(layer_output)] = (layer_output, layer, layer.input)

  model_inputs_by_id = {}
  memoized_results = {}

  # Relies on recursion, very low limit in python
  def _recurse_in_model(tensor):
    """Walk the existing model recursively to copy a submodel."""
    key = id(tensor)
    if key in memoized_results:
      return memoized_results[key]

    if key in boundary_ids:
      # Reached the start of the submodel: wrap the existing tensor in a new
      # Input instead of walking any further back.
      if key not in model_inputs_by_id:
        model_inputs_by_id[key] = tf.keras.layers.Input(tensor=tensor)
      out = model_inputs_by_id[key]
    else:
      if key not in producers:
        raise KeyError(
            'Tensor {!r} is neither one of `inputs` nor the output of a layer '
            'in `model.layers`.'.format(tensor))
      _, cur_layer, cur_inputs = producers[key]
      if cur_inputs is tensor:
        # The layer feeds on its own output, so following it would never
        # terminate. Fail with a clear message instead of a RecursionError.
        raise ValueError(
            'Reached tensor {!r}, which is both the input and the output of '
            'layer {!r}, without hitting any of `inputs`. The requested '
            '`outputs` depend on a tensor that is not listed in '
            '`inputs`.'.format(tensor, cur_layer))
      if isinstance(cur_inputs, list):
        out = cur_layer([_recurse_in_model(inp) for inp in cur_inputs])
      else:
        out = cur_layer(_recurse_in_model(cur_inputs))

    memoized_results[key] = out
    return out

  if isinstance(outputs, (list, tuple)):
    model_outputs = [_recurse_in_model(tensor) for tensor in outputs]
  else:
    model_outputs = _recurse_in_model(outputs)

  # Every requested input must have been reached while walking back from the
  # outputs; otherwise there is no Input object to hand to the new model.
  unreached = [
      tensor for tensor in boundary_tensors
      if id(tensor) not in model_inputs_by_id
  ]
  if unreached:
    raise KeyError(
        'The following `inputs` were never reached when walking back from '
        '`outputs`: {!r}'.format(unreached))

  if multiple_inputs:
    model_inputs = [model_inputs_by_id[id(tensor)] for tensor in boundary_tensors]
  else:
    model_inputs = model_inputs_by_id[id(inputs)]

  return tf.keras.Model(inputs=model_inputs, outputs=model_outputs, name=name)
|