You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

155 lines
5.9 KiB

  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Helper functions for manipulating collections of variables during training.
  16. """
  17. import logging
  18. import re
  19. import tensorflow as tf
  20. from tensorflow.python.ops import variables as tf_variables
  21. slim = tf.contrib.slim
  22. # TODO(derekjchow): Consider replacing with tf.contrib.filter_variables in
  23. # tensorflow/contrib/framework/python/ops/variables.py
  24. def filter_variables(variables, filter_regex_list, invert=False):
  25. """Filters out the variables matching the filter_regex.
  26. Filter out the variables whose name matches the any of the regular
  27. expressions in filter_regex_list and returns the remaining variables.
  28. Optionally, if invert=True, the complement set is returned.
  29. Args:
  30. variables: a list of tensorflow variables.
  31. filter_regex_list: a list of string regular expressions.
  32. invert: (boolean). If True, returns the complement of the filter set; that
  33. is, all variables matching filter_regex are kept and all others discarded.
  34. Returns:
  35. a list of filtered variables.
  36. """
  37. kept_vars = []
  38. variables_to_ignore_patterns = list(filter(None, filter_regex_list))
  39. for var in variables:
  40. add = True
  41. for pattern in variables_to_ignore_patterns:
  42. if re.match(pattern, var.op.name):
  43. add = False
  44. break
  45. if add != invert:
  46. kept_vars.append(var)
  47. return kept_vars
  48. def multiply_gradients_matching_regex(grads_and_vars, regex_list, multiplier):
  49. """Multiply gradients whose variable names match a regular expression.
  50. Args:
  51. grads_and_vars: A list of gradient to variable pairs (tuples).
  52. regex_list: A list of string regular expressions.
  53. multiplier: A (float) multiplier to apply to each gradient matching the
  54. regular expression.
  55. Returns:
  56. grads_and_vars: A list of gradient to variable pairs (tuples).
  57. """
  58. variables = [pair[1] for pair in grads_and_vars]
  59. matching_vars = filter_variables(variables, regex_list, invert=True)
  60. for var in matching_vars:
  61. logging.info('Applying multiplier %f to variable [%s]',
  62. multiplier, var.op.name)
  63. grad_multipliers = {var: float(multiplier) for var in matching_vars}
  64. return slim.learning.multiply_gradients(grads_and_vars,
  65. grad_multipliers)
  66. def freeze_gradients_matching_regex(grads_and_vars, regex_list):
  67. """Freeze gradients whose variable names match a regular expression.
  68. Args:
  69. grads_and_vars: A list of gradient to variable pairs (tuples).
  70. regex_list: A list of string regular expressions.
  71. Returns:
  72. grads_and_vars: A list of gradient to variable pairs (tuples) that do not
  73. contain the variables and gradients matching the regex.
  74. """
  75. variables = [pair[1] for pair in grads_and_vars]
  76. matching_vars = filter_variables(variables, regex_list, invert=True)
  77. kept_grads_and_vars = [pair for pair in grads_and_vars
  78. if pair[1] not in matching_vars]
  79. for var in matching_vars:
  80. logging.info('Freezing variable [%s]', var.op.name)
  81. return kept_grads_and_vars
  82. def get_variables_available_in_checkpoint(variables,
  83. checkpoint_path,
  84. include_global_step=True):
  85. """Returns the subset of variables available in the checkpoint.
  86. Inspects given checkpoint and returns the subset of variables that are
  87. available in it.
  88. TODO(rathodv): force input and output to be a dictionary.
  89. Args:
  90. variables: a list or dictionary of variables to find in checkpoint.
  91. checkpoint_path: path to the checkpoint to restore variables from.
  92. include_global_step: whether to include `global_step` variable, if it
  93. exists. Default True.
  94. Returns:
  95. A list or dictionary of variables.
  96. Raises:
  97. ValueError: if `variables` is not a list or dict.
  98. """
  99. if isinstance(variables, list):
  100. variable_names_map = {}
  101. for variable in variables:
  102. if isinstance(variable, tf_variables.PartitionedVariable):
  103. name = variable.name
  104. else:
  105. name = variable.op.name
  106. variable_names_map[name] = variable
  107. elif isinstance(variables, dict):
  108. variable_names_map = variables
  109. else:
  110. raise ValueError('`variables` is expected to be a list or dict.')
  111. ckpt_reader = tf.train.NewCheckpointReader(checkpoint_path)
  112. ckpt_vars_to_shape_map = ckpt_reader.get_variable_to_shape_map()
  113. if not include_global_step:
  114. ckpt_vars_to_shape_map.pop(tf.GraphKeys.GLOBAL_STEP, None)
  115. vars_in_ckpt = {}
  116. for variable_name, variable in sorted(variable_names_map.items()):
  117. if variable_name in ckpt_vars_to_shape_map:
  118. if ckpt_vars_to_shape_map[variable_name] == variable.shape.as_list():
  119. vars_in_ckpt[variable_name] = variable
  120. else:
  121. logging.warning('Variable [%s] is available in checkpoint, but has an '
  122. 'incompatible shape with model variable. Checkpoint '
  123. 'shape: [%s], model variable shape: [%s]. This '
  124. 'variable will not be initialized from the checkpoint.',
  125. variable_name, ckpt_vars_to_shape_map[variable_name],
  126. variable.shape.as_list())
  127. else:
  128. logging.warning('Variable [%s] is not available in checkpoint',
  129. variable_name)
  130. if isinstance(variables, list):
  131. return vars_in_ckpt.values()
  132. return vars_in_ckpt