You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

161 lines
6.8 KiB

  1. # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Builder function for post processing operations."""
  16. import functools
  17. import tensorflow as tf
  18. from object_detection.builders import calibration_builder
  19. from object_detection.core import post_processing
  20. from object_detection.protos import post_processing_pb2
  21. def build(post_processing_config):
  22. """Builds callables for post-processing operations.
  23. Builds callables for non-max suppression, score conversion, and (optionally)
  24. calibration based on the configuration.
  25. Non-max suppression callable takes `boxes`, `scores`, and optionally
  26. `clip_window`, `parallel_iterations` `masks, and `scope` as inputs. It returns
  27. `nms_boxes`, `nms_scores`, `nms_classes` `nms_masks` and `num_detections`. See
  28. post_processing.batch_multiclass_non_max_suppression for the type and shape
  29. of these tensors.
  30. Score converter callable should be called with `input` tensor. The callable
  31. returns the output from one of 3 tf operations based on the configuration -
  32. tf.identity, tf.sigmoid or tf.nn.softmax. If a calibration config is provided,
  33. score_converter also applies calibration transformations, as defined in
  34. calibration_builder.py. See tensorflow documentation for argument and return
  35. value descriptions.
  36. Args:
  37. post_processing_config: post_processing.proto object containing the
  38. parameters for the post-processing operations.
  39. Returns:
  40. non_max_suppressor_fn: Callable for non-max suppression.
  41. score_converter_fn: Callable for score conversion.
  42. Raises:
  43. ValueError: if the post_processing_config is of incorrect type.
  44. """
  45. if not isinstance(post_processing_config, post_processing_pb2.PostProcessing):
  46. raise ValueError('post_processing_config not of type '
  47. 'post_processing_pb2.Postprocessing.')
  48. non_max_suppressor_fn = _build_non_max_suppressor(
  49. post_processing_config.batch_non_max_suppression)
  50. score_converter_fn = _build_score_converter(
  51. post_processing_config.score_converter,
  52. post_processing_config.logit_scale)
  53. if post_processing_config.HasField('calibration_config'):
  54. score_converter_fn = _build_calibrated_score_converter(
  55. score_converter_fn,
  56. post_processing_config.calibration_config)
  57. return non_max_suppressor_fn, score_converter_fn
  58. def _build_non_max_suppressor(nms_config):
  59. """Builds non-max suppresson based on the nms config.
  60. Args:
  61. nms_config: post_processing_pb2.PostProcessing.BatchNonMaxSuppression proto.
  62. Returns:
  63. non_max_suppressor_fn: Callable non-max suppressor.
  64. Raises:
  65. ValueError: On incorrect iou_threshold or on incompatible values of
  66. max_total_detections and max_detections_per_class.
  67. """
  68. if nms_config.iou_threshold < 0 or nms_config.iou_threshold > 1.0:
  69. raise ValueError('iou_threshold not in [0, 1.0].')
  70. if nms_config.max_detections_per_class > nms_config.max_total_detections:
  71. raise ValueError('max_detections_per_class should be no greater than '
  72. 'max_total_detections.')
  73. non_max_suppressor_fn = functools.partial(
  74. post_processing.batch_multiclass_non_max_suppression,
  75. score_thresh=nms_config.score_threshold,
  76. iou_thresh=nms_config.iou_threshold,
  77. max_size_per_class=nms_config.max_detections_per_class,
  78. max_total_size=nms_config.max_total_detections,
  79. use_static_shapes=nms_config.use_static_shapes,
  80. use_class_agnostic_nms=nms_config.use_class_agnostic_nms,
  81. max_classes_per_detection=nms_config.max_classes_per_detection)
  82. return non_max_suppressor_fn
  83. def _score_converter_fn_with_logit_scale(tf_score_converter_fn, logit_scale):
  84. """Create a function to scale logits then apply a Tensorflow function."""
  85. def score_converter_fn(logits):
  86. scaled_logits = tf.divide(logits, logit_scale, name='scale_logits')
  87. return tf_score_converter_fn(scaled_logits, name='convert_scores')
  88. score_converter_fn.__name__ = '%s_with_logit_scale' % (
  89. tf_score_converter_fn.__name__)
  90. return score_converter_fn
  91. def _build_score_converter(score_converter_config, logit_scale):
  92. """Builds score converter based on the config.
  93. Builds one of [tf.identity, tf.sigmoid, tf.softmax] score converters based on
  94. the config.
  95. Args:
  96. score_converter_config: post_processing_pb2.PostProcessing.score_converter.
  97. logit_scale: temperature to use for SOFTMAX score_converter.
  98. Returns:
  99. Callable score converter op.
  100. Raises:
  101. ValueError: On unknown score converter.
  102. """
  103. if score_converter_config == post_processing_pb2.PostProcessing.IDENTITY:
  104. return _score_converter_fn_with_logit_scale(tf.identity, logit_scale)
  105. if score_converter_config == post_processing_pb2.PostProcessing.SIGMOID:
  106. return _score_converter_fn_with_logit_scale(tf.sigmoid, logit_scale)
  107. if score_converter_config == post_processing_pb2.PostProcessing.SOFTMAX:
  108. return _score_converter_fn_with_logit_scale(tf.nn.softmax, logit_scale)
  109. raise ValueError('Unknown score converter.')
  110. def _build_calibrated_score_converter(score_converter_fn, calibration_config):
  111. """Wraps a score_converter_fn, adding a calibration step.
  112. Builds a score converter function witha calibration transformation according
  113. to calibration_builder.py. Calibration applies positive monotonic
  114. transformations to inputs (i.e. score ordering is strictly preserved or
  115. adjacent scores are mapped to the same score). When calibration is
  116. class-agnostic, the highest-scoring class remains unchanged, unless two
  117. adjacent scores are mapped to the same value and one class arbitrarily
  118. selected to break the tie. In per-class calibration, it's possible (though
  119. rare in practice) that the highest-scoring class will change, since positive
  120. monotonicity is only required to hold within each class.
  121. Args:
  122. score_converter_fn: callable that takes logit scores as input.
  123. calibration_config: post_processing_pb2.PostProcessing.calibration_config.
  124. Returns:
  125. Callable calibrated score coverter op.
  126. """
  127. calibration_fn = calibration_builder.build(calibration_config)
  128. def calibrated_score_converter_fn(logits):
  129. converted_logits = score_converter_fn(logits)
  130. return calibration_fn(converted_logits)
  131. calibrated_score_converter_fn.__name__ = (
  132. 'calibrate_with_%s' % calibration_config.WhichOneof('calibrator'))
  133. return calibrated_score_converter_fn