You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

872 lines
41 KiB

  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for object_detection.utils.config_util."""
  16. import os
  17. import tensorflow as tf
  18. from google.protobuf import text_format
  19. from object_detection.protos import eval_pb2
  20. from object_detection.protos import image_resizer_pb2
  21. from object_detection.protos import input_reader_pb2
  22. from object_detection.protos import model_pb2
  23. from object_detection.protos import pipeline_pb2
  24. from object_detection.protos import train_pb2
  25. from object_detection.utils import config_util
  26. def _write_config(config, config_path):
  27. """Writes a config object to disk."""
  28. config_text = text_format.MessageToString(config)
  29. with tf.gfile.Open(config_path, "wb") as f:
  30. f.write(config_text)
  31. def _update_optimizer_with_constant_learning_rate(optimizer, learning_rate):
  32. """Adds a new constant learning rate."""
  33. constant_lr = optimizer.learning_rate.constant_learning_rate
  34. constant_lr.learning_rate = learning_rate
  35. def _update_optimizer_with_exponential_decay_learning_rate(
  36. optimizer, learning_rate):
  37. """Adds a new exponential decay learning rate."""
  38. exponential_lr = optimizer.learning_rate.exponential_decay_learning_rate
  39. exponential_lr.initial_learning_rate = learning_rate
  40. def _update_optimizer_with_manual_step_learning_rate(
  41. optimizer, initial_learning_rate, learning_rate_scaling):
  42. """Adds a learning rate schedule."""
  43. manual_lr = optimizer.learning_rate.manual_step_learning_rate
  44. manual_lr.initial_learning_rate = initial_learning_rate
  45. for i in range(3):
  46. schedule = manual_lr.schedule.add()
  47. schedule.learning_rate = initial_learning_rate * learning_rate_scaling**i
  48. def _update_optimizer_with_cosine_decay_learning_rate(
  49. optimizer, learning_rate, warmup_learning_rate):
  50. """Adds a new cosine decay learning rate."""
  51. cosine_lr = optimizer.learning_rate.cosine_decay_learning_rate
  52. cosine_lr.learning_rate_base = learning_rate
  53. cosine_lr.warmup_learning_rate = warmup_learning_rate
  54. class ConfigUtilTest(tf.test.TestCase):
  def _create_and_load_test_configs(self, pipeline_config):
    """Round-trips `pipeline_config` through a file in the test temp dir.

    Args:
      pipeline_config: pipeline proto to write (the callers in this file pass
        a pipeline_pb2.TrainEvalPipelineConfig).

    Returns:
      The configs dictionary that
      config_util.get_configs_from_pipeline_file builds from the written file.
    """
    config_file = os.path.join(self.get_temp_dir(), "pipeline.config")
    _write_config(pipeline_config, config_file)
    return config_util.get_configs_from_pipeline_file(config_file)
  def test_get_configs_from_pipeline_file(self):
    """Tests that every pipeline section is read back intact from a file."""
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.model.faster_rcnn.num_classes = 10
    pipeline_config.train_config.batch_size = 32
    pipeline_config.train_input_reader.label_map_path = "path/to/label_map"
    pipeline_config.eval_config.num_examples = 20
    pipeline_config.eval_input_reader.add().queue_capacity = 100

    configs = self._create_and_load_test_configs(pipeline_config)

    # Pair each configs-dict key with the pipeline field it must reproduce.
    expected_sections = (
        ("model", pipeline_config.model),
        ("train_config", pipeline_config.train_config),
        ("train_input_config", pipeline_config.train_input_reader),
        ("eval_config", pipeline_config.eval_config),
        ("eval_input_configs", pipeline_config.eval_input_reader),
    )
    for key, expected in expected_sections:
      self.assertProtoEquals(expected, configs[key])
  def test_create_configs_from_pipeline_proto(self):
    """Tests building the configs dictionary directly from a pipeline proto."""
    proto = pipeline_pb2.TrainEvalPipelineConfig()
    proto.model.faster_rcnn.num_classes = 10
    proto.train_config.batch_size = 32
    proto.train_input_reader.label_map_path = "path/to/label_map"
    proto.eval_config.num_examples = 20
    proto.eval_input_reader.add().queue_capacity = 100

    configs = config_util.create_configs_from_pipeline_proto(proto)

    # Each dictionary entry must mirror its source field in the proto.
    for key, expected in (
        ("model", proto.model),
        ("train_config", proto.train_config),
        ("train_input_config", proto.train_input_reader),
        ("eval_config", proto.eval_config),
        ("eval_input_configs", proto.eval_input_reader)):
      self.assertProtoEquals(expected, configs[key])
  def test_create_pipeline_proto_from_configs(self):
    """Tests that a loaded configs dictionary converts back to the proto."""
    original = pipeline_pb2.TrainEvalPipelineConfig()
    original.model.faster_rcnn.num_classes = 10
    original.train_config.batch_size = 32
    original.train_input_reader.label_map_path = "path/to/label_map"
    original.eval_config.num_examples = 20
    original.eval_input_reader.add().queue_capacity = 100

    configs = self._create_and_load_test_configs(original)
    rebuilt = config_util.create_pipeline_proto_from_configs(configs)
    self.assertEqual(original, rebuilt)
  def test_save_pipeline_config(self):
    """Tests that save_pipeline_config writes a loadable pipeline.config."""
    original = pipeline_pb2.TrainEvalPipelineConfig()
    original.model.faster_rcnn.num_classes = 10
    original.train_config.batch_size = 32
    original.train_input_reader.label_map_path = "path/to/label_map"
    original.eval_config.num_examples = 20
    original.eval_input_reader.add().queue_capacity = 100

    output_dir = self.get_temp_dir()
    config_util.save_pipeline_config(original, output_dir)
    # The saved file is read back from the "pipeline.config" name in output_dir.
    saved_path = os.path.join(output_dir, "pipeline.config")
    configs = config_util.get_configs_from_pipeline_file(saved_path)
    self.assertEqual(
        original, config_util.create_pipeline_proto_from_configs(configs))
  def test_get_configs_from_multiple_files(self):
    """Tests that proto configs can be read from multiple files.

    Writes the model, train, train-input, eval and eval-input configs to
    separate files and checks that get_configs_from_multiple_files loads each
    of them back unchanged.
    """
    temp_dir = self.get_temp_dir()
    # Write model config file.
    model_config_path = os.path.join(temp_dir, "model.config")
    model = model_pb2.DetectionModel()
    model.faster_rcnn.num_classes = 10
    _write_config(model, model_config_path)
    # Write train config file.
    train_config_path = os.path.join(temp_dir, "train.config")
    train_config = train_pb2.TrainConfig()
    train_config.batch_size = 32
    _write_config(train_config, train_config_path)
    # Write train input config file.
    train_input_config_path = os.path.join(temp_dir, "train_input.config")
    train_input_config = input_reader_pb2.InputReader()
    train_input_config.label_map_path = "path/to/label_map"
    _write_config(train_input_config, train_input_config_path)
    # Write eval config file.
    eval_config_path = os.path.join(temp_dir, "eval.config")
    eval_config = eval_pb2.EvalConfig()
    eval_config.num_examples = 20
    _write_config(eval_config, eval_config_path)
    # Write eval input config file.
    eval_input_config_path = os.path.join(temp_dir, "eval_input.config")
    eval_input_config = input_reader_pb2.InputReader()
    eval_input_config.label_map_path = "path/to/another/label_map"
    _write_config(eval_input_config, eval_input_config_path)
    configs = config_util.get_configs_from_multiple_files(
        model_config_path=model_config_path,
        train_config_path=train_config_path,
        train_input_config_path=train_input_config_path,
        eval_config_path=eval_config_path,
        eval_input_config_path=eval_input_config_path)
    self.assertProtoEquals(model, configs["model"])
    self.assertProtoEquals(train_config, configs["train_config"])
    self.assertProtoEquals(train_input_config,
                           configs["train_input_config"])
    self.assertProtoEquals(eval_config, configs["eval_config"])
    # Eval input configs come back as a sequence; the single file written
    # above is its first entry.
    self.assertProtoEquals(eval_input_config, configs["eval_input_configs"][0])
  def _assertOptimizerWithNewLearningRate(self, optimizer_name):
    """Asserts successful updating of all learning rate schemes.

    For each learning rate schedule (constant, exponential decay, manual step,
    cosine decay), builds a pipeline config whose `optimizer_name` optimizer
    uses that schedule, round-trips it through a file, merges an HParams
    `learning_rate` override, and checks that the schedule's rates were
    replaced (scaled proportionally where the schedule has several rates).

    Args:
      optimizer_name: name of the optimizer field inside
        `train_config.optimizer`, e.g. "rms_prop_optimizer".
    """
    original_learning_rate = 0.7
    learning_rate_scaling = 0.1
    warmup_learning_rate = 0.07
    hparams = tf.contrib.training.HParams(learning_rate=0.15)

    def _merged_optimizer(update_fn, *update_args):
      """Applies `update_fn` to a fresh config; returns the merged optimizer."""
      pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
      optimizer = getattr(pipeline_config.train_config.optimizer,
                          optimizer_name)
      update_fn(optimizer, *update_args)
      configs = self._create_and_load_test_configs(pipeline_config)
      configs = config_util.merge_external_params_with_configs(configs, hparams)
      return getattr(configs["train_config"].optimizer, optimizer_name)

    # Constant learning rate.
    optimizer = _merged_optimizer(
        _update_optimizer_with_constant_learning_rate, original_learning_rate)
    constant_lr = optimizer.learning_rate.constant_learning_rate
    self.assertAlmostEqual(hparams.learning_rate, constant_lr.learning_rate)

    # Exponential decay learning rate.
    optimizer = _merged_optimizer(
        _update_optimizer_with_exponential_decay_learning_rate,
        original_learning_rate)
    exponential_lr = optimizer.learning_rate.exponential_decay_learning_rate
    self.assertAlmostEqual(hparams.learning_rate,
                           exponential_lr.initial_learning_rate)

    # Manual step learning rate.
    optimizer = _merged_optimizer(
        _update_optimizer_with_manual_step_learning_rate,
        original_learning_rate, learning_rate_scaling)
    manual_lr = optimizer.learning_rate.manual_step_learning_rate
    self.assertAlmostEqual(hparams.learning_rate,
                           manual_lr.initial_learning_rate)
    # Without this check an empty schedule would make the loop below pass
    # without asserting anything.
    self.assertGreater(len(manual_lr.schedule), 0)
    for i, schedule in enumerate(manual_lr.schedule):
      self.assertAlmostEqual(hparams.learning_rate * learning_rate_scaling**i,
                             schedule.learning_rate)

    # Cosine decay learning rate.
    optimizer = _merged_optimizer(
        _update_optimizer_with_cosine_decay_learning_rate,
        original_learning_rate, warmup_learning_rate)
    cosine_lr = optimizer.learning_rate.cosine_decay_learning_rate
    self.assertAlmostEqual(hparams.learning_rate, cosine_lr.learning_rate_base)
    # The warmup rate should keep its original ratio to the base rate.
    warmup_scale_factor = warmup_learning_rate / original_learning_rate
    self.assertAlmostEqual(hparams.learning_rate * warmup_scale_factor,
                           cosine_lr.warmup_learning_rate)
  def testRMSPropWithNewLearingRate(self):
    """Checks learning rate overrides on the RMSProp optimizer config."""
    optimizer_field = "rms_prop_optimizer"
    self._assertOptimizerWithNewLearningRate(optimizer_field)
  def testMomentumOptimizerWithNewLearningRate(self):
    """Checks learning rate overrides on the Momentum optimizer config."""
    optimizer_field = "momentum_optimizer"
    self._assertOptimizerWithNewLearningRate(optimizer_field)
  def testAdamOptimizerWithNewLearningRate(self):
    """Checks learning rate overrides on the Adam optimizer config."""
    optimizer_field = "adam_optimizer"
    self._assertOptimizerWithNewLearningRate(optimizer_field)
  def testGenericConfigOverride(self):
    """Tests dotted-key hparam overrides on each top-level config section."""
    # Give one field in every top-level pipeline section a known value.
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.model.ssd.num_classes = 1
    pipeline_config.train_config.batch_size = 1
    pipeline_config.eval_config.num_visualizations = 1
    pipeline_config.train_input_reader.label_map_path = "/some/path"
    pipeline_config.eval_input_reader.add().label_map_path = "/some/path"
    pipeline_config.graph_rewriter.quantization.weight_bits = 1
    configs = self._create_and_load_test_configs(pipeline_config)

    # Override those fields through dotted "<config key>.<field path>" names.
    overrides = {
        "model.ssd.num_classes": 2,
        "train_config.batch_size": 2,
        "train_input_config.label_map_path": "/some/other/path",
        "eval_config.num_visualizations": 2,
        "graph_rewriter_config.quantization.weight_bits": 2
    }
    hparams = tf.contrib.training.HParams(**overrides)
    merged = config_util.merge_external_params_with_configs(configs, hparams)

    self.assertEqual(2, merged["model"].ssd.num_classes)
    self.assertEqual(2, merged["train_config"].batch_size)
    self.assertEqual("/some/other/path",
                     merged["train_input_config"].label_map_path)
    self.assertEqual(2, merged["eval_config"].num_visualizations)
    self.assertEqual(2,
                     merged["graph_rewriter_config"].quantization.weight_bits)
  def testNewBatchSize(self):
    """Tests that a `batch_size` hparam replaces the configured batch size."""
    original_batch_size = 2
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.train_config.batch_size = original_batch_size
    configs = self._create_and_load_test_configs(pipeline_config)

    hparams = tf.contrib.training.HParams(batch_size=16)
    merged = config_util.merge_external_params_with_configs(configs, hparams)
    self.assertEqual(16, merged["train_config"].batch_size)
  def testNewBatchSizeWithClipping(self):
    """Tests that a fractional `batch_size` hparam is clipped up to 1."""
    original_batch_size = 2
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.train_config.batch_size = original_batch_size
    configs = self._create_and_load_test_configs(pipeline_config)

    hparams = tf.contrib.training.HParams(batch_size=0.5)
    merged = config_util.merge_external_params_with_configs(configs, hparams)
    # 0.5 is below the minimum, so the merged batch size is expected to be 1.
    self.assertEqual(1, merged["train_config"].batch_size)
  def testOverwriteBatchSizeWithKeyValue(self):
    """Tests overriding batch size via the dotted key/value form."""
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.train_config.batch_size = 2
    loaded = self._create_and_load_test_configs(pipeline_config)
    overrides = {"train_config.batch_size": 10}
    merged = config_util.merge_external_params_with_configs(
        loaded, tf.contrib.training.HParams(**overrides))
    self.assertEqual(10, merged["train_config"].batch_size)
  def testKeyValueOverrideBadKey(self):
    """Tests that a dotted key naming a nonexistent field raises ValueError."""
    configs = self._create_and_load_test_configs(
        pipeline_pb2.TrainEvalPipelineConfig())
    bad_override = {"train_config.no_such_field": 10}
    hparams = tf.contrib.training.HParams(**bad_override)
    with self.assertRaises(ValueError):
      config_util.merge_external_params_with_configs(configs, hparams)
  def testOverwriteBatchSizeWithBadValueType(self):
    """Tests that overwriting with a bad value type causes an exception.

    `train_config.batch_size` holds an integer, so supplying the string "10"
    through a dotted key/value override must raise TypeError.
    """
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.train_config.batch_size = 2
    configs = self._create_and_load_test_configs(pipeline_config)
    # Type should be an integer, but we're passing a string "10".
    hparams = tf.contrib.training.HParams(**{"train_config.batch_size": "10"})
    with self.assertRaises(TypeError):
      config_util.merge_external_params_with_configs(configs, hparams)
  def testNewMomentumOptimizerValue(self):
    """Tests that a momentum hparam above 1.0 ends up clipped to 1.0."""
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    rms_prop = pipeline_config.train_config.optimizer.rms_prop_optimizer
    rms_prop.momentum_optimizer_value = 0.4
    configs = self._create_and_load_test_configs(pipeline_config)

    hparams = tf.contrib.training.HParams(momentum_optimizer_value=1.1)
    configs = config_util.merge_external_params_with_configs(configs, hparams)
    merged_rms_prop = configs["train_config"].optimizer.rms_prop_optimizer
    # The requested 1.1 is expected to be clipped to 1.0 by the merge.
    self.assertAlmostEqual(1.0, merged_rms_prop.momentum_optimizer_value)
  def testNewClassificationLocalizationWeightRatio(self):
    """Tests that the weight-ratio hparam rewrites both SSD loss weights."""
    new_weight_ratio = 5.0
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    ssd_loss = pipeline_config.model.ssd.loss
    ssd_loss.localization_weight = 0.1
    ssd_loss.classification_weight = 0.2
    configs = self._create_and_load_test_configs(pipeline_config)

    hparams = tf.contrib.training.HParams(
        classification_localization_weight_ratio=new_weight_ratio)
    configs = config_util.merge_external_params_with_configs(configs, hparams)
    merged_loss = configs["model"].ssd.loss
    # Expected: localization weight normalized to 1.0, classification weight
    # equal to the requested ratio.
    self.assertAlmostEqual(1.0, merged_loss.localization_weight)
    self.assertAlmostEqual(new_weight_ratio, merged_loss.classification_weight)
  def testNewFocalLossParameters(self):
    """Tests that focal loss alpha and gamma are updated appropriately.

    The stored SSD config uses a weighted sigmoid focal classification loss
    with alpha = gamma = 1.0; the `focal_loss_alpha` and `focal_loss_gamma`
    hparams must replace both values.
    """
    original_alpha = 1.0
    original_gamma = 1.0
    new_alpha = 0.3
    new_gamma = 2.0
    hparams = tf.contrib.training.HParams(
        focal_loss_alpha=new_alpha, focal_loss_gamma=new_gamma)
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    classification_loss = pipeline_config.model.ssd.loss.classification_loss
    classification_loss.weighted_sigmoid_focal.alpha = original_alpha
    classification_loss.weighted_sigmoid_focal.gamma = original_gamma
    configs = self._create_and_load_test_configs(pipeline_config)
    configs = config_util.merge_external_params_with_configs(configs, hparams)
    classification_loss = configs["model"].ssd.loss.classification_loss
    self.assertAlmostEqual(new_alpha,
                           classification_loss.weighted_sigmoid_focal.alpha)
    self.assertAlmostEqual(new_gamma,
                           classification_loss.weighted_sigmoid_focal.gamma)
  def testMergingKeywordArguments(self):
    """Tests that `kwargs_dict` overrides merge like hyperparameters do."""
    desired_num_train_steps = 10
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.train_config.num_steps = 100
    configs = self._create_and_load_test_configs(pipeline_config)

    merged = config_util.merge_external_params_with_configs(
        configs, kwargs_dict={"train_steps": desired_num_train_steps})
    self.assertEqual(desired_num_train_steps,
                     merged["train_config"].num_steps)
  def testGetNumberOfClasses(self):
    """Tests reading num_classes from a loaded Faster R-CNN model config."""
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.model.faster_rcnn.num_classes = 20
    model_config = self._create_and_load_test_configs(pipeline_config)["model"]
    self.assertEqual(20, config_util.get_number_of_classes(model_config))
  def testNewTrainInputPath(self):
    """Tests that a single-string `train_input_path` replaces input paths."""
    new_train_path = "another/path/to/data"
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    tf_record_reader = pipeline_config.train_input_reader.tf_record_input_reader
    tf_record_reader.input_path.extend(["path/to/data"])
    configs = self._create_and_load_test_configs(pipeline_config)

    configs = config_util.merge_external_params_with_configs(
        configs, kwargs_dict={"train_input_path": new_train_path})
    merged_reader = configs["train_input_config"].tf_record_input_reader
    # A lone string override becomes a one-element path list.
    self.assertEqual([new_train_path], merged_reader.input_path)
  def testNewTrainInputPathList(self):
    """Tests that a list-valued `train_input_path` replaces input paths."""
    new_train_path = ["another/path/to/data", "yet/another/path/to/data"]
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    tf_record_reader = pipeline_config.train_input_reader.tf_record_input_reader
    tf_record_reader.input_path.extend(["path/to/data"])
    configs = self._create_and_load_test_configs(pipeline_config)

    configs = config_util.merge_external_params_with_configs(
        configs, kwargs_dict={"train_input_path": new_train_path})
    merged_reader = configs["train_input_config"].tf_record_input_reader
    self.assertEqual(new_train_path, merged_reader.input_path)
  def testNewLabelMapPath(self):
    """Tests that `label_map_path` is overridden in every input reader."""
    original_label_map_path = "path/to/original/label_map"
    new_label_map_path = "path//to/new/label_map"
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.train_input_reader.label_map_path = original_label_map_path
    pipeline_config.eval_input_reader.add().label_map_path = (
        original_label_map_path)
    configs = self._create_and_load_test_configs(pipeline_config)

    configs = config_util.merge_external_params_with_configs(
        configs, kwargs_dict={"label_map_path": new_label_map_path})
    self.assertEqual(new_label_map_path,
                     configs["train_input_config"].label_map_path)
    # Every eval reader, not just the first, must pick up the new path.
    for eval_reader in configs["eval_input_configs"]:
      self.assertEqual(new_label_map_path, eval_reader.label_map_path)
  def testDontOverwriteEmptyLabelMapPath(self):
    """Tests that label map path will not be overwritten with empty string.

    An empty `label_map_path` override must leave both the train and the eval
    input readers pointing at their original label map.
    """
    original_label_map_path = "path/to/original/label_map"
    new_label_map_path = ""
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    train_input_reader = pipeline_config.train_input_reader
    train_input_reader.label_map_path = original_label_map_path
    eval_input_reader = pipeline_config.eval_input_reader.add()
    eval_input_reader.label_map_path = original_label_map_path
    configs = self._create_and_load_test_configs(pipeline_config)
    override_dict = {"label_map_path": new_label_map_path}
    configs = config_util.merge_external_params_with_configs(
        configs, kwargs_dict=override_dict)
    self.assertEqual(original_label_map_path,
                     configs["train_input_config"].label_map_path)
    self.assertEqual(original_label_map_path,
                     configs["eval_input_configs"][0].label_map_path)
  def testNewMaskType(self):
    """Tests that `mask_type` is overridden in the train and eval readers."""
    new_mask_type = input_reader_pb2.PNG_MASKS
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.train_input_reader.mask_type = (
        input_reader_pb2.NUMERICAL_MASKS)
    pipeline_config.eval_input_reader.add().mask_type = (
        input_reader_pb2.NUMERICAL_MASKS)
    configs = self._create_and_load_test_configs(pipeline_config)

    configs = config_util.merge_external_params_with_configs(
        configs, kwargs_dict={"mask_type": new_mask_type})
    self.assertEqual(new_mask_type, configs["train_input_config"].mask_type)
    self.assertEqual(new_mask_type, configs["eval_input_configs"][0].mask_type)
  def testUseMovingAverageForEval(self):
    """Tests that `eval_with_moving_averages` enables moving averages.

    The stored eval config explicitly sets `use_moving_averages` to False; the
    keyword override must turn it on.
    """
    use_moving_averages_orig = False
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.eval_config.use_moving_averages = use_moving_averages_orig
    configs = self._create_and_load_test_configs(pipeline_config)
    override_dict = {"eval_with_moving_averages": True}
    configs = config_util.merge_external_params_with_configs(
        configs, kwargs_dict=override_dict)
    self.assertTrue(configs["eval_config"].use_moving_averages)
  def testGetImageResizerConfig(self):
    """Tests that the image resizer config can be retrieved from a model.

    Uses a Faster R-CNN model with a fixed-shape resizer and checks that
    get_image_resizer_config returns it with the configured height and width.
    """
    model_config = model_pb2.DetectionModel()
    model_config.faster_rcnn.image_resizer.fixed_shape_resizer.height = 100
    model_config.faster_rcnn.image_resizer.fixed_shape_resizer.width = 300
    image_resizer_config = config_util.get_image_resizer_config(model_config)
    # Expected value first, matching the other assertions in this file.
    self.assertEqual(100, image_resizer_config.fixed_shape_resizer.height)
    self.assertEqual(300, image_resizer_config.fixed_shape_resizer.width)
  def testGetSpatialImageSizeFromFixedShapeResizerConfig(self):
    """Tests that a fixed-shape resizer yields its [height, width]."""
    resizer = image_resizer_pb2.ImageResizer()
    fixed_shape = resizer.fixed_shape_resizer
    fixed_shape.height = 100
    fixed_shape.width = 200
    self.assertAllEqual(config_util.get_spatial_image_size(resizer), [100, 200])
  def testGetSpatialImageSizeFromAspectPreservingResizerConfig(self):
    """Tests that padding to max dimension yields a square max-size shape."""
    resizer = image_resizer_pb2.ImageResizer()
    keep_aspect = resizer.keep_aspect_ratio_resizer
    keep_aspect.min_dimension = 100
    keep_aspect.max_dimension = 600
    keep_aspect.pad_to_max_dimension = True
    self.assertAllEqual(config_util.get_spatial_image_size(resizer), [600, 600])
  def testGetSpatialImageSizeFromAspectPreservingResizerDynamic(self):
    """Tests that an unpadded aspect-preserving resizer yields [-1, -1]."""
    resizer = image_resizer_pb2.ImageResizer()
    keep_aspect = resizer.keep_aspect_ratio_resizer
    keep_aspect.min_dimension = 100
    keep_aspect.max_dimension = 600
    # Without padding the output size varies per image, so both dims are -1.
    self.assertAllEqual(config_util.get_spatial_image_size(resizer), [-1, -1])
  def testGetSpatialImageSizeFromConditionalShapeResizer(self):
    """Tests that a conditional shape resizer yields [-1, -1]."""
    resizer = image_resizer_pb2.ImageResizer()
    resizer.conditional_shape_resizer.size_threshold = 100
    self.assertAllEqual(config_util.get_spatial_image_size(resizer), [-1, -1])
  def testEvalShuffle(self):
    """Tests that the `eval_shuffle` keyword argument reaches the eval reader."""
    desired_shuffle = False
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.eval_input_reader.add().shuffle = True
    configs = self._create_and_load_test_configs(pipeline_config)

    configs = config_util.merge_external_params_with_configs(
        configs, kwargs_dict={"eval_shuffle": desired_shuffle})
    self.assertEqual(desired_shuffle, configs["eval_input_configs"][0].shuffle)
  def testTrainShuffle(self):
    """Checks that the `train_shuffle` override reaches the train input."""
    shuffle_before = True
    shuffle_after = False
    config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    pipeline = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline.train_input_reader.shuffle = shuffle_before
    _write_config(pipeline, config_path)

    merged = config_util.merge_external_params_with_configs(
        config_util.get_configs_from_pipeline_file(config_path),
        kwargs_dict={"train_shuffle": shuffle_after})
    self.assertEqual(shuffle_after, merged["train_input_config"].shuffle)
  def testOverWriteRetainOriginalImages(self):
    """Tests that `retain_original_images_in_eval` overrides are applied.

    The override key updates `eval_config.retain_original_images`.
    """
    original_retain_original_images = True
    desired_retain_original_images = False
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.eval_config.retain_original_images = (
        original_retain_original_images)
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    override_dict = {
        "retain_original_images_in_eval": desired_retain_original_images
    }
    configs = config_util.merge_external_params_with_configs(
        configs, kwargs_dict=override_dict)
    retain_original_images = configs["eval_config"].retain_original_images
    self.assertEqual(desired_retain_original_images, retain_original_images)
  def testOverwriteAllEvalSampling(self):
    """`sample_1_of_n_eval_examples` overrides every eval input config."""
    old_sample_rate = 1
    new_sample_rate = 10
    config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    pipeline = pipeline_pb2.TrainEvalPipelineConfig()
    # Two eval readers, both starting from the same sampling rate.
    for _ in range(2):
      pipeline.eval_input_reader.add().sample_1_of_n_examples = old_sample_rate
    _write_config(pipeline, config_path)

    configs = config_util.get_configs_from_pipeline_file(config_path)
    configs = config_util.merge_external_params_with_configs(
        configs, kwargs_dict={"sample_1_of_n_eval_examples": new_sample_rate})
    for eval_cfg in configs["eval_input_configs"]:
      self.assertEqual(new_sample_rate, eval_cfg.sample_1_of_n_examples)
  def testOverwriteAllEvalNumEpochs(self):
    """`eval_num_epochs` overrides num_epochs on every eval input config."""
    old_epochs = 10
    new_epochs = 1
    config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    pipeline = pipeline_pb2.TrainEvalPipelineConfig()
    # Two eval readers, both starting from the same epoch count.
    for _ in range(2):
      pipeline.eval_input_reader.add().num_epochs = old_epochs
    _write_config(pipeline, config_path)

    merged = config_util.merge_external_params_with_configs(
        config_util.get_configs_from_pipeline_file(config_path),
        kwargs_dict={"eval_num_epochs": new_epochs})
    for eval_cfg in merged["eval_input_configs"]:
      self.assertEqual(new_epochs, eval_cfg.num_epochs)
  def testUpdateMaskTypeForAllInputConfigs(self):
    """A `mask_type` override updates the train and every eval input config."""
    old_mask_type = input_reader_pb2.NUMERICAL_MASKS
    new_mask_type = input_reader_pb2.PNG_MASKS
    config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    pipeline = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline.train_input_reader.mask_type = old_mask_type
    for eval_name in ("eval_1", "eval_2"):
      eval_reader = pipeline.eval_input_reader.add()
      eval_reader.mask_type = old_mask_type
      eval_reader.name = eval_name
    _write_config(pipeline, config_path)

    merged = config_util.merge_external_params_with_configs(
        config_util.get_configs_from_pipeline_file(config_path),
        kwargs_dict={"mask_type": new_mask_type})
    self.assertEqual(merged["train_input_config"].mask_type, new_mask_type)
    for eval_cfg in merged["eval_input_configs"]:
      self.assertEqual(eval_cfg.mask_type, new_mask_type)
  def testErrorOverwritingMultipleInputConfig(self):
    """Tests that a legacy override is rejected with several eval inputs.

    The unnamed `eval_shuffle` key does not say which of the two eval input
    configs to update, so merging it is expected to raise a ValueError.
    """
    original_shuffle = False
    new_shuffle = True
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    eval_1 = pipeline_config.eval_input_reader.add()
    eval_1.shuffle = original_shuffle
    eval_1.name = "eval_1"
    eval_2 = pipeline_config.eval_input_reader.add()
    eval_2.shuffle = original_shuffle
    eval_2.name = "eval_2"
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    override_dict = {"eval_shuffle": new_shuffle}
    with self.assertRaises(ValueError):
      # The merge must raise, so there is no result to keep.
      config_util.merge_external_params_with_configs(
          configs, kwargs_dict=override_dict)
  def testCheckAndParseInputConfigKey(self):
    """Tests parsing of input-config override keys, valid and invalid.

    Covers three key forms: a fully specified
    "key_name:input_name:field_name" key, a legacy alias ("eval_shuffle"),
    and a key that is not an input-config key at all. It also checks each
    ValueError raised for a malformed key.
    """
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.eval_input_reader.add().name = "eval_1"
    pipeline_config.eval_input_reader.add().name = "eval_2"
    _write_config(pipeline_config, pipeline_config_path)
    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)

    # Fully specified "key_name:input_name:field_name" key.
    specific_shuffle_update_key = "eval_input_configs:eval_2:shuffle"
    is_valid_input_config_key, key_name, input_name, field_name = (
        config_util.check_and_parse_input_config_key(
            configs, specific_shuffle_update_key))
    self.assertTrue(is_valid_input_config_key)
    self.assertEqual(key_name, "eval_input_configs")
    self.assertEqual(input_name, "eval_2")
    self.assertEqual(field_name, "shuffle")

    # Legacy key: maps to eval_input_configs without naming an input.
    legacy_shuffle_update_key = "eval_shuffle"
    is_valid_input_config_key, key_name, input_name, field_name = (
        config_util.check_and_parse_input_config_key(configs,
                                                     legacy_shuffle_update_key))
    self.assertTrue(is_valid_input_config_key)
    self.assertEqual(key_name, "eval_input_configs")
    self.assertIsNone(input_name)
    self.assertEqual(field_name, "shuffle")

    # Not an input-config key: the key itself comes back as the field name.
    non_input_config_update_key = "label_map_path"
    is_valid_input_config_key, key_name, input_name, field_name = (
        config_util.check_and_parse_input_config_key(
            configs, non_input_config_update_key))
    self.assertFalse(is_valid_input_config_key)
    self.assertIsNone(key_name)
    self.assertIsNone(input_name)
    self.assertEqual(field_name, "label_map_path")

    # Malformed keys.
    # NOTE(review): assertRaisesRegexp is a deprecated alias of
    # assertRaisesRegex (removed in Python 3.12). It is presumably kept for
    # Python 2 compatibility; switch once that is no longer needed -- confirm.
    with self.assertRaisesRegexp(ValueError,
                                 "Invalid key format when overriding configs."):
      config_util.check_and_parse_input_config_key(
          configs, "train_input_config:shuffle")
    with self.assertRaisesRegexp(
        ValueError, "Invalid key_name when overriding input config."):
      config_util.check_and_parse_input_config_key(
          configs, "invalid_key_name:train_name:shuffle")
    with self.assertRaisesRegexp(
        ValueError, "Invalid input_name when overriding input config."):
      config_util.check_and_parse_input_config_key(
          configs, "eval_input_configs:unknown_eval_name:shuffle")
    with self.assertRaisesRegexp(
        ValueError, "Invalid field_name when overriding input config."):
      config_util.check_and_parse_input_config_key(
          configs, "eval_input_configs:eval_2:unknown_field_name")
  def testUpdateInputReaderConfigSuccess(self):
    """Tests updating a singular input config and a named eval input config.

    The original version of this test repeated the identical train update
    twice. The second half now exercises the `input_name` lookup over the
    repeated `eval_input_configs` instead.
    """
    original_shuffle = False
    new_shuffle = True
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.train_input_reader.shuffle = original_shuffle
    eval_input_reader = pipeline_config.eval_input_reader.add()
    eval_input_reader.name = "eval_1"
    eval_input_reader.shuffle = original_shuffle
    _write_config(pipeline_config, pipeline_config_path)
    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)

    # Singular input config: addressed by key_name alone.
    config_util.update_input_reader_config(
        configs,
        key_name="train_input_config",
        input_name=None,
        field_name="shuffle",
        value=new_shuffle)
    self.assertEqual(configs["train_input_config"].shuffle, new_shuffle)

    # Repeated input configs: the target entry is selected by its `name`.
    config_util.update_input_reader_config(
        configs,
        key_name="eval_input_configs",
        input_name="eval_1",
        field_name="shuffle",
        value=new_shuffle)
    self.assertEqual(configs["eval_input_configs"][0].shuffle, new_shuffle)
  702. def testUpdateInputReaderConfigErrors(self):
  703. pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
  704. pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
  705. pipeline_config.eval_input_reader.add().name = "same_eval_name"
  706. pipeline_config.eval_input_reader.add().name = "same_eval_name"
  707. _write_config(pipeline_config, pipeline_config_path)
  708. configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
  709. with self.assertRaisesRegexp(ValueError,
  710. "Duplicate input name found when overriding."):
  711. config_util.update_input_reader_config(
  712. configs,
  713. key_name="eval_input_configs",
  714. input_name="same_eval_name",
  715. field_name="shuffle",
  716. value=False)
  717. with self.assertRaisesRegexp(
  718. ValueError, "Input name name_not_exist not found when overriding."):
  719. config_util.update_input_reader_config(
  720. configs,
  721. key_name="eval_input_configs",
  722. input_name="name_not_exist",
  723. field_name="shuffle",
  724. value=False)
  725. with self.assertRaisesRegexp(ValueError,
  726. "Unknown input config overriding."):
  727. config_util.update_input_reader_config(
  728. configs,
  729. key_name="eval_input_configs",
  730. input_name=None,
  731. field_name="shuffle",
  732. value=False)
  733. if __name__ == "__main__":
  734. tf.test.main()