# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library of common learning rate schedules."""

import numpy as np
import tensorflow as tf


def exponential_decay_with_burnin(global_step,
                                  learning_rate_base,
                                  learning_rate_decay_steps,
                                  learning_rate_decay_factor,
                                  burnin_learning_rate=0.0,
                                  burnin_steps=0,
                                  min_learning_rate=0.0,
                                  staircase=True):
  """Exponential decay schedule with burn-in period.

  In this schedule, the learning rate is fixed at burnin_learning_rate
  for a fixed period, before transitioning to a regular exponential
  decay schedule.

  Args:
    global_step: int tensor representing global step.
    learning_rate_base: base learning rate.
    learning_rate_decay_steps: steps to take between decaying the learning
      rate. Note that this includes the number of burn-in steps.
    learning_rate_decay_factor: multiplicative factor by which to decay the
      learning rate.
    burnin_learning_rate: initial learning rate during the burn-in period. If
      0.0 (which is the default), the burn-in learning rate is simply set to
      learning_rate_base.
    burnin_steps: number of steps to use the burn-in learning rate.
    min_learning_rate: the minimum learning rate.
    staircase: whether to use staircase decay.

  Returns:
    If executing eagerly:
      returns a no-arg callable that outputs the (scalar)
      float tensor learning rate given the current value of global_step.
    If in a graph:
      immediately returns a (scalar) float tensor representing learning rate.
  """
  if burnin_learning_rate == 0:
    burnin_learning_rate = learning_rate_base

  def eager_decay_rate():
    """Callable to compute the learning rate."""
    post_burnin_learning_rate = tf.train.exponential_decay(
        learning_rate_base,
        global_step - burnin_steps,
        learning_rate_decay_steps,
        learning_rate_decay_factor,
        staircase=staircase)
    if callable(post_burnin_learning_rate):
      post_burnin_learning_rate = post_burnin_learning_rate()
    return tf.maximum(tf.where(
        tf.less(tf.cast(global_step, tf.int32), tf.constant(burnin_steps)),
        tf.constant(burnin_learning_rate),
        post_burnin_learning_rate), min_learning_rate, name='learning_rate')

  if tf.executing_eagerly():
    return eager_decay_rate
  else:
    return eager_decay_rate()
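
# Illustrative usage sketch (an addition, not part of the original module);
# the hyperparameter values below are made up. In TF 1.x graph mode the
# schedule tensor is returned immediately and can be handed to an optimizer:
#
#   global_step = tf.train.get_or_create_global_step()
#   learning_rate = exponential_decay_with_burnin(
#       global_step,
#       learning_rate_base=0.004,
#       learning_rate_decay_steps=10000,
#       learning_rate_decay_factor=0.95,
#       burnin_learning_rate=0.0004,
#       burnin_steps=2000)
#   optimizer = tf.train.MomentumOptimizer(learning_rate, momentum=0.9)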


def cosine_decay_with_warmup(global_step,
                             learning_rate_base,
                             total_steps,
                             warmup_learning_rate=0.0,
                             warmup_steps=0,
                             hold_base_rate_steps=0):
  """Cosine decay schedule with warm-up period.

  Cosine annealing learning rate as described in:
    Loshchilov and Hutter, SGDR: Stochastic Gradient Descent with Warm
    Restarts. ICLR 2017. https://arxiv.org/abs/1608.03983
  In this schedule, the learning rate grows linearly from warmup_learning_rate
  to learning_rate_base for warmup_steps, then transitions to a cosine decay
  schedule.

  Args:
    global_step: int64 (scalar) tensor representing global step.
    learning_rate_base: base learning rate.
    total_steps: total number of training steps.
    warmup_learning_rate: initial learning rate for warm up.
    warmup_steps: number of warmup steps.
    hold_base_rate_steps: Optional number of steps to hold base learning rate
      before decaying.

  Returns:
    If executing eagerly:
      returns a no-arg callable that outputs the (scalar)
      float tensor learning rate given the current value of global_step.
    If in a graph:
      immediately returns a (scalar) float tensor representing learning rate.

  Raises:
    ValueError: if warmup_learning_rate is larger than learning_rate_base,
      or if warmup_steps is larger than total_steps.
  """
  if total_steps < warmup_steps:
    raise ValueError('total_steps must be larger or equal to '
                     'warmup_steps.')

  def eager_decay_rate():
    """Callable to compute the learning rate."""
    learning_rate = 0.5 * learning_rate_base * (1 + tf.cos(
        np.pi *
        (tf.cast(global_step, tf.float32) - warmup_steps - hold_base_rate_steps
        ) / float(total_steps - warmup_steps - hold_base_rate_steps)))
    if hold_base_rate_steps > 0:
      learning_rate = tf.where(
          global_step > warmup_steps + hold_base_rate_steps,
          learning_rate, learning_rate_base)
    if warmup_steps > 0:
      if learning_rate_base < warmup_learning_rate:
        raise ValueError('learning_rate_base must be larger or equal to '
                         'warmup_learning_rate.')
      slope = (learning_rate_base - warmup_learning_rate) / warmup_steps
      warmup_rate = slope * tf.cast(global_step,
                                    tf.float32) + warmup_learning_rate
      learning_rate = tf.where(global_step < warmup_steps, warmup_rate,
                               learning_rate)
    return tf.where(global_step > total_steps, 0.0, learning_rate,
                    name='learning_rate')

  if tf.executing_eagerly():
    return eager_decay_rate
  else:
    return eager_decay_rate()
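
# Worked example (an illustrative addition, not from the original module):
# with learning_rate_base=0.01, total_steps=100, warmup_learning_rate=0.001,
# warmup_steps=10 and hold_base_rate_steps=0, the rate ramps linearly from
# 0.001 to 0.01 over steps [0, 10), then follows
#   0.5 * 0.01 * (1 + cos(pi * (step - 10) / 90))
# reaching 0 at step 100; any step beyond total_steps yields 0.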


def manual_stepping(global_step, boundaries, rates, warmup=False):
  """Manually stepped learning rate schedule.

  This function provides fine grained control over learning rates. One must
  specify a sequence of learning rates as well as a set of integer steps
  at which the current learning rate must transition to the next. For example,
  if boundaries = [5, 10] and rates = [.1, .01, .001], then the learning
  rate returned by this function is .1 for global_step=0,...,4, .01 for
  global_step=5...9, and .001 for global_step=10 and onward.

  Args:
    global_step: int64 (scalar) tensor representing global step.
    boundaries: a list of global steps at which to switch learning
      rates. This list is assumed to consist of increasing positive integers.
    rates: a list of (float) learning rates corresponding to intervals between
      the boundaries. The length of this list must be exactly
      len(boundaries) + 1.
    warmup: Whether to linearly interpolate learning rate for steps in
      [0, boundaries[0]].

  Returns:
    If executing eagerly:
      returns a no-arg callable that outputs the (scalar)
      float tensor learning rate given the current value of global_step.
    If in a graph:
      immediately returns a (scalar) float tensor representing learning rate.

  Raises:
    ValueError: if one of the following checks fails:
      1. boundaries is a strictly increasing list of positive integers
      2. len(rates) == len(boundaries) + 1
      3. boundaries[0] != 0
  """
  if any(b < 0 for b in boundaries) or any(
      not isinstance(b, int) for b in boundaries):
    raise ValueError('boundaries must be a list of positive integers')
  if any(bnext <= b for bnext, b in zip(boundaries[1:], boundaries[:-1])):
    raise ValueError('Entries in boundaries must be strictly increasing.')
  if any(not isinstance(r, float) for r in rates):
    raise ValueError('Learning rates must be floats')
  if len(rates) != len(boundaries) + 1:
    raise ValueError('Number of provided learning rates must exceed '
                     'number of boundary points by exactly 1.')
  if boundaries and boundaries[0] == 0:
    raise ValueError('First step cannot be zero.')
  if warmup and boundaries:
    slope = (rates[1] - rates[0]) * 1.0 / boundaries[0]
    # list(range(...)) is required so the warm-up steps can be concatenated
    # with the boundaries list under Python 3.
    warmup_steps = list(range(boundaries[0]))
    warmup_rates = [rates[0] + slope * step for step in warmup_steps]
    boundaries = warmup_steps + boundaries
    rates = warmup_rates + rates[1:]
  else:
    boundaries = [0] + boundaries
  num_boundaries = len(boundaries)

  def eager_decay_rate():
    """Callable to compute the learning rate."""
    rate_index = tf.reduce_max(tf.where(
        tf.greater_equal(global_step, boundaries),
        list(range(num_boundaries)),
        [0] * num_boundaries))
    return tf.reduce_sum(rates * tf.one_hot(rate_index, depth=num_boundaries),
                         name='learning_rate')

  if tf.executing_eagerly():
    return eager_decay_rate
  else:
    return eager_decay_rate()
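

if __name__ == '__main__':
  # Minimal smoke test (an illustrative addition, not part of the original
  # module). Assumes TF 1.x graph mode, where tf.placeholder and tf.Session
  # are available.
  with tf.Graph().as_default():
    step = tf.placeholder(tf.int64, shape=[])
    learning_rate = manual_stepping(
        step, boundaries=[5, 10], rates=[0.1, 0.01, 0.001])
    with tf.Session() as sess:
      # Expected: 0.1 for steps 0-4, 0.01 for steps 5-9, 0.001 from 10 on.
      for value in [0, 4, 5, 9, 10, 20]:
        print(value, sess.run(learning_rate, feed_dict={step: value}))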