You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

160 lines
6.7 KiB

6 years ago
  # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
  # You may obtain a copy of the License at
  #
  #     http://www.apache.org/licenses/LICENSE-2.0
  #
  # Unless required by applicable law or agreed to in writing, software
  # distributed under the License is distributed on an "AS IS" BASIS,
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
  """Builder function for post processing operations."""
  # Standard library.
  import functools

  # Third party.
  import tensorflow as tf

  # Object Detection API.
  from object_detection.builders import calibration_builder
  from object_detection.core import post_processing
  from object_detection.protos import post_processing_pb2
  def build(post_processing_config):
    """Builds callables for post-processing operations.

    Builds callables for non-max suppression, score conversion, and (optionally)
    calibration based on the configuration.

    Non-max suppression callable takes `boxes`, `scores`, and optionally
    `clip_window`, `parallel_iterations`, `masks`, and `scope` as inputs. It
    returns `nms_boxes`, `nms_scores`, `nms_classes`, `nms_masks` and
    `num_detections`. See post_processing.batch_multiclass_non_max_suppression
    for the type and shape of these tensors.

    Score converter callable should be called with a single tensor of logits.
    It divides the logits by `logit_scale` and returns the output from one of 3
    tf operations based on the configuration - tf.identity, tf.sigmoid or
    tf.nn.softmax. If a calibration config is provided, score_converter also
    applies calibration transformations, as defined in calibration_builder.py.
    See tensorflow documentation for argument and return value descriptions.

    Args:
      post_processing_config: post_processing.proto object containing the
        parameters for the post-processing operations.

    Returns:
      non_max_suppressor_fn: Callable for non-max suppression.
      score_converter_fn: Callable for score conversion.

    Raises:
      ValueError: if the post_processing_config is of incorrect type.
    """
    if not isinstance(post_processing_config,
                      post_processing_pb2.PostProcessing):
      # The message names the class exactly as it is spelled in the proto
      # module (`PostProcessing`), matching the isinstance check above.
      raise ValueError('post_processing_config not of type '
                       'post_processing_pb2.PostProcessing.')
    non_max_suppressor_fn = _build_non_max_suppressor(
        post_processing_config.batch_non_max_suppression)
    score_converter_fn = _build_score_converter(
        post_processing_config.score_converter,
        post_processing_config.logit_scale)
    # Calibration is optional: only wrap the converter when the field is set.
    if post_processing_config.HasField('calibration_config'):
      score_converter_fn = _build_calibrated_score_converter(
          score_converter_fn,
          post_processing_config.calibration_config)
    return non_max_suppressor_fn, score_converter_fn
  def _build_non_max_suppressor(nms_config):
    """Builds non-max suppression based on the nms config.

    Args:
      nms_config: post_processing_pb2.PostProcessing.BatchNonMaxSuppression
        proto.

    Returns:
      non_max_suppressor_fn: Callable non-max suppressor. It is
        post_processing.batch_multiclass_non_max_suppression with the
        thresholds, size limits and static-shape flag from `nms_config`
        already bound as keyword arguments.

    Raises:
      ValueError: On incorrect iou_threshold or on incompatible values of
        max_total_detections and max_detections_per_class.
    """
    # Validate the config before binding anything, so a bad config fails here
    # rather than when the returned callable is first invoked.
    if nms_config.iou_threshold < 0 or nms_config.iou_threshold > 1.0:
      raise ValueError('iou_threshold not in [0, 1.0].')
    if nms_config.max_detections_per_class > nms_config.max_total_detections:
      raise ValueError('max_detections_per_class should be no greater than '
                       'max_total_detections.')
    non_max_suppressor_fn = functools.partial(
        post_processing.batch_multiclass_non_max_suppression,
        score_thresh=nms_config.score_threshold,
        iou_thresh=nms_config.iou_threshold,
        max_size_per_class=nms_config.max_detections_per_class,
        max_total_size=nms_config.max_total_detections,
        use_static_shapes=nms_config.use_static_shapes)
    return non_max_suppressor_fn
  def _score_converter_fn_with_logit_scale(tf_score_converter_fn, logit_scale):
    """Create a function to scale logits then apply a Tensorflow function.

    Args:
      tf_score_converter_fn: callable applied to the scaled logits. It is
        invoked as `tf_score_converter_fn(scaled_logits, name='convert_scores')`
        and must expose a `__name__` attribute.
      logit_scale: divisor applied to the logits before conversion.

    Returns:
      Callable taking a single `logits` argument. Its `__name__` is
      `<tf_score_converter_fn.__name__>_with_logit_scale`.
    """
    def score_converter_fn(logits):
      scaled_logits = tf.divide(logits, logit_scale, name='scale_logits')
      return tf_score_converter_fn(scaled_logits, name='convert_scores')

    # Give the closure a descriptive name that records which converter it
    # wraps, instead of the generic inner-function name.
    score_converter_fn.__name__ = '%s_with_logit_scale' % (
        tf_score_converter_fn.__name__)
    return score_converter_fn
  def _build_score_converter(score_converter_config, logit_scale):
    """Builds score converter based on the config.

    Builds one of [tf.identity, tf.sigmoid, tf.nn.softmax] score converters
    based on the config. Whichever converter is selected, the logits are first
    divided by `logit_scale`.

    Args:
      score_converter_config: post_processing_pb2.PostProcessing.score_converter.
      logit_scale: value the logits are divided by before the converter is
        applied (acts as a temperature for SOFTMAX; it is applied to IDENTITY
        and SIGMOID as well).

    Returns:
      Callable score converter op.

    Raises:
      ValueError: On unknown score converter.
    """
    # Map each supported enum value to the TensorFlow op that implements it.
    converters = {
        post_processing_pb2.PostProcessing.IDENTITY: tf.identity,
        post_processing_pb2.PostProcessing.SIGMOID: tf.sigmoid,
        post_processing_pb2.PostProcessing.SOFTMAX: tf.nn.softmax,
    }
    tf_score_converter_fn = converters.get(score_converter_config)
    if tf_score_converter_fn is None:
      raise ValueError('Unknown score converter.')
    return _score_converter_fn_with_logit_scale(tf_score_converter_fn,
                                                logit_scale)
  def _build_calibrated_score_converter(score_converter_fn, calibration_config):
    """Wraps a score_converter_fn, adding a calibration step.

    Builds a score converter function with a calibration transformation
    according to calibration_builder.py. Calibration applies positive monotonic
    transformations to inputs (i.e. score ordering is strictly preserved or
    adjacent scores are mapped to the same score). When calibration is
    class-agnostic, the highest-scoring class remains unchanged, unless two
    adjacent scores are mapped to the same value and one class arbitrarily
    selected to break the tie. In per-class calibration, it's possible (though
    rare in practice) that the highest-scoring class will change, since positive
    monotonicity is only required to hold within each class.

    Args:
      score_converter_fn: callable that takes logit scores as input.
      calibration_config: post_processing_pb2.PostProcessing.calibration_config.

    Returns:
      Callable calibrated score converter op. Its `__name__` is
      `calibrate_with_<calibrator>`, where `<calibrator>` is the field set in
      the config's `calibrator` oneof.
    """
    # Build the calibration function once, up front, so the returned callable
    # only applies it.
    calibration_fn = calibration_builder.build(calibration_config)

    def calibrated_score_converter_fn(logits):
      converted_logits = score_converter_fn(logits)
      return calibration_fn(converted_logits)

    calibrated_score_converter_fn.__name__ = (
        'calibrate_with_%s' % calibration_config.WhichOneof('calibrator'))
    return calibrated_score_converter_fn