# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """Tests for dataset_builder."""
  16. import os
  17. import numpy as np
  18. import tensorflow as tf
  19. from google.protobuf import text_format
  20. from object_detection.builders import dataset_builder
  21. from object_detection.core import standard_fields as fields
  22. from object_detection.protos import input_reader_pb2
  23. from object_detection.utils import dataset_util


class DatasetBuilderTest(tf.test.TestCase):

  def create_tf_record(self, has_additional_channels=False, num_examples=1):
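    """Writes a test TFRecord file and returns its path.

    Each example holds a random 4x5 JPEG image, one full-image bounding box
    labeled with class 2, and a flat all-ones instance mask; when
    has_additional_channels is set, an extra single-channel JPEG is attached.
    """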
    path = os.path.join(self.get_temp_dir(), 'tfrecord')
    writer = tf.python_io.TFRecordWriter(path)

    image_tensor = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8)
    additional_channels_tensor = np.random.randint(
        255, size=(4, 5, 1)).astype(np.uint8)
    flat_mask = (4 * 5) * [1.0]
    with self.test_session():
      encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).eval()
      encoded_additional_channels_jpeg = tf.image.encode_jpeg(
          tf.constant(additional_channels_tensor)).eval()
    for i in range(num_examples):
      features = {
          # source_id must be bytes for bytes_feature under Python 3.
          'image/source_id': dataset_util.bytes_feature(str(i).encode('utf8')),
          'image/encoded': dataset_util.bytes_feature(encoded_jpeg),
          'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
          'image/height': dataset_util.int64_feature(4),
          'image/width': dataset_util.int64_feature(5),
          'image/object/bbox/xmin': dataset_util.float_list_feature([0.0]),
          'image/object/bbox/xmax': dataset_util.float_list_feature([1.0]),
          'image/object/bbox/ymin': dataset_util.float_list_feature([0.0]),
          'image/object/bbox/ymax': dataset_util.float_list_feature([1.0]),
          'image/object/class/label': dataset_util.int64_list_feature([2]),
          'image/object/mask': dataset_util.float_list_feature(flat_mask),
      }
      if has_additional_channels:
        additional_channels_key = 'image/additional_channels/encoded'
        features[additional_channels_key] = dataset_util.bytes_list_feature(
            [encoded_additional_channels_jpeg] * 2)
      example = tf.train.Example(features=tf.train.Features(feature=features))
      writer.write(example.SerializeToString())
    writer.close()

    return path
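
  # Verifies that a plain TFRecord input reader yields the expected image,
  # class, and box tensors for a single example, and no instance masks.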
  def test_build_tf_record_input_reader(self):
    tf_record_path = self.create_tf_record()

    input_reader_text_proto = """
      shuffle: false
      num_readers: 1
      tf_record_input_reader {{
        input_path: '{0}'
      }}
    """.format(tf_record_path)
    input_reader_proto = input_reader_pb2.InputReader()
    text_format.Merge(input_reader_text_proto, input_reader_proto)
    tensor_dict = dataset_builder.make_initializable_iterator(
        dataset_builder.build(input_reader_proto, batch_size=1)).get_next()

    with tf.train.MonitoredSession() as sess:
      output_dict = sess.run(tensor_dict)

    self.assertTrue(
        fields.InputDataFields.groundtruth_instance_masks not in output_dict)
    self.assertEqual((1, 4, 5, 3),
                     output_dict[fields.InputDataFields.image].shape)
    self.assertAllEqual([[2]],
                        output_dict[fields.InputDataFields.groundtruth_classes])
    self.assertEqual(
        (1, 1, 4), output_dict[fields.InputDataFields.groundtruth_boxes].shape)
    self.assertAllEqual(
        [0.0, 0.0, 1.0, 1.0],
        output_dict[fields.InputDataFields.groundtruth_boxes][0][0])
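
  # Verifies that instance masks are decoded when load_instance_masks is set.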
  def test_build_tf_record_input_reader_and_load_instance_masks(self):
    tf_record_path = self.create_tf_record()

    input_reader_text_proto = """
      shuffle: false
      num_readers: 1
      load_instance_masks: true
      tf_record_input_reader {{
        input_path: '{0}'
      }}
    """.format(tf_record_path)
    input_reader_proto = input_reader_pb2.InputReader()
    text_format.Merge(input_reader_text_proto, input_reader_proto)
    tensor_dict = dataset_builder.make_initializable_iterator(
        dataset_builder.build(input_reader_proto, batch_size=1)).get_next()

    with tf.train.MonitoredSession() as sess:
      output_dict = sess.run(tensor_dict)

    self.assertAllEqual(
        (1, 1, 4, 5),
        output_dict[fields.InputDataFields.groundtruth_instance_masks].shape)
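
  # Checks batching: with batch_size=2 and a one-hot class transform, images,
  # classes, and boxes all come out with a leading batch dimension of 2.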
  def test_build_tf_record_input_reader_with_batch_size_two(self):
    tf_record_path = self.create_tf_record()

    input_reader_text_proto = """
      shuffle: false
      num_readers: 1
      tf_record_input_reader {{
        input_path: '{0}'
      }}
    """.format(tf_record_path)
    input_reader_proto = input_reader_pb2.InputReader()
    text_format.Merge(input_reader_text_proto, input_reader_proto)

    def one_hot_class_encoding_fn(tensor_dict):
      tensor_dict[fields.InputDataFields.groundtruth_classes] = tf.one_hot(
          tensor_dict[fields.InputDataFields.groundtruth_classes] - 1, depth=3)
      return tensor_dict

    tensor_dict = dataset_builder.make_initializable_iterator(
        dataset_builder.build(
            input_reader_proto,
            transform_input_data_fn=one_hot_class_encoding_fn,
            batch_size=2)).get_next()

    with tf.train.MonitoredSession() as sess:
      output_dict = sess.run(tensor_dict)

    self.assertAllEqual([2, 4, 5, 3],
                        output_dict[fields.InputDataFields.image].shape)
    self.assertAllEqual(
        [2, 1, 3],
        output_dict[fields.InputDataFields.groundtruth_classes].shape)
    self.assertAllEqual(
        [2, 1, 4], output_dict[fields.InputDataFields.groundtruth_boxes].shape)
    self.assertAllEqual([[[0.0, 0.0, 1.0, 1.0]], [[0.0, 0.0, 1.0, 1.0]]],
                        output_dict[fields.InputDataFields.groundtruth_boxes])
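
  # Same as above, but also loads instance masks and checks their batched
  # shape.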
  def test_build_tf_record_input_reader_with_batch_size_two_and_masks(self):
    tf_record_path = self.create_tf_record()

    input_reader_text_proto = """
      shuffle: false
      num_readers: 1
      load_instance_masks: true
      tf_record_input_reader {{
        input_path: '{0}'
      }}
    """.format(tf_record_path)
    input_reader_proto = input_reader_pb2.InputReader()
    text_format.Merge(input_reader_text_proto, input_reader_proto)

    def one_hot_class_encoding_fn(tensor_dict):
      tensor_dict[fields.InputDataFields.groundtruth_classes] = tf.one_hot(
          tensor_dict[fields.InputDataFields.groundtruth_classes] - 1, depth=3)
      return tensor_dict

    tensor_dict = dataset_builder.make_initializable_iterator(
        dataset_builder.build(
            input_reader_proto,
            transform_input_data_fn=one_hot_class_encoding_fn,
            batch_size=2)).get_next()

    with tf.train.MonitoredSession() as sess:
      output_dict = sess.run(tensor_dict)

    self.assertAllEqual(
        [2, 1, 4, 5],
        output_dict[fields.InputDataFields.groundtruth_instance_masks].shape)
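
  # The builder must reject a config that specifies no input paths.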
  def test_raises_error_with_no_input_paths(self):
    input_reader_text_proto = """
      shuffle: false
      num_readers: 1
      load_instance_masks: true
    """
    input_reader_proto = input_reader_pb2.InputReader()
    text_format.Merge(input_reader_text_proto, input_reader_proto)
    with self.assertRaises(ValueError):
      dataset_builder.build(input_reader_proto, batch_size=1)
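
  # With sample_1_of_n_examples: 1, every example is read back in order.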
  def test_sample_all_data(self):
    tf_record_path = self.create_tf_record(num_examples=2)

    input_reader_text_proto = """
      shuffle: false
      num_readers: 1
      sample_1_of_n_examples: 1
      tf_record_input_reader {{
        input_path: '{0}'
      }}
    """.format(tf_record_path)
    input_reader_proto = input_reader_pb2.InputReader()
    text_format.Merge(input_reader_text_proto, input_reader_proto)
    tensor_dict = dataset_builder.make_initializable_iterator(
        dataset_builder.build(input_reader_proto, batch_size=1)).get_next()

    with tf.train.MonitoredSession() as sess:
      output_dict = sess.run(tensor_dict)
      self.assertAllEqual(['0'], output_dict[fields.InputDataFields.source_id])
      output_dict = sess.run(tensor_dict)
      self.assertEqual(['1'], output_dict[fields.InputDataFields.source_id])
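
  # With sample_1_of_n_examples: 2, every other example is read (ids 0, 2).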
  def test_sample_one_of_n_shards(self):
    tf_record_path = self.create_tf_record(num_examples=4)

    input_reader_text_proto = """
      shuffle: false
      num_readers: 1
      sample_1_of_n_examples: 2
      tf_record_input_reader {{
        input_path: '{0}'
      }}
    """.format(tf_record_path)
    input_reader_proto = input_reader_pb2.InputReader()
    text_format.Merge(input_reader_text_proto, input_reader_proto)
    tensor_dict = dataset_builder.make_initializable_iterator(
        dataset_builder.build(input_reader_proto, batch_size=1)).get_next()

    with tf.train.MonitoredSession() as sess:
      output_dict = sess.run(tensor_dict)
      self.assertAllEqual(['0'], output_dict[fields.InputDataFields.source_id])
      output_dict = sess.run(tensor_dict)
      self.assertEqual(['2'], output_dict[fields.InputDataFields.source_id])


class ReadDatasetTest(tf.test.TestCase):
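
  # setUp writes five 'examples_%s.txt' files, each holding the numbers i+1
  # and (i+1)*10, plus two 'shuffle_%s.txt' files of repeated ids; the
  # read_dataset tests below consume these.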
  def setUp(self):
    self._path_template = os.path.join(self.get_temp_dir(), 'examples_%s.txt')

    for i in range(5):
      path = self._path_template % i
      with tf.gfile.Open(path, 'wb') as f:
        f.write('\n'.join([str(i + 1), str((i + 1) * 10)]))

    self._shuffle_path_template = os.path.join(self.get_temp_dir(),
                                               'shuffle_%s.txt')
    for i in range(2):
      path = self._shuffle_path_template % i
      with tf.gfile.Open(path, 'wb') as f:
        f.write('\n'.join([str(i)] * 5))
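
  # Helper: builds a dataset via dataset_builder.read_dataset over the given
  # files, decodes each text line to int32, batches, and returns the next op.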
  def _get_dataset_next(self, files, config, batch_size):

    def decode_func(value):
      return [tf.string_to_number(value, out_type=tf.int32)]

    dataset = dataset_builder.read_dataset(tf.data.TextLineDataset, files,
                                           config)
    dataset = dataset.map(decode_func)
    dataset = dataset.batch(batch_size)
    return dataset.make_one_shot_iterator().get_next()
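
  # make_initializable_iterator must register the table initializer so that
  # HashTable lookups resolve once tf.tables_initializer() has run.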
  def test_make_initializable_iterator_with_hashTable(self):
    keys = [1, 0, -1]
    dataset = tf.data.Dataset.from_tensor_slices([[1, 2, -1, 5]])
    table = tf.contrib.lookup.HashTable(
        initializer=tf.contrib.lookup.KeyValueTensorInitializer(
            keys=keys, values=list(reversed(keys))),
        default_value=100)
    dataset = dataset.map(table.lookup)
    data = dataset_builder.make_initializable_iterator(dataset).get_next()
    init = tf.tables_initializer()
    with self.test_session() as sess:
      sess.run(init)
      self.assertAllEqual(sess.run(data), [-1, 100, 1, 100])
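
  # Reading all five files with one reader yields their lines in order; the
  # dataset repeats, so a batch of 20 covers the ten values twice.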
  def test_read_dataset(self):
    config = input_reader_pb2.InputReader()
    config.num_readers = 1
    config.shuffle = False

    data = self._get_dataset_next(
        [self._path_template % '*'], config, batch_size=20)
    with self.test_session() as sess:
      self.assertAllEqual(
          sess.run(data), [[
              1, 10, 2, 20, 3, 30, 4, 40, 5, 50, 1, 10, 2, 20, 3, 30, 4, 40, 5,
              50
          ]])
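
  # Even with num_readers larger than the number of files, the output order
  # is unchanged when shuffle is disabled.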
  def test_reduce_num_reader(self):
    config = input_reader_pb2.InputReader()
    config.num_readers = 10
    config.shuffle = False

    data = self._get_dataset_next(
        [self._path_template % '*'], config, batch_size=20)
    with self.test_session() as sess:
      self.assertAllEqual(
          sess.run(data), [[
              1, 10, 2, 20, 3, 30, 4, 40, 5, 50, 1, 10, 2, 20, 3, 30, 4, 40, 5,
              50
          ]])
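
  # With shuffle enabled, the output should differ from the deterministic
  # file order (checked probabilistically under a fixed graph-level seed).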
  def test_enable_shuffle(self):
    config = input_reader_pb2.InputReader()
    config.num_readers = 1
    config.shuffle = True

    tf.set_random_seed(1)  # Set graph level seed.
    data = self._get_dataset_next(
        [self._shuffle_path_template % '*'], config, batch_size=10)
    expected_non_shuffle_output = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]

    with self.test_session() as sess:
      self.assertTrue(
          np.any(np.not_equal(sess.run(data), expected_non_shuffle_output)))
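
  # With shuffle disabled, the files are read back in deterministic order.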
  def test_disable_shuffle(self):
    config = input_reader_pb2.InputReader()
    config.num_readers = 1
    config.shuffle = False

    data = self._get_dataset_next(
        [self._shuffle_path_template % '*'], config, batch_size=10)
    expected_non_shuffle_output = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]

    with self.test_session() as sess:
      self.assertAllEqual(sess.run(data), [expected_non_shuffle_output])
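
  # With num_epochs=1 the first batch returns whatever is available and the
  # next read raises OutOfRangeError.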
  def test_read_dataset_single_epoch(self):
    config = input_reader_pb2.InputReader()
    config.num_epochs = 1
    config.num_readers = 1
    config.shuffle = False

    data = self._get_dataset_next(
        [self._path_template % '0'], config, batch_size=30)
    with self.test_session() as sess:
      # First batch will retrieve as much as it can, second batch will fail.
      self.assertAllEqual(sess.run(data), [[1, 10]])
      self.assertRaises(tf.errors.OutOfRangeError, sess.run, data)


if __name__ == '__main__':
  tf.test.main()