You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

154 lines
5.7 KiB

6 years ago
  1. # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. r"""Runs evaluation using OpenImages groundtruth and predictions.
  16. Example usage:
  17. python \
  18. models/research/object_detection/metrics/oid_vrd_challenge_evaluation.py \
  19. --input_annotations_boxes=/path/to/input/annotations-human-bbox.csv \
  20. --input_annotations_labels=/path/to/input/annotations-label.csv \
  21. --input_class_labelmap=/path/to/input/class_labelmap.pbtxt \
  22. --input_relationship_labelmap=/path/to/input/relationship_labelmap.pbtxt \
  23. --input_predictions=/path/to/input/predictions.csv \
  24. --output_metrics=/path/to/output/metric.csv
  25. CSVs with bounding box annotations and image label (including the image URLs)
  26. can be downloaded from the Open Images Challenge website:
  27. https://storage.googleapis.com/openimages/web/challenge.html
  28. The formats of the input CSVs and the metrics themselves are described on the
  29. challenge website.
  30. """
  31. from __future__ import absolute_import
  32. from __future__ import division
  33. from __future__ import print_function
  34. import argparse
  35. import pandas as pd
  36. from google.protobuf import text_format
  37. from object_detection.metrics import io_utils
  38. from object_detection.metrics import oid_vrd_challenge_evaluation_utils as utils
  39. from object_detection.protos import string_int_label_map_pb2
  40. from object_detection.utils import vrd_evaluation
  41. def _load_labelmap(labelmap_path):
  42. """Loads labelmap from the labelmap path.
  43. Args:
  44. labelmap_path: Path to the labelmap.
  45. Returns:
  46. A dictionary mapping class name to class numerical id.
  47. """
  48. label_map = string_int_label_map_pb2.StringIntLabelMap()
  49. with open(labelmap_path, 'r') as fid:
  50. label_map_string = fid.read()
  51. text_format.Merge(label_map_string, label_map)
  52. labelmap_dict = {}
  53. for item in label_map.item:
  54. labelmap_dict[item.name] = item.id
  55. return labelmap_dict
  56. def _swap_labelmap_dict(labelmap_dict):
  57. """Swaps keys and labels in labelmap.
  58. Args:
  59. labelmap_dict: Input dictionary.
  60. Returns:
  61. A dictionary mapping class name to class numerical id.
  62. """
  63. return dict((v, k) for k, v in labelmap_dict.iteritems())
  64. def main(parsed_args):
  65. all_box_annotations = pd.read_csv(parsed_args.input_annotations_boxes)
  66. all_label_annotations = pd.read_csv(parsed_args.input_annotations_labels)
  67. all_annotations = pd.concat([all_box_annotations, all_label_annotations])
  68. class_label_map = _load_labelmap(parsed_args.input_class_labelmap)
  69. relationship_label_map = _load_labelmap(
  70. parsed_args.input_relationship_labelmap)
  71. relation_evaluator = vrd_evaluation.VRDRelationDetectionEvaluator()
  72. phrase_evaluator = vrd_evaluation.VRDPhraseDetectionEvaluator()
  73. for _, groundtruth in enumerate(all_annotations.groupby('ImageID')):
  74. image_id, image_groundtruth = groundtruth
  75. groundtruth_dictionary = utils.build_groundtruth_vrd_dictionary(
  76. image_groundtruth, class_label_map, relationship_label_map)
  77. relation_evaluator.add_single_ground_truth_image_info(
  78. image_id, groundtruth_dictionary)
  79. phrase_evaluator.add_single_ground_truth_image_info(image_id,
  80. groundtruth_dictionary)
  81. all_predictions = pd.read_csv(parsed_args.input_predictions)
  82. for _, prediction_data in enumerate(all_predictions.groupby('ImageID')):
  83. image_id, image_predictions = prediction_data
  84. prediction_dictionary = utils.build_predictions_vrd_dictionary(
  85. image_predictions, class_label_map, relationship_label_map)
  86. relation_evaluator.add_single_detected_image_info(image_id,
  87. prediction_dictionary)
  88. phrase_evaluator.add_single_detected_image_info(image_id,
  89. prediction_dictionary)
  90. relation_metrics = relation_evaluator.evaluate(
  91. relationships=_swap_labelmap_dict(relationship_label_map))
  92. phrase_metrics = phrase_evaluator.evaluate(
  93. relationships=_swap_labelmap_dict(relationship_label_map))
  94. with open(parsed_args.output_metrics, 'w') as fid:
  95. io_utils.write_csv(fid, relation_metrics)
  96. io_utils.write_csv(fid, phrase_metrics)
  97. if __name__ == '__main__':
  98. parser = argparse.ArgumentParser(
  99. description=
  100. 'Evaluate Open Images Visual Relationship Detection predictions.')
  101. parser.add_argument(
  102. '--input_annotations_boxes',
  103. required=True,
  104. help='File with groundtruth vrd annotations.')
  105. parser.add_argument(
  106. '--input_annotations_labels',
  107. required=True,
  108. help='File with groundtruth labels annotations')
  109. parser.add_argument(
  110. '--input_predictions',
  111. required=True,
  112. help="""File with detection predictions; NOTE: no postprocessing is
  113. applied in the evaluation script.""")
  114. parser.add_argument(
  115. '--input_class_labelmap',
  116. required=True,
  117. help="""OpenImages Challenge labelmap; note: it is expected to include
  118. attributes.""")
  119. parser.add_argument(
  120. '--input_relationship_labelmap',
  121. required=True,
  122. help="""OpenImages Challenge relationship labelmap.""")
  123. parser.add_argument(
  124. '--output_metrics', required=True, help='Output file with csv metrics')
  125. args = parser.parse_args()
  126. main(args)