|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""Library of common learning rate schedules."""
|
|
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
|
|
|
|
def exponential_decay_with_burnin(global_step,
                                  learning_rate_base,
                                  learning_rate_decay_steps,
                                  learning_rate_decay_factor,
                                  burnin_learning_rate=0.0,
                                  burnin_steps=0,
                                  min_learning_rate=0.0,
                                  staircase=True):
    """Constant burn-in rate followed by exponential decay.

    While global_step is below burnin_steps the schedule yields
    burnin_learning_rate. From then on it yields the value of
    tf.train.exponential_decay evaluated at `global_step - burnin_steps`.
    The final value is never allowed to drop below min_learning_rate.

    Args:
      global_step: int tensor representing global step.
      learning_rate_base: base learning rate handed to the exponential decay.
      learning_rate_decay_steps: decay period handed to
        tf.train.exponential_decay; the step counter it is applied to is
        `global_step - burnin_steps`.
      learning_rate_decay_factor: multiplicative factor by which the learning
        rate is decayed.
      burnin_learning_rate: learning rate used during the burn-in period. A
        value equal to 0 (the default) means "use learning_rate_base".
      burnin_steps: number of steps during which the burn-in rate is used.
      min_learning_rate: lower bound applied to the returned learning rate.
      staircase: forwarded to tf.train.exponential_decay.

    Returns:
      If executing eagerly:
        a no-arg callable that outputs the (scalar) float tensor learning
        rate given the current value of global_step.
      If in a graph:
        a (scalar) float tensor representing the learning rate.
    """
    if burnin_learning_rate == 0:
        burnin_learning_rate = learning_rate_base

    def eager_decay_rate():
        """Builds the learning rate for the current value of global_step."""
        steps_after_burnin = global_step - burnin_steps
        decayed_rate = tf.train.exponential_decay(
            learning_rate_base,
            steps_after_burnin,
            learning_rate_decay_steps,
            learning_rate_decay_factor,
            staircase=staircase)
        # exponential_decay may hand back a callable instead of a value;
        # resolve it so the ops below always receive a tensor.
        if callable(decayed_rate):
            decayed_rate = decayed_rate()

        still_in_burnin = tf.less(
            tf.cast(global_step, tf.int32), tf.constant(burnin_steps))
        scheduled_rate = tf.where(
            still_in_burnin,
            tf.constant(burnin_learning_rate),
            decayed_rate)
        return tf.maximum(
            scheduled_rate, min_learning_rate, name='learning_rate')

    return eager_decay_rate if tf.executing_eagerly() else eager_decay_rate()
|
|
|
|
|
|
def cosine_decay_with_warmup(global_step,
                             learning_rate_base,
                             total_steps,
                             warmup_learning_rate=0.0,
                             warmup_steps=0,
                             hold_base_rate_steps=0):
    """Cosine decay schedule with warm up period.

    Cosine annealing learning rate as described in:
      Loshchilov and Hutter, SGDR: Stochastic Gradient Descent with Warm
      Restarts. ICLR 2017. https://arxiv.org/abs/1608.03983
    In this schedule, the learning rate grows linearly from
    warmup_learning_rate to learning_rate_base for warmup_steps, optionally
    stays at learning_rate_base for hold_base_rate_steps, then transitions to
    a cosine decay schedule. Past total_steps the learning rate is 0.

    Args:
      global_step: int64 (scalar) tensor representing global step.
      learning_rate_base: base learning rate.
      total_steps: total number of training steps.
      warmup_learning_rate: initial learning rate for warm up.
      warmup_steps: number of warmup steps.
      hold_base_rate_steps: Optional number of steps to hold base learning
        rate before decaying.

    Returns:
      If executing eagerly:
        returns a no-arg callable that outputs the (scalar)
        float tensor learning rate given the current value of global_step.
      If in a graph:
        immediately returns a (scalar) float tensor representing learning
        rate.

    Raises:
      ValueError: if warmup_steps is larger than total_steps, or if
        warmup_steps > 0 and warmup_learning_rate is larger than
        learning_rate_base. Both checks run when this function is called,
        regardless of eager or graph execution.
    """
    if total_steps < warmup_steps:
        raise ValueError('total_steps must be larger or equal to '
                         'warmup_steps.')
    # Validate here rather than inside the closure: in eager mode the closure
    # is only returned, so a check placed there would be deferred until the
    # schedule is first evaluated instead of failing at construction time.
    if warmup_steps > 0 and learning_rate_base < warmup_learning_rate:
        raise ValueError('learning_rate_base must be larger or equal to '
                         'warmup_learning_rate.')

    def eager_decay_rate():
        """Callable to compute the learning rate."""
        # Half cosine going from learning_rate_base down to 0 over the steps
        # that remain after the warmup and hold phases.
        # NOTE(review): the denominator is zero when total_steps equals
        # warmup_steps + hold_base_rate_steps -- confirm callers avoid that.
        learning_rate = 0.5 * learning_rate_base * (1 + tf.cos(
            np.pi *
            (tf.cast(global_step, tf.float32) - warmup_steps -
             hold_base_rate_steps) /
            float(total_steps - warmup_steps - hold_base_rate_steps)))
        if hold_base_rate_steps > 0:
            # Keep the base rate until the hold phase is over.
            learning_rate = tf.where(
                global_step > warmup_steps + hold_base_rate_steps,
                learning_rate, learning_rate_base)
        if warmup_steps > 0:
            # Linear ramp from warmup_learning_rate to learning_rate_base.
            slope = (learning_rate_base - warmup_learning_rate) / warmup_steps
            warmup_rate = slope * tf.cast(global_step,
                                          tf.float32) + warmup_learning_rate
            learning_rate = tf.where(global_step < warmup_steps, warmup_rate,
                                     learning_rate)
        return tf.where(global_step > total_steps, 0.0, learning_rate,
                        name='learning_rate')

    if tf.executing_eagerly():
        return eager_decay_rate
    else:
        return eager_decay_rate()
|
|
|
|
|
|
def manual_stepping(global_step, boundaries, rates, warmup=False):
    """Manually stepped learning rate schedule.

    This function provides fine grained control over learning rates. One must
    specify a sequence of learning rates as well as a set of integer steps
    at which the current learning rate must transition to the next. For
    example, if boundaries = [5, 10] and rates = [.1, .01, .001], then the
    learning rate returned by this function is .1 for global_step=0,...,4,
    .01 for global_step=5...9, and .001 for global_step=10 and onward.

    Args:
      global_step: int64 (scalar) tensor representing global step.
      boundaries: a sequence of global steps at which to switch learning
        rates. It must consist of strictly increasing positive integers.
      rates: a sequence of (float) learning rates corresponding to intervals
        between the boundaries. Its length must be exactly
        len(boundaries) + 1.
      warmup: Whether to linearly interpolate learning rate for steps in
        [0, boundaries[0]]. Has no effect when boundaries is empty.

    Returns:
      If executing eagerly:
        returns a no-arg callable that outputs the (scalar)
        float tensor learning rate given the current value of global_step.
      If in a graph:
        immediately returns a (scalar) float tensor representing learning
        rate.

    Raises:
      ValueError: if one of the following checks fails:
        1. boundaries is a strictly increasing list of positive integers
        2. every entry of rates is a float
        3. len(rates) == len(boundaries) + 1
        4. boundaries[0] != 0
    """
    # Work on list copies so tuples are accepted and the list concatenations
    # below are well defined; the caller's sequences are never modified.
    boundaries = list(boundaries)
    rates = list(rates)

    # Check types before values so that a non-integer entry produces the
    # documented ValueError instead of a TypeError from the `<` comparison.
    if any(not isinstance(b, int) for b in boundaries) or any(
            b < 0 for b in boundaries):
        raise ValueError('boundaries must be a list of positive integers')
    if any(bnext <= b for bnext, b in zip(boundaries[1:], boundaries[:-1])):
        raise ValueError('Entries in boundaries must be strictly increasing.')
    if any(not isinstance(r, float) for r in rates):
        raise ValueError('Learning rates must be floats')
    if len(rates) != len(boundaries) + 1:
        raise ValueError('Number of provided learning rates must exceed '
                         'number of boundary points by exactly 1.')

    if boundaries and boundaries[0] == 0:
        raise ValueError('First step cannot be zero.')

    if warmup and boundaries:
        # Replace the first interval by one boundary per step, each with its
        # own linearly interpolated rate between rates[0] and rates[1].
        slope = (rates[1] - rates[0]) * 1.0 / boundaries[0]
        # range() is not a list on Python 3 and cannot be concatenated with
        # one, so materialize it first.
        warmup_steps = list(range(boundaries[0]))
        warmup_rates = [rates[0] + slope * step for step in warmup_steps]
        boundaries = warmup_steps + boundaries
        rates = warmup_rates + rates[1:]
    else:
        # Prepend step 0 so that boundaries[i] is the first step of rates[i].
        boundaries = [0] + boundaries
    num_boundaries = len(boundaries)

    def eager_decay_rate():
        """Callable to compute the learning rate."""
        # Index of the last boundary that global_step has reached.
        rate_index = tf.reduce_max(tf.where(
            tf.greater_equal(global_step, boundaries),
            list(range(num_boundaries)),
            [0] * num_boundaries))
        # Select rates[rate_index] through a one-hot mask.
        return tf.reduce_sum(
            rates * tf.one_hot(rate_index, depth=num_boundaries),
            name='learning_rate')

    if tf.executing_eagerly():
        return eager_decay_rate
    else:
        return eager_decay_rate()
|