You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

257 lines
11 KiB

  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for tensorflow_models.object_detection.utils.vrd_evaluation."""
  16. import numpy as np
  17. import tensorflow as tf
  18. from object_detection.core import standard_fields
  19. from object_detection.utils import vrd_evaluation
  20. class VRDRelationDetectionEvaluatorTest(tf.test.TestCase):
  21. def test_vrdrelation_evaluator(self):
  22. self.vrd_eval = vrd_evaluation.VRDRelationDetectionEvaluator()
  23. image_key1 = 'img1'
  24. groundtruth_box_tuples1 = np.array(
  25. [([0, 0, 1, 1], [1, 1, 2, 2]), ([0, 0, 1, 1], [1, 2, 2, 3])],
  26. dtype=vrd_evaluation.vrd_box_data_type)
  27. groundtruth_class_tuples1 = np.array(
  28. [(1, 2, 3), (1, 4, 3)], dtype=vrd_evaluation.label_data_type)
  29. groundtruth_verified_labels1 = np.array([1, 2, 3, 4, 5], dtype=int)
  30. self.vrd_eval.add_single_ground_truth_image_info(
  31. image_key1, {
  32. standard_fields.InputDataFields.groundtruth_boxes:
  33. groundtruth_box_tuples1,
  34. standard_fields.InputDataFields.groundtruth_classes:
  35. groundtruth_class_tuples1,
  36. standard_fields.InputDataFields.groundtruth_image_classes:
  37. groundtruth_verified_labels1
  38. })
  39. image_key2 = 'img2'
  40. groundtruth_box_tuples2 = np.array(
  41. [([0, 0, 1, 1], [1, 1, 2, 2])], dtype=vrd_evaluation.vrd_box_data_type)
  42. groundtruth_class_tuples2 = np.array(
  43. [(1, 4, 3)], dtype=vrd_evaluation.label_data_type)
  44. self.vrd_eval.add_single_ground_truth_image_info(
  45. image_key2, {
  46. standard_fields.InputDataFields.groundtruth_boxes:
  47. groundtruth_box_tuples2,
  48. standard_fields.InputDataFields.groundtruth_classes:
  49. groundtruth_class_tuples2,
  50. })
  51. image_key3 = 'img3'
  52. groundtruth_box_tuples3 = np.array(
  53. [([0, 0, 1, 1], [1, 1, 2, 2])], dtype=vrd_evaluation.vrd_box_data_type)
  54. groundtruth_class_tuples3 = np.array(
  55. [(1, 2, 4)], dtype=vrd_evaluation.label_data_type)
  56. self.vrd_eval.add_single_ground_truth_image_info(
  57. image_key3, {
  58. standard_fields.InputDataFields.groundtruth_boxes:
  59. groundtruth_box_tuples3,
  60. standard_fields.InputDataFields.groundtruth_classes:
  61. groundtruth_class_tuples3,
  62. })
  63. image_key = 'img1'
  64. detected_box_tuples = np.array(
  65. [([0, 0.3, 1, 1], [1.1, 1, 2, 2]), ([0, 0, 1, 1], [1, 1, 2, 2]),
  66. ([0.5, 0, 1, 1], [1, 1, 3, 3])],
  67. dtype=vrd_evaluation.vrd_box_data_type)
  68. detected_class_tuples = np.array(
  69. [(1, 2, 5), (1, 2, 3), (1, 6, 3)], dtype=vrd_evaluation.label_data_type)
  70. detected_scores = np.array([0.7, 0.8, 0.9], dtype=float)
  71. self.vrd_eval.add_single_detected_image_info(
  72. image_key, {
  73. standard_fields.DetectionResultFields.detection_boxes:
  74. detected_box_tuples,
  75. standard_fields.DetectionResultFields.detection_scores:
  76. detected_scores,
  77. standard_fields.DetectionResultFields.detection_classes:
  78. detected_class_tuples
  79. })
  80. metrics = self.vrd_eval.evaluate()
  81. self.assertAlmostEqual(metrics['VRDMetric_Relationships_weightedAP@0.5IOU'],
  82. 0.25)
  83. self.assertAlmostEqual(metrics['VRDMetric_Relationships_mAP@0.5IOU'],
  84. 0.1666666666666666)
  85. self.assertAlmostEqual(metrics['VRDMetric_Relationships_AP@0.5IOU/3'],
  86. 0.3333333333333333)
  87. self.assertAlmostEqual(metrics['VRDMetric_Relationships_AP@0.5IOU/4'], 0)
  88. self.assertAlmostEqual(metrics['VRDMetric_Relationships_Recall@50@0.5IOU'],
  89. 0.25)
  90. self.assertAlmostEqual(metrics['VRDMetric_Relationships_Recall@100@0.5IOU'],
  91. 0.25)
  92. self.vrd_eval.clear()
  93. self.assertFalse(self.vrd_eval._image_ids)
  94. class VRDPhraseDetectionEvaluatorTest(tf.test.TestCase):
  95. def test_vrdphrase_evaluator(self):
  96. self.vrd_eval = vrd_evaluation.VRDPhraseDetectionEvaluator()
  97. image_key1 = 'img1'
  98. groundtruth_box_tuples1 = np.array(
  99. [([0, 0, 1, 1], [1, 1, 2, 2]), ([0, 0, 1, 1], [1, 2, 2, 3])],
  100. dtype=vrd_evaluation.vrd_box_data_type)
  101. groundtruth_class_tuples1 = np.array(
  102. [(1, 2, 3), (1, 4, 3)], dtype=vrd_evaluation.label_data_type)
  103. groundtruth_verified_labels1 = np.array([1, 2, 3, 4, 5], dtype=int)
  104. self.vrd_eval.add_single_ground_truth_image_info(
  105. image_key1, {
  106. standard_fields.InputDataFields.groundtruth_boxes:
  107. groundtruth_box_tuples1,
  108. standard_fields.InputDataFields.groundtruth_classes:
  109. groundtruth_class_tuples1,
  110. standard_fields.InputDataFields.groundtruth_image_classes:
  111. groundtruth_verified_labels1
  112. })
  113. image_key2 = 'img2'
  114. groundtruth_box_tuples2 = np.array(
  115. [([0, 0, 1, 1], [1, 1, 2, 2])], dtype=vrd_evaluation.vrd_box_data_type)
  116. groundtruth_class_tuples2 = np.array(
  117. [(1, 4, 3)], dtype=vrd_evaluation.label_data_type)
  118. self.vrd_eval.add_single_ground_truth_image_info(
  119. image_key2, {
  120. standard_fields.InputDataFields.groundtruth_boxes:
  121. groundtruth_box_tuples2,
  122. standard_fields.InputDataFields.groundtruth_classes:
  123. groundtruth_class_tuples2,
  124. })
  125. image_key3 = 'img3'
  126. groundtruth_box_tuples3 = np.array(
  127. [([0, 0, 1, 1], [1, 1, 2, 2])], dtype=vrd_evaluation.vrd_box_data_type)
  128. groundtruth_class_tuples3 = np.array(
  129. [(1, 2, 4)], dtype=vrd_evaluation.label_data_type)
  130. self.vrd_eval.add_single_ground_truth_image_info(
  131. image_key3, {
  132. standard_fields.InputDataFields.groundtruth_boxes:
  133. groundtruth_box_tuples3,
  134. standard_fields.InputDataFields.groundtruth_classes:
  135. groundtruth_class_tuples3,
  136. })
  137. image_key = 'img1'
  138. detected_box_tuples = np.array(
  139. [([0, 0.3, 0.5, 0.5], [0.3, 0.3, 1.0, 1.0]),
  140. ([0, 0, 1.2, 1.2], [0.0, 0.0, 2.0, 2.0]),
  141. ([0.5, 0, 1, 1], [1, 1, 3, 3])],
  142. dtype=vrd_evaluation.vrd_box_data_type)
  143. detected_class_tuples = np.array(
  144. [(1, 2, 5), (1, 2, 3), (1, 6, 3)], dtype=vrd_evaluation.label_data_type)
  145. detected_scores = np.array([0.7, 0.8, 0.9], dtype=float)
  146. self.vrd_eval.add_single_detected_image_info(
  147. image_key, {
  148. standard_fields.DetectionResultFields.detection_boxes:
  149. detected_box_tuples,
  150. standard_fields.DetectionResultFields.detection_scores:
  151. detected_scores,
  152. standard_fields.DetectionResultFields.detection_classes:
  153. detected_class_tuples
  154. })
  155. metrics = self.vrd_eval.evaluate()
  156. self.assertAlmostEqual(metrics['VRDMetric_Phrases_weightedAP@0.5IOU'], 0.25)
  157. self.assertAlmostEqual(metrics['VRDMetric_Phrases_mAP@0.5IOU'],
  158. 0.1666666666666666)
  159. self.assertAlmostEqual(metrics['VRDMetric_Phrases_AP@0.5IOU/3'],
  160. 0.3333333333333333)
  161. self.assertAlmostEqual(metrics['VRDMetric_Phrases_AP@0.5IOU/4'], 0)
  162. self.assertAlmostEqual(metrics['VRDMetric_Phrases_Recall@50@0.5IOU'], 0.25)
  163. self.assertAlmostEqual(metrics['VRDMetric_Phrases_Recall@100@0.5IOU'], 0.25)
  164. self.vrd_eval.clear()
  165. self.assertFalse(self.vrd_eval._image_ids)
  166. class VRDDetectionEvaluationTest(tf.test.TestCase):
  167. def setUp(self):
  168. self.vrd_eval = vrd_evaluation._VRDDetectionEvaluation(
  169. matching_iou_threshold=0.5)
  170. image_key1 = 'img1'
  171. groundtruth_box_tuples1 = np.array(
  172. [([0, 0, 1, 1], [1, 1, 2, 2]), ([0, 0, 1, 1], [1, 2, 2, 3])],
  173. dtype=vrd_evaluation.vrd_box_data_type)
  174. groundtruth_class_tuples1 = np.array(
  175. [(1, 2, 3), (1, 4, 3)], dtype=vrd_evaluation.label_data_type)
  176. self.vrd_eval.add_single_ground_truth_image_info(
  177. image_key1, groundtruth_box_tuples1, groundtruth_class_tuples1)
  178. image_key2 = 'img2'
  179. groundtruth_box_tuples2 = np.array(
  180. [([0, 0, 1, 1], [1, 1, 2, 2])], dtype=vrd_evaluation.vrd_box_data_type)
  181. groundtruth_class_tuples2 = np.array(
  182. [(1, 4, 3)], dtype=vrd_evaluation.label_data_type)
  183. self.vrd_eval.add_single_ground_truth_image_info(
  184. image_key2, groundtruth_box_tuples2, groundtruth_class_tuples2)
  185. image_key3 = 'img3'
  186. groundtruth_box_tuples3 = np.array(
  187. [([0, 0, 1, 1], [1, 1, 2, 2])], dtype=vrd_evaluation.vrd_box_data_type)
  188. groundtruth_class_tuples3 = np.array(
  189. [(1, 2, 4)], dtype=vrd_evaluation.label_data_type)
  190. self.vrd_eval.add_single_ground_truth_image_info(
  191. image_key3, groundtruth_box_tuples3, groundtruth_class_tuples3)
  192. image_key = 'img1'
  193. detected_box_tuples = np.array(
  194. [([0, 0.3, 1, 1], [1.1, 1, 2, 2]), ([0, 0, 1, 1], [1, 1, 2, 2])],
  195. dtype=vrd_evaluation.vrd_box_data_type)
  196. detected_class_tuples = np.array(
  197. [(1, 2, 3), (1, 2, 3)], dtype=vrd_evaluation.label_data_type)
  198. detected_scores = np.array([0.7, 0.8], dtype=float)
  199. self.vrd_eval.add_single_detected_image_info(
  200. image_key, detected_box_tuples, detected_scores, detected_class_tuples)
  201. metrics = self.vrd_eval.evaluate()
  202. expected_weighted_average_precision = 0.25
  203. expected_mean_average_precision = 0.16666666666666
  204. expected_precision = np.array([1., 0.5], dtype=float)
  205. expected_recall = np.array([0.25, 0.25], dtype=float)
  206. expected_recall_50 = 0.25
  207. expected_recall_100 = 0.25
  208. expected_median_rank_50 = 0
  209. expected_median_rank_100 = 0
  210. self.assertAlmostEqual(expected_weighted_average_precision,
  211. metrics.weighted_average_precision)
  212. self.assertAlmostEqual(expected_mean_average_precision,
  213. metrics.mean_average_precision)
  214. self.assertAlmostEqual(expected_mean_average_precision,
  215. metrics.mean_average_precision)
  216. self.assertAllClose(expected_precision, metrics.precisions)
  217. self.assertAllClose(expected_recall, metrics.recalls)
  218. self.assertAlmostEqual(expected_recall_50, metrics.recall_50)
  219. self.assertAlmostEqual(expected_recall_100, metrics.recall_100)
  220. self.assertAlmostEqual(expected_median_rank_50, metrics.median_rank_50)
  221. self.assertAlmostEqual(expected_median_rank_100, metrics.median_rank_100)
# Hand control to the TensorFlow test runner when executed as a script.
if __name__ == '__main__':
  tf.test.main()