You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

200 lines
8.9 KiB

6 years ago
  1. # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for tensorflow_models.object_detection.metrics.calibration_evaluation.""" # pylint: disable=line-too-long
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import tensorflow as tf
  20. from object_detection.core import standard_fields
  21. from object_detection.metrics import calibration_evaluation
  22. def _get_categories_list():
  23. return [{
  24. 'id': 1,
  25. 'name': 'person'
  26. }, {
  27. 'id': 2,
  28. 'name': 'dog'
  29. }, {
  30. 'id': 3,
  31. 'name': 'cat'
  32. }]
  33. class CalibrationDetectionEvaluationTest(tf.test.TestCase):
  34. def _get_ece(self, ece_op, update_op):
  35. """Return scalar expected calibration error."""
  36. with self.test_session() as sess:
  37. metrics_vars = tf.get_collection(tf.GraphKeys.METRIC_VARIABLES)
  38. sess.run(tf.variables_initializer(var_list=metrics_vars))
  39. _ = sess.run(update_op)
  40. return sess.run(ece_op)
  41. def testGetECEWithMatchingGroundtruthAndDetections(self):
  42. """Tests that ECE is calculated correctly when box matches exist."""
  43. calibration_evaluator = calibration_evaluation.CalibrationDetectionEvaluator(
  44. _get_categories_list(), iou_threshold=0.5)
  45. input_data_fields = standard_fields.InputDataFields
  46. detection_fields = standard_fields.DetectionResultFields
  47. # All gt and detection boxes match.
  48. base_eval_dict = {
  49. input_data_fields.key:
  50. tf.constant(['image_1', 'image_2', 'image_3']),
  51. input_data_fields.groundtruth_boxes:
  52. tf.constant([[[100., 100., 200., 200.]],
  53. [[50., 50., 100., 100.]],
  54. [[25., 25., 50., 50.]]],
  55. dtype=tf.float32),
  56. detection_fields.detection_boxes:
  57. tf.constant([[[100., 100., 200., 200.]],
  58. [[50., 50., 100., 100.]],
  59. [[25., 25., 50., 50.]]],
  60. dtype=tf.float32),
  61. input_data_fields.groundtruth_classes:
  62. tf.constant([[1], [2], [3]], dtype=tf.int64),
  63. # Note that, in the zero ECE case, the detection class for image_2
  64. # should NOT match groundtruth, since the detection score is zero.
  65. detection_fields.detection_scores:
  66. tf.constant([[1.0], [0.0], [1.0]], dtype=tf.float32)
  67. }
  68. # Zero ECE (perfectly calibrated).
  69. zero_ece_eval_dict = base_eval_dict.copy()
  70. zero_ece_eval_dict[detection_fields.detection_classes] = tf.constant(
  71. [[1], [1], [3]], dtype=tf.int64)
  72. zero_ece_op, zero_ece_update_op = (
  73. calibration_evaluator.get_estimator_eval_metric_ops(zero_ece_eval_dict)
  74. ['CalibrationError/ExpectedCalibrationError'])
  75. zero_ece = self._get_ece(zero_ece_op, zero_ece_update_op)
  76. self.assertAlmostEqual(zero_ece, 0.0)
  77. # ECE of 1 (poorest calibration).
  78. one_ece_eval_dict = base_eval_dict.copy()
  79. one_ece_eval_dict[detection_fields.detection_classes] = tf.constant(
  80. [[3], [2], [1]], dtype=tf.int64)
  81. one_ece_op, one_ece_update_op = (
  82. calibration_evaluator.get_estimator_eval_metric_ops(one_ece_eval_dict)
  83. ['CalibrationError/ExpectedCalibrationError'])
  84. one_ece = self._get_ece(one_ece_op, one_ece_update_op)
  85. self.assertAlmostEqual(one_ece, 1.0)
  86. def testGetECEWithUnmatchedGroundtruthAndDetections(self):
  87. """Tests that ECE is correctly calculated when boxes are unmatched."""
  88. calibration_evaluator = calibration_evaluation.CalibrationDetectionEvaluator(
  89. _get_categories_list(), iou_threshold=0.5)
  90. input_data_fields = standard_fields.InputDataFields
  91. detection_fields = standard_fields.DetectionResultFields
  92. # No gt and detection boxes match.
  93. eval_dict = {
  94. input_data_fields.key:
  95. tf.constant(['image_1', 'image_2', 'image_3']),
  96. input_data_fields.groundtruth_boxes:
  97. tf.constant([[[100., 100., 200., 200.]],
  98. [[50., 50., 100., 100.]],
  99. [[25., 25., 50., 50.]]],
  100. dtype=tf.float32),
  101. detection_fields.detection_boxes:
  102. tf.constant([[[50., 50., 100., 100.]],
  103. [[25., 25., 50., 50.]],
  104. [[100., 100., 200., 200.]]],
  105. dtype=tf.float32),
  106. input_data_fields.groundtruth_classes:
  107. tf.constant([[1], [2], [3]], dtype=tf.int64),
  108. detection_fields.detection_classes:
  109. tf.constant([[1], [1], [3]], dtype=tf.int64),
  110. # Detection scores of zero when boxes are unmatched = ECE of zero.
  111. detection_fields.detection_scores:
  112. tf.constant([[0.0], [0.0], [0.0]], dtype=tf.float32)
  113. }
  114. ece_op, update_op = calibration_evaluator.get_estimator_eval_metric_ops(
  115. eval_dict)['CalibrationError/ExpectedCalibrationError']
  116. ece = self._get_ece(ece_op, update_op)
  117. self.assertAlmostEqual(ece, 0.0)
  118. def testGetECEWithBatchedDetections(self):
  119. """Tests that ECE is correct with multiple detections per image."""
  120. calibration_evaluator = calibration_evaluation.CalibrationDetectionEvaluator(
  121. _get_categories_list(), iou_threshold=0.5)
  122. input_data_fields = standard_fields.InputDataFields
  123. detection_fields = standard_fields.DetectionResultFields
  124. # Note that image_2 has mismatched classes and detection scores but should
  125. # still produce ECE of 0 because detection scores are also 0.
  126. eval_dict = {
  127. input_data_fields.key:
  128. tf.constant(['image_1', 'image_2', 'image_3']),
  129. input_data_fields.groundtruth_boxes:
  130. tf.constant([[[100., 100., 200., 200.], [50., 50., 100., 100.]],
  131. [[50., 50., 100., 100.], [100., 100., 200., 200.]],
  132. [[25., 25., 50., 50.], [100., 100., 200., 200.]]],
  133. dtype=tf.float32),
  134. detection_fields.detection_boxes:
  135. tf.constant([[[100., 100., 200., 200.], [50., 50., 100., 100.]],
  136. [[50., 50., 100., 100.], [25., 25., 50., 50.]],
  137. [[25., 25., 50., 50.], [100., 100., 200., 200.]]],
  138. dtype=tf.float32),
  139. input_data_fields.groundtruth_classes:
  140. tf.constant([[1, 2], [2, 3], [3, 1]], dtype=tf.int64),
  141. detection_fields.detection_classes:
  142. tf.constant([[1, 2], [1, 1], [3, 1]], dtype=tf.int64),
  143. detection_fields.detection_scores:
  144. tf.constant([[1.0, 1.0], [0.0, 0.0], [1.0, 1.0]], dtype=tf.float32)
  145. }
  146. ece_op, update_op = calibration_evaluator.get_estimator_eval_metric_ops(
  147. eval_dict)['CalibrationError/ExpectedCalibrationError']
  148. ece = self._get_ece(ece_op, update_op)
  149. self.assertAlmostEqual(ece, 0.0)
  150. def testGetECEWhenImagesFilteredByIsAnnotated(self):
  151. """Tests that ECE is correct when detections filtered by is_annotated."""
  152. calibration_evaluator = calibration_evaluation.CalibrationDetectionEvaluator(
  153. _get_categories_list(), iou_threshold=0.5)
  154. input_data_fields = standard_fields.InputDataFields
  155. detection_fields = standard_fields.DetectionResultFields
  156. # ECE will be 0 only if the third image is filtered by is_annotated.
  157. eval_dict = {
  158. input_data_fields.key:
  159. tf.constant(['image_1', 'image_2', 'image_3']),
  160. input_data_fields.groundtruth_boxes:
  161. tf.constant([[[100., 100., 200., 200.]],
  162. [[50., 50., 100., 100.]],
  163. [[25., 25., 50., 50.]]],
  164. dtype=tf.float32),
  165. detection_fields.detection_boxes:
  166. tf.constant([[[100., 100., 200., 200.]],
  167. [[50., 50., 100., 100.]],
  168. [[25., 25., 50., 50.]]],
  169. dtype=tf.float32),
  170. input_data_fields.groundtruth_classes:
  171. tf.constant([[1], [2], [1]], dtype=tf.int64),
  172. detection_fields.detection_classes:
  173. tf.constant([[1], [1], [3]], dtype=tf.int64),
  174. detection_fields.detection_scores:
  175. tf.constant([[1.0], [0.0], [1.0]], dtype=tf.float32),
  176. 'is_annotated': tf.constant([True, True, False], dtype=tf.bool)
  177. }
  178. ece_op, update_op = calibration_evaluator.get_estimator_eval_metric_ops(
  179. eval_dict)['CalibrationError/ExpectedCalibrationError']
  180. ece = self._get_ece(ece_op, update_op)
  181. self.assertAlmostEqual(ece, 0.0)
  182. if __name__ == '__main__':
  183. tf.test.main()