You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

197 lines
7.6 KiB

6 years ago
  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
"""Tests for object_detection.metrics.tf_example_parser."""
  16. import numpy as np
  17. import numpy.testing as np_testing
  18. import tensorflow as tf
  19. from object_detection.core import standard_fields as fields
  20. from object_detection.metrics import tf_example_parser
  21. class TfExampleDecoderTest(tf.test.TestCase):
  22. def _Int64Feature(self, value):
  23. return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
  24. def _FloatFeature(self, value):
  25. return tf.train.Feature(float_list=tf.train.FloatList(value=value))
  26. def _BytesFeature(self, value):
  27. return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
  28. def testParseDetectionsAndGT(self):
  29. source_id = 'abc.jpg'
  30. # y_min, x_min, y_max, x_max
  31. object_bb = np.array([[0.0, 0.5, 0.3], [0.0, 0.1, 0.6], [1.0, 0.6, 0.8],
  32. [1.0, 0.6, 0.7]]).transpose()
  33. detection_bb = np.array([[0.1, 0.2], [0.0, 0.8], [1.0, 0.6],
  34. [1.0, 0.85]]).transpose()
  35. object_class_label = [1, 1, 2]
  36. object_difficult = [1, 0, 0]
  37. object_group_of = [0, 0, 1]
  38. verified_labels = [1, 2, 3, 4]
  39. detection_class_label = [2, 1]
  40. detection_score = [0.5, 0.3]
  41. features = {
  42. fields.TfExampleFields.source_id:
  43. self._BytesFeature(source_id),
  44. fields.TfExampleFields.object_bbox_ymin:
  45. self._FloatFeature(object_bb[:, 0].tolist()),
  46. fields.TfExampleFields.object_bbox_xmin:
  47. self._FloatFeature(object_bb[:, 1].tolist()),
  48. fields.TfExampleFields.object_bbox_ymax:
  49. self._FloatFeature(object_bb[:, 2].tolist()),
  50. fields.TfExampleFields.object_bbox_xmax:
  51. self._FloatFeature(object_bb[:, 3].tolist()),
  52. fields.TfExampleFields.detection_bbox_ymin:
  53. self._FloatFeature(detection_bb[:, 0].tolist()),
  54. fields.TfExampleFields.detection_bbox_xmin:
  55. self._FloatFeature(detection_bb[:, 1].tolist()),
  56. fields.TfExampleFields.detection_bbox_ymax:
  57. self._FloatFeature(detection_bb[:, 2].tolist()),
  58. fields.TfExampleFields.detection_bbox_xmax:
  59. self._FloatFeature(detection_bb[:, 3].tolist()),
  60. fields.TfExampleFields.detection_class_label:
  61. self._Int64Feature(detection_class_label),
  62. fields.TfExampleFields.detection_score:
  63. self._FloatFeature(detection_score),
  64. }
  65. example = tf.train.Example(features=tf.train.Features(feature=features))
  66. parser = tf_example_parser.TfExampleDetectionAndGTParser()
  67. results_dict = parser.parse(example)
  68. self.assertIsNone(results_dict)
  69. features[fields.TfExampleFields.object_class_label] = (
  70. self._Int64Feature(object_class_label))
  71. features[fields.TfExampleFields.object_difficult] = (
  72. self._Int64Feature(object_difficult))
  73. example = tf.train.Example(features=tf.train.Features(feature=features))
  74. results_dict = parser.parse(example)
  75. self.assertIsNotNone(results_dict)
  76. self.assertEqual(source_id, results_dict[fields.DetectionResultFields.key])
  77. np_testing.assert_almost_equal(
  78. object_bb, results_dict[fields.InputDataFields.groundtruth_boxes])
  79. np_testing.assert_almost_equal(
  80. detection_bb,
  81. results_dict[fields.DetectionResultFields.detection_boxes])
  82. np_testing.assert_almost_equal(
  83. detection_score,
  84. results_dict[fields.DetectionResultFields.detection_scores])
  85. np_testing.assert_almost_equal(
  86. detection_class_label,
  87. results_dict[fields.DetectionResultFields.detection_classes])
  88. np_testing.assert_almost_equal(
  89. object_difficult,
  90. results_dict[fields.InputDataFields.groundtruth_difficult])
  91. np_testing.assert_almost_equal(
  92. object_class_label,
  93. results_dict[fields.InputDataFields.groundtruth_classes])
  94. parser = tf_example_parser.TfExampleDetectionAndGTParser()
  95. features[fields.TfExampleFields.object_group_of] = (
  96. self._Int64Feature(object_group_of))
  97. example = tf.train.Example(features=tf.train.Features(feature=features))
  98. results_dict = parser.parse(example)
  99. self.assertIsNotNone(results_dict)
  100. np_testing.assert_equal(
  101. object_group_of,
  102. results_dict[fields.InputDataFields.groundtruth_group_of])
  103. features[fields.TfExampleFields.image_class_label] = (
  104. self._Int64Feature(verified_labels))
  105. example = tf.train.Example(features=tf.train.Features(feature=features))
  106. results_dict = parser.parse(example)
  107. self.assertIsNotNone(results_dict)
  108. np_testing.assert_equal(
  109. verified_labels,
  110. results_dict[fields.InputDataFields.groundtruth_image_classes])
  111. def testParseString(self):
  112. string_val = 'abc'
  113. features = {'string': self._BytesFeature(string_val)}
  114. example = tf.train.Example(features=tf.train.Features(feature=features))
  115. parser = tf_example_parser.StringParser('string')
  116. result = parser.parse(example)
  117. self.assertIsNotNone(result)
  118. self.assertEqual(result, string_val)
  119. parser = tf_example_parser.StringParser('another_string')
  120. result = parser.parse(example)
  121. self.assertIsNone(result)
  122. def testParseFloat(self):
  123. float_array_val = [1.5, 1.4, 2.0]
  124. features = {'floats': self._FloatFeature(float_array_val)}
  125. example = tf.train.Example(features=tf.train.Features(feature=features))
  126. parser = tf_example_parser.FloatParser('floats')
  127. result = parser.parse(example)
  128. self.assertIsNotNone(result)
  129. np_testing.assert_almost_equal(result, float_array_val)
  130. parser = tf_example_parser.StringParser('another_floats')
  131. result = parser.parse(example)
  132. self.assertIsNone(result)
  133. def testInt64Parser(self):
  134. int_val = [1, 2, 3]
  135. features = {'ints': self._Int64Feature(int_val)}
  136. example = tf.train.Example(features=tf.train.Features(feature=features))
  137. parser = tf_example_parser.Int64Parser('ints')
  138. result = parser.parse(example)
  139. self.assertIsNotNone(result)
  140. np_testing.assert_almost_equal(result, int_val)
  141. parser = tf_example_parser.Int64Parser('another_ints')
  142. result = parser.parse(example)
  143. self.assertIsNone(result)
  144. def testBoundingBoxParser(self):
  145. bounding_boxes = np.array([[0.0, 0.5, 0.3], [0.0, 0.1, 0.6],
  146. [1.0, 0.6, 0.8], [1.0, 0.6, 0.7]]).transpose()
  147. features = {
  148. 'ymin': self._FloatFeature(bounding_boxes[:, 0]),
  149. 'xmin': self._FloatFeature(bounding_boxes[:, 1]),
  150. 'ymax': self._FloatFeature(bounding_boxes[:, 2]),
  151. 'xmax': self._FloatFeature(bounding_boxes[:, 3])
  152. }
  153. example = tf.train.Example(features=tf.train.Features(feature=features))
  154. parser = tf_example_parser.BoundingBoxParser('xmin', 'ymin', 'xmax', 'ymax')
  155. result = parser.parse(example)
  156. self.assertIsNotNone(result)
  157. np_testing.assert_almost_equal(result, bounding_boxes)
  158. parser = tf_example_parser.BoundingBoxParser('xmin', 'ymin', 'xmax',
  159. 'another_ymax')
  160. result = parser.parse(example)
  161. self.assertIsNone(result)
  162. if __name__ == '__main__':
  163. tf.test.main()