You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

139 lines
5.0 KiB

  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Functions to build DetectionModel training optimizers."""
  16. import tensorflow as tf
  17. from object_detection.utils import learning_schedules
  18. def build(optimizer_config, global_step=None):
  19. """Create optimizer based on config.
  20. Args:
  21. optimizer_config: A Optimizer proto message.
  22. global_step: A variable representing the current step.
  23. If None, defaults to tf.train.get_or_create_global_step()
  24. Returns:
  25. An optimizer and a list of variables for summary.
  26. Raises:
  27. ValueError: when using an unsupported input data type.
  28. """
  29. optimizer_type = optimizer_config.WhichOneof('optimizer')
  30. optimizer = None
  31. summary_vars = []
  32. if optimizer_type == 'rms_prop_optimizer':
  33. config = optimizer_config.rms_prop_optimizer
  34. learning_rate = _create_learning_rate(config.learning_rate,
  35. global_step=global_step)
  36. summary_vars.append(learning_rate)
  37. optimizer = tf.train.RMSPropOptimizer(
  38. learning_rate,
  39. decay=config.decay,
  40. momentum=config.momentum_optimizer_value,
  41. epsilon=config.epsilon)
  42. if optimizer_type == 'momentum_optimizer':
  43. config = optimizer_config.momentum_optimizer
  44. learning_rate = _create_learning_rate(config.learning_rate,
  45. global_step=global_step)
  46. summary_vars.append(learning_rate)
  47. optimizer = tf.train.MomentumOptimizer(
  48. learning_rate,
  49. momentum=config.momentum_optimizer_value)
  50. if optimizer_type == 'adam_optimizer':
  51. config = optimizer_config.adam_optimizer
  52. learning_rate = _create_learning_rate(config.learning_rate,
  53. global_step=global_step)
  54. summary_vars.append(learning_rate)
  55. optimizer = tf.train.AdamOptimizer(learning_rate)
  56. if optimizer is None:
  57. raise ValueError('Optimizer %s not supported.' % optimizer_type)
  58. if optimizer_config.use_moving_average:
  59. optimizer = tf.contrib.opt.MovingAverageOptimizer(
  60. optimizer, average_decay=optimizer_config.moving_average_decay)
  61. return optimizer, summary_vars
  62. def _create_learning_rate(learning_rate_config, global_step=None):
  63. """Create optimizer learning rate based on config.
  64. Args:
  65. learning_rate_config: A LearningRate proto message.
  66. global_step: A variable representing the current step.
  67. If None, defaults to tf.train.get_or_create_global_step()
  68. Returns:
  69. A learning rate.
  70. Raises:
  71. ValueError: when using an unsupported input data type.
  72. """
  73. if global_step is None:
  74. global_step = tf.train.get_or_create_global_step()
  75. learning_rate = None
  76. learning_rate_type = learning_rate_config.WhichOneof('learning_rate')
  77. if learning_rate_type == 'constant_learning_rate':
  78. config = learning_rate_config.constant_learning_rate
  79. learning_rate = tf.constant(config.learning_rate, dtype=tf.float32,
  80. name='learning_rate')
  81. if learning_rate_type == 'exponential_decay_learning_rate':
  82. config = learning_rate_config.exponential_decay_learning_rate
  83. learning_rate = learning_schedules.exponential_decay_with_burnin(
  84. global_step,
  85. config.initial_learning_rate,
  86. config.decay_steps,
  87. config.decay_factor,
  88. burnin_learning_rate=config.burnin_learning_rate,
  89. burnin_steps=config.burnin_steps,
  90. min_learning_rate=config.min_learning_rate,
  91. staircase=config.staircase)
  92. if learning_rate_type == 'manual_step_learning_rate':
  93. config = learning_rate_config.manual_step_learning_rate
  94. if not config.schedule:
  95. raise ValueError('Empty learning rate schedule.')
  96. learning_rate_step_boundaries = [x.step for x in config.schedule]
  97. learning_rate_sequence = [config.initial_learning_rate]
  98. learning_rate_sequence += [x.learning_rate for x in config.schedule]
  99. learning_rate = learning_schedules.manual_stepping(
  100. global_step, learning_rate_step_boundaries,
  101. learning_rate_sequence, config.warmup)
  102. if learning_rate_type == 'cosine_decay_learning_rate':
  103. config = learning_rate_config.cosine_decay_learning_rate
  104. learning_rate = learning_schedules.cosine_decay_with_warmup(
  105. global_step,
  106. config.learning_rate_base,
  107. config.total_steps,
  108. config.warmup_learning_rate,
  109. config.warmup_steps,
  110. config.hold_base_rate_steps)
  111. if learning_rate is None:
  112. raise ValueError('Learning_rate %s not supported.' % learning_rate_type)
  113. return learning_rate