You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

252 lines
9.0 KiB

  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """A function to build localization and classification losses from config."""
  16. import functools
  17. from object_detection.core import balanced_positive_negative_sampler as sampler
  18. from object_detection.core import losses
  19. from object_detection.protos import losses_pb2
  20. from object_detection.utils import ops
  21. def build(loss_config):
  22. """Build losses based on the config.
  23. Builds classification, localization losses and optionally a hard example miner
  24. based on the config.
  25. Args:
  26. loss_config: A losses_pb2.Loss object.
  27. Returns:
  28. classification_loss: Classification loss object.
  29. localization_loss: Localization loss object.
  30. classification_weight: Classification loss weight.
  31. localization_weight: Localization loss weight.
  32. hard_example_miner: Hard example miner object.
  33. random_example_sampler: BalancedPositiveNegativeSampler object.
  34. Raises:
  35. ValueError: If hard_example_miner is used with sigmoid_focal_loss.
  36. ValueError: If random_example_sampler is getting non-positive value as
  37. desired positive example fraction.
  38. """
  39. classification_loss = _build_classification_loss(
  40. loss_config.classification_loss)
  41. localization_loss = _build_localization_loss(
  42. loss_config.localization_loss)
  43. classification_weight = loss_config.classification_weight
  44. localization_weight = loss_config.localization_weight
  45. hard_example_miner = None
  46. if loss_config.HasField('hard_example_miner'):
  47. if (loss_config.classification_loss.WhichOneof('classification_loss') ==
  48. 'weighted_sigmoid_focal'):
  49. raise ValueError('HardExampleMiner should not be used with sigmoid focal '
  50. 'loss')
  51. hard_example_miner = build_hard_example_miner(
  52. loss_config.hard_example_miner,
  53. classification_weight,
  54. localization_weight)
  55. random_example_sampler = None
  56. if loss_config.HasField('random_example_sampler'):
  57. if loss_config.random_example_sampler.positive_sample_fraction <= 0:
  58. raise ValueError('RandomExampleSampler should not use non-positive'
  59. 'value as positive sample fraction.')
  60. random_example_sampler = sampler.BalancedPositiveNegativeSampler(
  61. positive_fraction=loss_config.random_example_sampler.
  62. positive_sample_fraction)
  63. if loss_config.expected_loss_weights == loss_config.NONE:
  64. expected_loss_weights_fn = None
  65. elif loss_config.expected_loss_weights == loss_config.EXPECTED_SAMPLING:
  66. expected_loss_weights_fn = functools.partial(
  67. ops.expected_classification_loss_by_expected_sampling,
  68. min_num_negative_samples=loss_config.min_num_negative_samples,
  69. desired_negative_sampling_ratio=loss_config
  70. .desired_negative_sampling_ratio)
  71. elif (loss_config.expected_loss_weights == loss_config
  72. .REWEIGHTING_UNMATCHED_ANCHORS):
  73. expected_loss_weights_fn = functools.partial(
  74. ops.expected_classification_loss_by_reweighting_unmatched_anchors,
  75. min_num_negative_samples=loss_config.min_num_negative_samples,
  76. desired_negative_sampling_ratio=loss_config
  77. .desired_negative_sampling_ratio)
  78. else:
  79. raise ValueError('Not a valid value for expected_classification_loss.')
  80. return (classification_loss, localization_loss, classification_weight,
  81. localization_weight, hard_example_miner, random_example_sampler,
  82. expected_loss_weights_fn)
  83. def build_hard_example_miner(config,
  84. classification_weight,
  85. localization_weight):
  86. """Builds hard example miner based on the config.
  87. Args:
  88. config: A losses_pb2.HardExampleMiner object.
  89. classification_weight: Classification loss weight.
  90. localization_weight: Localization loss weight.
  91. Returns:
  92. Hard example miner.
  93. """
  94. loss_type = None
  95. if config.loss_type == losses_pb2.HardExampleMiner.BOTH:
  96. loss_type = 'both'
  97. if config.loss_type == losses_pb2.HardExampleMiner.CLASSIFICATION:
  98. loss_type = 'cls'
  99. if config.loss_type == losses_pb2.HardExampleMiner.LOCALIZATION:
  100. loss_type = 'loc'
  101. max_negatives_per_positive = None
  102. num_hard_examples = None
  103. if config.max_negatives_per_positive > 0:
  104. max_negatives_per_positive = config.max_negatives_per_positive
  105. if config.num_hard_examples > 0:
  106. num_hard_examples = config.num_hard_examples
  107. hard_example_miner = losses.HardExampleMiner(
  108. num_hard_examples=num_hard_examples,
  109. iou_threshold=config.iou_threshold,
  110. loss_type=loss_type,
  111. cls_loss_weight=classification_weight,
  112. loc_loss_weight=localization_weight,
  113. max_negatives_per_positive=max_negatives_per_positive,
  114. min_negatives_per_image=config.min_negatives_per_image)
  115. return hard_example_miner
  116. def build_faster_rcnn_classification_loss(loss_config):
  117. """Builds a classification loss for Faster RCNN based on the loss config.
  118. Args:
  119. loss_config: A losses_pb2.ClassificationLoss object.
  120. Returns:
  121. Loss based on the config.
  122. Raises:
  123. ValueError: On invalid loss_config.
  124. """
  125. if not isinstance(loss_config, losses_pb2.ClassificationLoss):
  126. raise ValueError('loss_config not of type losses_pb2.ClassificationLoss.')
  127. loss_type = loss_config.WhichOneof('classification_loss')
  128. if loss_type == 'weighted_sigmoid':
  129. return losses.WeightedSigmoidClassificationLoss()
  130. if loss_type == 'weighted_softmax':
  131. config = loss_config.weighted_softmax
  132. return losses.WeightedSoftmaxClassificationLoss(
  133. logit_scale=config.logit_scale)
  134. if loss_type == 'weighted_logits_softmax':
  135. config = loss_config.weighted_logits_softmax
  136. return losses.WeightedSoftmaxClassificationAgainstLogitsLoss(
  137. logit_scale=config.logit_scale)
  138. if loss_type == 'weighted_sigmoid_focal':
  139. config = loss_config.weighted_sigmoid_focal
  140. alpha = None
  141. if config.HasField('alpha'):
  142. alpha = config.alpha
  143. return losses.SigmoidFocalClassificationLoss(
  144. gamma=config.gamma,
  145. alpha=alpha)
  146. # By default, Faster RCNN second stage classifier uses Softmax loss
  147. # with anchor-wise outputs.
  148. config = loss_config.weighted_softmax
  149. return losses.WeightedSoftmaxClassificationLoss(
  150. logit_scale=config.logit_scale)
  151. def _build_localization_loss(loss_config):
  152. """Builds a localization loss based on the loss config.
  153. Args:
  154. loss_config: A losses_pb2.LocalizationLoss object.
  155. Returns:
  156. Loss based on the config.
  157. Raises:
  158. ValueError: On invalid loss_config.
  159. """
  160. if not isinstance(loss_config, losses_pb2.LocalizationLoss):
  161. raise ValueError('loss_config not of type losses_pb2.LocalizationLoss.')
  162. loss_type = loss_config.WhichOneof('localization_loss')
  163. if loss_type == 'weighted_l2':
  164. return losses.WeightedL2LocalizationLoss()
  165. if loss_type == 'weighted_smooth_l1':
  166. return losses.WeightedSmoothL1LocalizationLoss(
  167. loss_config.weighted_smooth_l1.delta)
  168. if loss_type == 'weighted_iou':
  169. return losses.WeightedIOULocalizationLoss()
  170. raise ValueError('Empty loss config.')
  171. def _build_classification_loss(loss_config):
  172. """Builds a classification loss based on the loss config.
  173. Args:
  174. loss_config: A losses_pb2.ClassificationLoss object.
  175. Returns:
  176. Loss based on the config.
  177. Raises:
  178. ValueError: On invalid loss_config.
  179. """
  180. if not isinstance(loss_config, losses_pb2.ClassificationLoss):
  181. raise ValueError('loss_config not of type losses_pb2.ClassificationLoss.')
  182. loss_type = loss_config.WhichOneof('classification_loss')
  183. if loss_type == 'weighted_sigmoid':
  184. return losses.WeightedSigmoidClassificationLoss()
  185. if loss_type == 'weighted_sigmoid_focal':
  186. config = loss_config.weighted_sigmoid_focal
  187. alpha = None
  188. if config.HasField('alpha'):
  189. alpha = config.alpha
  190. return losses.SigmoidFocalClassificationLoss(
  191. gamma=config.gamma,
  192. alpha=alpha)
  193. if loss_type == 'weighted_softmax':
  194. config = loss_config.weighted_softmax
  195. return losses.WeightedSoftmaxClassificationLoss(
  196. logit_scale=config.logit_scale)
  197. if loss_type == 'weighted_logits_softmax':
  198. config = loss_config.weighted_logits_softmax
  199. return losses.WeightedSoftmaxClassificationAgainstLogitsLoss(
  200. logit_scale=config.logit_scale)
  201. if loss_type == 'bootstrapped_sigmoid':
  202. config = loss_config.bootstrapped_sigmoid
  203. return losses.BootstrappedSigmoidClassificationLoss(
  204. alpha=config.alpha,
  205. bootstrap_type=('hard' if config.hard_bootstrap else 'soft'))
  206. raise ValueError('Empty loss config.')