You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

176 lines
6.3 KiB

  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. r"""Tests for detection_inference.py."""
import io
import os
import StringIO

import numpy as np
from PIL import Image
import tensorflow as tf

from object_detection.core import standard_fields
from object_detection.inference import detection_inference
from object_detection.utils import dataset_util
  24. def get_mock_tfrecord_path():
  25. return os.path.join(tf.test.get_temp_dir(), 'mock.tfrec')
  26. def create_mock_tfrecord():
  27. pil_image = Image.fromarray(np.array([[[123, 0, 0]]], dtype=np.uint8), 'RGB')
  28. image_output_stream = StringIO.StringIO()
  29. pil_image.save(image_output_stream, format='png')
  30. encoded_image = image_output_stream.getvalue()
  31. feature_map = {
  32. 'test_field':
  33. dataset_util.float_list_feature([1, 2, 3, 4]),
  34. standard_fields.TfExampleFields.image_encoded:
  35. dataset_util.bytes_feature(encoded_image),
  36. }
  37. tf_example = tf.train.Example(features=tf.train.Features(feature=feature_map))
  38. with tf.python_io.TFRecordWriter(get_mock_tfrecord_path()) as writer:
  39. writer.write(tf_example.SerializeToString())
  40. def get_mock_graph_path():
  41. return os.path.join(tf.test.get_temp_dir(), 'mock_graph.pb')
  42. def create_mock_graph():
  43. g = tf.Graph()
  44. with g.as_default():
  45. in_image_tensor = tf.placeholder(
  46. tf.uint8, shape=[1, None, None, 3], name='image_tensor')
  47. tf.constant([2.0], name='num_detections')
  48. tf.constant(
  49. [[[0, 0.8, 0.7, 1], [0.1, 0.2, 0.8, 0.9], [0.2, 0.3, 0.4, 0.5]]],
  50. name='detection_boxes')
  51. tf.constant([[0.1, 0.2, 0.3]], name='detection_scores')
  52. tf.identity(
  53. tf.constant([[1.0, 2.0, 3.0]]) *
  54. tf.reduce_sum(tf.cast(in_image_tensor, dtype=tf.float32)),
  55. name='detection_classes')
  56. graph_def = g.as_graph_def()
  57. with tf.gfile.Open(get_mock_graph_path(), 'w') as fl:
  58. fl.write(graph_def.SerializeToString())
  59. class InferDetectionsTests(tf.test.TestCase):
  60. def test_simple(self):
  61. create_mock_graph()
  62. create_mock_tfrecord()
  63. serialized_example_tensor, image_tensor = detection_inference.build_input(
  64. [get_mock_tfrecord_path()])
  65. self.assertAllEqual(image_tensor.get_shape().as_list(), [1, None, None, 3])
  66. (detected_boxes_tensor, detected_scores_tensor,
  67. detected_labels_tensor) = detection_inference.build_inference_graph(
  68. image_tensor, get_mock_graph_path())
  69. with self.test_session(use_gpu=False) as sess:
  70. sess.run(tf.global_variables_initializer())
  71. sess.run(tf.local_variables_initializer())
  72. tf.train.start_queue_runners()
  73. tf_example = detection_inference.infer_detections_and_add_to_example(
  74. serialized_example_tensor, detected_boxes_tensor,
  75. detected_scores_tensor, detected_labels_tensor, False)
  76. self.assertProtoEquals(r"""
  77. features {
  78. feature {
  79. key: "image/detection/bbox/ymin"
  80. value { float_list { value: [0.0, 0.1] } } }
  81. feature {
  82. key: "image/detection/bbox/xmin"
  83. value { float_list { value: [0.8, 0.2] } } }
  84. feature {
  85. key: "image/detection/bbox/ymax"
  86. value { float_list { value: [0.7, 0.8] } } }
  87. feature {
  88. key: "image/detection/bbox/xmax"
  89. value { float_list { value: [1.0, 0.9] } } }
  90. feature {
  91. key: "image/detection/label"
  92. value { int64_list { value: [123, 246] } } }
  93. feature {
  94. key: "image/detection/score"
  95. value { float_list { value: [0.1, 0.2] } } }
  96. feature {
  97. key: "image/encoded"
  98. value { bytes_list { value:
  99. "\211PNG\r\n\032\n\000\000\000\rIHDR\000\000\000\001\000\000"
  100. "\000\001\010\002\000\000\000\220wS\336\000\000\000\022IDATx"
  101. "\234b\250f`\000\000\000\000\377\377\003\000\001u\000|gO\242"
  102. "\213\000\000\000\000IEND\256B`\202" } } }
  103. feature {
  104. key: "test_field"
  105. value { float_list { value: [1.0, 2.0, 3.0, 4.0] } } } }
  106. """, tf_example)
  107. def test_discard_image(self):
  108. create_mock_graph()
  109. create_mock_tfrecord()
  110. serialized_example_tensor, image_tensor = detection_inference.build_input(
  111. [get_mock_tfrecord_path()])
  112. (detected_boxes_tensor, detected_scores_tensor,
  113. detected_labels_tensor) = detection_inference.build_inference_graph(
  114. image_tensor, get_mock_graph_path())
  115. with self.test_session(use_gpu=False) as sess:
  116. sess.run(tf.global_variables_initializer())
  117. sess.run(tf.local_variables_initializer())
  118. tf.train.start_queue_runners()
  119. tf_example = detection_inference.infer_detections_and_add_to_example(
  120. serialized_example_tensor, detected_boxes_tensor,
  121. detected_scores_tensor, detected_labels_tensor, True)
  122. self.assertProtoEquals(r"""
  123. features {
  124. feature {
  125. key: "image/detection/bbox/ymin"
  126. value { float_list { value: [0.0, 0.1] } } }
  127. feature {
  128. key: "image/detection/bbox/xmin"
  129. value { float_list { value: [0.8, 0.2] } } }
  130. feature {
  131. key: "image/detection/bbox/ymax"
  132. value { float_list { value: [0.7, 0.8] } } }
  133. feature {
  134. key: "image/detection/bbox/xmax"
  135. value { float_list { value: [1.0, 0.9] } } }
  136. feature {
  137. key: "image/detection/label"
  138. value { int64_list { value: [123, 246] } } }
  139. feature {
  140. key: "image/detection/score"
  141. value { float_list { value: [0.1, 0.2] } } }
  142. feature {
  143. key: "test_field"
  144. value { float_list { value: [1.0, 2.0, 3.0, 4.0] } } } }
  145. """, tf_example)
  146. if __name__ == '__main__':
  147. tf.test.main()