You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

129 lines
4.8 KiB

6 years ago
  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for input_reader_builder."""
  16. import os
  17. import numpy as np
  18. import tensorflow as tf
  19. from google.protobuf import text_format
  20. from object_detection.builders import input_reader_builder
  21. from object_detection.core import standard_fields as fields
  22. from object_detection.protos import input_reader_pb2
  23. from object_detection.utils import dataset_util
  class InputReaderBuilderTest(tf.test.TestCase):
    """Tests for `input_reader_builder.build` with TFRecord inputs."""

    def create_tf_record(self):
      """Writes a one-example TFRecord file into the test's temp directory.

      The example holds a random 4x5x3 uint8 image encoded as JPEG, a single
      box with corners 0.0 and 1.0 on both axes, class label 2, and a flat
      list of 4 * 5 mask values (all 1.0).

      Returns:
        Path to the written TFRecord file.
      """
      path = os.path.join(self.get_temp_dir(), 'tfrecord')
      image_tensor = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8)
      flat_mask = (4 * 5) * [1.0]
      with self.test_session():
        encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).eval()
      example = tf.train.Example(features=tf.train.Features(feature={
          'image/encoded': dataset_util.bytes_feature(encoded_jpeg),
          'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
          'image/height': dataset_util.int64_feature(4),
          'image/width': dataset_util.int64_feature(5),
          'image/object/bbox/xmin': dataset_util.float_list_feature([0.0]),
          'image/object/bbox/xmax': dataset_util.float_list_feature([1.0]),
          'image/object/bbox/ymin': dataset_util.float_list_feature([0.0]),
          'image/object/bbox/ymax': dataset_util.float_list_feature([1.0]),
          'image/object/class/label': dataset_util.int64_list_feature([2]),
          'image/object/mask': dataset_util.float_list_feature(flat_mask),
      }))
      # Using the writer as a context manager guarantees the file is flushed
      # and closed even if writing raises.
      with tf.python_io.TFRecordWriter(path) as writer:
        writer.write(example.SerializeToString())
      return path

    def _build_input_reader_proto(self, text_proto):
      """Parses `text_proto` into a new `input_reader_pb2.InputReader`."""
      input_reader_proto = input_reader_pb2.InputReader()
      text_format.Merge(text_proto, input_reader_proto)
      return input_reader_proto

    def test_build_tf_record_input_reader(self):
      """Decodes the record; instance masks are absent unless requested."""
      tf_record_path = self.create_tf_record()
      input_reader_text_proto = """
        shuffle: false
        num_readers: 1
        tf_record_input_reader {{
          input_path: '{0}'
        }}
      """.format(tf_record_path)
      input_reader_proto = self._build_input_reader_proto(
          input_reader_text_proto)
      tensor_dict = input_reader_builder.build(input_reader_proto)

      with tf.train.MonitoredSession() as sess:
        output_dict = sess.run(tensor_dict)

      self.assertNotIn(fields.InputDataFields.groundtruth_instance_masks,
                       output_dict)
      self.assertEqual(
          (4, 5, 3), output_dict[fields.InputDataFields.image].shape)
      # sess.run yields numpy arrays; compare element-wise rather than via
      # assertEqual, which would depend on the truth value of an array.
      self.assertAllEqual(
          [2], output_dict[fields.InputDataFields.groundtruth_classes])
      self.assertEqual(
          (1, 4), output_dict[fields.InputDataFields.groundtruth_boxes].shape)
      self.assertAllEqual(
          [0.0, 0.0, 1.0, 1.0],
          output_dict[fields.InputDataFields.groundtruth_boxes][0])

    def test_build_tf_record_input_reader_and_load_instance_masks(self):
      """With load_instance_masks, masks are decoded with shape (1, 4, 5)."""
      tf_record_path = self.create_tf_record()
      input_reader_text_proto = """
        shuffle: false
        num_readers: 1
        load_instance_masks: true
        tf_record_input_reader {{
          input_path: '{0}'
        }}
      """.format(tf_record_path)
      input_reader_proto = self._build_input_reader_proto(
          input_reader_text_proto)
      tensor_dict = input_reader_builder.build(input_reader_proto)

      with tf.train.MonitoredSession() as sess:
        output_dict = sess.run(tensor_dict)

      self.assertEqual(
          (4, 5, 3), output_dict[fields.InputDataFields.image].shape)
      self.assertAllEqual(
          [2], output_dict[fields.InputDataFields.groundtruth_classes])
      self.assertEqual(
          (1, 4), output_dict[fields.InputDataFields.groundtruth_boxes].shape)
      self.assertAllEqual(
          [0.0, 0.0, 1.0, 1.0],
          output_dict[fields.InputDataFields.groundtruth_boxes][0])
      self.assertEqual(
          (1, 4, 5),
          output_dict[fields.InputDataFields.groundtruth_instance_masks].shape)

    def test_raises_error_with_no_input_paths(self):
      """A config without a tf_record_input_reader raises ValueError."""
      input_reader_text_proto = """
        shuffle: false
        num_readers: 1
        load_instance_masks: true
      """
      input_reader_proto = self._build_input_reader_proto(
          input_reader_text_proto)
      with self.assertRaises(ValueError):
        input_reader_builder.build(input_reader_proto)
  if __name__ == '__main__':
    # Run this module's test cases only when executed as a script, not when
    # the module is imported.
    tf.test.main()