You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

208 lines
6.7 KiB

6 years ago
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for optimizer_builder."""
import tensorflow as tf

from google.protobuf import text_format

from object_detection.builders import optimizer_builder
from object_detection.protos import optimizer_pb2
  20. class LearningRateBuilderTest(tf.test.TestCase):
  21. def testBuildConstantLearningRate(self):
  22. learning_rate_text_proto = """
  23. constant_learning_rate {
  24. learning_rate: 0.004
  25. }
  26. """
  27. learning_rate_proto = optimizer_pb2.LearningRate()
  28. text_format.Merge(learning_rate_text_proto, learning_rate_proto)
  29. learning_rate = optimizer_builder._create_learning_rate(
  30. learning_rate_proto)
  31. self.assertTrue(learning_rate.op.name.endswith('learning_rate'))
  32. with self.test_session():
  33. learning_rate_out = learning_rate.eval()
  34. self.assertAlmostEqual(learning_rate_out, 0.004)
  35. def testBuildExponentialDecayLearningRate(self):
  36. learning_rate_text_proto = """
  37. exponential_decay_learning_rate {
  38. initial_learning_rate: 0.004
  39. decay_steps: 99999
  40. decay_factor: 0.85
  41. staircase: false
  42. }
  43. """
  44. learning_rate_proto = optimizer_pb2.LearningRate()
  45. text_format.Merge(learning_rate_text_proto, learning_rate_proto)
  46. learning_rate = optimizer_builder._create_learning_rate(
  47. learning_rate_proto)
  48. self.assertTrue(learning_rate.op.name.endswith('learning_rate'))
  49. self.assertTrue(isinstance(learning_rate, tf.Tensor))
  50. def testBuildManualStepLearningRate(self):
  51. learning_rate_text_proto = """
  52. manual_step_learning_rate {
  53. initial_learning_rate: 0.002
  54. schedule {
  55. step: 100
  56. learning_rate: 0.006
  57. }
  58. schedule {
  59. step: 90000
  60. learning_rate: 0.00006
  61. }
  62. warmup: true
  63. }
  64. """
  65. learning_rate_proto = optimizer_pb2.LearningRate()
  66. text_format.Merge(learning_rate_text_proto, learning_rate_proto)
  67. learning_rate = optimizer_builder._create_learning_rate(
  68. learning_rate_proto)
  69. self.assertTrue(isinstance(learning_rate, tf.Tensor))
  70. def testBuildCosineDecayLearningRate(self):
  71. learning_rate_text_proto = """
  72. cosine_decay_learning_rate {
  73. learning_rate_base: 0.002
  74. total_steps: 20000
  75. warmup_learning_rate: 0.0001
  76. warmup_steps: 1000
  77. hold_base_rate_steps: 20000
  78. }
  79. """
  80. learning_rate_proto = optimizer_pb2.LearningRate()
  81. text_format.Merge(learning_rate_text_proto, learning_rate_proto)
  82. learning_rate = optimizer_builder._create_learning_rate(
  83. learning_rate_proto)
  84. self.assertTrue(isinstance(learning_rate, tf.Tensor))
  85. def testRaiseErrorOnEmptyLearningRate(self):
  86. learning_rate_text_proto = """
  87. """
  88. learning_rate_proto = optimizer_pb2.LearningRate()
  89. text_format.Merge(learning_rate_text_proto, learning_rate_proto)
  90. with self.assertRaises(ValueError):
  91. optimizer_builder._create_learning_rate(learning_rate_proto)
  92. class OptimizerBuilderTest(tf.test.TestCase):
  93. def testBuildRMSPropOptimizer(self):
  94. optimizer_text_proto = """
  95. rms_prop_optimizer: {
  96. learning_rate: {
  97. exponential_decay_learning_rate {
  98. initial_learning_rate: 0.004
  99. decay_steps: 800720
  100. decay_factor: 0.95
  101. }
  102. }
  103. momentum_optimizer_value: 0.9
  104. decay: 0.9
  105. epsilon: 1.0
  106. }
  107. use_moving_average: false
  108. """
  109. optimizer_proto = optimizer_pb2.Optimizer()
  110. text_format.Merge(optimizer_text_proto, optimizer_proto)
  111. optimizer, _ = optimizer_builder.build(optimizer_proto)
  112. self.assertTrue(isinstance(optimizer, tf.train.RMSPropOptimizer))
  113. def testBuildMomentumOptimizer(self):
  114. optimizer_text_proto = """
  115. momentum_optimizer: {
  116. learning_rate: {
  117. constant_learning_rate {
  118. learning_rate: 0.001
  119. }
  120. }
  121. momentum_optimizer_value: 0.99
  122. }
  123. use_moving_average: false
  124. """
  125. optimizer_proto = optimizer_pb2.Optimizer()
  126. text_format.Merge(optimizer_text_proto, optimizer_proto)
  127. optimizer, _ = optimizer_builder.build(optimizer_proto)
  128. self.assertTrue(isinstance(optimizer, tf.train.MomentumOptimizer))
  129. def testBuildAdamOptimizer(self):
  130. optimizer_text_proto = """
  131. adam_optimizer: {
  132. learning_rate: {
  133. constant_learning_rate {
  134. learning_rate: 0.002
  135. }
  136. }
  137. }
  138. use_moving_average: false
  139. """
  140. optimizer_proto = optimizer_pb2.Optimizer()
  141. text_format.Merge(optimizer_text_proto, optimizer_proto)
  142. optimizer, _ = optimizer_builder.build(optimizer_proto)
  143. self.assertTrue(isinstance(optimizer, tf.train.AdamOptimizer))
  144. def testBuildMovingAverageOptimizer(self):
  145. optimizer_text_proto = """
  146. adam_optimizer: {
  147. learning_rate: {
  148. constant_learning_rate {
  149. learning_rate: 0.002
  150. }
  151. }
  152. }
  153. use_moving_average: True
  154. """
  155. optimizer_proto = optimizer_pb2.Optimizer()
  156. text_format.Merge(optimizer_text_proto, optimizer_proto)
  157. optimizer, _ = optimizer_builder.build(optimizer_proto)
  158. self.assertTrue(
  159. isinstance(optimizer, tf.contrib.opt.MovingAverageOptimizer))
  160. def testBuildMovingAverageOptimizerWithNonDefaultDecay(self):
  161. optimizer_text_proto = """
  162. adam_optimizer: {
  163. learning_rate: {
  164. constant_learning_rate {
  165. learning_rate: 0.002
  166. }
  167. }
  168. }
  169. use_moving_average: True
  170. moving_average_decay: 0.2
  171. """
  172. optimizer_proto = optimizer_pb2.Optimizer()
  173. text_format.Merge(optimizer_text_proto, optimizer_proto)
  174. optimizer, _ = optimizer_builder.build(optimizer_proto)
  175. self.assertTrue(
  176. isinstance(optimizer, tf.contrib.opt.MovingAverageOptimizer))
  177. # TODO(rathodv): Find a way to not depend on the private members.
  178. self.assertAlmostEqual(optimizer._ema._decay, 0.2)
  179. def testBuildEmptyOptimizer(self):
  180. optimizer_text_proto = """
  181. """
  182. optimizer_proto = optimizer_pb2.Optimizer()
  183. text_format.Merge(optimizer_text_proto, optimizer_proto)
  184. with self.assertRaises(ValueError):
  185. optimizer_builder.build(optimizer_proto)
  186. if __name__ == '__main__':
  187. tf.test.main()