# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Runs evaluation using Open Images groundtruth and predictions.

Uses Open Images Challenge 2018/2019 metrics.

Example usage:
python models/research/object_detection/metrics/oid_od_challenge_evaluation.py \
    --input_annotations_boxes=/path/to/input/annotations-human-bbox.csv \
    --input_annotations_labels=/path/to/input/annotations-label.csv \
    --input_class_labelmap=/path/to/input/class_labelmap.pbtxt \
    --input_predictions=/path/to/input/predictions.csv \
    --output_metrics=/path/to/output/metric.csv \
    --input_annotations_segm=/path/to/input/annotations-human-mask.csv

If the optional --input_annotations_segm flag is provided, a Mask column is
also expected in that CSV.

CSVs with bounding box annotations, instance segmentations and image labels
can be downloaded from the Open Images Challenge website:
https://storage.googleapis.com/openimages/web/challenge.html
The format of the input CSVs and the metrics themselves are described on the
challenge website as well.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app
from absl import flags
import pandas as pd
from google.protobuf import text_format

from object_detection.metrics import io_utils
from object_detection.metrics import oid_challenge_evaluation_utils as utils
from object_detection.protos import string_int_label_map_pb2
from object_detection.utils import object_detection_evaluation

flags.DEFINE_string('input_annotations_boxes', None,
                    'File with groundtruth boxes annotations.')
flags.DEFINE_string('input_annotations_labels', None,
                    'File with groundtruth labels annotations.')
flags.DEFINE_string(
    'input_predictions', None,
    'File with detection predictions; NOTE: no postprocessing is applied in '
    'the evaluation script.')
flags.DEFINE_string('input_class_labelmap', None,
                    'Open Images Challenge labelmap.')
flags.DEFINE_string('output_metrics', None, 'Output file with csv metrics.')
flags.DEFINE_string(
    'input_annotations_segm', None,
    'File with groundtruth instance segmentation annotations [OPTIONAL].')

FLAGS = flags.FLAGS
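
# The --input_class_labelmap file is a StringIntLabelMap text proto. A minimal
# sketch of a single entry is shown below (the identifier and display name are
# illustrative examples, not taken from this script); only `name` and `id` are
# read by _load_labelmap:
#
#   item {
#     name: "/m/061hd_"
#     id: 1
#     display_name: "Infant bed"
#   }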


def _load_labelmap(labelmap_path):
  """Loads labelmap from the labelmap path.

  Args:
    labelmap_path: Path to the labelmap.

  Returns:
    A dictionary mapping class name to class numerical id.
    A list with dictionaries, one dictionary per category.
  """
  label_map = string_int_label_map_pb2.StringIntLabelMap()
  with open(labelmap_path, 'r') as fid:
    label_map_string = fid.read()
    text_format.Merge(label_map_string, label_map)
  labelmap_dict = {}
  categories = []
  for item in label_map.item:
    labelmap_dict[item.name] = item.id
    categories.append({'id': item.id, 'name': item.name})
  return labelmap_dict, categories


def main(unused_argv):
  flags.mark_flag_as_required('input_annotations_boxes')
  flags.mark_flag_as_required('input_annotations_labels')
  flags.mark_flag_as_required('input_predictions')
  flags.mark_flag_as_required('input_class_labelmap')
  flags.mark_flag_as_required('output_metrics')
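  # Read the box-level groundtruth and the image-level label annotations.
  # The image-level 'Confidence' column is renamed to 'ConfidenceImageLabel'
  # so it is not confused with box-level confidences once the two frames are
  # concatenated below.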
  all_location_annotations = pd.read_csv(FLAGS.input_annotations_boxes)
  all_label_annotations = pd.read_csv(FLAGS.input_annotations_labels)
  all_label_annotations.rename(
      columns={'Confidence': 'ConfidenceImageLabel'}, inplace=True)
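  # If segmentation groundtruth is supplied, switch to mask evaluation and
  # merge the mask annotations into the box annotations.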
  is_instance_segmentation_eval = False
  if FLAGS.input_annotations_segm:
    is_instance_segmentation_eval = True
    all_segm_annotations = pd.read_csv(FLAGS.input_annotations_segm)
    # Note: this part is unstable as it requires that the floating point
    # numbers in both CSVs are exactly the same; it will be replaced by a more
    # stable solution: merge on LabelName and ImageID and filter down by IoU.
    all_location_annotations = utils.merge_boxes_and_masks(
        all_location_annotations, all_segm_annotations)
  all_annotations = pd.concat([all_location_annotations, all_label_annotations])
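  # Build the class-name-to-id mapping and instantiate the challenge
  # evaluator (box or mask metrics, depending on the inputs above).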
  class_label_map, categories = _load_labelmap(FLAGS.input_class_labelmap)
  challenge_evaluator = (
      object_detection_evaluation.OpenImagesChallengeEvaluator(
          categories, evaluate_masks=is_instance_segmentation_eval))
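  # Register the groundtruth with the evaluator, one image at a time.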
  for image_id, image_groundtruth in all_annotations.groupby('ImageID'):
    groundtruth_dictionary = utils.build_groundtruth_dictionary(
        image_groundtruth, class_label_map)
    challenge_evaluator.add_single_ground_truth_image_info(
        image_id, groundtruth_dictionary)
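  # Register the detections with the evaluator, also grouped per image.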
  all_predictions = pd.read_csv(FLAGS.input_predictions)
  for image_id, image_predictions in all_predictions.groupby('ImageID'):
    prediction_dictionary = utils.build_predictions_dictionary(
        image_predictions, class_label_map)
    challenge_evaluator.add_single_detected_image_info(image_id,
                                                       prediction_dictionary)
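  # Compute the challenge metrics and write them out as a CSV file.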
  metrics = challenge_evaluator.evaluate()
  with open(FLAGS.output_metrics, 'w') as fid:
    io_utils.write_csv(fid, metrics)


if __name__ == '__main__':
  app.run(main)