# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Evaluator class for Visual Relations Detection.

VRDDetectionEvaluator is a class which manages ground truth information of a
visual relations detection (vrd) dataset, and computes frequently used
detection metrics such as Precision, Recall and Recall@k of the provided vrd
detection results.

It supports the following operations:
1) Adding ground truth information of images sequentially.
2) Adding detection results of images sequentially.
3) Evaluating detection metrics on already inserted detection results.

Note 1: groundtruth should be inserted before evaluation.
Note 2: This module operates on numpy boxes and box lists.
"""
from abc import abstractmethod
import collections
import logging

import numpy as np

from object_detection.core import standard_fields
from object_detection.utils import metrics
from object_detection.utils import object_detection_evaluation
from object_detection.utils import per_image_vrd_evaluation

# Below, the standard input numpy datatypes are defined:
# vrd_box_data_type - datatype of the groundtruth visual relations box
#   annotations; this datatype consists of two named boxes: subject bounding
#   box and object bounding box. Each box is of the format
#   [y_min, x_min, y_max, x_max], each coordinate being of type float32.
# label_data_type - corresponding datatype of the visual relations label
#   annotations; it consists of three numerical class labels: subject class
#   label, object class label and relation class label, each class label being
#   of type int32.
vrd_box_data_type = np.dtype([('subject', 'f4', (4,)), ('object', 'f4', (4,))])

single_box_data_type = np.dtype([('box', 'f4', (4,))])

label_data_type = np.dtype([('subject', 'i4'), ('object', 'i4'),
                            ('relation', 'i4')])
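
# Example (a minimal sketch; the coordinates and class ids below are made up
# for illustration): a single "person rides horse" groundtruth tuple could be
# encoded as
#
#   gt_boxes = np.array(
#       [([0.1, 0.1, 0.9, 0.5], [0.3, 0.0, 1.0, 0.8])],
#       dtype=vrd_box_data_type)
#   gt_labels = np.array([(1, 2, 3)], dtype=label_data_type)
#
# where 1, 2 and 3 are hypothetical class ids for "person", "horse" and
# "rides". Individual fields are addressed as gt_boxes['subject'],
# gt_labels['relation'], etc.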


class VRDDetectionEvaluator(object_detection_evaluation.DetectionEvaluator):
  """A class to evaluate VRD detections.

  This class serves as a base class for VRD evaluation in two settings:
  - phrase detection
  - relation detection.
  """

  def __init__(self, matching_iou_threshold=0.5, metric_prefix=None):
    """Constructor.

    Args:
      matching_iou_threshold: IOU threshold to use for matching groundtruth
        boxes to detection boxes.
      metric_prefix: (optional) string prefix for metric name; if None, no
        prefix is used.
    """
    super(VRDDetectionEvaluator, self).__init__([])
    self._matching_iou_threshold = matching_iou_threshold
    self._evaluation = _VRDDetectionEvaluation(
        matching_iou_threshold=self._matching_iou_threshold)
    self._image_ids = set([])
    self._metric_prefix = (metric_prefix + '_') if metric_prefix else ''
    self._evaluatable_labels = {}
    self._negative_labels = {}

  @abstractmethod
  def _process_groundtruth_boxes(self, groundtruth_box_tuples):
    """Pre-processes boxes before adding them to the VRDDetectionEvaluation.

    Phrase detection and Relation detection subclasses re-implement this method
    depending on the task.

    Args:
      groundtruth_box_tuples: A numpy array of structures with the shape
        [M, 1], each structure containing the same number of named bounding
        boxes. Each box is of the format [y_min, x_min, y_max, x_max] (see
        datatype vrd_box_data_type, single_box_data_type above).
    """
    raise NotImplementedError(
        '_process_groundtruth_boxes method should be implemented in '
        'subclasses of VRDDetectionEvaluator.')

  @abstractmethod
  def _process_detection_boxes(self, detections_box_tuples):
    """Pre-processes boxes before adding them to the VRDDetectionEvaluation.

    Phrase detection and Relation detection subclasses re-implement this method
    depending on the task.

    Args:
      detections_box_tuples: A numpy array of structures with the shape
        [M, 1], each structure containing the same number of named bounding
        boxes. Each box is of the format [y_min, x_min, y_max, x_max] (see
        datatype vrd_box_data_type, single_box_data_type above).
    """
    raise NotImplementedError(
        '_process_detection_boxes method should be implemented in '
        'subclasses of VRDDetectionEvaluator.')

  def add_single_ground_truth_image_info(self, image_id, groundtruth_dict):
    """Adds groundtruth for a single image to be used for evaluation.

    Args:
      image_id: A unique string/integer identifier for the image.
      groundtruth_dict: A dictionary containing -
        standard_fields.InputDataFields.groundtruth_boxes: A numpy array
          of structures with the shape [M, 1], representing M tuples, each
          tuple containing the same number of named bounding boxes.
          Each box is of the format [y_min, x_min, y_max, x_max] (see
          datatype vrd_box_data_type, single_box_data_type above).
        standard_fields.InputDataFields.groundtruth_classes: A numpy array of
          structures shape [M, 1], representing the class labels of the
          corresponding bounding boxes and possibly additional classes (see
          datatype label_data_type above).
        standard_fields.InputDataFields.groundtruth_image_classes: numpy array
          of shape [K] containing verified labels.

    Raises:
      ValueError: On adding groundtruth for an image more than once.
    """
    if image_id in self._image_ids:
      raise ValueError('Image with id {} already added.'.format(image_id))
    groundtruth_class_tuples = (
        groundtruth_dict[standard_fields.InputDataFields.groundtruth_classes])
    groundtruth_box_tuples = (
        groundtruth_dict[standard_fields.InputDataFields.groundtruth_boxes])

    self._evaluation.add_single_ground_truth_image_info(
        image_key=image_id,
        groundtruth_box_tuples=self._process_groundtruth_boxes(
            groundtruth_box_tuples),
        groundtruth_class_tuples=groundtruth_class_tuples)
    self._image_ids.update([image_id])

    all_classes = []
    for field in groundtruth_box_tuples.dtype.fields:
      all_classes.append(groundtruth_class_tuples[field])
    groundtruth_positive_classes = np.unique(np.concatenate(all_classes))
    verified_labels = groundtruth_dict.get(
        standard_fields.InputDataFields.groundtruth_image_classes,
        np.array([], dtype=int))
    self._evaluatable_labels[image_id] = np.unique(
        np.concatenate((verified_labels, groundtruth_positive_classes)))
    self._negative_labels[image_id] = np.setdiff1d(
        verified_labels, groundtruth_positive_classes)

  def add_single_detected_image_info(self, image_id, detections_dict):
    """Adds detections for a single image to be used for evaluation.

    Args:
      image_id: A unique string/integer identifier for the image.
      detections_dict: A dictionary containing -
        standard_fields.DetectionResultFields.detection_boxes: A numpy array of
          structures with shape [N, 1], representing N tuples, each tuple
          containing the same number of named bounding boxes.
          Each box is of the format [y_min, x_min, y_max, x_max] (as an example
          see datatype vrd_box_data_type, single_box_data_type above).
        standard_fields.DetectionResultFields.detection_scores: float32 numpy
          array of shape [N] containing detection scores for the boxes.
        standard_fields.DetectionResultFields.detection_classes: A numpy array
          of structures shape [N, 1], representing the class labels of the
          corresponding bounding boxes and possibly additional classes (see
          datatype label_data_type above).
    """
    if image_id not in self._image_ids:
      logging.warning('No groundtruth for the image with id %s.', image_id)
      # Evaluation assumes that groundtruth is inserted first. If it was not,
      # register the image with empty negative and evaluatable label sets, so
      # that all of its detections are filtered out by the selectors below.
      self._image_ids.update([image_id])
      self._negative_labels[image_id] = np.array([])
      self._evaluatable_labels[image_id] = np.array([])

    num_detections = detections_dict[
        standard_fields.DetectionResultFields.detection_boxes].shape[0]
    detection_class_tuples = detections_dict[
        standard_fields.DetectionResultFields.detection_classes]
    detection_box_tuples = detections_dict[
        standard_fields.DetectionResultFields.detection_boxes]
    negative_selector = np.zeros(num_detections, dtype=bool)
    selector = np.ones(num_detections, dtype=bool)
    # Only check boxable labels.
    for field in detection_box_tuples.dtype.fields:
      # Verify if one of the labels is negative (this is a sure FP).
      negative_selector |= np.isin(detection_class_tuples[field],
                                   self._negative_labels[image_id])
      # Verify if all labels are verified.
      selector &= np.isin(detection_class_tuples[field],
                          self._evaluatable_labels[image_id])
    selector |= negative_selector
    self._evaluation.add_single_detected_image_info(
        image_key=image_id,
        detected_box_tuples=self._process_detection_boxes(
            detection_box_tuples[selector]),
        detected_scores=detections_dict[
            standard_fields.DetectionResultFields.detection_scores][selector],
        detected_class_tuples=detection_class_tuples[selector])
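
  # Example of the filtering above (hypothetical class ids): suppose an image
  # has verified labels {1, 2, 3} and groundtruth subject/object labels
  # {1, 2}, so self._negative_labels[image_id] is {3}. A detection tuple with
  # subject label 1 and object label 3 is kept as a guaranteed false positive
  # (a verified negative), one with labels 1 and 2 is kept as evaluatable, and
  # one with labels 1 and 5 is dropped, since label 5 was never verified for
  # this image.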

  def evaluate(self, relationships=None):
    """Computes evaluation result.

    Args:
      relationships: A dictionary of numerical label-text label mapping; if
        specified, returns per-relationship AP.

    Returns:
      A dictionary of metrics with the following fields -

      summary_metrics:
        'weightedAP@<matching_iou_threshold>IOU': weighted average precision
          at the specified IOU threshold.
        'AP@<matching_iou_threshold>IOU/<relationship>': AP per relationship.
        'mAP@<matching_iou_threshold>IOU': mean average precision at the
          specified IOU threshold.
        'Recall@50@<matching_iou_threshold>IOU': recall@50 at the specified
          IOU threshold.
        'Recall@100@<matching_iou_threshold>IOU': recall@100 at the specified
          IOU threshold.

      If relationships is specified, <relationship> in the AP metric names is
      a readable name; otherwise the names correspond to class numbers.
    """
    (weighted_average_precision, mean_average_precision, average_precisions, _,
     _, recall_50, recall_100, _, _) = (
         self._evaluation.evaluate())

    vrd_metrics = {
        (self._metric_prefix + 'weightedAP@{}IOU'.format(
            self._matching_iou_threshold)):
            weighted_average_precision,
        self._metric_prefix + 'mAP@{}IOU'.format(self._matching_iou_threshold):
            mean_average_precision,
        self._metric_prefix + 'Recall@50@{}IOU'.format(
            self._matching_iou_threshold):
            recall_50,
        self._metric_prefix + 'Recall@100@{}IOU'.format(
            self._matching_iou_threshold):
            recall_100,
    }
    if relationships:
      for key, average_precision in average_precisions.items():
        vrd_metrics[self._metric_prefix + 'AP@{}IOU/{}'.format(
            self._matching_iou_threshold,
            relationships[key])] = average_precision
    else:
      for key, average_precision in average_precisions.items():
        vrd_metrics[self._metric_prefix + 'AP@{}IOU/{}'.format(
            self._matching_iou_threshold, key)] = average_precision

    return vrd_metrics
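
  # For example, with the default threshold a VRDRelationDetectionEvaluator
  # instance produces keys such as 'VRDMetric_Relationships_weightedAP@0.5IOU'
  # and 'VRDMetric_Relationships_Recall@50@0.5IOU' in the returned dictionary.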

  def clear(self):
    """Clears the state to prepare for a fresh evaluation."""
    self._evaluation = _VRDDetectionEvaluation(
        matching_iou_threshold=self._matching_iou_threshold)
    self._image_ids.clear()
    self._negative_labels.clear()
    self._evaluatable_labels.clear()


class VRDRelationDetectionEvaluator(VRDDetectionEvaluator):
  """A class to evaluate VRD detections in the relations setting.

  Expected groundtruth box datatype is vrd_box_data_type, expected groundtruth
  labels datatype is label_data_type.
  Expected detection box datatype is vrd_box_data_type, expected detection
  labels datatype is label_data_type.
  """

  def __init__(self, matching_iou_threshold=0.5):
    super(VRDRelationDetectionEvaluator, self).__init__(
        matching_iou_threshold=matching_iou_threshold,
        metric_prefix='VRDMetric_Relationships')

  def _process_groundtruth_boxes(self, groundtruth_box_tuples):
    """Pre-processes boxes before adding them to the VRDDetectionEvaluation.

    Args:
      groundtruth_box_tuples: A numpy array of structures with the shape
        [M, 1], each structure containing the same number of named bounding
        boxes. Each box is of the format [y_min, x_min, y_max, x_max].

    Returns:
      Unchanged input.
    """
    return groundtruth_box_tuples

  def _process_detection_boxes(self, detections_box_tuples):
    """Pre-processes boxes before adding them to the VRDDetectionEvaluation.

    Args:
      detections_box_tuples: A numpy array of structures with the shape
        [M, 1], each structure containing the same number of named bounding
        boxes. Each box is of the format [y_min, x_min, y_max, x_max] (see
        datatype vrd_box_data_type, single_box_data_type above).

    Returns:
      Unchanged input.
    """
    return detections_box_tuples


class VRDPhraseDetectionEvaluator(VRDDetectionEvaluator):
  """A class to evaluate VRD detections in the phrase setting.

  Expected groundtruth box datatype is vrd_box_data_type, expected groundtruth
  labels datatype is label_data_type.
  Expected detection box datatype is single_box_data_type, expected detection
  labels datatype is label_data_type.
  """

  def __init__(self, matching_iou_threshold=0.5):
    super(VRDPhraseDetectionEvaluator, self).__init__(
        matching_iou_threshold=matching_iou_threshold,
        metric_prefix='VRDMetric_Phrases')

  def _process_groundtruth_boxes(self, groundtruth_box_tuples):
    """Pre-processes boxes before adding them to the VRDDetectionEvaluation.

    In the phrase evaluation task, evaluation expects exactly one bounding
    box containing all objects in the phrase. This bounding box is computed
    as an enclosing box of all groundtruth boxes of a phrase.

    Args:
      groundtruth_box_tuples: A numpy array of structures with the shape
        [M, 1], each structure containing the same number of named bounding
        boxes. Each box is of the format [y_min, x_min, y_max, x_max]. See
        vrd_box_data_type for an example of structure.

    Returns:
      result: A numpy array of structures with the shape [M, 1], each
        structure containing exactly one named bounding box. The i-th output
        structure corresponds to the result of processing the i-th input
        structure, where the named bounding box is computed as an enclosing
        bounding box of all bounding boxes of the i-th input structure.
    """
    first_box_key = list(groundtruth_box_tuples.dtype.fields.keys())[0]
    miny = groundtruth_box_tuples[first_box_key][:, 0]
    minx = groundtruth_box_tuples[first_box_key][:, 1]
    maxy = groundtruth_box_tuples[first_box_key][:, 2]
    maxx = groundtruth_box_tuples[first_box_key][:, 3]
    for field in groundtruth_box_tuples.dtype.fields:
      miny = np.minimum(groundtruth_box_tuples[field][:, 0], miny)
      minx = np.minimum(groundtruth_box_tuples[field][:, 1], minx)
      maxy = np.maximum(groundtruth_box_tuples[field][:, 2], maxy)
      maxx = np.maximum(groundtruth_box_tuples[field][:, 3], maxx)

    data_result = []
    for i in range(groundtruth_box_tuples.shape[0]):
      data_result.append(([miny[i], minx[i], maxy[i], maxx[i]],))
    result = np.array(data_result, dtype=[('box', 'f4', (4,))])
    return result

  def _process_detection_boxes(self, detections_box_tuples):
    """Pre-processes boxes before adding them to the VRDDetectionEvaluation.

    In the phrase evaluation task, evaluation expects exactly one bounding
    box containing all objects in the phrase. This bounding box is computed
    as an enclosing box of all detection boxes of a phrase.

    Args:
      detections_box_tuples: A numpy array of structures with the shape
        [M, 1], each structure containing the same number of named bounding
        boxes. Each box is of the format [y_min, x_min, y_max, x_max]. See
        vrd_box_data_type for an example of this structure.

    Returns:
      result: A numpy array of structures with the shape [M, 1], each
        structure containing exactly one named bounding box. The i-th output
        structure corresponds to the result of processing the i-th input
        structure, where the named bounding box is computed as an enclosing
        bounding box of all bounding boxes of the i-th input structure.
    """
    first_box_key = list(detections_box_tuples.dtype.fields.keys())[0]
    miny = detections_box_tuples[first_box_key][:, 0]
    minx = detections_box_tuples[first_box_key][:, 1]
    maxy = detections_box_tuples[first_box_key][:, 2]
    maxx = detections_box_tuples[first_box_key][:, 3]
    for field in detections_box_tuples.dtype.fields:
      miny = np.minimum(detections_box_tuples[field][:, 0], miny)
      minx = np.minimum(detections_box_tuples[field][:, 1], minx)
      maxy = np.maximum(detections_box_tuples[field][:, 2], maxy)
      maxx = np.maximum(detections_box_tuples[field][:, 3], maxx)

    data_result = []
    for i in range(detections_box_tuples.shape[0]):
      data_result.append(([miny[i], minx[i], maxy[i], maxx[i]],))
    result = np.array(data_result, dtype=[('box', 'f4', (4,))])
    return result
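
  # Worked example (made-up coordinates): for a single tuple with subject box
  # [0.1, 0.2, 0.5, 0.6] and object box [0.3, 0.0, 0.9, 0.4], the enclosing
  # box is the elementwise minimum of the mins and maximum of the maxes,
  # i.e. [0.1, 0.0, 0.9, 0.6].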


VRDDetectionEvalMetrics = collections.namedtuple('VRDDetectionEvalMetrics', [
    'weighted_average_precision', 'mean_average_precision',
    'average_precisions', 'precisions', 'recalls', 'recall_50', 'recall_100',
    'median_rank_50', 'median_rank_100'
])


class _VRDDetectionEvaluation(object):
  """Performs metric computation for the VRD task. This class is internal."""

  def __init__(self, matching_iou_threshold=0.5):
    """Constructor.

    Args:
      matching_iou_threshold: IOU threshold to use for matching groundtruth
        boxes to detection boxes.
    """
    self._per_image_eval = per_image_vrd_evaluation.PerImageVRDEvaluation(
        matching_iou_threshold=matching_iou_threshold)

    self._groundtruth_box_tuples = {}
    self._groundtruth_class_tuples = {}
    self._num_gt_instances = 0
    self._num_gt_imgs = 0
    self._num_gt_instances_per_relationship = {}

    self.clear_detections()

  def clear_detections(self):
    """Clears detections."""
    self._detection_keys = set()
    self._scores = []
    self._relation_field_values = []
    self._tp_fp_labels = []
    self._average_precisions = {}
    self._precisions = []
    self._recalls = []

  def add_single_ground_truth_image_info(
      self, image_key, groundtruth_box_tuples, groundtruth_class_tuples):
    """Adds groundtruth for a single image to be used for evaluation.

    Args:
      image_key: A unique string/integer identifier for the image.
      groundtruth_box_tuples: A numpy array of structures with the shape
        [M, 1], representing M tuples, each tuple containing the same number
        of named bounding boxes.
        Each box is of the format [y_min, x_min, y_max, x_max].
      groundtruth_class_tuples: A numpy array of structures shape [M, 1],
        representing the class labels of the corresponding bounding boxes and
        possibly additional classes.
    """
    if image_key in self._groundtruth_box_tuples:
      logging.warning(
          'image %s has already been added to the ground truth database.',
          image_key)
      return

    self._groundtruth_box_tuples[image_key] = groundtruth_box_tuples
    self._groundtruth_class_tuples[image_key] = groundtruth_class_tuples
    self._update_groundtruth_statistics(groundtruth_class_tuples)

  def add_single_detected_image_info(self, image_key, detected_box_tuples,
                                     detected_scores, detected_class_tuples):
    """Adds detections for a single image to be used for evaluation.

    Args:
      image_key: A unique string/integer identifier for the image.
      detected_box_tuples: A numpy array of structures with shape [N, 1],
        representing N tuples, each tuple containing the same number of named
        bounding boxes.
        Each box is of the format [y_min, x_min, y_max, x_max].
      detected_scores: A float numpy array of shape [N], representing
        the confidence scores of the detected N object instances.
      detected_class_tuples: A numpy array of structures shape [N, 1],
        representing the class labels of the corresponding bounding boxes and
        possibly additional classes.
    """
    self._detection_keys.add(image_key)
    if image_key in self._groundtruth_box_tuples:
      groundtruth_box_tuples = self._groundtruth_box_tuples[image_key]
      groundtruth_class_tuples = self._groundtruth_class_tuples[image_key]
    else:
      groundtruth_box_tuples = np.empty(
          shape=[0, 4], dtype=detected_box_tuples.dtype)
      groundtruth_class_tuples = np.array([], dtype=detected_class_tuples.dtype)

    scores, tp_fp_labels, mapping = (
        self._per_image_eval.compute_detection_tp_fp(
            detected_box_tuples=detected_box_tuples,
            detected_scores=detected_scores,
            detected_class_tuples=detected_class_tuples,
            groundtruth_box_tuples=groundtruth_box_tuples,
            groundtruth_class_tuples=groundtruth_class_tuples))

    self._scores += [scores]
    self._tp_fp_labels += [tp_fp_labels]
    self._relation_field_values += [detected_class_tuples[mapping]['relation']]

  def _update_groundtruth_statistics(self, groundtruth_class_tuples):
    """Updates ground truth statistics.

    Args:
      groundtruth_class_tuples: A numpy array of structures shape [M, 1],
        representing the class labels of the corresponding bounding boxes and
        possibly additional classes.
    """
    self._num_gt_instances += groundtruth_class_tuples.shape[0]
    self._num_gt_imgs += 1
    for relation_field_value in np.unique(groundtruth_class_tuples['relation']):
      if relation_field_value not in self._num_gt_instances_per_relationship:
        self._num_gt_instances_per_relationship[relation_field_value] = 0
      self._num_gt_instances_per_relationship[relation_field_value] += np.sum(
          groundtruth_class_tuples['relation'] == relation_field_value)

  def evaluate(self):
    """Computes evaluation result.

    Returns:
      A VRDDetectionEvalMetrics named tuple with the following fields -
        weighted_average_precision: weighted average precision over all
          relationships.
        mean_average_precision: mean of the per-relationship average
          precisions.
        average_precisions: a dict of per-relationship average precisions.
        precisions: an array of precisions.
        recalls: an array of recalls.
        recall_50: recall computed on 50 top-scoring samples.
        recall_100: recall computed on 100 top-scoring samples.
        median_rank_50: median rank computed on 50 top-scoring samples.
        median_rank_100: median rank computed on 100 top-scoring samples.
    """
    if self._num_gt_instances == 0:
      logging.warning('No ground truth instances')

    if not self._scores:
      scores = np.array([], dtype=float)
      tp_fp_labels = np.array([], dtype=bool)
      # Keep the per-relationship loop below well-defined when no detections
      # were added.
      relation_field_values = np.array([])
    else:
      scores = np.concatenate(self._scores)
      tp_fp_labels = np.concatenate(self._tp_fp_labels)
      relation_field_values = np.concatenate(self._relation_field_values)

    for relation_field_value in self._num_gt_instances_per_relationship:
      precisions, recalls = metrics.compute_precision_recall(
          scores[relation_field_values == relation_field_value],
          tp_fp_labels[relation_field_values == relation_field_value],
          self._num_gt_instances_per_relationship[relation_field_value])
      self._average_precisions[
          relation_field_value] = metrics.compute_average_precision(
              precisions, recalls)

    self._mean_average_precision = np.mean(
        list(self._average_precisions.values()))

    self._precisions, self._recalls = metrics.compute_precision_recall(
        scores, tp_fp_labels, self._num_gt_instances)
    self._weighted_average_precision = metrics.compute_average_precision(
        self._precisions, self._recalls)

    self._recall_50 = (
        metrics.compute_recall_at_k(self._tp_fp_labels, self._num_gt_instances,
                                    50))
    self._median_rank_50 = (
        metrics.compute_median_rank_at_k(self._tp_fp_labels, 50))
    self._recall_100 = (
        metrics.compute_recall_at_k(self._tp_fp_labels, self._num_gt_instances,
                                    100))
    self._median_rank_100 = (
        metrics.compute_median_rank_at_k(self._tp_fp_labels, 100))

    return VRDDetectionEvalMetrics(
        self._weighted_average_precision, self._mean_average_precision,
        self._average_precisions, self._precisions, self._recalls,
        self._recall_50, self._recall_100, self._median_rank_50,
        self._median_rank_100)
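

# Example end-to-end usage (a minimal sketch; the image id, boxes, class ids
# and score below are made up for illustration):
#
#   evaluator = VRDRelationDetectionEvaluator()
#   evaluator.add_single_ground_truth_image_info(
#       'img_0',
#       {standard_fields.InputDataFields.groundtruth_boxes: np.array(
#            [([0.0, 0.0, 0.5, 0.5], [0.2, 0.2, 0.9, 0.9])],
#            dtype=vrd_box_data_type),
#        standard_fields.InputDataFields.groundtruth_classes: np.array(
#            [(1, 2, 3)], dtype=label_data_type)})
#   evaluator.add_single_detected_image_info(
#       'img_0',
#       {standard_fields.DetectionResultFields.detection_boxes: np.array(
#            [([0.0, 0.0, 0.5, 0.5], [0.2, 0.2, 0.9, 0.9])],
#            dtype=vrd_box_data_type),
#        standard_fields.DetectionResultFields.detection_scores:
#            np.array([0.8], dtype=np.float32),
#        standard_fields.DetectionResultFields.detection_classes: np.array(
#            [(1, 2, 3)], dtype=label_data_type)})
#   metrics_dict = evaluator.evaluate()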