|
|
- # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- """Library of common learning rate schedules."""
-
- import numpy as np
- import tensorflow as tf
-
-
def exponential_decay_with_burnin(global_step,
                                  learning_rate_base,
                                  learning_rate_decay_steps,
                                  learning_rate_decay_factor,
                                  burnin_learning_rate=0.0,
                                  burnin_steps=0,
                                  min_learning_rate=0.0,
                                  staircase=True):
    """Exponential decay learning-rate schedule preceded by a burn-in phase.

    While global_step < burnin_steps the learning rate is held constant at the
    burn-in rate. From then on it follows an exponential decay of
    learning_rate_base, evaluated at (global_step - burnin_steps). The result
    is never allowed to drop below min_learning_rate.

    Args:
      global_step: int tensor representing global step.
      learning_rate_base: base learning rate of the exponential decay.
      learning_rate_decay_steps: number of steps between successive decays.
        Decay is counted from the end of the burn-in period, because the decay
        is evaluated at global_step - burnin_steps.
      learning_rate_decay_factor: multiplicative factor by which to decay
        the learning rate.
      burnin_learning_rate: constant learning rate used during burn-in. A value
        of 0 (the default) means learning_rate_base is used instead.
      burnin_steps: number of initial steps that use the burn-in rate.
      min_learning_rate: lower bound applied to the returned learning rate.
      staircase: whether the exponential decay is applied in discrete steps.

    Returns:
      If executing eagerly: a no-arg callable producing the (scalar) float
      tensor learning rate for the current value of global_step.
      Otherwise (graph mode): the (scalar) float tensor learning rate itself.
    """
    initial_rate = (learning_rate_base if burnin_learning_rate == 0
                    else burnin_learning_rate)

    def _compute_learning_rate():
        """Builds/evaluates the learning-rate tensor."""
        decayed = tf.train.exponential_decay(
            learning_rate_base,
            global_step - burnin_steps,
            learning_rate_decay_steps,
            learning_rate_decay_factor,
            staircase=staircase)
        # exponential_decay may hand back a no-arg callable (e.g. when
        # executing eagerly); unwrap it to get the actual tensor.
        if callable(decayed):
            decayed = decayed()
        in_burnin = tf.less(tf.cast(global_step, tf.int32),
                            tf.constant(burnin_steps))
        scheduled = tf.where(in_burnin, tf.constant(initial_rate), decayed)
        return tf.maximum(scheduled, min_learning_rate, name='learning_rate')

    return (_compute_learning_rate if tf.executing_eagerly()
            else _compute_learning_rate())
-
-
def cosine_decay_with_warmup(global_step,
                             learning_rate_base,
                             total_steps,
                             warmup_learning_rate=0.0,
                             warmup_steps=0,
                             hold_base_rate_steps=0):
    """Cosine decay schedule with warm up period.

    Cosine annealing learning rate as described in:
      Loshchilov and Hutter, SGDR: Stochastic Gradient Descent with Warm
      Restarts. ICLR 2017. https://arxiv.org/abs/1608.03983
    In this schedule, the learning rate grows linearly from
    warmup_learning_rate to learning_rate_base for warmup_steps, is optionally
    held at learning_rate_base for hold_base_rate_steps, then follows a cosine
    decay that reaches 0 at total_steps. After total_steps it is 0.

    Args:
      global_step: int64 (scalar) tensor representing global step.
      learning_rate_base: base learning rate.
      total_steps: total number of training steps.
      warmup_learning_rate: initial learning rate for warm up.
      warmup_steps: number of warmup steps.
      hold_base_rate_steps: Optional number of steps to hold base learning rate
        before decaying.

    Returns:
      If executing eagerly:
        returns a no-arg callable that outputs the (scalar)
        float tensor learning rate given the current value of global_step.
      If in a graph:
        immediately returns a (scalar) float tensor representing learning rate.

    Raises:
      ValueError: if warmup_steps is larger than total_steps, or if
        warmup_steps > 0 and warmup_learning_rate is larger than
        learning_rate_base. Both checks run when this function is called, so
        the error is reported at construction time in eager mode as well.
    """
    if total_steps < warmup_steps:
        raise ValueError('total_steps must be larger or equal to '
                         'warmup_steps.')
    # Validate up front rather than inside the returned callable, so eager
    # callers see a bad configuration immediately instead of on first use.
    if warmup_steps > 0 and learning_rate_base < warmup_learning_rate:
        raise ValueError('learning_rate_base must be larger or equal to '
                         'warmup_learning_rate.')

    # Length of the cosine phase. It may be <= 0, e.g. total_steps ==
    # warmup_steps with no hold period, where dividing by it would give
    # 0/0 = NaN at global_step == total_steps. In that situation the cosine
    # branch is selected below at most at its start step (numerator 0); every
    # later step exceeds total_steps and is zeroed. Clamping the denominator
    # to >= 1 therefore only replaces that NaN with learning_rate_base and
    # leaves all selected values unchanged for a non-empty cosine phase.
    decay_steps = total_steps - warmup_steps - hold_base_rate_steps
    cosine_denominator = float(max(decay_steps, 1))

    def eager_decay_rate():
        """Callable to compute the learning rate."""
        learning_rate = 0.5 * learning_rate_base * (1 + tf.cos(
            np.pi *
            (tf.cast(global_step, tf.float32) - warmup_steps
             - hold_base_rate_steps) / cosine_denominator))
        if hold_base_rate_steps > 0:
            # Use learning_rate_base until global_step passes
            # warmup_steps + hold_base_rate_steps (steps below warmup_steps
            # are replaced by the warmup ramp next).
            learning_rate = tf.where(
                global_step > warmup_steps + hold_base_rate_steps,
                learning_rate, learning_rate_base)
        if warmup_steps > 0:
            # Linear ramp: warmup_learning_rate at step 0, reaching
            # learning_rate_base at step warmup_steps.
            slope = (learning_rate_base - warmup_learning_rate) / warmup_steps
            warmup_rate = slope * tf.cast(global_step,
                                          tf.float32) + warmup_learning_rate
            learning_rate = tf.where(global_step < warmup_steps, warmup_rate,
                                     learning_rate)
        # Beyond the end of training the learning rate is 0.
        return tf.where(global_step > total_steps, 0.0, learning_rate,
                        name='learning_rate')

    if tf.executing_eagerly():
        return eager_decay_rate
    else:
        return eager_decay_rate()
