You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

239 lines
9.5 KiB

  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for object_detection.utils.variables_helper."""
  16. import os
  17. import tensorflow as tf
  18. from object_detection.utils import variables_helper
  19. class FilterVariablesTest(tf.test.TestCase):
  20. def _create_variables(self):
  21. return [tf.Variable(1.0, name='FeatureExtractor/InceptionV3/weights'),
  22. tf.Variable(1.0, name='FeatureExtractor/InceptionV3/biases'),
  23. tf.Variable(1.0, name='StackProposalGenerator/weights'),
  24. tf.Variable(1.0, name='StackProposalGenerator/biases')]
  25. def test_return_all_variables_when_empty_regex(self):
  26. variables = self._create_variables()
  27. out_variables = variables_helper.filter_variables(variables, [''])
  28. self.assertItemsEqual(out_variables, variables)
  29. def test_return_variables_which_do_not_match_single_regex(self):
  30. variables = self._create_variables()
  31. out_variables = variables_helper.filter_variables(variables,
  32. ['FeatureExtractor/.*'])
  33. self.assertItemsEqual(out_variables, variables[2:])
  34. def test_return_variables_which_do_not_match_any_regex_in_list(self):
  35. variables = self._create_variables()
  36. out_variables = variables_helper.filter_variables(variables, [
  37. 'FeatureExtractor.*biases', 'StackProposalGenerator.*biases'
  38. ])
  39. self.assertItemsEqual(out_variables, [variables[0], variables[2]])
  40. def test_return_variables_matching_empty_regex_list(self):
  41. variables = self._create_variables()
  42. out_variables = variables_helper.filter_variables(
  43. variables, [''], invert=True)
  44. self.assertItemsEqual(out_variables, [])
  45. def test_return_variables_matching_some_regex_in_list(self):
  46. variables = self._create_variables()
  47. out_variables = variables_helper.filter_variables(
  48. variables,
  49. ['FeatureExtractor.*biases', 'StackProposalGenerator.*biases'],
  50. invert=True)
  51. self.assertItemsEqual(out_variables, [variables[1], variables[3]])
  52. class MultiplyGradientsMatchingRegexTest(tf.test.TestCase):
  53. def _create_grads_and_vars(self):
  54. return [(tf.constant(1.0),
  55. tf.Variable(1.0, name='FeatureExtractor/InceptionV3/weights')),
  56. (tf.constant(2.0),
  57. tf.Variable(2.0, name='FeatureExtractor/InceptionV3/biases')),
  58. (tf.constant(3.0),
  59. tf.Variable(3.0, name='StackProposalGenerator/weights')),
  60. (tf.constant(4.0),
  61. tf.Variable(4.0, name='StackProposalGenerator/biases'))]
  62. def test_multiply_all_feature_extractor_variables(self):
  63. grads_and_vars = self._create_grads_and_vars()
  64. regex_list = ['FeatureExtractor/.*']
  65. multiplier = 0.0
  66. grads_and_vars = variables_helper.multiply_gradients_matching_regex(
  67. grads_and_vars, regex_list, multiplier)
  68. exp_output = [(0.0, 1.0), (0.0, 2.0), (3.0, 3.0), (4.0, 4.0)]
  69. init_op = tf.global_variables_initializer()
  70. with self.test_session() as sess:
  71. sess.run(init_op)
  72. output = sess.run(grads_and_vars)
  73. self.assertItemsEqual(output, exp_output)
  74. def test_multiply_all_bias_variables(self):
  75. grads_and_vars = self._create_grads_and_vars()
  76. regex_list = ['.*/biases']
  77. multiplier = 0.0
  78. grads_and_vars = variables_helper.multiply_gradients_matching_regex(
  79. grads_and_vars, regex_list, multiplier)
  80. exp_output = [(1.0, 1.0), (0.0, 2.0), (3.0, 3.0), (0.0, 4.0)]
  81. init_op = tf.global_variables_initializer()
  82. with self.test_session() as sess:
  83. sess.run(init_op)
  84. output = sess.run(grads_and_vars)
  85. self.assertItemsEqual(output, exp_output)
  86. class FreezeGradientsMatchingRegexTest(tf.test.TestCase):
  87. def _create_grads_and_vars(self):
  88. return [(tf.constant(1.0),
  89. tf.Variable(1.0, name='FeatureExtractor/InceptionV3/weights')),
  90. (tf.constant(2.0),
  91. tf.Variable(2.0, name='FeatureExtractor/InceptionV3/biases')),
  92. (tf.constant(3.0),
  93. tf.Variable(3.0, name='StackProposalGenerator/weights')),
  94. (tf.constant(4.0),
  95. tf.Variable(4.0, name='StackProposalGenerator/biases'))]
  96. def test_freeze_all_feature_extractor_variables(self):
  97. grads_and_vars = self._create_grads_and_vars()
  98. regex_list = ['FeatureExtractor/.*']
  99. grads_and_vars = variables_helper.freeze_gradients_matching_regex(
  100. grads_and_vars, regex_list)
  101. exp_output = [(3.0, 3.0), (4.0, 4.0)]
  102. init_op = tf.global_variables_initializer()
  103. with self.test_session() as sess:
  104. sess.run(init_op)
  105. output = sess.run(grads_and_vars)
  106. self.assertItemsEqual(output, exp_output)
  107. class GetVariablesAvailableInCheckpointTest(tf.test.TestCase):
  108. def test_return_all_variables_from_checkpoint(self):
  109. with tf.Graph().as_default():
  110. variables = [
  111. tf.Variable(1.0, name='weights'),
  112. tf.Variable(1.0, name='biases')
  113. ]
  114. checkpoint_path = os.path.join(self.get_temp_dir(), 'model.ckpt')
  115. init_op = tf.global_variables_initializer()
  116. saver = tf.train.Saver(variables)
  117. with self.test_session() as sess:
  118. sess.run(init_op)
  119. saver.save(sess, checkpoint_path)
  120. out_variables = variables_helper.get_variables_available_in_checkpoint(
  121. variables, checkpoint_path)
  122. self.assertItemsEqual(out_variables, variables)
  123. def test_return_all_variables_from_checkpoint_with_partition(self):
  124. with tf.Graph().as_default():
  125. partitioner = tf.fixed_size_partitioner(2)
  126. variables = [
  127. tf.get_variable(
  128. name='weights', shape=(2, 2), partitioner=partitioner),
  129. tf.Variable([1.0, 2.0], name='biases')
  130. ]
  131. checkpoint_path = os.path.join(self.get_temp_dir(), 'model.ckpt')
  132. init_op = tf.global_variables_initializer()
  133. saver = tf.train.Saver(variables)
  134. with self.test_session() as sess:
  135. sess.run(init_op)
  136. saver.save(sess, checkpoint_path)
  137. out_variables = variables_helper.get_variables_available_in_checkpoint(
  138. variables, checkpoint_path)
  139. self.assertItemsEqual(out_variables, variables)
  140. def test_return_variables_available_in_checkpoint(self):
  141. checkpoint_path = os.path.join(self.get_temp_dir(), 'model.ckpt')
  142. with tf.Graph().as_default():
  143. weight_variable = tf.Variable(1.0, name='weights')
  144. global_step = tf.train.get_or_create_global_step()
  145. graph1_variables = [
  146. weight_variable,
  147. global_step
  148. ]
  149. init_op = tf.global_variables_initializer()
  150. saver = tf.train.Saver(graph1_variables)
  151. with self.test_session() as sess:
  152. sess.run(init_op)
  153. saver.save(sess, checkpoint_path)
  154. with tf.Graph().as_default():
  155. graph2_variables = graph1_variables + [tf.Variable(1.0, name='biases')]
  156. out_variables = variables_helper.get_variables_available_in_checkpoint(
  157. graph2_variables, checkpoint_path, include_global_step=False)
  158. self.assertItemsEqual(out_variables, [weight_variable])
  159. def test_return_variables_available_an_checkpoint_with_dict_inputs(self):
  160. checkpoint_path = os.path.join(self.get_temp_dir(), 'model.ckpt')
  161. with tf.Graph().as_default():
  162. graph1_variables = [
  163. tf.Variable(1.0, name='ckpt_weights'),
  164. ]
  165. init_op = tf.global_variables_initializer()
  166. saver = tf.train.Saver(graph1_variables)
  167. with self.test_session() as sess:
  168. sess.run(init_op)
  169. saver.save(sess, checkpoint_path)
  170. with tf.Graph().as_default():
  171. graph2_variables_dict = {
  172. 'ckpt_weights': tf.Variable(1.0, name='weights'),
  173. 'ckpt_biases': tf.Variable(1.0, name='biases')
  174. }
  175. out_variables = variables_helper.get_variables_available_in_checkpoint(
  176. graph2_variables_dict, checkpoint_path)
  177. self.assertTrue(isinstance(out_variables, dict))
  178. self.assertItemsEqual(out_variables.keys(), ['ckpt_weights'])
  179. self.assertTrue(out_variables['ckpt_weights'].op.name == 'weights')
  180. def test_return_variables_with_correct_sizes(self):
  181. checkpoint_path = os.path.join(self.get_temp_dir(), 'model.ckpt')
  182. with tf.Graph().as_default():
  183. bias_variable = tf.Variable(3.0, name='biases')
  184. global_step = tf.train.get_or_create_global_step()
  185. graph1_variables = [
  186. tf.Variable([[1.0, 2.0], [3.0, 4.0]], name='weights'),
  187. bias_variable,
  188. global_step
  189. ]
  190. init_op = tf.global_variables_initializer()
  191. saver = tf.train.Saver(graph1_variables)
  192. with self.test_session() as sess:
  193. sess.run(init_op)
  194. saver.save(sess, checkpoint_path)
  195. with tf.Graph().as_default():
  196. graph2_variables = [
  197. tf.Variable([1.0, 2.0], name='weights'), # New variable shape.
  198. bias_variable,
  199. global_step
  200. ]
  201. out_variables = variables_helper.get_variables_available_in_checkpoint(
  202. graph2_variables, checkpoint_path, include_global_step=True)
  203. self.assertItemsEqual(out_variables, [bias_variable, global_step])
if __name__ == '__main__':
  # Run the test cases in this module through TensorFlow's test runner.
  tf.test.main()