You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

430 lines
18 KiB

6 years ago
  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for object detection model library."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import functools
  20. import os
  21. import numpy as np
  22. import tensorflow as tf
  23. from tensorflow.contrib.tpu.python.tpu import tpu_config
  24. from tensorflow.contrib.tpu.python.tpu import tpu_estimator
  25. from object_detection import inputs
  26. from object_detection import model_hparams
  27. from object_detection import model_lib
  28. from object_detection.builders import model_builder
  29. from object_detection.core import standard_fields as fields
  30. from object_detection.utils import config_util
  31. # Model for test. Options are:
  32. # 'ssd_inception_v2_pets', 'faster_rcnn_resnet50_pets'
  33. MODEL_NAME_FOR_TEST = 'ssd_inception_v2_pets'
  34. def _get_data_path():
  35. """Returns an absolute path to TFRecord file."""
  36. return os.path.join(tf.resource_loader.get_data_files_path(), 'test_data',
  37. 'pets_examples.record')
  38. def get_pipeline_config_path(model_name):
  39. """Returns path to the local pipeline config file."""
  40. return os.path.join(tf.resource_loader.get_data_files_path(), 'samples',
  41. 'configs', model_name + '.config')
  42. def _get_labelmap_path():
  43. """Returns an absolute path to label map file."""
  44. return os.path.join(tf.resource_loader.get_data_files_path(), 'data',
  45. 'pet_label_map.pbtxt')
  46. def _get_configs_for_model(model_name):
  47. """Returns configurations for model."""
  48. filename = get_pipeline_config_path(model_name)
  49. data_path = _get_data_path()
  50. label_map_path = _get_labelmap_path()
  51. configs = config_util.get_configs_from_pipeline_file(filename)
  52. override_dict = {
  53. 'train_input_path': data_path,
  54. 'eval_input_path': data_path,
  55. 'label_map_path': label_map_path
  56. }
  57. configs = config_util.merge_external_params_with_configs(
  58. configs, kwargs_dict=override_dict)
  59. return configs
  60. def _make_initializable_iterator(dataset):
  61. """Creates an iterator, and initializes tables.
  62. Args:
  63. dataset: A `tf.data.Dataset` object.
  64. Returns:
  65. A `tf.data.Iterator`.
  66. """
  67. iterator = dataset.make_initializable_iterator()
  68. tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
  69. return iterator
  70. class ModelLibTest(tf.test.TestCase):
  71. @classmethod
  72. def setUpClass(cls):
  73. tf.reset_default_graph()
  74. def _assert_model_fn_for_train_eval(self, configs, mode,
  75. class_agnostic=False):
  76. model_config = configs['model']
  77. train_config = configs['train_config']
  78. with tf.Graph().as_default():
  79. if mode == 'train':
  80. features, labels = _make_initializable_iterator(
  81. inputs.create_train_input_fn(configs['train_config'],
  82. configs['train_input_config'],
  83. configs['model'])()).get_next()
  84. model_mode = tf.estimator.ModeKeys.TRAIN
  85. batch_size = train_config.batch_size
  86. elif mode == 'eval':
  87. features, labels = _make_initializable_iterator(
  88. inputs.create_eval_input_fn(configs['eval_config'],
  89. configs['eval_input_config'],
  90. configs['model'])()).get_next()
  91. model_mode = tf.estimator.ModeKeys.EVAL
  92. batch_size = 1
  93. elif mode == 'eval_on_train':
  94. features, labels = _make_initializable_iterator(
  95. inputs.create_eval_input_fn(configs['eval_config'],
  96. configs['train_input_config'],
  97. configs['model'])()).get_next()
  98. model_mode = tf.estimator.ModeKeys.EVAL
  99. batch_size = 1
  100. detection_model_fn = functools.partial(
  101. model_builder.build, model_config=model_config, is_training=True)
  102. hparams = model_hparams.create_hparams(
  103. hparams_overrides='load_pretrained=false')
  104. model_fn = model_lib.create_model_fn(detection_model_fn, configs, hparams)
  105. estimator_spec = model_fn(features, labels, model_mode)
  106. self.assertIsNotNone(estimator_spec.loss)
  107. self.assertIsNotNone(estimator_spec.predictions)
  108. if mode == 'eval' or mode == 'eval_on_train':
  109. if class_agnostic:
  110. self.assertNotIn('detection_classes', estimator_spec.predictions)
  111. else:
  112. detection_classes = estimator_spec.predictions['detection_classes']
  113. self.assertEqual(batch_size, detection_classes.shape.as_list()[0])
  114. self.assertEqual(tf.float32, detection_classes.dtype)
  115. detection_boxes = estimator_spec.predictions['detection_boxes']
  116. detection_scores = estimator_spec.predictions['detection_scores']
  117. num_detections = estimator_spec.predictions['num_detections']
  118. self.assertEqual(batch_size, detection_boxes.shape.as_list()[0])
  119. self.assertEqual(tf.float32, detection_boxes.dtype)
  120. self.assertEqual(batch_size, detection_scores.shape.as_list()[0])
  121. self.assertEqual(tf.float32, detection_scores.dtype)
  122. self.assertEqual(tf.float32, num_detections.dtype)
  123. if mode == 'eval':
  124. self.assertIn('Detections_Left_Groundtruth_Right/0',
  125. estimator_spec.eval_metric_ops)
  126. if model_mode == tf.estimator.ModeKeys.TRAIN:
  127. self.assertIsNotNone(estimator_spec.train_op)
  128. return estimator_spec
  129. def _assert_model_fn_for_predict(self, configs):
  130. model_config = configs['model']
  131. with tf.Graph().as_default():
  132. features, _ = _make_initializable_iterator(
  133. inputs.create_eval_input_fn(configs['eval_config'],
  134. configs['eval_input_config'],
  135. configs['model'])()).get_next()
  136. detection_model_fn = functools.partial(
  137. model_builder.build, model_config=model_config, is_training=False)
  138. hparams = model_hparams.create_hparams(
  139. hparams_overrides='load_pretrained=false')
  140. model_fn = model_lib.create_model_fn(detection_model_fn, configs, hparams)
  141. estimator_spec = model_fn(features, None, tf.estimator.ModeKeys.PREDICT)
  142. self.assertIsNone(estimator_spec.loss)
  143. self.assertIsNone(estimator_spec.train_op)
  144. self.assertIsNotNone(estimator_spec.predictions)
  145. self.assertIsNotNone(estimator_spec.export_outputs)
  146. self.assertIn(tf.saved_model.signature_constants.PREDICT_METHOD_NAME,
  147. estimator_spec.export_outputs)
  148. def test_model_fn_in_train_mode(self):
  149. """Tests the model function in TRAIN mode."""
  150. configs = _get_configs_for_model(MODEL_NAME_FOR_TEST)
  151. self._assert_model_fn_for_train_eval(configs, 'train')
  152. def test_model_fn_in_train_mode_freeze_all_variables(self):
  153. """Tests model_fn TRAIN mode with all variables frozen."""
  154. configs = _get_configs_for_model(MODEL_NAME_FOR_TEST)
  155. configs['train_config'].freeze_variables.append('.*')
  156. with self.assertRaisesRegexp(ValueError, 'No variables to optimize'):
  157. self._assert_model_fn_for_train_eval(configs, 'train')
  158. def test_model_fn_in_train_mode_freeze_all_included_variables(self):
  159. """Tests model_fn TRAIN mode with all included variables frozen."""
  160. configs = _get_configs_for_model(MODEL_NAME_FOR_TEST)
  161. train_config = configs['train_config']
  162. train_config.update_trainable_variables.append('FeatureExtractor')
  163. train_config.freeze_variables.append('.*')
  164. with self.assertRaisesRegexp(ValueError, 'No variables to optimize'):
  165. self._assert_model_fn_for_train_eval(configs, 'train')
  166. def test_model_fn_in_train_mode_freeze_box_predictor(self):
  167. """Tests model_fn TRAIN mode with FeatureExtractor variables frozen."""
  168. configs = _get_configs_for_model(MODEL_NAME_FOR_TEST)
  169. train_config = configs['train_config']
  170. train_config.update_trainable_variables.append('FeatureExtractor')
  171. train_config.update_trainable_variables.append('BoxPredictor')
  172. train_config.freeze_variables.append('FeatureExtractor')
  173. self._assert_model_fn_for_train_eval(configs, 'train')
  174. def test_model_fn_in_eval_mode(self):
  175. """Tests the model function in EVAL mode."""
  176. configs = _get_configs_for_model(MODEL_NAME_FOR_TEST)
  177. self._assert_model_fn_for_train_eval(configs, 'eval')
  178. def test_model_fn_in_eval_on_train_mode(self):
  179. """Tests the model function in EVAL mode with train data."""
  180. configs = _get_configs_for_model(MODEL_NAME_FOR_TEST)
  181. self._assert_model_fn_for_train_eval(configs, 'eval_on_train')
  182. def test_model_fn_in_predict_mode(self):
  183. """Tests the model function in PREDICT mode."""
  184. configs = _get_configs_for_model(MODEL_NAME_FOR_TEST)
  185. self._assert_model_fn_for_predict(configs)
  186. def test_create_estimator_and_inputs(self):
  187. """Tests that Estimator and input function are constructed correctly."""
  188. run_config = tf.estimator.RunConfig()
  189. hparams = model_hparams.create_hparams(
  190. hparams_overrides='load_pretrained=false')
  191. pipeline_config_path = get_pipeline_config_path(MODEL_NAME_FOR_TEST)
  192. train_steps = 20
  193. train_and_eval_dict = model_lib.create_estimator_and_inputs(
  194. run_config,
  195. hparams,
  196. pipeline_config_path,
  197. train_steps=train_steps)
  198. estimator = train_and_eval_dict['estimator']
  199. train_steps = train_and_eval_dict['train_steps']
  200. self.assertIsInstance(estimator, tf.estimator.Estimator)
  201. self.assertEqual(20, train_steps)
  202. self.assertIn('train_input_fn', train_and_eval_dict)
  203. self.assertIn('eval_input_fns', train_and_eval_dict)
  204. self.assertIn('eval_on_train_input_fn', train_and_eval_dict)
  205. def test_create_estimator_with_default_train_eval_steps(self):
  206. """Tests that number of train/eval defaults to config values."""
  207. run_config = tf.estimator.RunConfig()
  208. hparams = model_hparams.create_hparams(
  209. hparams_overrides='load_pretrained=false')
  210. pipeline_config_path = get_pipeline_config_path(MODEL_NAME_FOR_TEST)
  211. configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
  212. config_train_steps = configs['train_config'].num_steps
  213. train_and_eval_dict = model_lib.create_estimator_and_inputs(
  214. run_config, hparams, pipeline_config_path)
  215. estimator = train_and_eval_dict['estimator']
  216. train_steps = train_and_eval_dict['train_steps']
  217. self.assertIsInstance(estimator, tf.estimator.Estimator)
  218. self.assertEqual(config_train_steps, train_steps)
  219. def test_create_tpu_estimator_and_inputs(self):
  220. """Tests that number of train/eval defaults to config values."""
  221. run_config = tpu_config.RunConfig()
  222. hparams = model_hparams.create_hparams(
  223. hparams_overrides='load_pretrained=false')
  224. pipeline_config_path = get_pipeline_config_path(MODEL_NAME_FOR_TEST)
  225. train_steps = 20
  226. train_and_eval_dict = model_lib.create_estimator_and_inputs(
  227. run_config,
  228. hparams,
  229. pipeline_config_path,
  230. train_steps=train_steps,
  231. use_tpu_estimator=True)
  232. estimator = train_and_eval_dict['estimator']
  233. train_steps = train_and_eval_dict['train_steps']
  234. self.assertIsInstance(estimator, tpu_estimator.TPUEstimator)
  235. self.assertEqual(20, train_steps)
  236. def test_create_train_and_eval_specs(self):
  237. """Tests that `TrainSpec` and `EvalSpec` is created correctly."""
  238. run_config = tf.estimator.RunConfig()
  239. hparams = model_hparams.create_hparams(
  240. hparams_overrides='load_pretrained=false')
  241. pipeline_config_path = get_pipeline_config_path(MODEL_NAME_FOR_TEST)
  242. train_steps = 20
  243. train_and_eval_dict = model_lib.create_estimator_and_inputs(
  244. run_config,
  245. hparams,
  246. pipeline_config_path,
  247. train_steps=train_steps)
  248. train_input_fn = train_and_eval_dict['train_input_fn']
  249. eval_input_fns = train_and_eval_dict['eval_input_fns']
  250. eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
  251. predict_input_fn = train_and_eval_dict['predict_input_fn']
  252. train_steps = train_and_eval_dict['train_steps']
  253. train_spec, eval_specs = model_lib.create_train_and_eval_specs(
  254. train_input_fn,
  255. eval_input_fns,
  256. eval_on_train_input_fn,
  257. predict_input_fn,
  258. train_steps,
  259. eval_on_train_data=True,
  260. final_exporter_name='exporter',
  261. eval_spec_names=['holdout'])
  262. self.assertEqual(train_steps, train_spec.max_steps)
  263. self.assertEqual(2, len(eval_specs))
  264. self.assertEqual(None, eval_specs[0].steps)
  265. self.assertEqual('holdout', eval_specs[0].name)
  266. self.assertEqual('exporter', eval_specs[0].exporters[0].name)
  267. self.assertEqual(None, eval_specs[1].steps)
  268. self.assertEqual('eval_on_train', eval_specs[1].name)
  269. def test_experiment(self):
  270. """Tests that the `Experiment` object is constructed correctly."""
  271. run_config = tf.estimator.RunConfig()
  272. hparams = model_hparams.create_hparams(
  273. hparams_overrides='load_pretrained=false')
  274. pipeline_config_path = get_pipeline_config_path(MODEL_NAME_FOR_TEST)
  275. experiment = model_lib.populate_experiment(
  276. run_config,
  277. hparams,
  278. pipeline_config_path,
  279. train_steps=10,
  280. eval_steps=20)
  281. self.assertEqual(10, experiment.train_steps)
  282. self.assertEqual(None, experiment.eval_steps)
  283. class UnbatchTensorsTest(tf.test.TestCase):
  284. def test_unbatch_without_unpadding(self):
  285. image_placeholder = tf.placeholder(tf.float32, [2, None, None, None])
  286. groundtruth_boxes_placeholder = tf.placeholder(tf.float32, [2, None, None])
  287. groundtruth_classes_placeholder = tf.placeholder(tf.float32,
  288. [2, None, None])
  289. groundtruth_weights_placeholder = tf.placeholder(tf.float32, [2, None])
  290. tensor_dict = {
  291. fields.InputDataFields.image:
  292. image_placeholder,
  293. fields.InputDataFields.groundtruth_boxes:
  294. groundtruth_boxes_placeholder,
  295. fields.InputDataFields.groundtruth_classes:
  296. groundtruth_classes_placeholder,
  297. fields.InputDataFields.groundtruth_weights:
  298. groundtruth_weights_placeholder
  299. }
  300. unbatched_tensor_dict = model_lib.unstack_batch(
  301. tensor_dict, unpad_groundtruth_tensors=False)
  302. with self.test_session() as sess:
  303. unbatched_tensor_dict_out = sess.run(
  304. unbatched_tensor_dict,
  305. feed_dict={
  306. image_placeholder:
  307. np.random.rand(2, 4, 4, 3).astype(np.float32),
  308. groundtruth_boxes_placeholder:
  309. np.random.rand(2, 5, 4).astype(np.float32),
  310. groundtruth_classes_placeholder:
  311. np.random.rand(2, 5, 6).astype(np.float32),
  312. groundtruth_weights_placeholder:
  313. np.random.rand(2, 5).astype(np.float32)
  314. })
  315. for image_out in unbatched_tensor_dict_out[fields.InputDataFields.image]:
  316. self.assertAllEqual(image_out.shape, [4, 4, 3])
  317. for groundtruth_boxes_out in unbatched_tensor_dict_out[
  318. fields.InputDataFields.groundtruth_boxes]:
  319. self.assertAllEqual(groundtruth_boxes_out.shape, [5, 4])
  320. for groundtruth_classes_out in unbatched_tensor_dict_out[
  321. fields.InputDataFields.groundtruth_classes]:
  322. self.assertAllEqual(groundtruth_classes_out.shape, [5, 6])
  323. for groundtruth_weights_out in unbatched_tensor_dict_out[
  324. fields.InputDataFields.groundtruth_weights]:
  325. self.assertAllEqual(groundtruth_weights_out.shape, [5])
  326. def test_unbatch_and_unpad_groundtruth_tensors(self):
  327. image_placeholder = tf.placeholder(tf.float32, [2, None, None, None])
  328. groundtruth_boxes_placeholder = tf.placeholder(tf.float32, [2, 5, None])
  329. groundtruth_classes_placeholder = tf.placeholder(tf.float32, [2, 5, None])
  330. groundtruth_weights_placeholder = tf.placeholder(tf.float32, [2, 5])
  331. num_groundtruth_placeholder = tf.placeholder(tf.int32, [2])
  332. tensor_dict = {
  333. fields.InputDataFields.image:
  334. image_placeholder,
  335. fields.InputDataFields.groundtruth_boxes:
  336. groundtruth_boxes_placeholder,
  337. fields.InputDataFields.groundtruth_classes:
  338. groundtruth_classes_placeholder,
  339. fields.InputDataFields.groundtruth_weights:
  340. groundtruth_weights_placeholder,
  341. fields.InputDataFields.num_groundtruth_boxes:
  342. num_groundtruth_placeholder
  343. }
  344. unbatched_tensor_dict = model_lib.unstack_batch(
  345. tensor_dict, unpad_groundtruth_tensors=True)
  346. with self.test_session() as sess:
  347. unbatched_tensor_dict_out = sess.run(
  348. unbatched_tensor_dict,
  349. feed_dict={
  350. image_placeholder:
  351. np.random.rand(2, 4, 4, 3).astype(np.float32),
  352. groundtruth_boxes_placeholder:
  353. np.random.rand(2, 5, 4).astype(np.float32),
  354. groundtruth_classes_placeholder:
  355. np.random.rand(2, 5, 6).astype(np.float32),
  356. groundtruth_weights_placeholder:
  357. np.random.rand(2, 5).astype(np.float32),
  358. num_groundtruth_placeholder:
  359. np.array([3, 3], np.int32)
  360. })
  361. for image_out in unbatched_tensor_dict_out[fields.InputDataFields.image]:
  362. self.assertAllEqual(image_out.shape, [4, 4, 3])
  363. for groundtruth_boxes_out in unbatched_tensor_dict_out[
  364. fields.InputDataFields.groundtruth_boxes]:
  365. self.assertAllEqual(groundtruth_boxes_out.shape, [3, 4])
  366. for groundtruth_classes_out in unbatched_tensor_dict_out[
  367. fields.InputDataFields.groundtruth_classes]:
  368. self.assertAllEqual(groundtruth_classes_out.shape, [3, 6])
  369. for groundtruth_weights_out in unbatched_tensor_dict_out[
  370. fields.InputDataFields.groundtruth_weights]:
  371. self.assertAllEqual(groundtruth_weights_out.shape, [3])
  372. if __name__ == '__main__':
  373. tf.test.main()