-
-
def manual_stepping(global_step, boundaries, rates, warmup=False):
    """Manually stepped learning rate schedule.

    This function provides fine grained control over learning rates. One must
    specify a sequence of learning rates as well as a set of integer steps
    at which the current learning rate must transition to the next. For
    example, if boundaries = [5, 10] and rates = [.1, .01, .001], then the
    learning rate returned by this function is .1 for global_step=0,...,4,
    .01 for global_step=5...9, and .001 for global_step=10 and onward.

    Args:
      global_step: int64 (scalar) tensor representing global step.
      boundaries: a sequence (list or tuple) of global steps at which to
        switch learning rates. It must consist of strictly increasing
        positive integers.
      rates: a sequence (list or tuple) of (float) learning rates
        corresponding to intervals between the boundaries. Its length must be
        exactly len(boundaries) + 1.
      warmup: Whether to linearly interpolate learning rate for steps in
        [0, boundaries[0]].

    Returns:
      If executing eagerly:
        returns a no-arg callable that outputs the (scalar)
        float tensor learning rate given the current value of global_step.
      If in a graph:
        immediately returns a (scalar) float tensor representing learning rate.

    Raises:
      ValueError: if one of the following checks fails:
        1. boundaries is a strictly increasing list of positive integers
        2. every rate is a float
        3. len(rates) == len(boundaries) + 1
        4. boundaries[0] != 0
    """
    # Work on list copies: accepts tuples, and the concatenations below never
    # touch the caller's sequences.
    boundaries = list(boundaries)
    rates = list(rates)

    # Type check first so non-numeric entries raise the documented ValueError
    # instead of a TypeError from the `b < 0` comparison.
    if any(not isinstance(b, int) for b in boundaries) or any(
            b < 0 for b in boundaries):
        raise ValueError('boundaries must be a list of positive integers')
    if any(bnext <= b for b, bnext in zip(boundaries, boundaries[1:])):
        raise ValueError('Entries in boundaries must be strictly increasing.')
    if any(not isinstance(r, float) for r in rates):
        raise ValueError('Learning rates must be floats')
    if len(rates) != len(boundaries) + 1:
        raise ValueError('Number of provided learning rates must exceed '
                         'number of boundary points by exactly 1.')

    if boundaries and boundaries[0] == 0:
        raise ValueError('First step cannot be zero.')

    if warmup and boundaries:
        # Replace the first interval by one boundary per step, with rates
        # interpolated linearly from rates[0] towards rates[1].
        slope = (rates[1] - rates[0]) * 1.0 / boundaries[0]
        # Materialize the range: `range + list` is a TypeError on Python 3.
        warmup_steps = list(range(boundaries[0]))
        warmup_rates = [rates[0] + slope * step for step in warmup_steps]
        boundaries = warmup_steps + boundaries
        rates = warmup_rates + rates[1:]
    else:
        boundaries = [0] + boundaries
    num_boundaries = len(boundaries)

    def eager_decay_rate():
        """Callable to compute the learning rate."""
        # Largest index i with global_step >= boundaries[i] (0 otherwise).
        rate_index = tf.reduce_max(tf.where(
            tf.greater_equal(global_step, boundaries),
            list(range(num_boundaries)),
            [0] * num_boundaries))
        # Select rates[rate_index] via a one-hot mask.
        return tf.reduce_sum(
            rates * tf.one_hot(rate_index, depth=num_boundaries),
            name='learning_rate')

    if tf.executing_eagerly():
        return eager_decay_rate
    else:
        return eager_decay_rate()
|