You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1091 lines
44 KiB

  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for object_detection.tflearn.inputs."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import functools
  20. import os
  21. from absl.testing import parameterized
  22. import numpy as np
  23. import tensorflow as tf
  24. from object_detection import inputs
  25. from object_detection.core import preprocessor
  26. from object_detection.core import standard_fields as fields
  27. from object_detection.utils import config_util
  28. from object_detection.utils import test_case
# Handle on the TensorFlow command-line flags registry; it is not referenced
# anywhere in the visible portion of this file.
FLAGS = tf.flags.FLAGS
  30. def _get_configs_for_model(model_name):
  31. """Returns configurations for model."""
  32. fname = os.path.join(tf.resource_loader.get_data_files_path(),
  33. 'samples/configs/' + model_name + '.config')
  34. label_map_path = os.path.join(tf.resource_loader.get_data_files_path(),
  35. 'data/pet_label_map.pbtxt')
  36. data_path = os.path.join(tf.resource_loader.get_data_files_path(),
  37. 'test_data/pets_examples.record')
  38. configs = config_util.get_configs_from_pipeline_file(fname)
  39. override_dict = {
  40. 'train_input_path': data_path,
  41. 'eval_input_path': data_path,
  42. 'label_map_path': label_map_path
  43. }
  44. return config_util.merge_external_params_with_configs(
  45. configs, kwargs_dict=override_dict)
  46. def _make_initializable_iterator(dataset):
  47. """Creates an iterator, and initializes tables.
  48. Args:
  49. dataset: A `tf.data.Dataset` object.
  50. Returns:
  51. A `tf.data.Iterator`.
  52. """
  53. iterator = dataset.make_initializable_iterator()
  54. tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
  55. return iterator
  56. class InputsTest(test_case.TestCase, parameterized.TestCase):
  57. def test_faster_rcnn_resnet50_train_input(self):
  58. """Tests the training input function for FasterRcnnResnet50."""
  59. configs = _get_configs_for_model('faster_rcnn_resnet50_pets')
  60. model_config = configs['model']
  61. model_config.faster_rcnn.num_classes = 37
  62. train_input_fn = inputs.create_train_input_fn(
  63. configs['train_config'], configs['train_input_config'], model_config)
  64. features, labels = _make_initializable_iterator(train_input_fn()).get_next()
  65. self.assertAllEqual([1, None, None, 3],
  66. features[fields.InputDataFields.image].shape.as_list())
  67. self.assertEqual(tf.float32, features[fields.InputDataFields.image].dtype)
  68. self.assertAllEqual([1],
  69. features[inputs.HASH_KEY].shape.as_list())
  70. self.assertEqual(tf.int32, features[inputs.HASH_KEY].dtype)
  71. self.assertAllEqual(
  72. [1, 100, 4],
  73. labels[fields.InputDataFields.groundtruth_boxes].shape.as_list())
  74. self.assertEqual(tf.float32,
  75. labels[fields.InputDataFields.groundtruth_boxes].dtype)
  76. self.assertAllEqual(
  77. [1, 100, model_config.faster_rcnn.num_classes],
  78. labels[fields.InputDataFields.groundtruth_classes].shape.as_list())
  79. self.assertEqual(tf.float32,
  80. labels[fields.InputDataFields.groundtruth_classes].dtype)
  81. self.assertAllEqual(
  82. [1, 100],
  83. labels[fields.InputDataFields.groundtruth_weights].shape.as_list())
  84. self.assertEqual(tf.float32,
  85. labels[fields.InputDataFields.groundtruth_weights].dtype)
  86. self.assertAllEqual(
  87. [1, 100, model_config.faster_rcnn.num_classes],
  88. labels[fields.InputDataFields.groundtruth_confidences].shape.as_list())
  89. self.assertEqual(
  90. tf.float32,
  91. labels[fields.InputDataFields.groundtruth_confidences].dtype)
  92. def test_faster_rcnn_resnet50_train_input_with_additional_channels(self):
  93. """Tests the training input function for FasterRcnnResnet50."""
  94. configs = _get_configs_for_model('faster_rcnn_resnet50_pets')
  95. model_config = configs['model']
  96. configs['train_input_config'].num_additional_channels = 2
  97. configs['train_config'].retain_original_images = True
  98. model_config.faster_rcnn.num_classes = 37
  99. train_input_fn = inputs.create_train_input_fn(
  100. configs['train_config'], configs['train_input_config'], model_config)
  101. features, labels = _make_initializable_iterator(train_input_fn()).get_next()
  102. self.assertAllEqual([1, None, None, 5],
  103. features[fields.InputDataFields.image].shape.as_list())
  104. self.assertAllEqual(
  105. [1, None, None, 3],
  106. features[fields.InputDataFields.original_image].shape.as_list())
  107. self.assertEqual(tf.float32, features[fields.InputDataFields.image].dtype)
  108. self.assertAllEqual([1],
  109. features[inputs.HASH_KEY].shape.as_list())
  110. self.assertEqual(tf.int32, features[inputs.HASH_KEY].dtype)
  111. self.assertAllEqual(
  112. [1, 100, 4],
  113. labels[fields.InputDataFields.groundtruth_boxes].shape.as_list())
  114. self.assertEqual(tf.float32,
  115. labels[fields.InputDataFields.groundtruth_boxes].dtype)
  116. self.assertAllEqual(
  117. [1, 100, model_config.faster_rcnn.num_classes],
  118. labels[fields.InputDataFields.groundtruth_classes].shape.as_list())
  119. self.assertEqual(tf.float32,
  120. labels[fields.InputDataFields.groundtruth_classes].dtype)
  121. self.assertAllEqual(
  122. [1, 100],
  123. labels[fields.InputDataFields.groundtruth_weights].shape.as_list())
  124. self.assertEqual(tf.float32,
  125. labels[fields.InputDataFields.groundtruth_weights].dtype)
  126. self.assertAllEqual(
  127. [1, 100, model_config.faster_rcnn.num_classes],
  128. labels[fields.InputDataFields.groundtruth_confidences].shape.as_list())
  129. self.assertEqual(
  130. tf.float32,
  131. labels[fields.InputDataFields.groundtruth_confidences].dtype)
  132. @parameterized.parameters(
  133. {'eval_batch_size': 1},
  134. {'eval_batch_size': 8}
  135. )
  136. def test_faster_rcnn_resnet50_eval_input(self, eval_batch_size=1):
  137. """Tests the eval input function for FasterRcnnResnet50."""
  138. configs = _get_configs_for_model('faster_rcnn_resnet50_pets')
  139. model_config = configs['model']
  140. model_config.faster_rcnn.num_classes = 37
  141. eval_config = configs['eval_config']
  142. eval_config.batch_size = eval_batch_size
  143. eval_input_fn = inputs.create_eval_input_fn(
  144. eval_config, configs['eval_input_configs'][0], model_config)
  145. features, labels = _make_initializable_iterator(eval_input_fn()).get_next()
  146. self.assertAllEqual([eval_batch_size, None, None, 3],
  147. features[fields.InputDataFields.image].shape.as_list())
  148. self.assertEqual(tf.float32, features[fields.InputDataFields.image].dtype)
  149. self.assertAllEqual(
  150. [eval_batch_size, None, None, 3],
  151. features[fields.InputDataFields.original_image].shape.as_list())
  152. self.assertEqual(tf.uint8,
  153. features[fields.InputDataFields.original_image].dtype)
  154. self.assertAllEqual([eval_batch_size],
  155. features[inputs.HASH_KEY].shape.as_list())
  156. self.assertEqual(tf.int32, features[inputs.HASH_KEY].dtype)
  157. self.assertAllEqual(
  158. [eval_batch_size, 100, 4],
  159. labels[fields.InputDataFields.groundtruth_boxes].shape.as_list())
  160. self.assertEqual(tf.float32,
  161. labels[fields.InputDataFields.groundtruth_boxes].dtype)
  162. self.assertAllEqual(
  163. [eval_batch_size, 100, model_config.faster_rcnn.num_classes],
  164. labels[fields.InputDataFields.groundtruth_classes].shape.as_list())
  165. self.assertEqual(tf.float32,
  166. labels[fields.InputDataFields.groundtruth_classes].dtype)
  167. self.assertAllEqual(
  168. [eval_batch_size, 100],
  169. labels[fields.InputDataFields.groundtruth_weights].shape.as_list())
  170. self.assertEqual(
  171. tf.float32,
  172. labels[fields.InputDataFields.groundtruth_weights].dtype)
  173. self.assertAllEqual(
  174. [eval_batch_size, 100],
  175. labels[fields.InputDataFields.groundtruth_area].shape.as_list())
  176. self.assertEqual(tf.float32,
  177. labels[fields.InputDataFields.groundtruth_area].dtype)
  178. self.assertAllEqual(
  179. [eval_batch_size, 100],
  180. labels[fields.InputDataFields.groundtruth_is_crowd].shape.as_list())
  181. self.assertEqual(
  182. tf.bool, labels[fields.InputDataFields.groundtruth_is_crowd].dtype)
  183. self.assertAllEqual(
  184. [eval_batch_size, 100],
  185. labels[fields.InputDataFields.groundtruth_difficult].shape.as_list())
  186. self.assertEqual(
  187. tf.int32, labels[fields.InputDataFields.groundtruth_difficult].dtype)
  188. def test_ssd_inceptionV2_train_input(self):
  189. """Tests the training input function for SSDInceptionV2."""
  190. configs = _get_configs_for_model('ssd_inception_v2_pets')
  191. model_config = configs['model']
  192. model_config.ssd.num_classes = 37
  193. batch_size = configs['train_config'].batch_size
  194. train_input_fn = inputs.create_train_input_fn(
  195. configs['train_config'], configs['train_input_config'], model_config)
  196. features, labels = _make_initializable_iterator(train_input_fn()).get_next()
  197. self.assertAllEqual([batch_size, 300, 300, 3],
  198. features[fields.InputDataFields.image].shape.as_list())
  199. self.assertEqual(tf.float32, features[fields.InputDataFields.image].dtype)
  200. self.assertAllEqual([batch_size],
  201. features[inputs.HASH_KEY].shape.as_list())
  202. self.assertEqual(tf.int32, features[inputs.HASH_KEY].dtype)
  203. self.assertAllEqual(
  204. [batch_size],
  205. labels[fields.InputDataFields.num_groundtruth_boxes].shape.as_list())
  206. self.assertEqual(tf.int32,
  207. labels[fields.InputDataFields.num_groundtruth_boxes].dtype)
  208. self.assertAllEqual(
  209. [batch_size, 100, 4],
  210. labels[fields.InputDataFields.groundtruth_boxes].shape.as_list())
  211. self.assertEqual(tf.float32,
  212. labels[fields.InputDataFields.groundtruth_boxes].dtype)
  213. self.assertAllEqual(
  214. [batch_size, 100, model_config.ssd.num_classes],
  215. labels[fields.InputDataFields.groundtruth_classes].shape.as_list())
  216. self.assertEqual(tf.float32,
  217. labels[fields.InputDataFields.groundtruth_classes].dtype)
  218. self.assertAllEqual(
  219. [batch_size, 100],
  220. labels[
  221. fields.InputDataFields.groundtruth_weights].shape.as_list())
  222. self.assertEqual(
  223. tf.float32,
  224. labels[fields.InputDataFields.groundtruth_weights].dtype)
  225. @parameterized.parameters(
  226. {'eval_batch_size': 1},
  227. {'eval_batch_size': 8}
  228. )
  229. def test_ssd_inceptionV2_eval_input(self, eval_batch_size=1):
  230. """Tests the eval input function for SSDInceptionV2."""
  231. configs = _get_configs_for_model('ssd_inception_v2_pets')
  232. model_config = configs['model']
  233. model_config.ssd.num_classes = 37
  234. eval_config = configs['eval_config']
  235. eval_config.batch_size = eval_batch_size
  236. eval_input_fn = inputs.create_eval_input_fn(
  237. eval_config, configs['eval_input_configs'][0], model_config)
  238. features, labels = _make_initializable_iterator(eval_input_fn()).get_next()
  239. self.assertAllEqual([eval_batch_size, 300, 300, 3],
  240. features[fields.InputDataFields.image].shape.as_list())
  241. self.assertEqual(tf.float32, features[fields.InputDataFields.image].dtype)
  242. self.assertAllEqual(
  243. [eval_batch_size, 300, 300, 3],
  244. features[fields.InputDataFields.original_image].shape.as_list())
  245. self.assertEqual(tf.uint8,
  246. features[fields.InputDataFields.original_image].dtype)
  247. self.assertAllEqual([eval_batch_size],
  248. features[inputs.HASH_KEY].shape.as_list())
  249. self.assertEqual(tf.int32, features[inputs.HASH_KEY].dtype)
  250. self.assertAllEqual(
  251. [eval_batch_size, 100, 4],
  252. labels[fields.InputDataFields.groundtruth_boxes].shape.as_list())
  253. self.assertEqual(tf.float32,
  254. labels[fields.InputDataFields.groundtruth_boxes].dtype)
  255. self.assertAllEqual(
  256. [eval_batch_size, 100, model_config.ssd.num_classes],
  257. labels[fields.InputDataFields.groundtruth_classes].shape.as_list())
  258. self.assertEqual(tf.float32,
  259. labels[fields.InputDataFields.groundtruth_classes].dtype)
  260. self.assertAllEqual(
  261. [eval_batch_size, 100],
  262. labels[
  263. fields.InputDataFields.groundtruth_weights].shape.as_list())
  264. self.assertEqual(
  265. tf.float32,
  266. labels[fields.InputDataFields.groundtruth_weights].dtype)
  267. self.assertAllEqual(
  268. [eval_batch_size, 100],
  269. labels[fields.InputDataFields.groundtruth_area].shape.as_list())
  270. self.assertEqual(tf.float32,
  271. labels[fields.InputDataFields.groundtruth_area].dtype)
  272. self.assertAllEqual(
  273. [eval_batch_size, 100],
  274. labels[fields.InputDataFields.groundtruth_is_crowd].shape.as_list())
  275. self.assertEqual(
  276. tf.bool, labels[fields.InputDataFields.groundtruth_is_crowd].dtype)
  277. self.assertAllEqual(
  278. [eval_batch_size, 100],
  279. labels[fields.InputDataFields.groundtruth_difficult].shape.as_list())
  280. self.assertEqual(
  281. tf.int32, labels[fields.InputDataFields.groundtruth_difficult].dtype)
  282. def test_predict_input(self):
  283. """Tests the predict input function."""
  284. configs = _get_configs_for_model('ssd_inception_v2_pets')
  285. predict_input_fn = inputs.create_predict_input_fn(
  286. model_config=configs['model'],
  287. predict_input_config=configs['eval_input_configs'][0])
  288. serving_input_receiver = predict_input_fn()
  289. image = serving_input_receiver.features[fields.InputDataFields.image]
  290. receiver_tensors = serving_input_receiver.receiver_tensors[
  291. inputs.SERVING_FED_EXAMPLE_KEY]
  292. self.assertEqual([1, 300, 300, 3], image.shape.as_list())
  293. self.assertEqual(tf.float32, image.dtype)
  294. self.assertEqual(tf.string, receiver_tensors.dtype)
  295. def test_predict_input_with_additional_channels(self):
  296. """Tests the predict input function with additional channels."""
  297. configs = _get_configs_for_model('ssd_inception_v2_pets')
  298. configs['eval_input_configs'][0].num_additional_channels = 2
  299. predict_input_fn = inputs.create_predict_input_fn(
  300. model_config=configs['model'],
  301. predict_input_config=configs['eval_input_configs'][0])
  302. serving_input_receiver = predict_input_fn()
  303. image = serving_input_receiver.features[fields.InputDataFields.image]
  304. receiver_tensors = serving_input_receiver.receiver_tensors[
  305. inputs.SERVING_FED_EXAMPLE_KEY]
  306. # RGB + 2 additional channels = 5 channels.
  307. self.assertEqual([1, 300, 300, 5], image.shape.as_list())
  308. self.assertEqual(tf.float32, image.dtype)
  309. self.assertEqual(tf.string, receiver_tensors.dtype)
  310. def test_error_with_bad_train_config(self):
  311. """Tests that a TypeError is raised with improper train config."""
  312. configs = _get_configs_for_model('ssd_inception_v2_pets')
  313. configs['model'].ssd.num_classes = 37
  314. train_input_fn = inputs.create_train_input_fn(
  315. train_config=configs['eval_config'], # Expecting `TrainConfig`.
  316. train_input_config=configs['train_input_config'],
  317. model_config=configs['model'])
  318. with self.assertRaises(TypeError):
  319. train_input_fn()
  320. def test_error_with_bad_train_input_config(self):
  321. """Tests that a TypeError is raised with improper train input config."""
  322. configs = _get_configs_for_model('ssd_inception_v2_pets')
  323. configs['model'].ssd.num_classes = 37
  324. train_input_fn = inputs.create_train_input_fn(
  325. train_config=configs['train_config'],
  326. train_input_config=configs['model'], # Expecting `InputReader`.
  327. model_config=configs['model'])
  328. with self.assertRaises(TypeError):
  329. train_input_fn()
  330. def test_error_with_bad_train_model_config(self):
  331. """Tests that a TypeError is raised with improper train model config."""
  332. configs = _get_configs_for_model('ssd_inception_v2_pets')
  333. configs['model'].ssd.num_classes = 37
  334. train_input_fn = inputs.create_train_input_fn(
  335. train_config=configs['train_config'],
  336. train_input_config=configs['train_input_config'],
  337. model_config=configs['train_config']) # Expecting `DetectionModel`.
  338. with self.assertRaises(TypeError):
  339. train_input_fn()
  340. def test_error_with_bad_eval_config(self):
  341. """Tests that a TypeError is raised with improper eval config."""
  342. configs = _get_configs_for_model('ssd_inception_v2_pets')
  343. configs['model'].ssd.num_classes = 37
  344. eval_input_fn = inputs.create_eval_input_fn(
  345. eval_config=configs['train_config'], # Expecting `EvalConfig`.
  346. eval_input_config=configs['eval_input_configs'][0],
  347. model_config=configs['model'])
  348. with self.assertRaises(TypeError):
  349. eval_input_fn()
  350. def test_error_with_bad_eval_input_config(self):
  351. """Tests that a TypeError is raised with improper eval input config."""
  352. configs = _get_configs_for_model('ssd_inception_v2_pets')
  353. configs['model'].ssd.num_classes = 37
  354. eval_input_fn = inputs.create_eval_input_fn(
  355. eval_config=configs['eval_config'],
  356. eval_input_config=configs['model'], # Expecting `InputReader`.
  357. model_config=configs['model'])
  358. with self.assertRaises(TypeError):
  359. eval_input_fn()
  360. def test_error_with_bad_eval_model_config(self):
  361. """Tests that a TypeError is raised with improper eval model config."""
  362. configs = _get_configs_for_model('ssd_inception_v2_pets')
  363. configs['model'].ssd.num_classes = 37
  364. eval_input_fn = inputs.create_eval_input_fn(
  365. eval_config=configs['eval_config'],
  366. eval_input_config=configs['eval_input_configs'][0],
  367. model_config=configs['eval_config']) # Expecting `DetectionModel`.
  368. with self.assertRaises(TypeError):
  369. eval_input_fn()
  370. def test_output_equal_in_replace_empty_string_with_random_number(self):
  371. string_placeholder = tf.placeholder(tf.string, shape=[])
  372. replaced_string = inputs._replace_empty_string_with_random_number(
  373. string_placeholder)
  374. test_string = 'hello world'
  375. feed_dict = {string_placeholder: test_string}
  376. with self.test_session() as sess:
  377. out_string = sess.run(replaced_string, feed_dict=feed_dict)
  378. self.assertEqual(test_string, out_string)
  379. def test_output_is_integer_in_replace_empty_string_with_random_number(self):
  380. string_placeholder = tf.placeholder(tf.string, shape=[])
  381. replaced_string = inputs._replace_empty_string_with_random_number(
  382. string_placeholder)
  383. empty_string = ''
  384. feed_dict = {string_placeholder: empty_string}
  385. tf.set_random_seed(0)
  386. with self.test_session() as sess:
  387. out_string = sess.run(replaced_string, feed_dict=feed_dict)
  388. # Test whether out_string is a string which represents an integer.
  389. int(out_string) # throws an error if out_string is not castable to int.
  390. self.assertEqual(out_string, '2798129067578209328')
  391. class DataAugmentationFnTest(test_case.TestCase):
  392. def test_apply_image_and_box_augmentation(self):
  393. data_augmentation_options = [
  394. (preprocessor.resize_image, {
  395. 'new_height': 20,
  396. 'new_width': 20,
  397. 'method': tf.image.ResizeMethod.NEAREST_NEIGHBOR
  398. }),
  399. (preprocessor.scale_boxes_to_pixel_coordinates, {}),
  400. ]
  401. data_augmentation_fn = functools.partial(
  402. inputs.augment_input_data,
  403. data_augmentation_options=data_augmentation_options)
  404. tensor_dict = {
  405. fields.InputDataFields.image:
  406. tf.constant(np.random.rand(10, 10, 3).astype(np.float32)),
  407. fields.InputDataFields.groundtruth_boxes:
  408. tf.constant(np.array([[.5, .5, 1., 1.]], np.float32))
  409. }
  410. augmented_tensor_dict = data_augmentation_fn(tensor_dict=tensor_dict)
  411. with self.test_session() as sess:
  412. augmented_tensor_dict_out = sess.run(augmented_tensor_dict)
  413. self.assertAllEqual(
  414. augmented_tensor_dict_out[fields.InputDataFields.image].shape,
  415. [20, 20, 3]
  416. )
  417. self.assertAllClose(
  418. augmented_tensor_dict_out[fields.InputDataFields.groundtruth_boxes],
  419. [[10, 10, 20, 20]]
  420. )
  421. def test_apply_image_and_box_augmentation_with_scores(self):
  422. data_augmentation_options = [
  423. (preprocessor.resize_image, {
  424. 'new_height': 20,
  425. 'new_width': 20,
  426. 'method': tf.image.ResizeMethod.NEAREST_NEIGHBOR
  427. }),
  428. (preprocessor.scale_boxes_to_pixel_coordinates, {}),
  429. ]
  430. data_augmentation_fn = functools.partial(
  431. inputs.augment_input_data,
  432. data_augmentation_options=data_augmentation_options)
  433. tensor_dict = {
  434. fields.InputDataFields.image:
  435. tf.constant(np.random.rand(10, 10, 3).astype(np.float32)),
  436. fields.InputDataFields.groundtruth_boxes:
  437. tf.constant(np.array([[.5, .5, 1., 1.]], np.float32)),
  438. fields.InputDataFields.groundtruth_classes:
  439. tf.constant(np.array([1.0], np.float32)),
  440. fields.InputDataFields.groundtruth_weights:
  441. tf.constant(np.array([0.8], np.float32)),
  442. }
  443. augmented_tensor_dict = data_augmentation_fn(tensor_dict=tensor_dict)
  444. with self.test_session() as sess:
  445. augmented_tensor_dict_out = sess.run(augmented_tensor_dict)
  446. self.assertAllEqual(
  447. augmented_tensor_dict_out[fields.InputDataFields.image].shape,
  448. [20, 20, 3]
  449. )
  450. self.assertAllClose(
  451. augmented_tensor_dict_out[fields.InputDataFields.groundtruth_boxes],
  452. [[10, 10, 20, 20]]
  453. )
  454. self.assertAllClose(
  455. augmented_tensor_dict_out[fields.InputDataFields.groundtruth_classes],
  456. [1.0]
  457. )
  458. self.assertAllClose(
  459. augmented_tensor_dict_out[
  460. fields.InputDataFields.groundtruth_weights],
  461. [0.8]
  462. )
  463. def test_include_masks_in_data_augmentation(self):
  464. data_augmentation_options = [
  465. (preprocessor.resize_image, {
  466. 'new_height': 20,
  467. 'new_width': 20,
  468. 'method': tf.image.ResizeMethod.NEAREST_NEIGHBOR
  469. })
  470. ]
  471. data_augmentation_fn = functools.partial(
  472. inputs.augment_input_data,
  473. data_augmentation_options=data_augmentation_options)
  474. tensor_dict = {
  475. fields.InputDataFields.image:
  476. tf.constant(np.random.rand(10, 10, 3).astype(np.float32)),
  477. fields.InputDataFields.groundtruth_instance_masks:
  478. tf.constant(np.zeros([2, 10, 10], np.uint8))
  479. }
  480. augmented_tensor_dict = data_augmentation_fn(tensor_dict=tensor_dict)
  481. with self.test_session() as sess:
  482. augmented_tensor_dict_out = sess.run(augmented_tensor_dict)
  483. self.assertAllEqual(
  484. augmented_tensor_dict_out[fields.InputDataFields.image].shape,
  485. [20, 20, 3])
  486. self.assertAllEqual(augmented_tensor_dict_out[
  487. fields.InputDataFields.groundtruth_instance_masks].shape, [2, 20, 20])
  488. def test_include_keypoints_in_data_augmentation(self):
  489. data_augmentation_options = [
  490. (preprocessor.resize_image, {
  491. 'new_height': 20,
  492. 'new_width': 20,
  493. 'method': tf.image.ResizeMethod.NEAREST_NEIGHBOR
  494. }),
  495. (preprocessor.scale_boxes_to_pixel_coordinates, {}),
  496. ]
  497. data_augmentation_fn = functools.partial(
  498. inputs.augment_input_data,
  499. data_augmentation_options=data_augmentation_options)
  500. tensor_dict = {
  501. fields.InputDataFields.image:
  502. tf.constant(np.random.rand(10, 10, 3).astype(np.float32)),
  503. fields.InputDataFields.groundtruth_boxes:
  504. tf.constant(np.array([[.5, .5, 1., 1.]], np.float32)),
  505. fields.InputDataFields.groundtruth_keypoints:
  506. tf.constant(np.array([[[0.5, 1.0], [0.5, 0.5]]], np.float32))
  507. }
  508. augmented_tensor_dict = data_augmentation_fn(tensor_dict=tensor_dict)
  509. with self.test_session() as sess:
  510. augmented_tensor_dict_out = sess.run(augmented_tensor_dict)
  511. self.assertAllEqual(
  512. augmented_tensor_dict_out[fields.InputDataFields.image].shape,
  513. [20, 20, 3]
  514. )
  515. self.assertAllClose(
  516. augmented_tensor_dict_out[fields.InputDataFields.groundtruth_boxes],
  517. [[10, 10, 20, 20]]
  518. )
  519. self.assertAllClose(
  520. augmented_tensor_dict_out[fields.InputDataFields.groundtruth_keypoints],
  521. [[[10, 20], [10, 10]]]
  522. )
  523. def _fake_model_preprocessor_fn(image):
  524. return (image, tf.expand_dims(tf.shape(image)[1:], axis=0))
  525. def _fake_image_resizer_fn(image, mask):
  526. return (image, mask, tf.shape(image))
  527. class DataTransformationFnTest(test_case.TestCase):
  def test_combine_additional_channels_if_present(self):
    """Additional channels are concatenated onto the image along axis 2."""
    image = np.random.rand(4, 4, 3).astype(np.float32)
    additional_channels = np.random.rand(4, 4, 2).astype(np.float32)
    tensor_dict = {
        fields.InputDataFields.image:
            tf.constant(image),
        fields.InputDataFields.image_additional_channels:
            tf.constant(additional_channels),
        fields.InputDataFields.groundtruth_classes:
            tf.constant(np.array([1, 1], np.int32))
    }
    input_transformation_fn = functools.partial(
        inputs.transform_input_data,
        model_preprocess_fn=_fake_model_preprocessor_fn,
        image_resizer_fn=_fake_image_resizer_fn,
        num_classes=1)
    with self.test_session() as sess:
      transformed_inputs = sess.run(
          input_transformation_fn(tensor_dict=tensor_dict))

    transformed_image = transformed_inputs[fields.InputDataFields.image]
    # `sess.run` yields a numpy array. Compare its dtype with a plain equality
    # check (dispatched to `tf.DType.__eq__`) instead of `assertAllEqual`,
    # which would wrap both dtype objects in object arrays before comparing.
    self.assertEqual(tf.float32, transformed_image.dtype)
    self.assertAllEqual(transformed_image.shape, [4, 4, 5])
    self.assertAllClose(transformed_image,
                        np.concatenate((image, additional_channels), axis=2))
  def test_returns_correct_class_label_encodings(self):
    """1-indexed class ids become one-hot rows of length `num_classes`."""
    tensor_dict = {
        fields.InputDataFields.image:
            tf.constant(np.random.rand(4, 4, 3).astype(np.float32)),
        fields.InputDataFields.groundtruth_boxes:
            tf.constant(np.array([[0, 0, 1, 1], [.5, .5, 1, 1]], np.float32)),
        fields.InputDataFields.groundtruth_classes:
            tf.constant(np.array([3, 1], np.int32))
    }
    transform_fn = functools.partial(
        inputs.transform_input_data,
        model_preprocess_fn=_fake_model_preprocessor_fn,
        image_resizer_fn=_fake_image_resizer_fn,
        num_classes=3)
    with self.test_session() as sess:
      outputs = sess.run(transform_fn(tensor_dict=tensor_dict))

    # Class 3 -> column 2, class 1 -> column 0; confidences mirror the labels.
    expected_one_hot = [[0, 0, 1], [1, 0, 0]]
    self.assertAllClose(outputs[fields.InputDataFields.groundtruth_classes],
                        expected_one_hot)
    self.assertAllClose(
        outputs[fields.InputDataFields.groundtruth_confidences],
        expected_one_hot)
  def test_returns_correct_labels_with_unrecognized_class(self):
    """Boxes labelled with class -1 are dropped from every groundtruth field."""
    tensor_dict = {
        fields.InputDataFields.image:
            tf.constant(np.random.rand(4, 4, 3).astype(np.float32)),
        fields.InputDataFields.groundtruth_boxes:
            tf.constant(
                np.array([[0, 0, 1, 1], [.2, .2, 4, 4], [.5, .5, 1, 1]],
                         np.float32)),
        # float32 to match the dtype the input pipeline produces for area
        # (the eval input tests assert tf.float32); without an explicit dtype
        # numpy would make this float64.
        fields.InputDataFields.groundtruth_area:
            tf.constant(np.array([.5, .4, .3], np.float32)),
        fields.InputDataFields.groundtruth_classes:
            tf.constant(np.array([3, -1, 1], np.int32)),
        fields.InputDataFields.groundtruth_keypoints:
            tf.constant(
                np.array([[[.1, .1]], [[.2, .2]], [[.5, .5]]],
                         np.float32)),
        fields.InputDataFields.groundtruth_keypoint_visibilities:
            tf.constant([True, False, True]),
        fields.InputDataFields.groundtruth_instance_masks:
            tf.constant(np.random.rand(3, 4, 4).astype(np.float32)),
        fields.InputDataFields.groundtruth_is_crowd:
            tf.constant([False, True, False]),
        fields.InputDataFields.groundtruth_difficult:
            tf.constant(np.array([0, 0, 1], np.int32))
    }
    num_classes = 3
    input_transformation_fn = functools.partial(
        inputs.transform_input_data,
        model_preprocess_fn=_fake_model_preprocessor_fn,
        image_resizer_fn=_fake_image_resizer_fn,
        num_classes=num_classes)
    with self.test_session() as sess:
      transformed_inputs = sess.run(
          input_transformation_fn(tensor_dict=tensor_dict))

    # The middle (class -1) entry must be removed from every field.
    self.assertAllClose(
        transformed_inputs[fields.InputDataFields.groundtruth_classes],
        [[0, 0, 1], [1, 0, 0]])
    self.assertAllEqual(
        transformed_inputs[fields.InputDataFields.num_groundtruth_boxes], 2)
    self.assertAllClose(
        transformed_inputs[fields.InputDataFields.groundtruth_area], [.5, .3])
    self.assertAllEqual(
        transformed_inputs[fields.InputDataFields.groundtruth_confidences],
        [[0, 0, 1], [1, 0, 0]])
    self.assertAllClose(
        transformed_inputs[fields.InputDataFields.groundtruth_boxes],
        [[0, 0, 1, 1], [.5, .5, 1, 1]])
    self.assertAllClose(
        transformed_inputs[fields.InputDataFields.groundtruth_keypoints],
        [[[.1, .1]], [[.5, .5]]])
    self.assertAllEqual(
        transformed_inputs[
            fields.InputDataFields.groundtruth_keypoint_visibilities],
        [True, True])
    self.assertAllEqual(
        transformed_inputs[
            fields.InputDataFields.groundtruth_instance_masks].shape, [2, 4, 4])
    self.assertAllEqual(
        transformed_inputs[fields.InputDataFields.groundtruth_is_crowd],
        [False, False])
    self.assertAllEqual(
        transformed_inputs[fields.InputDataFields.groundtruth_difficult],
        [0, 1])
  def test_returns_correct_merged_boxes(self):
    """Identical boxes are merged into one box with a multi-hot class row."""
    duplicated_box = [.5, .5, 1, 1]
    tensor_dict = {
        fields.InputDataFields.image:
            tf.constant(np.random.rand(4, 4, 3).astype(np.float32)),
        fields.InputDataFields.groundtruth_boxes:
            tf.constant(np.array([duplicated_box, duplicated_box], np.float32)),
        fields.InputDataFields.groundtruth_classes:
            tf.constant(np.array([3, 1], np.int32))
    }
    transform_fn = functools.partial(
        inputs.transform_input_data,
        model_preprocess_fn=_fake_model_preprocessor_fn,
        image_resizer_fn=_fake_image_resizer_fn,
        num_classes=3,
        merge_multiple_boxes=True)
    with self.test_session() as sess:
      outputs = sess.run(transform_fn(tensor_dict=tensor_dict))

    # Classes 3 and 1 on the same box collapse into a single multi-hot row.
    expected_multi_hot = [[1, 0, 1]]
    self.assertAllClose(outputs[fields.InputDataFields.groundtruth_boxes],
                        [[.5, .5, 1., 1.]])
    self.assertAllClose(outputs[fields.InputDataFields.groundtruth_classes],
                        expected_multi_hot)
    self.assertAllClose(
        outputs[fields.InputDataFields.groundtruth_confidences],
        expected_multi_hot)
    self.assertAllClose(
        outputs[fields.InputDataFields.num_groundtruth_boxes], 1)
  def test_returns_correct_groundtruth_confidences_when_input_present(self):
    """Supplied per-box confidences are scattered into the one-hot layout."""
    per_box_confidences = np.array([1.0, -1.0], np.float32)
    tensor_dict = {
        fields.InputDataFields.image:
            tf.constant(np.random.rand(4, 4, 3).astype(np.float32)),
        fields.InputDataFields.groundtruth_boxes:
            tf.constant(np.array([[0, 0, 1, 1], [.5, .5, 1, 1]], np.float32)),
        fields.InputDataFields.groundtruth_classes:
            tf.constant(np.array([3, 1], np.int32)),
        fields.InputDataFields.groundtruth_confidences:
            tf.constant(per_box_confidences)
    }
    transform_fn = functools.partial(
        inputs.transform_input_data,
        model_preprocess_fn=_fake_model_preprocessor_fn,
        image_resizer_fn=_fake_image_resizer_fn,
        num_classes=3)
    with self.test_session() as sess:
      outputs = sess.run(transform_fn(tensor_dict=tensor_dict))

    self.assertAllClose(outputs[fields.InputDataFields.groundtruth_classes],
                        [[0, 0, 1], [1, 0, 0]])
    # Each box's confidence lands in the column of its class.
    self.assertAllClose(
        outputs[fields.InputDataFields.groundtruth_confidences],
        [[0, 0, 1], [-1, 0, 0]])
  def test_returns_resized_masks(self):
    """Instance masks are resized together with the image.

    A local resizer upsamples the 4x4 image and the (2, 4, 4) masks to 8x8.
    With `retain_original_image=True`, the returned `original_image` is
    expected to be uint8 with the resized [8, 8, 3] shape. The
    `original_image_spatial_shape` is expected to keep the [4, 4] value
    supplied in the input.
    """
    tensor_dict = {
        fields.InputDataFields.image:
            tf.constant(np.random.rand(4, 4, 3).astype(np.float32)),
        fields.InputDataFields.groundtruth_instance_masks:
            tf.constant(np.random.rand(2, 4, 4).astype(np.float32)),
        fields.InputDataFields.groundtruth_classes:
            tf.constant(np.array([3, 1], np.int32)),
        fields.InputDataFields.original_image_spatial_shape:
            tf.constant(np.array([4, 4], np.int32))
    }

    def fake_image_resizer_fn(image, masks=None):
      """Resizes `image` (and `masks`, when given) to 8x8.

      Returns [resized_image, resized_masks (only if masks given),
      shape of resized_image].
      """
      resized_image = tf.image.resize_images(image, [8, 8])
      results = [resized_image]
      if masks is not None:
        # Masks here are (num_instances, H, W): move the instance axis last
        # so they can be resized like image channels, then move it back.
        resized_masks = tf.transpose(
            tf.image.resize_images(tf.transpose(masks, [1, 2, 0]), [8, 8]),
            [2, 0, 1])
        results.append(resized_masks)
      results.append(tf.shape(resized_image))
      return results

    num_classes = 3
    input_transformation_fn = functools.partial(
        inputs.transform_input_data,
        model_preprocess_fn=_fake_model_preprocessor_fn,
        image_resizer_fn=fake_image_resizer_fn,
        num_classes=num_classes,
        retain_original_image=True)
    with self.test_session() as sess:
      transformed_inputs = sess.run(
          input_transformation_fn(tensor_dict=tensor_dict))
    # sess.run returns numpy arrays, so compare against the numpy dtype
    # directly instead of pushing a dtype through assertAllEqual against a
    # tf.DType (which relies on cross-library dtype __eq__ fallbacks).
    self.assertEqual(
        transformed_inputs[fields.InputDataFields.original_image].dtype,
        np.uint8)
    self.assertAllEqual(transformed_inputs[
        fields.InputDataFields.original_image_spatial_shape], [4, 4])
    self.assertAllEqual(transformed_inputs[
        fields.InputDataFields.original_image].shape, [8, 8, 3])
    self.assertAllEqual(transformed_inputs[
        fields.InputDataFields.groundtruth_instance_masks].shape, [2, 8, 8])
  def test_applies_model_preprocess_fn_to_image_tensor(self):
    """The model preprocessing fn's outputs become image and true shape.

    A local preprocessor divides pixel values by 255. It reports the input
    shape minus its leading dimension (presumably the batch axis added by
    `transform_input_data` — confirm there). The expected outputs are the
    scaled image and a true_image_shape of [4, 4, 3].
    """
    raw_pixels = np.random.randint(256, size=(4, 4, 3))
    input_dict = {
        fields.InputDataFields.image:
            tf.constant(raw_pixels),
        fields.InputDataFields.groundtruth_classes:
            tf.constant(np.array([3, 1], np.int32)),
    }

    def scale_to_unit_preprocessor_fn(image):
      spatial_shape = tf.shape(image)[1:]
      return image / 255., tf.expand_dims(spatial_shape, axis=0)

    transform_fn = functools.partial(
        inputs.transform_input_data,
        model_preprocess_fn=scale_to_unit_preprocessor_fn,
        image_resizer_fn=_fake_image_resizer_fn,
        num_classes=3)

    with self.test_session() as sess:
      output = sess.run(transform_fn(tensor_dict=input_dict))

    self.assertAllClose(output[fields.InputDataFields.image],
                        raw_pixels / 255.)
    self.assertAllClose(output[fields.InputDataFields.true_image_shape],
                        [4, 4, 3])
  def test_applies_data_augmentation_fn_to_tensor_dict(self):
    """The data augmentation fn is applied to every tensor in the dict.

    The local augmentation adds 1 to every value. The image comes back
    shifted by 1. Classes [3, 1] become [4, 2], which are one-hot encoded
    with num_classes=4.
    """
    raw_pixels = np.random.randint(256, size=(4, 4, 3))
    input_dict = {
        fields.InputDataFields.image:
            tf.constant(raw_pixels),
        fields.InputDataFields.groundtruth_classes:
            tf.constant(np.array([3, 1], np.int32)),
    }

    def increment_everything_fn(tensor_dict):
      incremented = {}
      for name, tensor in tensor_dict.items():
        incremented[name] = tensor + 1
      return incremented

    transform_fn = functools.partial(
        inputs.transform_input_data,
        model_preprocess_fn=_fake_model_preprocessor_fn,
        image_resizer_fn=_fake_image_resizer_fn,
        num_classes=4,
        data_augmentation_fn=increment_everything_fn)

    with self.test_session() as sess:
      augmented = sess.run(transform_fn(tensor_dict=input_dict))

    self.assertAllEqual(augmented[fields.InputDataFields.image],
                        raw_pixels + 1)
    self.assertAllEqual(
        augmented[fields.InputDataFields.groundtruth_classes],
        [[0, 0, 0, 1], [0, 1, 0, 0]])
  def test_applies_data_augmentation_fn_before_model_preprocess_fn(self):
    """Augmentation runs first, then model preprocessing.

    Augmentation adds 5 to the image and preprocessing doubles it. The
    expected result is therefore (image + 5) * 2, not image * 2 + 5.
    """
    raw_pixels = np.random.randint(256, size=(4, 4, 3))
    input_dict = {
        fields.InputDataFields.image:
            tf.constant(raw_pixels),
        fields.InputDataFields.groundtruth_classes:
            tf.constant(np.array([3, 1], np.int32)),
    }

    def double_image_preprocessor_fn(image):
      spatial_shape = tf.shape(image)[1:]
      return image * 2, tf.expand_dims(spatial_shape, axis=0)

    def shift_image_by_five_fn(tensor_dict):
      # Updates the incoming dict in place and hands the same dict back.
      image_key = fields.InputDataFields.image
      tensor_dict[image_key] = tensor_dict[image_key] + 5
      return tensor_dict

    transform_fn = functools.partial(
        inputs.transform_input_data,
        model_preprocess_fn=double_image_preprocessor_fn,
        image_resizer_fn=_fake_image_resizer_fn,
        num_classes=4,
        data_augmentation_fn=shift_image_by_five_fn)

    with self.test_session() as sess:
      augmented = sess.run(transform_fn(tensor_dict=input_dict))

    self.assertAllEqual(augmented[fields.InputDataFields.image],
                        (raw_pixels + 5) * 2)
  810. class PadInputDataToStaticShapesFnTest(test_case.TestCase):
  811. def test_pad_images_boxes_and_classes(self):
  812. input_tensor_dict = {
  813. fields.InputDataFields.image:
  814. tf.placeholder(tf.float32, [None, None, 3]),
  815. fields.InputDataFields.groundtruth_boxes:
  816. tf.placeholder(tf.float32, [None, 4]),
  817. fields.InputDataFields.groundtruth_classes:
  818. tf.placeholder(tf.int32, [None, 3]),
  819. fields.InputDataFields.true_image_shape:
  820. tf.placeholder(tf.int32, [3]),
  821. fields.InputDataFields.original_image_spatial_shape:
  822. tf.placeholder(tf.int32, [2])
  823. }
  824. padded_tensor_dict = inputs.pad_input_data_to_static_shapes(
  825. tensor_dict=input_tensor_dict,
  826. max_num_boxes=3,
  827. num_classes=3,
  828. spatial_image_shape=[5, 6])
  829. self.assertAllEqual(
  830. padded_tensor_dict[fields.InputDataFields.image].shape.as_list(),
  831. [5, 6, 3])
  832. self.assertAllEqual(
  833. padded_tensor_dict[fields.InputDataFields.true_image_shape]
  834. .shape.as_list(), [3])
  835. self.assertAllEqual(
  836. padded_tensor_dict[fields.InputDataFields.original_image_spatial_shape]
  837. .shape.as_list(), [2])
  838. self.assertAllEqual(
  839. padded_tensor_dict[fields.InputDataFields.groundtruth_boxes]
  840. .shape.as_list(), [3, 4])
  841. self.assertAllEqual(
  842. padded_tensor_dict[fields.InputDataFields.groundtruth_classes]
  843. .shape.as_list(), [3, 3])
  844. def test_clip_boxes_and_classes(self):
  845. input_tensor_dict = {
  846. fields.InputDataFields.groundtruth_boxes:
  847. tf.placeholder(tf.float32, [None, 4]),
  848. fields.InputDataFields.groundtruth_classes:
  849. tf.placeholder(tf.int32, [None, 3]),
  850. fields.InputDataFields.num_groundtruth_boxes:
  851. tf.placeholder(tf.int32, [])
  852. }
  853. padded_tensor_dict = inputs.pad_input_data_to_static_shapes(
  854. tensor_dict=input_tensor_dict,
  855. max_num_boxes=3,
  856. num_classes=3,
  857. spatial_image_shape=[5, 6])
  858. self.assertAllEqual(
  859. padded_tensor_dict[fields.InputDataFields.groundtruth_boxes]
  860. .shape.as_list(), [3, 4])
  861. self.assertAllEqual(
  862. padded_tensor_dict[fields.InputDataFields.groundtruth_classes]
  863. .shape.as_list(), [3, 3])
  864. with self.test_session() as sess:
  865. out_tensor_dict = sess.run(
  866. padded_tensor_dict,
  867. feed_dict={
  868. input_tensor_dict[fields.InputDataFields.groundtruth_boxes]:
  869. np.random.rand(5, 4),
  870. input_tensor_dict[fields.InputDataFields.groundtruth_classes]:
  871. np.random.rand(2, 3),
  872. input_tensor_dict[fields.InputDataFields.num_groundtruth_boxes]:
  873. 5,
  874. })
  875. self.assertAllEqual(
  876. out_tensor_dict[fields.InputDataFields.groundtruth_boxes].shape, [3, 4])
  877. self.assertAllEqual(
  878. out_tensor_dict[fields.InputDataFields.groundtruth_classes].shape,
  879. [3, 3])
  880. self.assertEqual(
  881. out_tensor_dict[fields.InputDataFields.num_groundtruth_boxes],
  882. 3)
  883. def test_do_not_pad_dynamic_images(self):
  884. input_tensor_dict = {
  885. fields.InputDataFields.image:
  886. tf.placeholder(tf.float32, [None, None, 3]),
  887. }
  888. padded_tensor_dict = inputs.pad_input_data_to_static_shapes(
  889. tensor_dict=input_tensor_dict,
  890. max_num_boxes=3,
  891. num_classes=3,
  892. spatial_image_shape=[None, None])
  893. self.assertAllEqual(
  894. padded_tensor_dict[fields.InputDataFields.image].shape.as_list(),
  895. [None, None, 3])
  896. def test_images_and_additional_channels(self):
  897. input_tensor_dict = {
  898. fields.InputDataFields.image:
  899. tf.placeholder(tf.float32, [None, None, 5]),
  900. fields.InputDataFields.image_additional_channels:
  901. tf.placeholder(tf.float32, [None, None, 2]),
  902. }
  903. padded_tensor_dict = inputs.pad_input_data_to_static_shapes(
  904. tensor_dict=input_tensor_dict,
  905. max_num_boxes=3,
  906. num_classes=3,
  907. spatial_image_shape=[5, 6])
  908. # pad_input_data_to_static_shape assumes that image is already concatenated
  909. # with additional channels.
  910. self.assertAllEqual(
  911. padded_tensor_dict[fields.InputDataFields.image].shape.as_list(),
  912. [5, 6, 5])
  913. self.assertAllEqual(
  914. padded_tensor_dict[fields.InputDataFields.image_additional_channels]
  915. .shape.as_list(), [5, 6, 2])
  916. def test_images_and_additional_channels_errors(self):
  917. input_tensor_dict = {
  918. fields.InputDataFields.image:
  919. tf.placeholder(tf.float32, [None, None, 3]),
  920. fields.InputDataFields.image_additional_channels:
  921. tf.placeholder(tf.float32, [None, None, 2]),
  922. fields.InputDataFields.original_image:
  923. tf.placeholder(tf.float32, [None, None, 3]),
  924. }
  925. with self.assertRaises(ValueError):
  926. _ = inputs.pad_input_data_to_static_shapes(
  927. tensor_dict=input_tensor_dict,
  928. max_num_boxes=3,
  929. num_classes=3,
  930. spatial_image_shape=[5, 6])
  931. def test_gray_images(self):
  932. input_tensor_dict = {
  933. fields.InputDataFields.image:
  934. tf.placeholder(tf.float32, [None, None, 1]),
  935. }
  936. padded_tensor_dict = inputs.pad_input_data_to_static_shapes(
  937. tensor_dict=input_tensor_dict,
  938. max_num_boxes=3,
  939. num_classes=3,
  940. spatial_image_shape=[5, 6])
  941. self.assertAllEqual(
  942. padded_tensor_dict[fields.InputDataFields.image].shape.as_list(),
  943. [5, 6, 1])
  944. def test_gray_images_and_additional_channels(self):
  945. input_tensor_dict = {
  946. fields.InputDataFields.image:
  947. tf.placeholder(tf.float32, [None, None, 3]),
  948. fields.InputDataFields.image_additional_channels:
  949. tf.placeholder(tf.float32, [None, None, 2]),
  950. }
  951. # pad_input_data_to_static_shape assumes that image is already concatenated
  952. # with additional channels.
  953. padded_tensor_dict = inputs.pad_input_data_to_static_shapes(
  954. tensor_dict=input_tensor_dict,
  955. max_num_boxes=3,
  956. num_classes=3,
  957. spatial_image_shape=[5, 6])
  958. self.assertAllEqual(
  959. padded_tensor_dict[fields.InputDataFields.image].shape.as_list(),
  960. [5, 6, 3])
  961. self.assertAllEqual(
  962. padded_tensor_dict[fields.InputDataFields.image_additional_channels]
  963. .shape.as_list(), [5, 6, 2])
  964. def test_keypoints(self):
  965. input_tensor_dict = {
  966. fields.InputDataFields.groundtruth_keypoints:
  967. tf.placeholder(tf.float32, [None, 16, 4]),
  968. fields.InputDataFields.groundtruth_keypoint_visibilities:
  969. tf.placeholder(tf.bool, [None, 16]),
  970. }
  971. padded_tensor_dict = inputs.pad_input_data_to_static_shapes(
  972. tensor_dict=input_tensor_dict,
  973. max_num_boxes=3,
  974. num_classes=3,
  975. spatial_image_shape=[5, 6])
  976. self.assertAllEqual(
  977. padded_tensor_dict[fields.InputDataFields.groundtruth_keypoints]
  978. .shape.as_list(), [3, 16, 4])
  979. self.assertAllEqual(
  980. padded_tensor_dict[
  981. fields.InputDataFields.groundtruth_keypoint_visibilities]
  982. .shape.as_list(), [3, 16])
if __name__ == '__main__':
  # Entry point when this file is executed directly: hand off to
  # TensorFlow's test runner.
  tf.test.main()