|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
|
|
"""Helper functions for manipulating collections of variables during training.
|
|
"""
|
|
import logging
|
|
import re
|
|
|
|
import tensorflow as tf
|
|
|
|
from tensorflow.python.ops import variables as tf_variables
|
|
|
|
slim = tf.contrib.slim
|
|
|
|
|
|
# TODO(derekjchow): Consider replacing with tf.contrib.filter_variables in
|
|
# tensorflow/contrib/framework/python/ops/variables.py
|
|
def filter_variables(variables, filter_regex_list, invert=False):
|
|
"""Filters out the variables matching the filter_regex.
|
|
|
|
Filter out the variables whose name matches the any of the regular
|
|
expressions in filter_regex_list and returns the remaining variables.
|
|
Optionally, if invert=True, the complement set is returned.
|
|
|
|
Args:
|
|
variables: a list of tensorflow variables.
|
|
filter_regex_list: a list of string regular expressions.
|
|
invert: (boolean). If True, returns the complement of the filter set; that
|
|
is, all variables matching filter_regex are kept and all others discarded.
|
|
|
|
Returns:
|
|
a list of filtered variables.
|
|
"""
|
|
kept_vars = []
|
|
variables_to_ignore_patterns = list(filter(None, filter_regex_list))
|
|
for var in variables:
|
|
add = True
|
|
for pattern in variables_to_ignore_patterns:
|
|
if re.match(pattern, var.op.name):
|
|
add = False
|
|
break
|
|
if add != invert:
|
|
kept_vars.append(var)
|
|
return kept_vars
|
|
|
|
|
|
def multiply_gradients_matching_regex(grads_and_vars, regex_list, multiplier):
|
|
"""Multiply gradients whose variable names match a regular expression.
|
|
|
|
Args:
|
|
grads_and_vars: A list of gradient to variable pairs (tuples).
|
|
regex_list: A list of string regular expressions.
|
|
multiplier: A (float) multiplier to apply to each gradient matching the
|
|
regular expression.
|
|
|
|
Returns:
|
|
grads_and_vars: A list of gradient to variable pairs (tuples).
|
|
"""
|
|
variables = [pair[1] for pair in grads_and_vars]
|
|
matching_vars = filter_variables(variables, regex_list, invert=True)
|
|
for var in matching_vars:
|
|
logging.info('Applying multiplier %f to variable [%s]',
|
|
multiplier, var.op.name)
|
|
grad_multipliers = {var: float(multiplier) for var in matching_vars}
|
|
return slim.learning.multiply_gradients(grads_and_vars,
|
|
grad_multipliers)
|
|
|
|
|
|
def freeze_gradients_matching_regex(grads_and_vars, regex_list):
|
|
"""Freeze gradients whose variable names match a regular expression.
|
|
|
|
Args:
|
|
grads_and_vars: A list of gradient to variable pairs (tuples).
|
|
regex_list: A list of string regular expressions.
|
|
|
|
Returns:
|
|
grads_and_vars: A list of gradient to variable pairs (tuples) that do not
|
|
contain the variables and gradients matching the regex.
|
|
"""
|
|
variables = [pair[1] for pair in grads_and_vars]
|
|
matching_vars = filter_variables(variables, regex_list, invert=True)
|
|
kept_grads_and_vars = [pair for pair in grads_and_vars
|
|
if pair[1] not in matching_vars]
|
|
for var in matching_vars:
|
|
logging.info('Freezing variable [%s]', var.op.name)
|
|
return kept_grads_and_vars
|
|
|
|
|
|
def get_variables_available_in_checkpoint(variables,
|
|
checkpoint_path,
|
|
include_global_step=True):
|
|
"""Returns the subset of variables available in the checkpoint.
|
|
|
|
Inspects given checkpoint and returns the subset of variables that are
|
|
available in it.
|
|
|
|
TODO(rathodv): force input and output to be a dictionary.
|
|
|
|
Args:
|
|
variables: a list or dictionary of variables to find in checkpoint.
|
|
checkpoint_path: path to the checkpoint to restore variables from.
|
|
include_global_step: whether to include `global_step` variable, if it
|
|
exists. Default True.
|
|
|
|
Returns:
|
|
A list or dictionary of variables.
|
|
Raises:
|
|
ValueError: if `variables` is not a list or dict.
|
|
"""
|
|
if isinstance(variables, list):
|
|
variable_names_map = {}
|
|
for variable in variables:
|
|
if isinstance(variable, tf_variables.PartitionedVariable):
|
|
name = variable.name
|
|
else:
|
|
name = variable.op.name
|
|
variable_names_map[name] = variable
|
|
elif isinstance(variables, dict):
|
|
variable_names_map = variables
|
|
else:
|
|
raise ValueError('`variables` is expected to be a list or dict.')
|
|
ckpt_reader = tf.train.NewCheckpointReader(checkpoint_path)
|
|
ckpt_vars_to_shape_map = ckpt_reader.get_variable_to_shape_map()
|
|
if not include_global_step:
|
|
ckpt_vars_to_shape_map.pop(tf.GraphKeys.GLOBAL_STEP, None)
|
|
vars_in_ckpt = {}
|
|
for variable_name, variable in sorted(variable_names_map.items()):
|
|
if variable_name in ckpt_vars_to_shape_map:
|
|
if ckpt_vars_to_shape_map[variable_name] == variable.shape.as_list():
|
|
vars_in_ckpt[variable_name] = variable
|
|
else:
|
|
logging.warning('Variable [%s] is available in checkpoint, but has an '
|
|
'incompatible shape with model variable. Checkpoint '
|
|
'shape: [%s], model variable shape: [%s]. This '
|
|
'variable will not be initialized from the checkpoint.',
|
|
variable_name, ckpt_vars_to_shape_map[variable_name],
|
|
variable.shape.as_list())
|
|
else:
|
|
logging.warning('Variable [%s] is not available in checkpoint',
|
|
variable_name)
|
|
if isinstance(variables, list):
|
|
return vars_in_ckpt.values()
|
|
return vars_in_ckpt
|