You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

156 lines
5.5 KiB

6 years ago
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for object_detection.utils.learning_schedules."""
import numpy as np
import tensorflow as tf

from object_detection.utils import learning_schedules
from object_detection.utils import test_case
  20. class LearningSchedulesTest(test_case.TestCase):
  21. def testExponentialDecayWithBurnin(self):
  22. def graph_fn(global_step):
  23. learning_rate_base = 1.0
  24. learning_rate_decay_steps = 3
  25. learning_rate_decay_factor = .1
  26. burnin_learning_rate = .5
  27. burnin_steps = 2
  28. min_learning_rate = .05
  29. learning_rate = learning_schedules.exponential_decay_with_burnin(
  30. global_step, learning_rate_base, learning_rate_decay_steps,
  31. learning_rate_decay_factor, burnin_learning_rate, burnin_steps,
  32. min_learning_rate)
  33. assert learning_rate.op.name.endswith('learning_rate')
  34. return (learning_rate,)
  35. output_rates = [
  36. self.execute(graph_fn, [np.array(i).astype(np.int64)]) for i in range(9)
  37. ]
  38. exp_rates = [.5, .5, 1, 1, 1, .1, .1, .1, .05]
  39. self.assertAllClose(output_rates, exp_rates, rtol=1e-4)
  40. def testCosineDecayWithWarmup(self):
  41. def graph_fn(global_step):
  42. learning_rate_base = 1.0
  43. total_steps = 100
  44. warmup_learning_rate = 0.1
  45. warmup_steps = 9
  46. learning_rate = learning_schedules.cosine_decay_with_warmup(
  47. global_step, learning_rate_base, total_steps,
  48. warmup_learning_rate, warmup_steps)
  49. assert learning_rate.op.name.endswith('learning_rate')
  50. return (learning_rate,)
  51. exp_rates = [0.1, 0.5, 0.9, 1.0, 0]
  52. input_global_steps = [0, 4, 8, 9, 100]
  53. output_rates = [
  54. self.execute(graph_fn, [np.array(step).astype(np.int64)])
  55. for step in input_global_steps
  56. ]
  57. self.assertAllClose(output_rates, exp_rates)
  58. def testCosineDecayAfterTotalSteps(self):
  59. def graph_fn(global_step):
  60. learning_rate_base = 1.0
  61. total_steps = 100
  62. warmup_learning_rate = 0.1
  63. warmup_steps = 9
  64. learning_rate = learning_schedules.cosine_decay_with_warmup(
  65. global_step, learning_rate_base, total_steps,
  66. warmup_learning_rate, warmup_steps)
  67. assert learning_rate.op.name.endswith('learning_rate')
  68. return (learning_rate,)
  69. exp_rates = [0]
  70. input_global_steps = [101]
  71. output_rates = [
  72. self.execute(graph_fn, [np.array(step).astype(np.int64)])
  73. for step in input_global_steps
  74. ]
  75. self.assertAllClose(output_rates, exp_rates)
  76. def testCosineDecayWithHoldBaseLearningRateSteps(self):
  77. def graph_fn(global_step):
  78. learning_rate_base = 1.0
  79. total_steps = 120
  80. warmup_learning_rate = 0.1
  81. warmup_steps = 9
  82. hold_base_rate_steps = 20
  83. learning_rate = learning_schedules.cosine_decay_with_warmup(
  84. global_step, learning_rate_base, total_steps,
  85. warmup_learning_rate, warmup_steps, hold_base_rate_steps)
  86. assert learning_rate.op.name.endswith('learning_rate')
  87. return (learning_rate,)
  88. exp_rates = [0.1, 0.5, 0.9, 1.0, 1.0, 1.0, 0.999702, 0.874255, 0.577365,
  89. 0.0]
  90. input_global_steps = [0, 4, 8, 9, 10, 29, 30, 50, 70, 120]
  91. output_rates = [
  92. self.execute(graph_fn, [np.array(step).astype(np.int64)])
  93. for step in input_global_steps
  94. ]
  95. self.assertAllClose(output_rates, exp_rates)
  96. def testManualStepping(self):
  97. def graph_fn(global_step):
  98. boundaries = [2, 3, 7]
  99. rates = [1.0, 2.0, 3.0, 4.0]
  100. learning_rate = learning_schedules.manual_stepping(
  101. global_step, boundaries, rates)
  102. assert learning_rate.op.name.endswith('learning_rate')
  103. return (learning_rate,)
  104. output_rates = [
  105. self.execute(graph_fn, [np.array(i).astype(np.int64)])
  106. for i in range(10)
  107. ]
  108. exp_rates = [1.0, 1.0, 2.0, 3.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0]
  109. self.assertAllClose(output_rates, exp_rates)
  110. def testManualSteppingWithWarmup(self):
  111. def graph_fn(global_step):
  112. boundaries = [4, 6, 8]
  113. rates = [0.02, 0.10, 0.01, 0.001]
  114. learning_rate = learning_schedules.manual_stepping(
  115. global_step, boundaries, rates, warmup=True)
  116. assert learning_rate.op.name.endswith('learning_rate')
  117. return (learning_rate,)
  118. output_rates = [
  119. self.execute(graph_fn, [np.array(i).astype(np.int64)])
  120. for i in range(9)
  121. ]
  122. exp_rates = [0.02, 0.04, 0.06, 0.08, 0.10, 0.10, 0.01, 0.01, 0.001]
  123. self.assertAllClose(output_rates, exp_rates)
  124. def testManualSteppingWithZeroBoundaries(self):
  125. def graph_fn(global_step):
  126. boundaries = []
  127. rates = [0.01]
  128. learning_rate = learning_schedules.manual_stepping(
  129. global_step, boundaries, rates)
  130. return (learning_rate,)
  131. output_rates = [
  132. self.execute(graph_fn, [np.array(i).astype(np.int64)])
  133. for i in range(4)
  134. ]
  135. exp_rates = [0.01] * 4
  136. self.assertAllClose(output_rates, exp_rates)
  137. if __name__ == '__main__':
  138. tf.test.main()