# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library of common learning rate schedules."""
import numpy as np
import tensorflow as tf


def exponential_decay_with_burnin(global_step,
                                  learning_rate_base,
                                  learning_rate_decay_steps,
                                  learning_rate_decay_factor,
                                  burnin_learning_rate=0.0,
                                  burnin_steps=0,
                                  min_learning_rate=0.0,
                                  staircase=True):
  """Exponential decay schedule with burn-in period.

  In this schedule, the learning rate is fixed at burnin_learning_rate
  for a fixed period, before transitioning to a regular exponential
  decay schedule.

  Args:
    global_step: int tensor representing global step.
    learning_rate_base: base learning rate.
    learning_rate_decay_steps: steps to take between decaying the learning
      rate. Note that this includes the number of burn-in steps.
    learning_rate_decay_factor: multiplicative factor by which to decay the
      learning rate.
    burnin_learning_rate: initial learning rate during the burn-in period. If
      0.0 (the default), the burn-in learning rate is simply set to
      learning_rate_base.
    burnin_steps: number of steps to use the burn-in learning rate.
    min_learning_rate: the minimum learning rate.
    staircase: whether to use staircase decay.

  Returns:
    a (scalar) float tensor representing the learning rate.
  """
  if burnin_learning_rate == 0:
    burnin_learning_rate = learning_rate_base

  # Regular exponential decay, with the step count shifted so that decay
  # starts from the end of the burn-in period.
  post_burnin_learning_rate = tf.train.exponential_decay(
      learning_rate_base,
      global_step - burnin_steps,
      learning_rate_decay_steps,
      learning_rate_decay_factor,
      staircase=staircase)
  # Use the constant burn-in rate before burnin_steps, then switch to the
  # decayed rate, clipping from below at min_learning_rate.
  return tf.maximum(tf.where(
      tf.less(tf.cast(global_step, tf.int32), tf.constant(burnin_steps)),
      tf.constant(burnin_learning_rate),
      post_burnin_learning_rate), min_learning_rate, name='learning_rate')
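

# A minimal usage sketch (an illustrative addition, not part of the original
# file): under the TF 1.x graph API this schedule plugs into an optimizer as
# below; the numeric values are hypothetical.
#
#   global_step = tf.train.get_or_create_global_step()
#   learning_rate = exponential_decay_with_burnin(
#       global_step,
#       learning_rate_base=0.004,
#       learning_rate_decay_steps=10000,
#       learning_rate_decay_factor=0.95,
#       burnin_learning_rate=0.001,
#       burnin_steps=2000)
#   optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)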


def cosine_decay_with_warmup(global_step,
                             learning_rate_base,
                             total_steps,
                             warmup_learning_rate=0.0,
                             warmup_steps=0,
                             hold_base_rate_steps=0):
  """Cosine decay schedule with warm-up period.

  Cosine annealing learning rate as described in:
    Loshchilov and Hutter, SGDR: Stochastic Gradient Descent with Warm Restarts.
    ICLR 2017. https://arxiv.org/abs/1608.03983
  In this schedule, the learning rate grows linearly from warmup_learning_rate
  to learning_rate_base over warmup_steps steps, then transitions to a cosine
  decay schedule.

  Args:
    global_step: int64 (scalar) tensor representing global step.
    learning_rate_base: base learning rate.
    total_steps: total number of training steps.
    warmup_learning_rate: initial learning rate for warm-up.
    warmup_steps: number of warm-up steps.
    hold_base_rate_steps: optional number of steps to hold the base learning
      rate before decaying.

  Returns:
    a (scalar) float tensor representing the learning rate.

  Raises:
    ValueError: if warmup_learning_rate is larger than learning_rate_base,
      or if warmup_steps is larger than total_steps.
  """
  if total_steps < warmup_steps:
    raise ValueError('total_steps must be greater than or equal to '
                     'warmup_steps.')
  # Cosine decay from learning_rate_base down to 0 over the steps remaining
  # after warm-up and any hold period.
  learning_rate = 0.5 * learning_rate_base * (1 + tf.cos(
      np.pi *
      (tf.cast(global_step, tf.float32) - warmup_steps - hold_base_rate_steps
      ) / float(total_steps - warmup_steps - hold_base_rate_steps)))
  if hold_base_rate_steps > 0:
    learning_rate = tf.where(global_step > warmup_steps + hold_base_rate_steps,
                             learning_rate, learning_rate_base)
  if warmup_steps > 0:
    if learning_rate_base < warmup_learning_rate:
      raise ValueError('learning_rate_base must be greater than or equal to '
                       'warmup_learning_rate.')
    # Linear warm-up from warmup_learning_rate to learning_rate_base.
    slope = (learning_rate_base - warmup_learning_rate) / warmup_steps
    warmup_rate = slope * tf.cast(global_step,
                                  tf.float32) + warmup_learning_rate
    learning_rate = tf.where(global_step < warmup_steps, warmup_rate,
                             learning_rate)
  # Clamp the learning rate to 0 once training passes total_steps.
  return tf.where(global_step > total_steps, 0.0, learning_rate,
                  name='learning_rate')
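

# For illustration (not part of the original file): outside the warm-up and
# hold phases, the code above evaluates to
#   lr(s) = 0.5 * learning_rate_base *
#           (1 + cos(pi * (s - warmup_steps - hold_base_rate_steps) /
#                    (total_steps - warmup_steps - hold_base_rate_steps)))
# so the rate starts at learning_rate_base and anneals to 0 at total_steps.
# A hypothetical 100k-step configuration:
#   learning_rate = cosine_decay_with_warmup(
#       global_step=tf.train.get_or_create_global_step(),
#       learning_rate_base=0.04,
#       total_steps=100000,
#       warmup_learning_rate=0.001,
#       warmup_steps=2000)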


def manual_stepping(global_step, boundaries, rates, warmup=False):
  """Manually stepped learning rate schedule.

  This function provides fine-grained control over learning rates. One must
  specify a sequence of learning rates as well as a set of integer steps
  at which the current learning rate must transition to the next. For example,
  if boundaries = [5, 10] and rates = [.1, .01, .001], then the learning
  rate returned by this function is .1 for global_step=0,...,4, .01 for
  global_step=5...9, and .001 for global_step=10 and onward.

  Args:
    global_step: int64 (scalar) tensor representing global step.
    boundaries: a list of global steps at which to switch learning
      rates. This list is assumed to consist of increasing positive integers.
    rates: a list of (float) learning rates corresponding to intervals between
      the boundaries. The length of this list must be exactly
      len(boundaries) + 1.
    warmup: whether to linearly interpolate the learning rate for steps in
      [0, boundaries[0]].

  Returns:
    a (scalar) float tensor representing the learning rate.

  Raises:
    ValueError: if one of the following checks fails:
      1. boundaries is a strictly increasing list of positive integers
      2. len(rates) == len(boundaries) + 1
      3. boundaries[0] != 0
  """
  if any([b < 0 for b in boundaries]) or any(
      [not isinstance(b, int) for b in boundaries]):
    raise ValueError('boundaries must be a list of positive integers')
  if any([bnext <= b for bnext, b in zip(boundaries[1:], boundaries[:-1])]):
    raise ValueError('Entries in boundaries must be strictly increasing.')
  if any([not isinstance(r, float) for r in rates]):
    raise ValueError('Learning rates must be floats')
  if len(rates) != len(boundaries) + 1:
    raise ValueError('Number of provided learning rates must exceed '
                     'number of boundary points by exactly 1.')
  if boundaries and boundaries[0] == 0:
    raise ValueError('First step cannot be zero.')
  if warmup and boundaries:
    # Prepend one boundary per warm-up step, linearly interpolating between
    # rates[0] and rates[1]. The list() call is needed on Python 3, where a
    # range object cannot be concatenated with a list.
    slope = (rates[1] - rates[0]) * 1.0 / boundaries[0]
    warmup_steps = list(range(boundaries[0]))
    warmup_rates = [rates[0] + slope * step for step in warmup_steps]
    boundaries = warmup_steps + boundaries
    rates = warmup_rates + rates[1:]
  else:
    boundaries = [0] + boundaries
  num_boundaries = len(boundaries)
  # Index of the last boundary that global_step has reached; the matching
  # rate is then selected via a one-hot dot product.
  rate_index = tf.reduce_max(tf.where(tf.greater_equal(global_step, boundaries),
                                      list(range(num_boundaries)),
                                      [0] * num_boundaries))
  return tf.reduce_sum(rates * tf.one_hot(rate_index, depth=num_boundaries),
                       name='learning_rate')
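

# For illustration (not part of the original file): a schedule that uses .1
# for steps 0-4, .01 for steps 5-9, and .001 from step 10 onward:
#   learning_rate = manual_stepping(
#       global_step=tf.train.get_or_create_global_step(),
#       boundaries=[5, 10],
#       rates=[.1, .01, .001])
# With warmup=True, the rate instead ramps linearly from rates[0] to rates[1]
# over steps [0, boundaries[0]).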