You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

213 lines
8.8 KiB

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Library of common learning rate schedules."""
import numpy as np
import tensorflow as tf
def exponential_decay_with_burnin(global_step,
                                  learning_rate_base,
                                  learning_rate_decay_steps,
                                  learning_rate_decay_factor,
                                  burnin_learning_rate=0.0,
                                  burnin_steps=0,
                                  min_learning_rate=0.0,
                                  staircase=True):
  """Exponential decay schedule with burn-in period.

  In this schedule, learning rate is fixed at burnin_learning_rate
  for a fixed period, before transitioning to a regular exponential
  decay schedule:

    p = (global_step - burnin_steps) / learning_rate_decay_steps
    rate = learning_rate_base * learning_rate_decay_factor ** p

  where p is rounded down when `staircase` is True. The returned rate is
  clipped from below at `min_learning_rate`.

  Args:
    global_step: int tensor representing global step.
    learning_rate_base: base learning rate.
    learning_rate_decay_steps: steps to take between decaying the learning
      rate. The decay exponent is counted from the end of the burn-in period,
      i.e. from global_step - burnin_steps.
    learning_rate_decay_factor: multiplicative factor by which to decay
      learning rate.
    burnin_learning_rate: initial learning rate during burn-in period. If
      0.0 (which is the default), then the burn-in learning rate is simply
      set to learning_rate_base.
    burnin_steps: number of steps to use burnin learning rate; steps with
      global_step < burnin_steps use the burn-in rate.
    min_learning_rate: the minimum learning rate.
    staircase: whether use staircase decay.
  Returns:
    If executing eagerly:
      returns a no-arg callable that outputs the (scalar)
      float tensor learning rate given the current value of global_step.
    If in a graph:
      immediately returns a (scalar) float tensor representing learning rate.
  """
  if burnin_learning_rate == 0:
    burnin_learning_rate = learning_rate_base

  def eager_decay_rate():
    """Callable to compute the learning rate."""
    # The decay is written with core ops instead of tf.train.exponential_decay,
    # which only exists in the TF1 API (TF2 keeps it under tf.compat.v1) and
    # returns a callable rather than a tensor when executing eagerly.
    decay_exponent = (
        tf.cast(global_step - burnin_steps, tf.float32) /
        tf.cast(learning_rate_decay_steps, tf.float32))
    if staircase:
      decay_exponent = tf.floor(decay_exponent)
    base_rate = tf.cast(learning_rate_base, tf.float32)
    decay_factor = tf.cast(learning_rate_decay_factor, tf.float32)
    post_burnin_learning_rate = base_rate * tf.pow(decay_factor,
                                                   decay_exponent)
    # Cast (rather than tf.constant) so both tf.where branches are float32
    # even when an integer burn-in rate is supplied.
    learning_rate = tf.where(
        tf.less(tf.cast(global_step, tf.int32), tf.constant(burnin_steps)),
        tf.cast(burnin_learning_rate, tf.float32),
        post_burnin_learning_rate)
    return tf.maximum(learning_rate, min_learning_rate, name='learning_rate')

  if tf.executing_eagerly():
    return eager_decay_rate
  else:
    return eager_decay_rate()
def cosine_decay_with_warmup(global_step,
                             learning_rate_base,
                             total_steps,
                             warmup_learning_rate=0.0,
                             warmup_steps=0,
                             hold_base_rate_steps=0):
  """Cosine decay schedule with warm up period.

  Cosine annealing learning rate as described in:
    Loshchilov and Hutter, SGDR: Stochastic Gradient Descent with Warm Restarts.
    ICLR 2017. https://arxiv.org/abs/1608.03983
  In this schedule, the learning rate grows linearly from warmup_learning_rate
  to learning_rate_base for warmup_steps, is optionally held at
  learning_rate_base for hold_base_rate_steps, then follows

    0.5 * learning_rate_base * (1 + cos(pi * (global_step - warmup_steps -
        hold_base_rate_steps) / (total_steps - warmup_steps -
        hold_base_rate_steps)))

  and is 0.0 once global_step exceeds total_steps.

  Args:
    global_step: int64 (scalar) tensor representing global step.
    learning_rate_base: base learning rate.
    total_steps: total number of training steps.
    warmup_learning_rate: initial learning rate for warm up.
    warmup_steps: number of warmup steps.
    hold_base_rate_steps: Optional number of steps to hold base learning rate
      before decaying.
  Returns:
    If executing eagerly:
      returns a no-arg callable that outputs the (scalar)
      float tensor learning rate given the current value of global_step.
    If in a graph:
      immediately returns a (scalar) float tensor representing learning rate.

  Raises:
    ValueError: if warmup_learning_rate is larger than learning_rate_base
      (checked only when warmup_steps > 0), or if warmup_steps is larger than
      total_steps. Both checks run when this function is called, in eager and
      graph mode alike.
  """
  if total_steps < warmup_steps:
    raise ValueError('total_steps must be larger or equal to '
                     'warmup_steps.')
  # Validated here rather than inside eager_decay_rate: in eager mode the
  # closure only runs when the returned callable is first invoked, so an
  # invalid configuration would otherwise go unnoticed at construction time.
  if warmup_steps > 0 and learning_rate_base < warmup_learning_rate:
    raise ValueError('learning_rate_base must be larger or equal to '
                     'warmup_learning_rate.')

  def eager_decay_rate():
    """Callable to compute the learning rate."""
    learning_rate = 0.5 * learning_rate_base * (1 + tf.cos(
        np.pi *
        (tf.cast(global_step, tf.float32) - warmup_steps - hold_base_rate_steps
        ) / float(total_steps - warmup_steps - hold_base_rate_steps)))
    if hold_base_rate_steps > 0:
      # Keep the base rate until the warm-up and hold phases are both over.
      learning_rate = tf.where(
          global_step > warmup_steps + hold_base_rate_steps,
          learning_rate, learning_rate_base)
    if warmup_steps > 0:
      # Linear ramp from warmup_learning_rate (step 0) towards
      # learning_rate_base (step warmup_steps).
      slope = (learning_rate_base - warmup_learning_rate) / warmup_steps
      warmup_rate = slope * tf.cast(global_step,
                                    tf.float32) + warmup_learning_rate
      learning_rate = tf.where(global_step < warmup_steps, warmup_rate,
                               learning_rate)
    return tf.where(global_step > total_steps, 0.0, learning_rate,
                    name='learning_rate')

  if tf.executing_eagerly():
    return eager_decay_rate
  else:
    return eager_decay_rate()
def manual_stepping(global_step, boundaries, rates, warmup=False):
  """Manually stepped learning rate schedule.

  This function provides fine grained control over learning rates. One must
  specify a sequence of learning rates as well as a set of integer steps
  at which the current learning rate must transition to the next. For example,
  if boundaries = [5, 10] and rates = [.1, .01, .001], then the learning
  rate returned by this function is .1 for global_step=0,...,4, .01 for
  global_step=5...9, and .001 for global_step=10 and onward.

  Args:
    global_step: int64 (scalar) tensor representing global step.
    boundaries: a list (or other sequence) of global steps at which to switch
      learning rates. This list is assumed to consist of increasing positive
      integers.
    rates: a list (or other sequence) of (float) learning rates corresponding
      to intervals between the boundaries. The length of this list must be
      exactly len(boundaries) + 1.
    warmup: Whether to linearly interpolate learning rate for steps in
      [0, boundaries[0]], going from rates[0] at step 0 towards rates[1],
      which is reached at step boundaries[0].
  Returns:
    If executing eagerly:
      returns a no-arg callable that outputs the (scalar)
      float tensor learning rate given the current value of global_step.
    If in a graph:
      immediately returns a (scalar) float tensor representing learning rate.
  Raises:
    ValueError: if one of the following checks fails:
      1. boundaries is a strictly increasing list of positive integers
      2. len(rates) == len(boundaries) + 1
      3. boundaries[0] != 0
  """
  # Work on list copies: this accepts tuples (list concatenation below would
  # otherwise fail) and never mutates the caller's sequences.
  boundaries = list(boundaries)
  rates = list(rates)
  # The type test comes first so a non-numeric entry yields the documented
  # ValueError instead of a TypeError from the comparison.
  if any(not isinstance(b, int) or b < 0 for b in boundaries):
    raise ValueError('boundaries must be a list of positive integers')
  if any(nxt <= cur for cur, nxt in zip(boundaries, boundaries[1:])):
    raise ValueError('Entries in boundaries must be strictly increasing.')
  if any(not isinstance(r, float) for r in rates):
    raise ValueError('Learning rates must be floats')
  if len(rates) != len(boundaries) + 1:
    raise ValueError('Number of provided learning rates must exceed '
                     'number of boundary points by exactly 1.')
  if boundaries and boundaries[0] == 0:
    raise ValueError('First step cannot be zero.')

  if warmup and boundaries:
    # Replace the first interval by one boundary per step, each with a
    # linearly interpolated rate. list() is required: on Python 3 a range
    # cannot be concatenated with a list.
    slope = (rates[1] - rates[0]) * 1.0 / boundaries[0]
    warmup_steps = list(range(boundaries[0]))
    warmup_rates = [rates[0] + slope * step for step in warmup_steps]
    boundaries = warmup_steps + boundaries
    rates = warmup_rates + rates[1:]
  else:
    boundaries = [0] + boundaries
  num_boundaries = len(boundaries)

  def eager_decay_rate():
    """Callable to compute the learning rate."""
    # Boundaries not yet reached contribute index 0, so the max is the index
    # of the last boundary global_step has reached.
    rate_index = tf.reduce_max(tf.where(
        tf.greater_equal(global_step, boundaries),
        list(range(num_boundaries)),
        [0] * num_boundaries))
    # Select rates[rate_index] as a one-hot weighted sum.
    return tf.reduce_sum(rates * tf.one_hot(rate_index, depth=num_boundaries),
                         name='learning_rate')

  if tf.executing_eagerly():
    return eager_decay_rate
  else:
    return eager_decay_rate()