# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""object_detection_evaluation module.

ObjectDetectionEvaluation is a class which manages ground truth information of
an object detection dataset, and computes frequently used detection metrics
such as Precision, Recall, CorLoc of the provided detection results.
It supports the following operations:
1) Add ground truth information of images sequentially.
2) Add detection results of images sequentially.
3) Evaluate detection metrics on already inserted detection results.
4) Write evaluation result into a pickle file for future processing or
   visualization.

Note: This module operates on numpy boxes and box lists.
"""

from abc import ABCMeta
from abc import abstractmethod
import collections
import logging
import unicodedata

import numpy as np
import tensorflow as tf

from object_detection.core import standard_fields
from object_detection.utils import label_map_util
from object_detection.utils import metrics
from object_detection.utils import per_image_evaluation


class DetectionEvaluator(object):
  """Interface for object detection evaluation classes.

  Example usage of the Evaluator:
  ------------------------------
  evaluator = DetectionEvaluator(categories)

  # Detections and groundtruth for image 1.
  evaluator.add_single_groundtruth_image_info(...)
  evaluator.add_single_detected_image_info(...)

  # Detections and groundtruth for image 2.
  evaluator.add_single_groundtruth_image_info(...)
  evaluator.add_single_detected_image_info(...)

  metrics_dict = evaluator.evaluate()
  """
  __metaclass__ = ABCMeta

  def __init__(self, categories):
    """Constructor.

    Args:
      categories: A list of dicts, each of which has the following keys -
        'id': (required) an integer id uniquely identifying this category.
        'name': (required) string representing category name e.g., 'cat',
          'dog'.
    """
    self._categories = categories

  def observe_result_dict_for_single_example(self, eval_dict):
    """Observes an evaluation result dict for a single example.

    When executing eagerly, once all observations have been observed by this
    method you can use `.evaluate()` to get the final metrics.

    When using `tf.estimator.Estimator` for evaluation this function is used by
    `get_estimator_eval_metric_ops()` to construct the metric update op.

    Args:
      eval_dict: A dictionary that holds tensors for evaluating an object
        detection model, returned from
        eval_util.result_dict_for_single_example().

    Returns:
      None when executing eagerly, or an update_op that can be used to update
      the eval metrics in `tf.estimator.EstimatorSpec`.
    """
    raise NotImplementedError('Not implemented for this evaluator!')

  @abstractmethod
  def add_single_ground_truth_image_info(self, image_id, groundtruth_dict):
    """Adds groundtruth for a single image to be used for evaluation.

    Args:
      image_id: A unique string/integer identifier for the image.
      groundtruth_dict: A dictionary of groundtruth numpy arrays required for
        evaluations.
    """
    pass

  @abstractmethod
  def add_single_detected_image_info(self, image_id, detections_dict):
    """Adds detections for a single image to be used for evaluation.

    Args:
      image_id: A unique string/integer identifier for the image.
      detections_dict: A dictionary of detection numpy arrays required for
        evaluation.
    """
    pass

  def get_estimator_eval_metric_ops(self, eval_dict):
    """Returns dict of metrics to use with `tf.estimator.EstimatorSpec`.

    Note that this must only be implemented if performing evaluation with a
    `tf.estimator.Estimator`.

    Args:
      eval_dict: A dictionary that holds tensors for evaluating an object
        detection model, returned from
        eval_util.result_dict_for_single_example().

    Returns:
      A dictionary of metric names to tuple of value_op and update_op that can
      be used as eval metric ops in `tf.estimator.EstimatorSpec`.
    """
    pass

  @abstractmethod
  def evaluate(self):
    """Evaluates detections and returns a dictionary of metrics."""
    pass

  @abstractmethod
  def clear(self):
    """Clears the state to prepare for a fresh evaluation."""
    pass
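

# What follows is an illustrative sketch, not part of the original module: a
# minimal DetectionEvaluator subclass showing what the abstract interface
# requires. The counting "metric" is made up purely for demonstration.
class _ExampleCountingEvaluator(DetectionEvaluator):
  """Toy evaluator that only counts observed groundtruth/detection boxes."""

  def __init__(self, categories):
    super(_ExampleCountingEvaluator, self).__init__(categories)
    self._num_gt_boxes = 0
    self._num_det_boxes = 0

  def add_single_ground_truth_image_info(self, image_id, groundtruth_dict):
    self._num_gt_boxes += groundtruth_dict[
        standard_fields.InputDataFields.groundtruth_boxes].shape[0]

  def add_single_detected_image_info(self, image_id, detections_dict):
    self._num_det_boxes += detections_dict[
        standard_fields.DetectionResultFields.detection_boxes].shape[0]

  def evaluate(self):
    return {'example/num_groundtruth_boxes': self._num_gt_boxes,
            'example/num_detection_boxes': self._num_det_boxes}

  def clear(self):
    self._num_gt_boxes = 0
    self._num_det_boxes = 0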


class ObjectDetectionEvaluator(DetectionEvaluator):
  """A class to evaluate detections."""

  def __init__(self,
               categories,
               matching_iou_threshold=0.5,
               recall_lower_bound=0.0,
               recall_upper_bound=1.0,
               evaluate_corlocs=False,
               evaluate_precision_recall=False,
               metric_prefix=None,
               use_weighted_mean_ap=False,
               evaluate_masks=False,
               group_of_weight=0.0):
    """Constructor.

    Args:
      categories: A list of dicts, each of which has the following keys -
        'id': (required) an integer id uniquely identifying this category.
        'name': (required) string representing category name e.g., 'cat',
          'dog'.
      matching_iou_threshold: IOU threshold to use for matching groundtruth
        boxes to detection boxes.
      recall_lower_bound: lower bound of recall operating area.
      recall_upper_bound: upper bound of recall operating area.
      evaluate_corlocs: (optional) boolean which determines whether corloc
        scores are to be returned or not.
      evaluate_precision_recall: (optional) boolean which determines whether
        precision and recall values are to be returned or not.
      metric_prefix: (optional) string prefix for metric name; if None, no
        prefix is used.
      use_weighted_mean_ap: (optional) boolean which determines if the mean
        average precision is computed directly from the scores and tp_fp_labels
        of all classes.
      evaluate_masks: If False, evaluation will be performed based on boxes. If
        True, mask evaluation will be performed instead.
      group_of_weight: Weight of group-of boxes. If set to 0, detections of the
        correct class within a group-of box are ignored. If weight is > 0, then
        if at least one detection falls within a group-of box with
        matching_iou_threshold, weight group_of_weight is added to true
        positives. Consequently, if no detection falls within a group-of box,
        weight group_of_weight is added to false negatives.

    Raises:
      ValueError: If the category ids are not 1-indexed.
    """
    super(ObjectDetectionEvaluator, self).__init__(categories)
    self._num_classes = max([cat['id'] for cat in categories])
    if min(cat['id'] for cat in categories) < 1:
      raise ValueError('Classes should be 1-indexed.')
    self._matching_iou_threshold = matching_iou_threshold
    self._recall_lower_bound = recall_lower_bound
    self._recall_upper_bound = recall_upper_bound
    self._use_weighted_mean_ap = use_weighted_mean_ap
    self._label_id_offset = 1
    self._evaluate_masks = evaluate_masks
    self._group_of_weight = group_of_weight
    self._evaluation = ObjectDetectionEvaluation(
        num_groundtruth_classes=self._num_classes,
        matching_iou_threshold=self._matching_iou_threshold,
        recall_lower_bound=self._recall_lower_bound,
        recall_upper_bound=self._recall_upper_bound,
        use_weighted_mean_ap=self._use_weighted_mean_ap,
        label_id_offset=self._label_id_offset,
        group_of_weight=self._group_of_weight)
    self._image_ids = set([])
    self._evaluate_corlocs = evaluate_corlocs
    self._evaluate_precision_recall = evaluate_precision_recall
    self._metric_prefix = (metric_prefix + '_') if metric_prefix else ''
    self._expected_keys = set([
        standard_fields.InputDataFields.key,
        standard_fields.InputDataFields.groundtruth_boxes,
        standard_fields.InputDataFields.groundtruth_classes,
        standard_fields.InputDataFields.groundtruth_difficult,
        standard_fields.InputDataFields.groundtruth_instance_masks,
        standard_fields.DetectionResultFields.detection_boxes,
        standard_fields.DetectionResultFields.detection_scores,
        standard_fields.DetectionResultFields.detection_classes,
        standard_fields.DetectionResultFields.detection_masks
    ])
    self._build_metric_names()

  def _build_metric_names(self):
    """Builds a list with metric names."""
    if self._recall_lower_bound > 0.0 or self._recall_upper_bound < 1.0:
      self._metric_names = [
          self._metric_prefix +
          'Precision/mAP@{}IOU@[{:.1f},{:.1f}]Recall'.format(
              self._matching_iou_threshold, self._recall_lower_bound,
              self._recall_upper_bound)
      ]
    else:
      self._metric_names = [
          self._metric_prefix +
          'Precision/mAP@{}IOU'.format(self._matching_iou_threshold)
      ]
    if self._evaluate_corlocs:
      self._metric_names.append(
          self._metric_prefix +
          'Precision/meanCorLoc@{}IOU'.format(self._matching_iou_threshold))

    category_index = label_map_util.create_category_index(self._categories)
    for idx in range(self._num_classes):
      if idx + self._label_id_offset in category_index:
        category_name = category_index[idx + self._label_id_offset]['name']
        try:
          category_name = unicode(category_name, 'utf-8')
        except (TypeError, NameError):
          # `unicode` does not exist in Python 3, where names are already text.
          pass
        category_name = unicodedata.normalize('NFKD', category_name).encode(
            'ascii', 'ignore')
        self._metric_names.append(
            self._metric_prefix + 'PerformanceByCategory/AP@{}IOU/{}'.format(
                self._matching_iou_threshold, category_name))
        if self._evaluate_corlocs:
          self._metric_names.append(
              self._metric_prefix + 'PerformanceByCategory/CorLoc@{}IOU/{}'
              .format(self._matching_iou_threshold, category_name))

  def add_single_ground_truth_image_info(self, image_id, groundtruth_dict):
    """Adds groundtruth for a single image to be used for evaluation.

    Args:
      image_id: A unique string/integer identifier for the image.
      groundtruth_dict: A dictionary containing -
        standard_fields.InputDataFields.groundtruth_boxes: float32 numpy array
          of shape [num_boxes, 4] containing `num_boxes` groundtruth boxes of
          the format [ymin, xmin, ymax, xmax] in absolute image coordinates.
        standard_fields.InputDataFields.groundtruth_classes: integer numpy
          array of shape [num_boxes] containing 1-indexed groundtruth classes
          for the boxes.
        standard_fields.InputDataFields.groundtruth_difficult: Optional length
          M numpy boolean array denoting whether a ground truth box is a
          difficult instance or not. This field is optional to support the case
          that no boxes are difficult.
        standard_fields.InputDataFields.groundtruth_instance_masks: Optional
          numpy array of shape [num_boxes, height, width] with values in
          {0, 1}.

    Raises:
      ValueError: On adding groundtruth for an image more than once. Will also
        raise error if instance masks are not in groundtruth dictionary.
    """
    if image_id in self._image_ids:
      raise ValueError('Image with id {} already added.'.format(image_id))
    groundtruth_classes = (
        groundtruth_dict[standard_fields.InputDataFields.groundtruth_classes] -
        self._label_id_offset)
    # Use the difficult flags only if the key is present and the array is
    # non-empty (or the image has no groundtruth boxes at all); otherwise fall
    # back to None and occasionally log that the flag is missing.
    if (standard_fields.InputDataFields.groundtruth_difficult in
        groundtruth_dict.keys() and
        (groundtruth_dict[standard_fields.InputDataFields.groundtruth_difficult]
         .size or not groundtruth_classes.size)):
      groundtruth_difficult = groundtruth_dict[
          standard_fields.InputDataFields.groundtruth_difficult]
    else:
      groundtruth_difficult = None
      if not len(self._image_ids) % 1000:
        logging.warn(
            'image %s does not have groundtruth difficult flag specified',
            image_id)
    groundtruth_masks = None
    if self._evaluate_masks:
      if (standard_fields.InputDataFields.groundtruth_instance_masks not in
          groundtruth_dict):
        raise ValueError('Instance masks not in groundtruth dictionary.')
      groundtruth_masks = groundtruth_dict[
          standard_fields.InputDataFields.groundtruth_instance_masks]
    self._evaluation.add_single_ground_truth_image_info(
        image_key=image_id,
        groundtruth_boxes=groundtruth_dict[
            standard_fields.InputDataFields.groundtruth_boxes],
        groundtruth_class_labels=groundtruth_classes,
        groundtruth_is_difficult_list=groundtruth_difficult,
        groundtruth_masks=groundtruth_masks)
    self._image_ids.update([image_id])

  def add_single_detected_image_info(self, image_id, detections_dict):
    """Adds detections for a single image to be used for evaluation.

    Args:
      image_id: A unique string/integer identifier for the image.
      detections_dict: A dictionary containing -
        standard_fields.DetectionResultFields.detection_boxes: float32 numpy
          array of shape [num_boxes, 4] containing `num_boxes` detection boxes
          of the format [ymin, xmin, ymax, xmax] in absolute image coordinates.
        standard_fields.DetectionResultFields.detection_scores: float32 numpy
          array of shape [num_boxes] containing detection scores for the boxes.
        standard_fields.DetectionResultFields.detection_classes: integer numpy
          array of shape [num_boxes] containing 1-indexed detection classes for
          the boxes.
        standard_fields.DetectionResultFields.detection_masks: uint8 numpy
          array of shape [num_boxes, height, width] containing `num_boxes`
          masks of values ranging between 0 and 1.

    Raises:
      ValueError: If detection masks are not in detections dictionary.
    """
    detection_classes = (
        detections_dict[standard_fields.DetectionResultFields.detection_classes]
        - self._label_id_offset)
    detection_masks = None
    if self._evaluate_masks:
      if (standard_fields.DetectionResultFields.detection_masks not in
          detections_dict):
        raise ValueError('Detection masks not in detections dictionary.')
      detection_masks = detections_dict[
          standard_fields.DetectionResultFields.detection_masks]
    self._evaluation.add_single_detected_image_info(
        image_key=image_id,
        detected_boxes=detections_dict[
            standard_fields.DetectionResultFields.detection_boxes],
        detected_scores=detections_dict[
            standard_fields.DetectionResultFields.detection_scores],
        detected_class_labels=detection_classes,
        detected_masks=detection_masks)

  def evaluate(self):
    """Compute evaluation result.

    Returns:
      A dictionary of metrics with the following fields -

      1. summary_metrics:
        '<prefix if not empty>_Precision/mAP@<matching_iou_threshold>IOU': mean
        average precision at the specified IOU threshold.

      2. per_category_ap: category specific results with keys of the form
        '<prefix if not empty>_PerformanceByCategory/
        mAP@<matching_iou_threshold>IOU/category'.
    """
    (per_class_ap, mean_ap, per_class_precision, per_class_recall,
     per_class_corloc, mean_corloc) = (
         self._evaluation.evaluate())
    pascal_metrics = {self._metric_names[0]: mean_ap}
    if self._evaluate_corlocs:
      pascal_metrics[self._metric_names[1]] = mean_corloc
    category_index = label_map_util.create_category_index(self._categories)
    for idx in range(per_class_ap.size):
      if idx + self._label_id_offset in category_index:
        category_name = category_index[idx + self._label_id_offset]['name']
        try:
          category_name = unicode(category_name, 'utf-8')
        except (TypeError, NameError):
          # `unicode` does not exist in Python 3, where names are already text.
          pass
        category_name = unicodedata.normalize(
            'NFKD', category_name).encode('ascii', 'ignore')
        display_name = (
            self._metric_prefix + 'PerformanceByCategory/AP@{}IOU/{}'.format(
                self._matching_iou_threshold, category_name))
        pascal_metrics[display_name] = per_class_ap[idx]

        # Optionally add precision and recall values.
        if self._evaluate_precision_recall:
          display_name = (
              self._metric_prefix +
              'PerformanceByCategory/Precision@{}IOU/{}'.format(
                  self._matching_iou_threshold, category_name))
          pascal_metrics[display_name] = per_class_precision[idx]
          display_name = (
              self._metric_prefix +
              'PerformanceByCategory/Recall@{}IOU/{}'.format(
                  self._matching_iou_threshold, category_name))
          pascal_metrics[display_name] = per_class_recall[idx]

        # Optionally add CorLoc metrics.
        if self._evaluate_corlocs:
          display_name = (
              self._metric_prefix + 'PerformanceByCategory/CorLoc@{}IOU/{}'
              .format(self._matching_iou_threshold, category_name))
          pascal_metrics[display_name] = per_class_corloc[idx]
    return pascal_metrics

  def clear(self):
    """Clears the state to prepare for a fresh evaluation."""
    self._evaluation = ObjectDetectionEvaluation(
        num_groundtruth_classes=self._num_classes,
        matching_iou_threshold=self._matching_iou_threshold,
        use_weighted_mean_ap=self._use_weighted_mean_ap,
        label_id_offset=self._label_id_offset)
    self._image_ids.clear()

  def get_estimator_eval_metric_ops(self, eval_dict):
    """Returns dict of metrics to use with `tf.estimator.EstimatorSpec`.

    Note that this must only be implemented if performing evaluation with a
    `tf.estimator.Estimator`.

    Args:
      eval_dict: A dictionary that holds tensors for evaluating an object
        detection model, returned from
        eval_util.result_dict_for_single_example(). It must contain
        standard_fields.InputDataFields.key.

    Returns:
      A dictionary of metric names to tuple of value_op and update_op that can
      be used as eval metric ops in `tf.estimator.EstimatorSpec`.
    """
    # Remove unexpected fields.
    eval_dict_filtered = dict()
    for key, value in eval_dict.items():
      if key in self._expected_keys:
        eval_dict_filtered[key] = value

    eval_dict_keys = eval_dict_filtered.keys()

    def update_op(image_id, *eval_dict_batched_as_list):
      """Update operation that adds batch of images to ObjectDetectionEvaluator.

      Args:
        image_id: image id (single id or an array)
        *eval_dict_batched_as_list: the values of the dictionary of tensors.
      """
      if np.isscalar(image_id):
        single_example_dict = dict(
            zip(eval_dict_keys, eval_dict_batched_as_list))
        self.add_single_ground_truth_image_info(image_id, single_example_dict)
        self.add_single_detected_image_info(image_id, single_example_dict)
      else:
        for unzipped_tuple in zip(*eval_dict_batched_as_list):
          single_example_dict = dict(zip(eval_dict_keys, unzipped_tuple))
          image_id = single_example_dict[standard_fields.InputDataFields.key]
          self.add_single_ground_truth_image_info(image_id,
                                                  single_example_dict)
          self.add_single_detected_image_info(image_id, single_example_dict)

    args = [eval_dict_filtered[standard_fields.InputDataFields.key]]
    args.extend(eval_dict_filtered.values())
    update_op = tf.py_func(update_op, args, [])

    def first_value_func():
      self._metrics = self.evaluate()
      self.clear()
      return np.float32(self._metrics[self._metric_names[0]])

    def value_func_factory(metric_name):

      def value_func():
        return np.float32(self._metrics[metric_name])

      return value_func

    # Ensure that the metrics are only evaluated once.
    first_value_op = tf.py_func(first_value_func, [], tf.float32)
    eval_metric_ops = {self._metric_names[0]: (first_value_op, update_op)}
    with tf.control_dependencies([first_value_op]):
      for metric_name in self._metric_names[1:]:
        eval_metric_ops[metric_name] = (tf.py_func(
            value_func_factory(metric_name), [], np.float32), update_op)
    return eval_metric_ops


class PascalDetectionEvaluator(ObjectDetectionEvaluator):
  """A class to evaluate detections using PASCAL metrics."""

  def __init__(self, categories, matching_iou_threshold=0.5):
    super(PascalDetectionEvaluator, self).__init__(
        categories,
        matching_iou_threshold=matching_iou_threshold,
        evaluate_corlocs=False,
        metric_prefix='PascalBoxes',
        use_weighted_mean_ap=False)
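

def _example_pascal_evaluator_usage():
  """Illustrative sketch, not part of the original module.

  Shows the add-groundtruth / add-detections / evaluate flow for
  PascalDetectionEvaluator; the categories and box values are made up.
  """
  categories = [{'id': 1, 'name': 'cat'}, {'id': 2, 'name': 'dog'}]
  evaluator = PascalDetectionEvaluator(categories)
  evaluator.add_single_ground_truth_image_info(
      image_id='image_1',
      groundtruth_dict={
          standard_fields.InputDataFields.groundtruth_boxes:
              np.array([[10., 10., 50., 50.]], dtype=np.float32),
          standard_fields.InputDataFields.groundtruth_classes:
              np.array([1], dtype=int),
      })
  evaluator.add_single_detected_image_info(
      image_id='image_1',
      detections_dict={
          standard_fields.DetectionResultFields.detection_boxes:
              np.array([[12., 11., 48., 49.]], dtype=np.float32),
          standard_fields.DetectionResultFields.detection_scores:
              np.array([0.9], dtype=np.float32),
          standard_fields.DetectionResultFields.detection_classes:
              np.array([1], dtype=int),
      })
  # Returns e.g. {'PascalBoxes_Precision/mAP@0.5IOU': ...} plus per-category
  # AP entries.
  return evaluator.evaluate()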


class WeightedPascalDetectionEvaluator(ObjectDetectionEvaluator):
  """A class to evaluate detections using weighted PASCAL metrics.

  Weighted PASCAL metrics compute the mean average precision as the average
  precision given the scores and tp_fp_labels of all classes. In comparison,
  PASCAL metrics compute the mean average precision as the mean of the
  per-class average precisions.

  This definition is very similar to the mean of the per-class average
  precisions weighted by class frequency. However, they are typically not the
  same, as average precision is not a linear function of the scores and
  tp_fp_labels.
  """

  def __init__(self, categories, matching_iou_threshold=0.5):
    super(WeightedPascalDetectionEvaluator, self).__init__(
        categories,
        matching_iou_threshold=matching_iou_threshold,
        evaluate_corlocs=False,
        metric_prefix='WeightedPascalBoxes',
        use_weighted_mean_ap=True)
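

def _example_weighted_vs_mean_ap():
  """Illustrative sketch, not part of the original module.

  Demonstrates numerically why pooled ("weighted") AP differs from the mean of
  per-class APs. The simplified AP below (precision averaged at each true
  positive) is an assumption for illustration; it is not the interpolated AP
  computed by metrics.compute_average_precision.
  """
  def simple_ap(scores, tp):
    # Sort by descending score, then average precision at each true positive.
    order = np.argsort(-scores)
    tp = tp[order]
    precision = np.cumsum(tp) / (np.arange(tp.size) + 1)
    return float(np.sum(precision * tp / tp.sum()))

  # Class A: one correct detection. Class B: one incorrect, one correct.
  scores_a, tp_a = np.array([0.9]), np.array([1.0])
  scores_b, tp_b = np.array([0.8, 0.6]), np.array([0.0, 1.0])
  mean_of_aps = 0.5 * (simple_ap(scores_a, tp_a) + simple_ap(scores_b, tp_b))
  pooled_ap = simple_ap(np.concatenate([scores_a, scores_b]),
                        np.concatenate([tp_a, tp_b]))
  return mean_of_aps, pooled_ap  # 0.75 vs ~0.83: not the same.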


class PrecisionAtRecallDetectionEvaluator(ObjectDetectionEvaluator):
  """A class to evaluate detections using precision@recall metrics."""

  def __init__(self,
               categories,
               matching_iou_threshold=0.5,
               recall_lower_bound=0.0,
               recall_upper_bound=1.0):
    super(PrecisionAtRecallDetectionEvaluator, self).__init__(
        categories,
        matching_iou_threshold=matching_iou_threshold,
        recall_lower_bound=recall_lower_bound,
        recall_upper_bound=recall_upper_bound,
        evaluate_corlocs=False,
        metric_prefix='PrecisionAtRecallBoxes',
        use_weighted_mean_ap=False)


class PascalInstanceSegmentationEvaluator(ObjectDetectionEvaluator):
  """A class to evaluate instance masks using PASCAL metrics."""

  def __init__(self, categories, matching_iou_threshold=0.5):
    super(PascalInstanceSegmentationEvaluator, self).__init__(
        categories,
        matching_iou_threshold=matching_iou_threshold,
        evaluate_corlocs=False,
        metric_prefix='PascalMasks',
        use_weighted_mean_ap=False,
        evaluate_masks=True)


class WeightedPascalInstanceSegmentationEvaluator(ObjectDetectionEvaluator):
  """A class to evaluate instance masks using weighted PASCAL metrics.

  Weighted PASCAL metrics compute the mean average precision as the average
  precision given the scores and tp_fp_labels of all classes. In comparison,
  PASCAL metrics compute the mean average precision as the mean of the
  per-class average precisions.

  This definition is very similar to the mean of the per-class average
  precisions weighted by class frequency. However, they are typically not the
  same, as average precision is not a linear function of the scores and
  tp_fp_labels.
  """

  def __init__(self, categories, matching_iou_threshold=0.5):
    super(WeightedPascalInstanceSegmentationEvaluator, self).__init__(
        categories,
        matching_iou_threshold=matching_iou_threshold,
        evaluate_corlocs=False,
        metric_prefix='WeightedPascalMasks',
        use_weighted_mean_ap=True,
        evaluate_masks=True)


class OpenImagesDetectionEvaluator(ObjectDetectionEvaluator):
  """A class to evaluate detections using Open Images V2 metrics.

  Open Images V2 introduces the group-of type of bounding box, and this metric
  handles those boxes appropriately.
  """

  def __init__(self,
               categories,
               matching_iou_threshold=0.5,
               evaluate_masks=False,
               evaluate_corlocs=False,
               metric_prefix='OpenImagesV2',
               group_of_weight=0.0):
    """Constructor.

    Args:
      categories: A list of dicts, each of which has the following keys -
        'id': (required) an integer id uniquely identifying this category.
        'name': (required) string representing category name e.g., 'cat',
          'dog'.
      matching_iou_threshold: IOU threshold to use for matching groundtruth
        boxes to detection boxes.
      evaluate_masks: if True, evaluator evaluates masks.
      evaluate_corlocs: if True, additionally evaluates and returns CorLoc.
      metric_prefix: Prefix name of the metric.
      group_of_weight: Weight of the group-of bounding box. If set to 0
        (default for Open Images V2 detection protocol), detections of the
        correct class within a group-of box are ignored. If weight is > 0, then
        if at least one detection falls within a group-of box with
        matching_iou_threshold, weight group_of_weight is added to true
        positives. Consequently, if no detection falls within a group-of box,
        weight group_of_weight is added to false negatives.
    """
    super(OpenImagesDetectionEvaluator, self).__init__(
        categories,
        matching_iou_threshold,
        evaluate_corlocs,
        metric_prefix=metric_prefix,
        group_of_weight=group_of_weight,
        evaluate_masks=evaluate_masks)
    self._expected_keys = set([
        standard_fields.InputDataFields.key,
        standard_fields.InputDataFields.groundtruth_boxes,
        standard_fields.InputDataFields.groundtruth_classes,
        standard_fields.InputDataFields.groundtruth_group_of,
        standard_fields.DetectionResultFields.detection_boxes,
        standard_fields.DetectionResultFields.detection_scores,
        standard_fields.DetectionResultFields.detection_classes,
    ])
    if evaluate_masks:
      self._expected_keys.add(
          standard_fields.InputDataFields.groundtruth_instance_masks)
      self._expected_keys.add(
          standard_fields.DetectionResultFields.detection_masks)

  def add_single_ground_truth_image_info(self, image_id, groundtruth_dict):
    """Adds groundtruth for a single image to be used for evaluation.

    Args:
      image_id: A unique string/integer identifier for the image.
      groundtruth_dict: A dictionary containing -
        standard_fields.InputDataFields.groundtruth_boxes: float32 numpy array
          of shape [num_boxes, 4] containing `num_boxes` groundtruth boxes of
          the format [ymin, xmin, ymax, xmax] in absolute image coordinates.
        standard_fields.InputDataFields.groundtruth_classes: integer numpy
          array of shape [num_boxes] containing 1-indexed groundtruth classes
          for the boxes.
        standard_fields.InputDataFields.groundtruth_group_of: Optional length
          M numpy boolean array denoting whether a groundtruth box contains a
          group of instances.

    Raises:
      ValueError: On adding groundtruth for an image more than once.
    """
    if image_id in self._image_ids:
      raise ValueError('Image with id {} already added.'.format(image_id))
    groundtruth_classes = (
        groundtruth_dict[standard_fields.InputDataFields.groundtruth_classes] -
        self._label_id_offset)
    # Use the group_of flags only if the key is present and the array is
    # non-empty (or the image has no groundtruth boxes at all); otherwise fall
    # back to None and occasionally log that the flag is missing.
    if (standard_fields.InputDataFields.groundtruth_group_of in
        groundtruth_dict.keys() and
        (groundtruth_dict[standard_fields.InputDataFields.groundtruth_group_of]
         .size or not groundtruth_classes.size)):
      groundtruth_group_of = groundtruth_dict[
          standard_fields.InputDataFields.groundtruth_group_of]
    else:
      groundtruth_group_of = None
      if not len(self._image_ids) % 1000:
        logging.warn(
            'image %s does not have groundtruth group_of flag specified',
            image_id)
    if self._evaluate_masks:
      groundtruth_masks = groundtruth_dict[
          standard_fields.InputDataFields.groundtruth_instance_masks]
    else:
      groundtruth_masks = None
    self._evaluation.add_single_ground_truth_image_info(
        image_id,
        groundtruth_dict[standard_fields.InputDataFields.groundtruth_boxes],
        groundtruth_classes,
        groundtruth_is_difficult_list=None,
        groundtruth_is_group_of_list=groundtruth_group_of,
        groundtruth_masks=groundtruth_masks)
    self._image_ids.update([image_id])


class OpenImagesChallengeEvaluator(OpenImagesDetectionEvaluator):
  """A class that implements Open Images Challenge metrics.

  Both Detection and Instance Segmentation evaluation metrics are implemented.

  The Open Images Challenge Detection metric has two major changes in
  comparison with the Open Images V2 detection metric:
  - a custom weight might be specified for detecting an object contained in
    a group-of box.
  - verified image-level labels should be explicitly provided for evaluation:
    if an image has neither a positive nor a negative image-level label of
    class c, all detections of this class on this image will be ignored.

  The Open Images Challenge Instance Segmentation metric allows measuring the
  performance of models in case of incomplete annotations: some instances are
  annotated only at the box level and some only at the image level. In
  addition, image-level labels are taken into account as in the detection
  metric.

  Open Images Challenge Detection metric default parameters:
  evaluate_masks = False
  group_of_weight = 1.0

  Open Images Challenge Instance Segmentation metric default parameters:
  evaluate_masks = True
  (group_of_weight will not matter)
  """

  def __init__(self,
               categories,
               evaluate_masks=False,
               matching_iou_threshold=0.5,
               evaluate_corlocs=False,
               group_of_weight=1.0):
    """Constructor.

    Args:
      categories: A list of dicts, each of which has the following keys -
        'id': (required) an integer id uniquely identifying this category.
        'name': (required) string representing category name e.g., 'cat',
          'dog'.
      evaluate_masks: set to true for instance segmentation metric and to false
        for detection metric.
      matching_iou_threshold: IOU threshold to use for matching groundtruth
        boxes to detection boxes.
      evaluate_corlocs: if True, additionally evaluates and returns CorLoc.
      group_of_weight: weight of a group-of box. If set to 0, detections of the
        correct class within a group-of box are ignored. If weight is > 0
        (default for Open Images Detection Challenge), then if at least one
        detection falls within a group-of box with matching_iou_threshold,
        weight group_of_weight is added to true positives. Consequently, if no
        detection falls within a group-of box, weight group_of_weight is added
        to false negatives.
    """
    if not evaluate_masks:
      metrics_prefix = 'OpenImagesDetectionChallenge'
    else:
      metrics_prefix = 'OpenImagesInstanceSegmentationChallenge'
    super(OpenImagesChallengeEvaluator, self).__init__(
        categories,
        matching_iou_threshold,
        evaluate_masks=evaluate_masks,
        evaluate_corlocs=evaluate_corlocs,
        group_of_weight=group_of_weight,
        metric_prefix=metrics_prefix)
    self._evaluatable_labels = {}
    self._expected_keys.add(
        standard_fields.InputDataFields.groundtruth_image_classes)

  def add_single_ground_truth_image_info(self, image_id, groundtruth_dict):
    """Adds groundtruth for a single image to be used for evaluation.

    Args:
      image_id: A unique string/integer identifier for the image.
      groundtruth_dict: A dictionary containing -
        standard_fields.InputDataFields.groundtruth_boxes: float32 numpy array
          of shape [num_boxes, 4] containing `num_boxes` groundtruth boxes of
          the format [ymin, xmin, ymax, xmax] in absolute image coordinates.
        standard_fields.InputDataFields.groundtruth_classes: integer numpy
          array of shape [num_boxes] containing 1-indexed groundtruth classes
          for the boxes.
        standard_fields.InputDataFields.groundtruth_image_classes: integer 1D
          numpy array containing all classes for which labels are verified.
        standard_fields.InputDataFields.groundtruth_group_of: Optional length
          M numpy boolean array denoting whether a groundtruth box contains a
          group of instances.

    Raises:
      ValueError: On adding groundtruth for an image more than once.
    """
    super(OpenImagesChallengeEvaluator,
          self).add_single_ground_truth_image_info(image_id, groundtruth_dict)
    groundtruth_classes = (
        groundtruth_dict[standard_fields.InputDataFields.groundtruth_classes] -
        self._label_id_offset)
    self._evaluatable_labels[image_id] = np.unique(
        np.concatenate(((groundtruth_dict.get(
            standard_fields.InputDataFields.groundtruth_image_classes,
            np.array([], dtype=int)) - self._label_id_offset),
                        groundtruth_classes)))

  def add_single_detected_image_info(self, image_id, detections_dict):
    """Adds detections for a single image to be used for evaluation.

    Args:
      image_id: A unique string/integer identifier for the image.
      detections_dict: A dictionary containing -
        standard_fields.DetectionResultFields.detection_boxes: float32 numpy
          array of shape [num_boxes, 4] containing `num_boxes` detection boxes
          of the format [ymin, xmin, ymax, xmax] in absolute image coordinates.
        standard_fields.DetectionResultFields.detection_scores: float32 numpy
          array of shape [num_boxes] containing detection scores for the boxes.
        standard_fields.DetectionResultFields.detection_classes: integer numpy
          array of shape [num_boxes] containing 1-indexed detection classes for
          the boxes.

    Raises:
      ValueError: If detection masks are not in detections dictionary.
    """
    if image_id not in self._image_ids:
      # The evaluator assumes that groundtruth is added before detections. If
      # it was not, register the image with an empty set of evaluatable labels
      # so that all detections on this image are ignored.
      self._image_ids.update([image_id])
      self._evaluatable_labels[image_id] = np.array([])

    detection_classes = (
        detections_dict[standard_fields.DetectionResultFields.detection_classes]
        - self._label_id_offset)
    allowed_classes = np.where(
        np.isin(detection_classes, self._evaluatable_labels[image_id]))
    detection_classes = detection_classes[allowed_classes]
    detected_boxes = detections_dict[
        standard_fields.DetectionResultFields.detection_boxes][allowed_classes]
    detected_scores = detections_dict[
        standard_fields.DetectionResultFields.detection_scores][allowed_classes]

    if self._evaluate_masks:
      detection_masks = detections_dict[
          standard_fields.DetectionResultFields.detection_masks][
              allowed_classes]
    else:
      detection_masks = None
    self._evaluation.add_single_detected_image_info(
        image_key=image_id,
        detected_boxes=detected_boxes,
        detected_scores=detected_scores,
        detected_class_labels=detection_classes,
        detected_masks=detection_masks)

  def clear(self):
    """Clears stored data."""
    super(OpenImagesChallengeEvaluator, self).clear()
    self._evaluatable_labels.clear()
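

def _example_open_images_challenge_usage():
  """Illustrative sketch, not part of the original module.

  Shows how verified image-level labels interact with detections: class 2
  ('dog') has no verified label on 'image_1', so its detection is dropped
  before matching. All values are made up.
  """
  categories = [{'id': 1, 'name': 'cat'}, {'id': 2, 'name': 'dog'}]
  evaluator = OpenImagesChallengeEvaluator(categories)
  evaluator.add_single_ground_truth_image_info(
      'image_1', {
          standard_fields.InputDataFields.groundtruth_boxes:
              np.array([[10., 10., 50., 50.]], dtype=np.float32),
          standard_fields.InputDataFields.groundtruth_classes:
              np.array([1], dtype=int),
          standard_fields.InputDataFields.groundtruth_group_of:
              np.array([False]),
          # Only class 1 is verified at the image level.
          standard_fields.InputDataFields.groundtruth_image_classes:
              np.array([1], dtype=int),
      })
  evaluator.add_single_detected_image_info(
      'image_1', {
          standard_fields.DetectionResultFields.detection_boxes:
              np.array([[12., 11., 48., 49.], [0., 0., 5., 5.]],
                       dtype=np.float32),
          standard_fields.DetectionResultFields.detection_scores:
              np.array([0.9, 0.8], dtype=np.float32),
          standard_fields.DetectionResultFields.detection_classes:
              np.array([1, 2], dtype=int),
      })
  return evaluator.evaluate()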


ObjectDetectionEvalMetrics = collections.namedtuple(
    'ObjectDetectionEvalMetrics', [
        'average_precisions', 'mean_ap', 'precisions', 'recalls', 'corlocs',
        'mean_corloc'
    ])


class OpenImagesDetectionChallengeEvaluator(OpenImagesChallengeEvaluator):
  """A class that implements the Open Images Detection Challenge metric."""

  def __init__(self,
               categories,
               matching_iou_threshold=0.5,
               evaluate_corlocs=False,
               group_of_weight=1.0):
    """Constructor.

    Args:
      categories: A list of dicts, each of which has the following keys -
        'id': (required) an integer id uniquely identifying this category.
        'name': (required) string representing category name e.g., 'cat',
          'dog'.
      matching_iou_threshold: IOU threshold to use for matching groundtruth
        boxes to detection boxes.
      evaluate_corlocs: if True, additionally evaluates and returns CorLoc.
      group_of_weight: weight of a group-of box. If set to 0, detections of the
        correct class within a group-of box are ignored. If weight is > 0
        (default for Open Images Detection Challenge), then if at least one
        detection falls within a group-of box with matching_iou_threshold,
        weight group_of_weight is added to true positives. Consequently, if no
        detection falls within a group-of box, weight group_of_weight is added
        to false negatives.
    """
    # Note: the challenge protocol fixes evaluate_corlocs=False and
    # group_of_weight=1.0; the corresponding constructor arguments are
    # accepted for interface compatibility but not forwarded.
    super(OpenImagesDetectionChallengeEvaluator, self).__init__(
        categories=categories,
        evaluate_masks=False,
        matching_iou_threshold=matching_iou_threshold,
        evaluate_corlocs=False,
        group_of_weight=1.0)


class OpenImagesInstanceSegmentationChallengeEvaluator(
    OpenImagesChallengeEvaluator):
  """A class that implements the Open Images Instance Segmentation Challenge metric."""

  def __init__(self,
               categories,
               matching_iou_threshold=0.5,
               evaluate_corlocs=False,
               group_of_weight=1.0):
    """Constructor.

    Args:
      categories: A list of dicts, each of which has the following keys -
        'id': (required) an integer id uniquely identifying this category.
        'name': (required) string representing category name e.g., 'cat',
          'dog'.
      matching_iou_threshold: IOU threshold to use for matching groundtruth
        boxes to detection boxes.
      evaluate_corlocs: if True, additionally evaluates and returns CorLoc.
      group_of_weight: weight of a group-of box. If set to 0, detections of the
        correct class within a group-of box are ignored. If weight is > 0
        (default for Open Images Detection Challenge), then if at least one
        detection falls within a group-of box with matching_iou_threshold,
        weight group_of_weight is added to true positives. Consequently, if no
        detection falls within a group-of box, weight group_of_weight is added
        to false negatives.
    """
    # Note: as with the detection challenge evaluator above, evaluate_corlocs
    # and group_of_weight are fixed and not forwarded.
    super(OpenImagesInstanceSegmentationChallengeEvaluator, self).__init__(
        categories=categories,
        evaluate_masks=True,
        matching_iou_threshold=matching_iou_threshold,
        evaluate_corlocs=False,
        group_of_weight=1.0)


class ObjectDetectionEvaluation(object):
  """Internal implementation of Pascal object detection metrics."""

  def __init__(self,
               num_groundtruth_classes,
               matching_iou_threshold=0.5,
               nms_iou_threshold=1.0,
               nms_max_output_boxes=10000,
               recall_lower_bound=0.0,
               recall_upper_bound=1.0,
               use_weighted_mean_ap=False,
               label_id_offset=0,
               group_of_weight=0.0,
               per_image_eval_class=per_image_evaluation.PerImageEvaluation):
    """Constructor.

    Args:
      num_groundtruth_classes: Number of ground-truth classes.
      matching_iou_threshold: IOU threshold used for matching detected boxes
        to ground-truth boxes.
      nms_iou_threshold: IOU threshold used for non-maximum suppression.
      nms_max_output_boxes: Maximum number of boxes returned by non-maximum
        suppression.
      recall_lower_bound: lower bound of recall operating area.
      recall_upper_bound: upper bound of recall operating area.
      use_weighted_mean_ap: (optional) boolean which determines if the mean
        average precision is computed directly from the scores and tp_fp_labels
        of all classes.
      label_id_offset: The label id offset.
      group_of_weight: Weight of group-of boxes. If set to 0, detections of the
        correct class within a group-of box are ignored. If weight is > 0, then
        if at least one detection falls within a group-of box with
        matching_iou_threshold, weight group_of_weight is added to true
        positives. Consequently, if no detection falls within a group-of box,
        weight group_of_weight is added to false negatives.
      per_image_eval_class: The class that contains functions for computing
        per image metrics.

    Raises:
      ValueError: if num_groundtruth_classes is smaller than 1.
    """
    if num_groundtruth_classes < 1:
      raise ValueError('Need at least 1 groundtruth class for evaluation.')

    self.per_image_eval = per_image_eval_class(
        num_groundtruth_classes=num_groundtruth_classes,
        matching_iou_threshold=matching_iou_threshold,
        nms_iou_threshold=nms_iou_threshold,
        nms_max_output_boxes=nms_max_output_boxes,
        group_of_weight=group_of_weight)
    self.recall_lower_bound = recall_lower_bound
    self.recall_upper_bound = recall_upper_bound
    self.group_of_weight = group_of_weight
    self.num_class = num_groundtruth_classes
    self.use_weighted_mean_ap = use_weighted_mean_ap
    self.label_id_offset = label_id_offset

    self.groundtruth_boxes = {}
    self.groundtruth_class_labels = {}
    self.groundtruth_masks = {}
    self.groundtruth_is_difficult_list = {}
    self.groundtruth_is_group_of_list = {}
    self.num_gt_instances_per_class = np.zeros(self.num_class, dtype=float)
    self.num_gt_imgs_per_class = np.zeros(self.num_class, dtype=int)

    self._initialize_detections()

  def _initialize_detections(self):
    """Initializes internal data structures."""
    self.detection_keys = set()
    self.scores_per_class = [[] for _ in range(self.num_class)]
    self.tp_fp_labels_per_class = [[] for _ in range(self.num_class)]
    self.num_images_correctly_detected_per_class = np.zeros(self.num_class)
    self.average_precision_per_class = np.empty(self.num_class, dtype=float)
    self.average_precision_per_class.fill(np.nan)
    self.precisions_per_class = [np.nan] * self.num_class
    self.recalls_per_class = [np.nan] * self.num_class
    self.corloc_per_class = np.ones(self.num_class, dtype=float)

  def clear_detections(self):
    self._initialize_detections()

  def add_single_ground_truth_image_info(self,
                                         image_key,
                                         groundtruth_boxes,
                                         groundtruth_class_labels,
                                         groundtruth_is_difficult_list=None,
                                         groundtruth_is_group_of_list=None,
                                         groundtruth_masks=None):
    """Adds groundtruth for a single image to be used for evaluation.

    Args:
      image_key: A unique string/integer identifier for the image.
      groundtruth_boxes: float32 numpy array of shape [num_boxes, 4]
        containing `num_boxes` groundtruth boxes of the format
        [ymin, xmin, ymax, xmax] in absolute image coordinates.
      groundtruth_class_labels: integer numpy array of shape [num_boxes]
        containing 0-indexed groundtruth classes for the boxes.
      groundtruth_is_difficult_list: A length M numpy boolean array denoting
        whether a ground truth box is a difficult instance or not. To support
        the case that no boxes are difficult, it is by default set as None.
      groundtruth_is_group_of_list: A length M numpy boolean array denoting
        whether a ground truth box is a group-of box or not. To support the
        case that no boxes are groups-of, it is by default set as None.
      groundtruth_masks: uint8 numpy array of shape [num_boxes, height, width]
        containing `num_boxes` groundtruth masks. The mask values range from
        0 to 1.
    """
    if image_key in self.groundtruth_boxes:
      logging.warn(
          'image %s has already been added to the ground truth database.',
          image_key)
      return

    self.groundtruth_boxes[image_key] = groundtruth_boxes
    self.groundtruth_class_labels[image_key] = groundtruth_class_labels
    self.groundtruth_masks[image_key] = groundtruth_masks
    if groundtruth_is_difficult_list is None:
      num_boxes = groundtruth_boxes.shape[0]
      groundtruth_is_difficult_list = np.zeros(num_boxes, dtype=bool)
    self.groundtruth_is_difficult_list[
        image_key] = groundtruth_is_difficult_list.astype(dtype=bool)
    if groundtruth_is_group_of_list is None:
      num_boxes = groundtruth_boxes.shape[0]
      groundtruth_is_group_of_list = np.zeros(num_boxes, dtype=bool)
    self.groundtruth_is_group_of_list[
        image_key] = groundtruth_is_group_of_list.astype(dtype=bool)

    self._update_ground_truth_statistics(
        groundtruth_class_labels,
        groundtruth_is_difficult_list.astype(dtype=bool),
        groundtruth_is_group_of_list.astype(dtype=bool))

  def add_single_detected_image_info(self, image_key, detected_boxes,
                                     detected_scores, detected_class_labels,
                                     detected_masks=None):
    """Adds detections for a single image to be used for evaluation.

    Args:
      image_key: A unique string/integer identifier for the image.
      detected_boxes: float32 numpy array of shape [num_boxes, 4]
        containing `num_boxes` detection boxes of the format
        [ymin, xmin, ymax, xmax] in absolute image coordinates.
      detected_scores: float32 numpy array of shape [num_boxes] containing
        detection scores for the boxes.
      detected_class_labels: integer numpy array of shape [num_boxes]
        containing 0-indexed detection classes for the boxes.
      detected_masks: np.uint8 numpy array of shape [num_boxes, height, width]
        containing `num_boxes` detection masks with values ranging
        between 0 and 1.

    Raises:
      ValueError: if the number of boxes, scores and class labels differ in
        length.
    """
    if (len(detected_boxes) != len(detected_scores) or
        len(detected_boxes) != len(detected_class_labels)):
      raise ValueError(
          'detected_boxes, detected_scores and '
          'detected_class_labels should all have same lengths. Got '
          '[%d, %d, %d]' % (len(detected_boxes), len(detected_scores),
                            len(detected_class_labels)))

    if image_key in self.detection_keys:
      logging.warn(
          'image %s has already been added to the detection result database',
          image_key)
      return

    self.detection_keys.add(image_key)
    if image_key in self.groundtruth_boxes:
      groundtruth_boxes = self.groundtruth_boxes[image_key]
      groundtruth_class_labels = self.groundtruth_class_labels[image_key]
      # Masks are popped instead of looked up. The reason is that we do not
      # want to keep all masks in memory which can cause memory overflow.
      groundtruth_masks = self.groundtruth_masks.pop(image_key)
      groundtruth_is_difficult_list = self.groundtruth_is_difficult_list[
          image_key]
      groundtruth_is_group_of_list = self.groundtruth_is_group_of_list[
          image_key]
    else:
      groundtruth_boxes = np.empty(shape=[0, 4], dtype=float)
      groundtruth_class_labels = np.array([], dtype=int)
      if detected_masks is None:
        groundtruth_masks = None
      else:
        groundtruth_masks = np.empty(shape=[0, 1, 1], dtype=float)
      groundtruth_is_difficult_list = np.array([], dtype=bool)
      groundtruth_is_group_of_list = np.array([], dtype=bool)
    scores, tp_fp_labels, is_class_correctly_detected_in_image = (
        self.per_image_eval.compute_object_detection_metrics(
            detected_boxes=detected_boxes,
            detected_scores=detected_scores,
            detected_class_labels=detected_class_labels,
            groundtruth_boxes=groundtruth_boxes,
            groundtruth_class_labels=groundtruth_class_labels,
            groundtruth_is_difficult_list=groundtruth_is_difficult_list,
            groundtruth_is_group_of_list=groundtruth_is_group_of_list,
            detected_masks=detected_masks,
            groundtruth_masks=groundtruth_masks))

    for i in range(self.num_class):
      if scores[i].shape[0] > 0:
        self.scores_per_class[i].append(scores[i])
        self.tp_fp_labels_per_class[i].append(tp_fp_labels[i])
    (self.num_images_correctly_detected_per_class
    ) += is_class_correctly_detected_in_image

  def _update_ground_truth_statistics(self, groundtruth_class_labels,
                                      groundtruth_is_difficult_list,
                                      groundtruth_is_group_of_list):
    """Updates ground truth statistics.

    1. Difficult boxes are ignored when counting the number of ground truth
    instances, as done in the Pascal VOC devkit.
    2. Difficult boxes are treated as normal boxes when computing CorLoc
    related statistics.

    Args:
      groundtruth_class_labels: An integer numpy array of length M,
        representing M class labels of object instances in ground truth.
      groundtruth_is_difficult_list: A boolean numpy array of length M denoting
        whether a ground truth box is a difficult instance or not.
      groundtruth_is_group_of_list: A boolean numpy array of length M denoting
        whether a ground truth box is a group-of box or not.
    """
    for class_index in range(self.num_class):
      num_gt_instances = np.sum(groundtruth_class_labels[
          ~groundtruth_is_difficult_list
          & ~groundtruth_is_group_of_list] == class_index)
      num_groupof_gt_instances = self.group_of_weight * np.sum(
          groundtruth_class_labels[groundtruth_is_group_of_list] ==
          class_index)
      self.num_gt_instances_per_class[
          class_index] += num_gt_instances + num_groupof_gt_instances
      if np.any(groundtruth_class_labels == class_index):
        self.num_gt_imgs_per_class[class_index] += 1

  def evaluate(self):
    """Compute evaluation result.

    Returns:
      A named tuple with the following fields -
        average_precision: float numpy array of average precision for
          each class.
        mean_ap: mean average precision of all classes, float scalar.
        precisions: List of precisions, each precision is a float numpy array.
        recalls: List of recalls, each recall is a float numpy array.
        corloc: numpy float array of CorLoc scores per class.
        mean_corloc: Mean CorLoc score over all classes, float scalar.
    """
    if (self.num_gt_instances_per_class == 0).any():
      logging.warn(
          'The following classes have no ground truth examples: %s',
          np.squeeze(np.argwhere(self.num_gt_instances_per_class == 0)) +
          self.label_id_offset)

    if self.use_weighted_mean_ap:
      all_scores = np.array([], dtype=float)
      all_tp_fp_labels = np.array([], dtype=bool)

    for class_index in range(self.num_class):
      if self.num_gt_instances_per_class[class_index] == 0:
        continue
      if not self.scores_per_class[class_index]:
        scores = np.array([], dtype=float)
        tp_fp_labels = np.array([], dtype=float)
      else:
        scores = np.concatenate(self.scores_per_class[class_index])
        tp_fp_labels = np.concatenate(self.tp_fp_labels_per_class[class_index])
      if self.use_weighted_mean_ap:
        all_scores = np.append(all_scores, scores)
        all_tp_fp_labels = np.append(all_tp_fp_labels, tp_fp_labels)
      precision, recall = metrics.compute_precision_recall(
          scores, tp_fp_labels, self.num_gt_instances_per_class[class_index])
      recall_within_bound_indices = [
          index for index, value in enumerate(recall) if
          value >= self.recall_lower_bound and value <= self.recall_upper_bound
      ]
      recall_within_bound = recall[recall_within_bound_indices]
      precision_within_bound = precision[recall_within_bound_indices]

      self.precisions_per_class[class_index] = precision_within_bound
      self.recalls_per_class[class_index] = recall_within_bound
      average_precision = metrics.compute_average_precision(
          precision_within_bound, recall_within_bound)
      self.average_precision_per_class[class_index] = average_precision
      logging.info('average_precision: %f', average_precision)

    self.corloc_per_class = metrics.compute_cor_loc(
        self.num_gt_imgs_per_class,
        self.num_images_correctly_detected_per_class)

    if self.use_weighted_mean_ap:
      num_gt_instances = np.sum(self.num_gt_instances_per_class)
      precision, recall = metrics.compute_precision_recall(
          all_scores, all_tp_fp_labels, num_gt_instances)
      recall_within_bound_indices = [
          index for index, value in enumerate(recall) if
          value >= self.recall_lower_bound and value <= self.recall_upper_bound
      ]
      recall_within_bound = recall[recall_within_bound_indices]
      precision_within_bound = precision[recall_within_bound_indices]
      mean_ap = metrics.compute_average_precision(precision_within_bound,
                                                  recall_within_bound)
    else:
      mean_ap = np.nanmean(self.average_precision_per_class)
    mean_corloc = np.nanmean(self.corloc_per_class)
    return ObjectDetectionEvalMetrics(
        self.average_precision_per_class, mean_ap, self.precisions_per_class,
        self.recalls_per_class, self.corloc_per_class, mean_corloc)
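

def _example_object_detection_evaluation_usage():
  """Illustrative sketch, not part of the original module.

  Drives the internal ObjectDetectionEvaluation directly. Note that, unlike
  the evaluator wrappers above, it expects 0-indexed class labels (no
  label_id_offset is applied here). All values are made up.
  """
  evaluation = ObjectDetectionEvaluation(num_groundtruth_classes=2)
  evaluation.add_single_ground_truth_image_info(
      image_key='image_1',
      groundtruth_boxes=np.array([[10., 10., 50., 50.]], dtype=np.float32),
      groundtruth_class_labels=np.array([0], dtype=int))
  evaluation.add_single_detected_image_info(
      image_key='image_1',
      detected_boxes=np.array([[12., 11., 48., 49.]], dtype=np.float32),
      detected_scores=np.array([0.9], dtype=np.float32),
      detected_class_labels=np.array([0], dtype=int))
  metrics_tuple = evaluation.evaluate()
  return metrics_tuple.mean_ap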