You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

228 lines
10 KiB

  1. # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Class for evaluating object detections with calibration metrics."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import tensorflow as tf
  20. from object_detection.box_coders import mean_stddev_box_coder
  21. from object_detection.core import box_list
  22. from object_detection.core import region_similarity_calculator
  23. from object_detection.core import standard_fields
  24. from object_detection.core import target_assigner
  25. from object_detection.matchers import argmax_matcher
  26. from object_detection.metrics import calibration_metrics
  27. from object_detection.utils import object_detection_evaluation
  28. # TODO(zbeaver): Implement metrics per category.
  29. class CalibrationDetectionEvaluator(
  30. object_detection_evaluation.DetectionEvaluator):
  31. """Class to evaluate calibration detection metrics."""
  32. def __init__(self,
  33. categories,
  34. iou_threshold=0.5):
  35. """Constructor.
  36. Args:
  37. categories: A list of dicts, each of which has the following keys -
  38. 'id': (required) an integer id uniquely identifying this category.
  39. 'name': (required) string representing category name e.g., 'cat', 'dog'.
  40. iou_threshold: Threshold above which to consider a box as matched during
  41. evaluation.
  42. """
  43. super(CalibrationDetectionEvaluator, self).__init__(categories)
  44. # Constructing target_assigner to match detections to groundtruth.
  45. similarity_calc = region_similarity_calculator.IouSimilarity()
  46. matcher = argmax_matcher.ArgMaxMatcher(
  47. matched_threshold=iou_threshold, unmatched_threshold=iou_threshold)
  48. box_coder = mean_stddev_box_coder.MeanStddevBoxCoder(stddev=0.1)
  49. self._target_assigner = target_assigner.TargetAssigner(
  50. similarity_calc, matcher, box_coder)
  51. def match_single_image_info(self, image_info):
  52. """Match detections to groundtruth for a single image.
  53. Detections are matched to available groundtruth in the image based on the
  54. IOU threshold from the constructor. The classes of the detections and
  55. groundtruth matches are then compared. Detections that do not have IOU above
  56. the required threshold or have different classes from their match are
  57. considered negative matches. All inputs in `image_info` originate or are
  58. inferred from the eval_dict passed to class method
  59. `get_estimator_eval_metric_ops`.
  60. Args:
  61. image_info: a tuple or list containing the following (in order):
  62. - gt_boxes: tf.float32 tensor of groundtruth boxes.
  63. - gt_classes: tf.int64 tensor of groundtruth classes associated with
  64. groundtruth boxes.
  65. - num_gt_box: scalar indicating the number of groundtruth boxes per
  66. image.
  67. - det_boxes: tf.float32 tensor of detection boxes.
  68. - det_classes: tf.int64 tensor of detection classes associated with
  69. detection boxes.
  70. - num_det_box: scalar indicating the number of detection boxes per
  71. image.
  72. Returns:
  73. is_class_matched: tf.int64 tensor identical in shape to det_boxes,
  74. indicating whether detection boxes matched with and had the same
  75. class as groundtruth annotations.
  76. """
  77. (gt_boxes, gt_classes, num_gt_box, det_boxes, det_classes,
  78. num_det_box) = image_info
  79. detection_boxes = det_boxes[:num_det_box]
  80. detection_classes = det_classes[:num_det_box]
  81. groundtruth_boxes = gt_boxes[:num_gt_box]
  82. groundtruth_classes = gt_classes[:num_gt_box]
  83. det_boxlist = box_list.BoxList(detection_boxes)
  84. gt_boxlist = box_list.BoxList(groundtruth_boxes)
  85. # Target assigner requires classes in one-hot format. An additional
  86. # dimension is required since gt_classes are 1-indexed; the zero index is
  87. # provided to all non-matches.
  88. one_hot_depth = tf.cast(tf.add(tf.reduce_max(groundtruth_classes), 1),
  89. dtype=tf.int32)
  90. gt_classes_one_hot = tf.one_hot(
  91. groundtruth_classes, one_hot_depth, dtype=tf.float32)
  92. one_hot_cls_targets, _, _, _, _ = self._target_assigner.assign(
  93. det_boxlist,
  94. gt_boxlist,
  95. gt_classes_one_hot,
  96. unmatched_class_label=tf.zeros(shape=one_hot_depth, dtype=tf.float32))
  97. # Transform from one-hot back to indexes.
  98. cls_targets = tf.argmax(one_hot_cls_targets, axis=1)
  99. is_class_matched = tf.cast(
  100. tf.equal(tf.cast(cls_targets, tf.int64), detection_classes),
  101. dtype=tf.int64)
  102. return is_class_matched
  103. def get_estimator_eval_metric_ops(self, eval_dict):
  104. """Returns a dictionary of eval metric ops.
  105. Note that once value_op is called, the detections and groundtruth added via
  106. update_op are cleared.
  107. This function can take in groundtruth and detections for a batch of images,
  108. or for a single image. For the latter case, the batch dimension for input
  109. tensors need not be present.
  110. Args:
  111. eval_dict: A dictionary that holds tensors for evaluating object detection
  112. performance. For single-image evaluation, this dictionary may be
  113. produced from eval_util.result_dict_for_single_example(). If multi-image
  114. evaluation, `eval_dict` should contain the fields
  115. 'num_groundtruth_boxes_per_image' and 'num_det_boxes_per_image' to
  116. properly unpad the tensors from the batch.
  117. Returns:
  118. a dictionary of metric names to tuple of value_op and update_op that can
  119. be used as eval metric ops in tf.estimator.EstimatorSpec. Note that all
  120. update ops must be run together and similarly all value ops must be run
  121. together to guarantee correct behaviour.
  122. """
  123. # Unpack items from the evaluation dictionary.
  124. input_data_fields = standard_fields.InputDataFields
  125. detection_fields = standard_fields.DetectionResultFields
  126. image_id = eval_dict[input_data_fields.key]
  127. groundtruth_boxes = eval_dict[input_data_fields.groundtruth_boxes]
  128. groundtruth_classes = eval_dict[input_data_fields.groundtruth_classes]
  129. detection_boxes = eval_dict[detection_fields.detection_boxes]
  130. detection_scores = eval_dict[detection_fields.detection_scores]
  131. detection_classes = eval_dict[detection_fields.detection_classes]
  132. num_gt_boxes_per_image = eval_dict.get(
  133. 'num_groundtruth_boxes_per_image', None)
  134. num_det_boxes_per_image = eval_dict.get('num_det_boxes_per_image', None)
  135. is_annotated_batched = eval_dict.get('is_annotated', None)
  136. if not image_id.shape.as_list():
  137. # Apply a batch dimension to all tensors.
  138. image_id = tf.expand_dims(image_id, 0)
  139. groundtruth_boxes = tf.expand_dims(groundtruth_boxes, 0)
  140. groundtruth_classes = tf.expand_dims(groundtruth_classes, 0)
  141. detection_boxes = tf.expand_dims(detection_boxes, 0)
  142. detection_scores = tf.expand_dims(detection_scores, 0)
  143. detection_classes = tf.expand_dims(detection_classes, 0)
  144. if num_gt_boxes_per_image is None:
  145. num_gt_boxes_per_image = tf.shape(groundtruth_boxes)[1:2]
  146. else:
  147. num_gt_boxes_per_image = tf.expand_dims(num_gt_boxes_per_image, 0)
  148. if num_det_boxes_per_image is None:
  149. num_det_boxes_per_image = tf.shape(detection_boxes)[1:2]
  150. else:
  151. num_det_boxes_per_image = tf.expand_dims(num_det_boxes_per_image, 0)
  152. if is_annotated_batched is None:
  153. is_annotated_batched = tf.constant([True])
  154. else:
  155. is_annotated_batched = tf.expand_dims(is_annotated_batched, 0)
  156. else:
  157. if num_gt_boxes_per_image is None:
  158. num_gt_boxes_per_image = tf.tile(
  159. tf.shape(groundtruth_boxes)[1:2],
  160. multiples=tf.shape(groundtruth_boxes)[0:1])
  161. if num_det_boxes_per_image is None:
  162. num_det_boxes_per_image = tf.tile(
  163. tf.shape(detection_boxes)[1:2],
  164. multiples=tf.shape(detection_boxes)[0:1])
  165. if is_annotated_batched is None:
  166. is_annotated_batched = tf.ones_like(image_id, dtype=tf.bool)
  167. # Filter images based on is_annotated_batched and match detections.
  168. image_info = [tf.boolean_mask(tensor, is_annotated_batched) for tensor in
  169. [groundtruth_boxes, groundtruth_classes,
  170. num_gt_boxes_per_image, detection_boxes, detection_classes,
  171. num_det_boxes_per_image]]
  172. is_class_matched = tf.map_fn(
  173. self.match_single_image_info, image_info, dtype=tf.int64)
  174. y_true = tf.squeeze(is_class_matched)
  175. y_pred = tf.squeeze(tf.boolean_mask(detection_scores, is_annotated_batched))
  176. ece, update_op = calibration_metrics.expected_calibration_error(
  177. y_true, y_pred)
  178. return {'CalibrationError/ExpectedCalibrationError': (ece, update_op)}
  179. def add_single_ground_truth_image_info(self, image_id, groundtruth_dict):
  180. """Adds groundtruth for a single image to be used for evaluation.
  181. Args:
  182. image_id: A unique string/integer identifier for the image.
  183. groundtruth_dict: A dictionary of groundtruth numpy arrays required
  184. for evaluations.
  185. """
  186. raise NotImplementedError
  187. def add_single_detected_image_info(self, image_id, detections_dict):
  188. """Adds detections for a single image to be used for evaluation.
  189. Args:
  190. image_id: A unique string/integer identifier for the image.
  191. detections_dict: A dictionary of detection numpy arrays required for
  192. evaluation.
  193. """
  194. raise NotImplementedError
  195. def evaluate(self):
  196. """Evaluates detections and returns a dictionary of metrics."""
  197. raise NotImplementedError
  198. def clear(self):
  199. """Clears the state to prepare for a fresh evaluation."""
  200. raise NotImplementedError