You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

159 lines
6.1 KiB

  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tensorflow Example proto parser for data loading.
  16. A parser to decode data containing serialized tensorflow.Example
  17. protos into materialized tensors (numpy arrays).
  18. """
  19. import numpy as np
  20. from object_detection.core import data_parser
  21. from object_detection.core import standard_fields as fields
  22. class FloatParser(data_parser.DataToNumpyParser):
  23. """Tensorflow Example float parser."""
  24. def __init__(self, field_name):
  25. self.field_name = field_name
  26. def parse(self, tf_example):
  27. return np.array(
  28. tf_example.features.feature[self.field_name].float_list.value,
  29. dtype=np.float).transpose() if tf_example.features.feature[
  30. self.field_name].HasField("float_list") else None
  31. class StringParser(data_parser.DataToNumpyParser):
  32. """Tensorflow Example string parser."""
  33. def __init__(self, field_name):
  34. self.field_name = field_name
  35. def parse(self, tf_example):
  36. return "".join(tf_example.features.feature[self.field_name]
  37. .bytes_list.value) if tf_example.features.feature[
  38. self.field_name].HasField("bytes_list") else None
  39. class Int64Parser(data_parser.DataToNumpyParser):
  40. """Tensorflow Example int64 parser."""
  41. def __init__(self, field_name):
  42. self.field_name = field_name
  43. def parse(self, tf_example):
  44. return np.array(
  45. tf_example.features.feature[self.field_name].int64_list.value,
  46. dtype=np.int64).transpose() if tf_example.features.feature[
  47. self.field_name].HasField("int64_list") else None
  48. class BoundingBoxParser(data_parser.DataToNumpyParser):
  49. """Tensorflow Example bounding box parser."""
  50. def __init__(self, xmin_field_name, ymin_field_name, xmax_field_name,
  51. ymax_field_name):
  52. self.field_names = [
  53. ymin_field_name, xmin_field_name, ymax_field_name, xmax_field_name
  54. ]
  55. def parse(self, tf_example):
  56. result = []
  57. parsed = True
  58. for field_name in self.field_names:
  59. result.append(tf_example.features.feature[field_name].float_list.value)
  60. parsed &= (
  61. tf_example.features.feature[field_name].HasField("float_list"))
  62. return np.array(result).transpose() if parsed else None
  63. class TfExampleDetectionAndGTParser(data_parser.DataToNumpyParser):
  64. """Tensorflow Example proto parser."""
  65. def __init__(self):
  66. self.items_to_handlers = {
  67. fields.DetectionResultFields.key:
  68. StringParser(fields.TfExampleFields.source_id),
  69. # Object ground truth boxes and classes.
  70. fields.InputDataFields.groundtruth_boxes: (BoundingBoxParser(
  71. fields.TfExampleFields.object_bbox_xmin,
  72. fields.TfExampleFields.object_bbox_ymin,
  73. fields.TfExampleFields.object_bbox_xmax,
  74. fields.TfExampleFields.object_bbox_ymax)),
  75. fields.InputDataFields.groundtruth_classes: (
  76. Int64Parser(fields.TfExampleFields.object_class_label)),
  77. # Object detections.
  78. fields.DetectionResultFields.detection_boxes: (BoundingBoxParser(
  79. fields.TfExampleFields.detection_bbox_xmin,
  80. fields.TfExampleFields.detection_bbox_ymin,
  81. fields.TfExampleFields.detection_bbox_xmax,
  82. fields.TfExampleFields.detection_bbox_ymax)),
  83. fields.DetectionResultFields.detection_classes: (
  84. Int64Parser(fields.TfExampleFields.detection_class_label)),
  85. fields.DetectionResultFields.detection_scores: (
  86. FloatParser(fields.TfExampleFields.detection_score)),
  87. }
  88. self.optional_items_to_handlers = {
  89. fields.InputDataFields.groundtruth_difficult:
  90. Int64Parser(fields.TfExampleFields.object_difficult),
  91. fields.InputDataFields.groundtruth_group_of:
  92. Int64Parser(fields.TfExampleFields.object_group_of),
  93. fields.InputDataFields.groundtruth_image_classes:
  94. Int64Parser(fields.TfExampleFields.image_class_label),
  95. }
  96. def parse(self, tf_example):
  97. """Parses tensorflow example and returns a tensor dictionary.
  98. Args:
  99. tf_example: a tf.Example object.
  100. Returns:
  101. A dictionary of the following numpy arrays:
  102. fields.DetectionResultFields.source_id - string containing original image
  103. id.
  104. fields.InputDataFields.groundtruth_boxes - a numpy array containing
  105. groundtruth boxes.
  106. fields.InputDataFields.groundtruth_classes - a numpy array containing
  107. groundtruth classes.
  108. fields.InputDataFields.groundtruth_group_of - a numpy array containing
  109. groundtruth group of flag (optional, None if not specified).
  110. fields.InputDataFields.groundtruth_difficult - a numpy array containing
  111. groundtruth difficult flag (optional, None if not specified).
  112. fields.InputDataFields.groundtruth_image_classes - a numpy array
  113. containing groundtruth image-level labels.
  114. fields.DetectionResultFields.detection_boxes - a numpy array containing
  115. detection boxes.
  116. fields.DetectionResultFields.detection_classes - a numpy array containing
  117. detection class labels.
  118. fields.DetectionResultFields.detection_scores - a numpy array containing
  119. detection scores.
  120. Returns None if tf.Example was not parsed or non-optional fields were not
  121. found.
  122. """
  123. results_dict = {}
  124. parsed = True
  125. for key, parser in self.items_to_handlers.items():
  126. results_dict[key] = parser.parse(tf_example)
  127. parsed &= (results_dict[key] is not None)
  128. for key, parser in self.optional_items_to_handlers.items():
  129. results_dict[key] = parser.parse(tf_example)
  130. return results_dict if parsed else None