You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

143 lines
5.9 KiB

6 years ago
  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for object_detection.metrics."""
  16. import numpy as np
  17. import tensorflow as tf
  18. from object_detection.utils import metrics
  19. class MetricsTest(tf.test.TestCase):
  20. def test_compute_cor_loc(self):
  21. num_gt_imgs_per_class = np.array([100, 1, 5, 1, 1], dtype=int)
  22. num_images_correctly_detected_per_class = np.array(
  23. [10, 0, 1, 0, 0], dtype=int)
  24. corloc = metrics.compute_cor_loc(num_gt_imgs_per_class,
  25. num_images_correctly_detected_per_class)
  26. expected_corloc = np.array([0.1, 0, 0.2, 0, 0], dtype=float)
  27. self.assertTrue(np.allclose(corloc, expected_corloc))
  28. def test_compute_cor_loc_nans(self):
  29. num_gt_imgs_per_class = np.array([100, 0, 0, 1, 1], dtype=int)
  30. num_images_correctly_detected_per_class = np.array(
  31. [10, 0, 1, 0, 0], dtype=int)
  32. corloc = metrics.compute_cor_loc(num_gt_imgs_per_class,
  33. num_images_correctly_detected_per_class)
  34. expected_corloc = np.array([0.1, np.nan, np.nan, 0, 0], dtype=float)
  35. self.assertAllClose(corloc, expected_corloc)
  36. def test_compute_precision_recall(self):
  37. num_gt = 10
  38. scores = np.array([0.4, 0.3, 0.6, 0.2, 0.7, 0.1], dtype=float)
  39. labels = np.array([0, 1, 1, 0, 0, 1], dtype=bool)
  40. labels_float_type = np.array([0, 1, 1, 0, 0, 1], dtype=float)
  41. accumulated_tp_count = np.array([0, 1, 1, 2, 2, 3], dtype=float)
  42. expected_precision = accumulated_tp_count / np.array([1, 2, 3, 4, 5, 6])
  43. expected_recall = accumulated_tp_count / num_gt
  44. precision, recall = metrics.compute_precision_recall(scores, labels, num_gt)
  45. precision_float_type, recall_float_type = metrics.compute_precision_recall(
  46. scores, labels_float_type, num_gt)
  47. self.assertAllClose(precision, expected_precision)
  48. self.assertAllClose(recall, expected_recall)
  49. self.assertAllClose(precision_float_type, expected_precision)
  50. self.assertAllClose(recall_float_type, expected_recall)
  51. def test_compute_precision_recall_float(self):
  52. num_gt = 10
  53. scores = np.array([0.4, 0.3, 0.6, 0.2, 0.7, 0.1], dtype=float)
  54. labels_float = np.array([0, 1, 1, 0.5, 0, 1], dtype=float)
  55. expected_precision = np.array(
  56. [0., 0.5, 0.33333333, 0.5, 0.55555556, 0.63636364], dtype=float)
  57. expected_recall = np.array([0., 0.1, 0.1, 0.2, 0.25, 0.35], dtype=float)
  58. precision, recall = metrics.compute_precision_recall(
  59. scores, labels_float, num_gt)
  60. self.assertAllClose(precision, expected_precision)
  61. self.assertAllClose(recall, expected_recall)
  62. def test_compute_average_precision(self):
  63. precision = np.array([0.8, 0.76, 0.9, 0.65, 0.7, 0.5, 0.55, 0], dtype=float)
  64. recall = np.array([0.3, 0.3, 0.4, 0.4, 0.45, 0.45, 0.5, 0.5], dtype=float)
  65. processed_precision = np.array(
  66. [0.9, 0.9, 0.9, 0.7, 0.7, 0.55, 0.55, 0], dtype=float)
  67. recall_interval = np.array([0.3, 0, 0.1, 0, 0.05, 0, 0.05, 0], dtype=float)
  68. expected_mean_ap = np.sum(recall_interval * processed_precision)
  69. mean_ap = metrics.compute_average_precision(precision, recall)
  70. self.assertAlmostEqual(expected_mean_ap, mean_ap)
  71. def test_compute_precision_recall_and_ap_no_groundtruth(self):
  72. num_gt = 0
  73. scores = np.array([0.4, 0.3, 0.6, 0.2, 0.7, 0.1], dtype=float)
  74. labels = np.array([0, 0, 0, 0, 0, 0], dtype=bool)
  75. expected_precision = None
  76. expected_recall = None
  77. precision, recall = metrics.compute_precision_recall(scores, labels, num_gt)
  78. self.assertEquals(precision, expected_precision)
  79. self.assertEquals(recall, expected_recall)
  80. ap = metrics.compute_average_precision(precision, recall)
  81. self.assertTrue(np.isnan(ap))
  82. def test_compute_recall_at_k(self):
  83. num_gt = 4
  84. tp_fp = [
  85. np.array([1, 0, 0], dtype=float),
  86. np.array([0, 1], dtype=float),
  87. np.array([0, 0, 0, 0, 0], dtype=float)
  88. ]
  89. tp_fp_bool = [
  90. np.array([True, False, False], dtype=bool),
  91. np.array([False, True], dtype=float),
  92. np.array([False, False, False, False, False], dtype=float)
  93. ]
  94. recall_1 = metrics.compute_recall_at_k(tp_fp, num_gt, 1)
  95. recall_3 = metrics.compute_recall_at_k(tp_fp, num_gt, 3)
  96. recall_5 = metrics.compute_recall_at_k(tp_fp, num_gt, 5)
  97. recall_3_bool = metrics.compute_recall_at_k(tp_fp_bool, num_gt, 3)
  98. self.assertAlmostEqual(recall_1, 0.25)
  99. self.assertAlmostEqual(recall_3, 0.5)
  100. self.assertAlmostEqual(recall_3_bool, 0.5)
  101. self.assertAlmostEqual(recall_5, 0.5)
  102. def test_compute_median_rank_at_k(self):
  103. tp_fp = [
  104. np.array([1, 0, 0], dtype=float),
  105. np.array([0, 0.1], dtype=float),
  106. np.array([0, 0, 0, 0, 0], dtype=float)
  107. ]
  108. tp_fp_bool = [
  109. np.array([True, False, False], dtype=bool),
  110. np.array([False, True], dtype=float),
  111. np.array([False, False, False, False, False], dtype=float)
  112. ]
  113. median_ranks_1 = metrics.compute_median_rank_at_k(tp_fp, 1)
  114. median_ranks_3 = metrics.compute_median_rank_at_k(tp_fp, 3)
  115. median_ranks_5 = metrics.compute_median_rank_at_k(tp_fp, 5)
  116. median_ranks_3_bool = metrics.compute_median_rank_at_k(tp_fp_bool, 3)
  117. self.assertEquals(median_ranks_1, 0)
  118. self.assertEquals(median_ranks_3, 0.5)
  119. self.assertEquals(median_ranks_3_bool, 0.5)
  120. self.assertEquals(median_ranks_5, 0.5)
  121. if __name__ == '__main__':
  122. tf.test.main()