# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests object_detection.core.hyperparams_builder."""

import numpy as np
import tensorflow as tf

from google.protobuf import text_format

from object_detection.builders import hyperparams_builder
from object_detection.core import freezable_batch_norm
from object_detection.protos import hyperparams_pb2

slim = tf.contrib.slim
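

# slim arg_scope dictionaries are keyed by each op's `_key_op` attribute
# (falling back to str(op)); the tests use this helper to look up the kwargs
# registered for a given op in a built scope.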
def _get_scope_key(op):
  return getattr(op, '_key_op', str(op))


class HyperparamsBuilderTest(tf.test.TestCase):

  def test_default_arg_scope_has_conv2d_op(self):
    conv_hyperparams_text_proto = """
      regularizer {
        l1_regularizer {
        }
      }
      initializer {
        truncated_normal_initializer {
        }
      }
    """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
                                         is_training=True)
    scope = scope_fn()
    self.assertTrue(_get_scope_key(slim.conv2d) in scope)

  def test_default_arg_scope_has_separable_conv2d_op(self):
    conv_hyperparams_text_proto = """
      regularizer {
        l1_regularizer {
        }
      }
      initializer {
        truncated_normal_initializer {
        }
      }
    """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
                                         is_training=True)
    scope = scope_fn()
    self.assertTrue(_get_scope_key(slim.separable_conv2d) in scope)

  def test_default_arg_scope_has_conv2d_transpose_op(self):
    conv_hyperparams_text_proto = """
      regularizer {
        l1_regularizer {
        }
      }
      initializer {
        truncated_normal_initializer {
        }
      }
    """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
                                         is_training=True)
    scope = scope_fn()
    self.assertTrue(_get_scope_key(slim.conv2d_transpose) in scope)

  def test_explicit_fc_op_arg_scope_has_fully_connected_op(self):
    conv_hyperparams_text_proto = """
      op: FC
      regularizer {
        l1_regularizer {
        }
      }
      initializer {
        truncated_normal_initializer {
        }
      }
    """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
                                         is_training=True)
    scope = scope_fn()
    self.assertTrue(_get_scope_key(slim.fully_connected) in scope)

  def test_separable_conv2d_and_conv2d_and_transpose_have_same_parameters(self):
    conv_hyperparams_text_proto = """
      regularizer {
        l1_regularizer {
        }
      }
      initializer {
        truncated_normal_initializer {
        }
      }
    """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
                                         is_training=True)
    scope = scope_fn()
    kwargs_1, kwargs_2, kwargs_3 = scope.values()
    self.assertDictEqual(kwargs_1, kwargs_2)
    self.assertDictEqual(kwargs_1, kwargs_3)
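
  # The regularizer tests below run the built regularizer op on a fixed
  # weight vector: l1_regularizer is expected to return weight * sum(|w|)
  # and l2_regularizer weight * sum(w**2) / 2 (TensorFlow's l2 loss includes
  # the 1/2 factor), which is what the assertions encode.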
  def test_return_l1_regularized_weights(self):
    conv_hyperparams_text_proto = """
      regularizer {
        l1_regularizer {
          weight: 0.5
        }
      }
      initializer {
        truncated_normal_initializer {
        }
      }
    """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
                                         is_training=True)
    scope = scope_fn()
    conv_scope_arguments = list(scope.values())[0]
    regularizer = conv_scope_arguments['weights_regularizer']
    weights = np.array([1., -1, 4., 2.])
    with self.test_session() as sess:
      result = sess.run(regularizer(tf.constant(weights)))
    self.assertAllClose(np.abs(weights).sum() * 0.5, result)

  def test_return_l1_regularized_weights_keras(self):
    conv_hyperparams_text_proto = """
      regularizer {
        l1_regularizer {
          weight: 0.5
        }
      }
      initializer {
        truncated_normal_initializer {
        }
      }
    """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    keras_config = hyperparams_builder.KerasLayerHyperparams(
        conv_hyperparams_proto)
    regularizer = keras_config.params()['kernel_regularizer']
    weights = np.array([1., -1, 4., 2.])
    with self.test_session() as sess:
      result = sess.run(regularizer(tf.constant(weights)))
    self.assertAllClose(np.abs(weights).sum() * 0.5, result)

  def test_return_l2_regularizer_weights(self):
    conv_hyperparams_text_proto = """
      regularizer {
        l2_regularizer {
          weight: 0.42
        }
      }
      initializer {
        truncated_normal_initializer {
        }
      }
    """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
                                         is_training=True)
    scope = scope_fn()
    conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
    regularizer = conv_scope_arguments['weights_regularizer']
    weights = np.array([1., -1, 4., 2.])
    with self.test_session() as sess:
      result = sess.run(regularizer(tf.constant(weights)))
    self.assertAllClose(np.power(weights, 2).sum() / 2.0 * 0.42, result)

  def test_return_l2_regularizer_weights_keras(self):
    conv_hyperparams_text_proto = """
      regularizer {
        l2_regularizer {
          weight: 0.42
        }
      }
      initializer {
        truncated_normal_initializer {
        }
      }
    """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    keras_config = hyperparams_builder.KerasLayerHyperparams(
        conv_hyperparams_proto)
    regularizer = keras_config.params()['kernel_regularizer']
    weights = np.array([1., -1, 4., 2.])
    with self.test_session() as sess:
      result = sess.run(regularizer(tf.constant(weights)))
    self.assertAllClose(np.power(weights, 2).sum() / 2.0 * 0.42, result)

  def test_return_non_default_batch_norm_params_with_train_during_train(self):
    conv_hyperparams_text_proto = """
      regularizer {
        l2_regularizer {
        }
      }
      initializer {
        truncated_normal_initializer {
        }
      }
      batch_norm {
        decay: 0.7
        center: false
        scale: true
        epsilon: 0.03
        train: true
      }
    """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
                                         is_training=True)
    scope = scope_fn()
    conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
    self.assertEqual(conv_scope_arguments['normalizer_fn'], slim.batch_norm)
    batch_norm_params = scope[_get_scope_key(slim.batch_norm)]
    self.assertAlmostEqual(batch_norm_params['decay'], 0.7)
    self.assertAlmostEqual(batch_norm_params['epsilon'], 0.03)
    self.assertFalse(batch_norm_params['center'])
    self.assertTrue(batch_norm_params['scale'])
    self.assertTrue(batch_norm_params['is_training'])
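
  # The Keras hyperparams config reports batch norm arguments under Keras
  # names, so the proto's `decay: 0.7` surfaces as `momentum` below, and
  # build_batch_norm() yields a FreezableBatchNorm layer rather than a
  # slim.batch_norm arg_scope entry.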
  def test_return_non_default_batch_norm_params_keras(self):
    conv_hyperparams_text_proto = """
      regularizer {
        l2_regularizer {
        }
      }
      initializer {
        truncated_normal_initializer {
        }
      }
      batch_norm {
        decay: 0.7
        center: false
        scale: true
        epsilon: 0.03
      }
    """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    keras_config = hyperparams_builder.KerasLayerHyperparams(
        conv_hyperparams_proto)
    self.assertTrue(keras_config.use_batch_norm())
    batch_norm_params = keras_config.batch_norm_params()
    self.assertAlmostEqual(batch_norm_params['momentum'], 0.7)
    self.assertAlmostEqual(batch_norm_params['epsilon'], 0.03)
    self.assertFalse(batch_norm_params['center'])
    self.assertTrue(batch_norm_params['scale'])
    batch_norm_layer = keras_config.build_batch_norm()
    self.assertTrue(isinstance(batch_norm_layer,
                               freezable_batch_norm.FreezableBatchNorm))

  def test_return_non_default_batch_norm_params_keras_override(self):
    conv_hyperparams_text_proto = """
      regularizer {
        l2_regularizer {
        }
      }
      initializer {
        truncated_normal_initializer {
        }
      }
      batch_norm {
        decay: 0.7
        center: false
        scale: true
        epsilon: 0.03
      }
    """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    keras_config = hyperparams_builder.KerasLayerHyperparams(
        conv_hyperparams_proto)
    self.assertTrue(keras_config.use_batch_norm())
    batch_norm_params = keras_config.batch_norm_params(momentum=0.4)
    self.assertAlmostEqual(batch_norm_params['momentum'], 0.4)
    self.assertAlmostEqual(batch_norm_params['epsilon'], 0.03)
    self.assertFalse(batch_norm_params['center'])
    self.assertTrue(batch_norm_params['scale'])

  def test_return_batch_norm_params_with_notrain_during_eval(self):
    conv_hyperparams_text_proto = """
      regularizer {
        l2_regularizer {
        }
      }
      initializer {
        truncated_normal_initializer {
        }
      }
      batch_norm {
        decay: 0.7
        center: false
        scale: true
        epsilon: 0.03
        train: true
      }
    """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
                                         is_training=False)
    scope = scope_fn()
    conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
    self.assertEqual(conv_scope_arguments['normalizer_fn'], slim.batch_norm)
    batch_norm_params = scope[_get_scope_key(slim.batch_norm)]
    self.assertAlmostEqual(batch_norm_params['decay'], 0.7)
    self.assertAlmostEqual(batch_norm_params['epsilon'], 0.03)
    self.assertFalse(batch_norm_params['center'])
    self.assertTrue(batch_norm_params['scale'])
    self.assertFalse(batch_norm_params['is_training'])

  def test_return_batch_norm_params_with_notrain_when_train_is_false(self):
    conv_hyperparams_text_proto = """
      regularizer {
        l2_regularizer {
        }
      }
      initializer {
        truncated_normal_initializer {
        }
      }
      batch_norm {
        decay: 0.7
        center: false
        scale: true
        epsilon: 0.03
        train: false
      }
    """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
                                         is_training=True)
    scope = scope_fn()
    conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
    self.assertEqual(conv_scope_arguments['normalizer_fn'], slim.batch_norm)
    batch_norm_params = scope[_get_scope_key(slim.batch_norm)]
    self.assertAlmostEqual(batch_norm_params['decay'], 0.7)
    self.assertAlmostEqual(batch_norm_params['epsilon'], 0.03)
    self.assertFalse(batch_norm_params['center'])
    self.assertTrue(batch_norm_params['scale'])
    self.assertFalse(batch_norm_params['is_training'])
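
  # When the proto carries no batch_norm message, the slim scope leaves
  # normalizer_fn unset and the Keras config reports no batch norm, building
  # an identity Lambda layer in its place.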
  def test_do_not_use_batch_norm_if_default(self):
    conv_hyperparams_text_proto = """
      regularizer {
        l2_regularizer {
        }
      }
      initializer {
        truncated_normal_initializer {
        }
      }
    """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
                                         is_training=True)
    scope = scope_fn()
    conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
    self.assertEqual(conv_scope_arguments['normalizer_fn'], None)

  def test_do_not_use_batch_norm_if_default_keras(self):
    conv_hyperparams_text_proto = """
      regularizer {
        l2_regularizer {
        }
      }
      initializer {
        truncated_normal_initializer {
        }
      }
    """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    keras_config = hyperparams_builder.KerasLayerHyperparams(
        conv_hyperparams_proto)
    self.assertFalse(keras_config.use_batch_norm())
    self.assertEqual(keras_config.batch_norm_params(), {})

    # The batch norm builder should build an identity Lambda layer
    identity_layer = keras_config.build_batch_norm()
    self.assertTrue(isinstance(identity_layer,
                               tf.keras.layers.Lambda))

  def test_use_none_activation(self):
    conv_hyperparams_text_proto = """
      regularizer {
        l2_regularizer {
        }
      }
      initializer {
        truncated_normal_initializer {
        }
      }
      activation: NONE
    """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
                                         is_training=True)
    scope = scope_fn()
    conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
    self.assertEqual(conv_scope_arguments['activation_fn'], None)

  def test_use_none_activation_keras(self):
    conv_hyperparams_text_proto = """
      regularizer {
        l2_regularizer {
        }
      }
      initializer {
        truncated_normal_initializer {
        }
      }
      activation: NONE
    """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    keras_config = hyperparams_builder.KerasLayerHyperparams(
        conv_hyperparams_proto)
    self.assertEqual(keras_config.params()['activation'], None)
    self.assertEqual(
        keras_config.params(include_activation=True)['activation'], None)
    activation_layer = keras_config.build_activation_layer()
    self.assertTrue(isinstance(activation_layer, tf.keras.layers.Lambda))
    self.assertEqual(activation_layer.function, tf.identity)

  def test_use_relu_activation(self):
    conv_hyperparams_text_proto = """
      regularizer {
        l2_regularizer {
        }
      }
      initializer {
        truncated_normal_initializer {
        }
      }
      activation: RELU
    """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
                                         is_training=True)
    scope = scope_fn()
    conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
    self.assertEqual(conv_scope_arguments['activation_fn'], tf.nn.relu)

  def test_use_relu_activation_keras(self):
    conv_hyperparams_text_proto = """
      regularizer {
        l2_regularizer {
        }
      }
      initializer {
        truncated_normal_initializer {
        }
      }
      activation: RELU
    """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    keras_config = hyperparams_builder.KerasLayerHyperparams(
        conv_hyperparams_proto)
    self.assertEqual(keras_config.params()['activation'], None)
    self.assertEqual(
        keras_config.params(include_activation=True)['activation'], tf.nn.relu)
    activation_layer = keras_config.build_activation_layer()
    self.assertTrue(isinstance(activation_layer, tf.keras.layers.Lambda))
    self.assertEqual(activation_layer.function, tf.nn.relu)

  def test_use_relu_6_activation(self):
    conv_hyperparams_text_proto = """
      regularizer {
        l2_regularizer {
        }
      }
      initializer {
        truncated_normal_initializer {
        }
      }
      activation: RELU_6
    """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
                                         is_training=True)
    scope = scope_fn()
    conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
    self.assertEqual(conv_scope_arguments['activation_fn'], tf.nn.relu6)

  def test_use_relu_6_activation_keras(self):
    conv_hyperparams_text_proto = """
      regularizer {
        l2_regularizer {
        }
      }
      initializer {
        truncated_normal_initializer {
        }
      }
      activation: RELU_6
    """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    keras_config = hyperparams_builder.KerasLayerHyperparams(
        conv_hyperparams_proto)
    self.assertEqual(keras_config.params()['activation'], None)
    self.assertEqual(
        keras_config.params(include_activation=True)['activation'], tf.nn.relu6)
    activation_layer = keras_config.build_activation_layer()
    self.assertTrue(isinstance(activation_layer, tf.keras.layers.Lambda))
    self.assertEqual(activation_layer.function, tf.nn.relu6)

  def test_override_activation_keras(self):
    conv_hyperparams_text_proto = """
      regularizer {
        l2_regularizer {
        }
      }
      initializer {
        truncated_normal_initializer {
        }
      }
      activation: RELU_6
    """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    keras_config = hyperparams_builder.KerasLayerHyperparams(
        conv_hyperparams_proto)
    new_params = keras_config.params(activation=tf.nn.relu)
    self.assertEqual(new_params['activation'], tf.nn.relu)
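
  # Draws one sample of a variable from the given initializer and checks that
  # its empirical variance matches the expected value to within `tol`.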
  def _assert_variance_in_range(self, initializer, shape, variance,
                                tol=1e-2):
    with tf.Graph().as_default() as g:
      with self.test_session(graph=g) as sess:
        var = tf.get_variable(
            name='test',
            shape=shape,
            dtype=tf.float32,
            initializer=initializer)
        sess.run(tf.global_variables_initializer())
        values = sess.run(var)
        self.assertAllClose(np.var(values), variance, tol, tol)
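
  # For variance_scaling_initializer with factor 2.0 on a [100, 40] variable,
  # the expected variance is factor / fan_in (2/100) for FAN_IN,
  # factor / fan_out (2/40) for FAN_OUT, and 2 * factor / (fan_in + fan_out)
  # (4/140) for FAN_AVG, regardless of whether sampling is uniform or normal.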
  def test_variance_in_range_with_variance_scaling_initializer_fan_in(self):
    conv_hyperparams_text_proto = """
      regularizer {
        l2_regularizer {
        }
      }
      initializer {
        variance_scaling_initializer {
          factor: 2.0
          mode: FAN_IN
          uniform: false
        }
      }
    """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
                                         is_training=True)
    scope = scope_fn()
    conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
    initializer = conv_scope_arguments['weights_initializer']
    self._assert_variance_in_range(initializer, shape=[100, 40],
                                   variance=2. / 100.)

  def test_variance_in_range_with_variance_scaling_initializer_fan_in_keras(
      self):
    conv_hyperparams_text_proto = """
      regularizer {
        l2_regularizer {
        }
      }
      initializer {
        variance_scaling_initializer {
          factor: 2.0
          mode: FAN_IN
          uniform: false
        }
      }
    """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    keras_config = hyperparams_builder.KerasLayerHyperparams(
        conv_hyperparams_proto)
    initializer = keras_config.params()['kernel_initializer']
    self._assert_variance_in_range(initializer, shape=[100, 40],
                                   variance=2. / 100.)

  def test_variance_in_range_with_variance_scaling_initializer_fan_out(self):
    conv_hyperparams_text_proto = """
      regularizer {
        l2_regularizer {
        }
      }
      initializer {
        variance_scaling_initializer {
          factor: 2.0
          mode: FAN_OUT
          uniform: false
        }
      }
    """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
                                         is_training=True)
    scope = scope_fn()
    conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
    initializer = conv_scope_arguments['weights_initializer']
    self._assert_variance_in_range(initializer, shape=[100, 40],
                                   variance=2. / 40.)

  def test_variance_in_range_with_variance_scaling_initializer_fan_out_keras(
      self):
    conv_hyperparams_text_proto = """
      regularizer {
        l2_regularizer {
        }
      }
      initializer {
        variance_scaling_initializer {
          factor: 2.0
          mode: FAN_OUT
          uniform: false
        }
      }
    """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    keras_config = hyperparams_builder.KerasLayerHyperparams(
        conv_hyperparams_proto)
    initializer = keras_config.params()['kernel_initializer']
    self._assert_variance_in_range(initializer, shape=[100, 40],
                                   variance=2. / 40.)

  def test_variance_in_range_with_variance_scaling_initializer_fan_avg(self):
    conv_hyperparams_text_proto = """
      regularizer {
        l2_regularizer {
        }
      }
      initializer {
        variance_scaling_initializer {
          factor: 2.0
          mode: FAN_AVG
          uniform: false
        }
      }
    """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
                                         is_training=True)
    scope = scope_fn()
    conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
    initializer = conv_scope_arguments['weights_initializer']
    self._assert_variance_in_range(initializer, shape=[100, 40],
                                   variance=4. / (100. + 40.))

  def test_variance_in_range_with_variance_scaling_initializer_fan_avg_keras(
      self):
    conv_hyperparams_text_proto = """
      regularizer {
        l2_regularizer {
        }
      }
      initializer {
        variance_scaling_initializer {
          factor: 2.0
          mode: FAN_AVG
          uniform: false
        }
      }
    """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    keras_config = hyperparams_builder.KerasLayerHyperparams(
        conv_hyperparams_proto)
    initializer = keras_config.params()['kernel_initializer']
    self._assert_variance_in_range(initializer, shape=[100, 40],
                                   variance=4. / (100. + 40.))

  def test_variance_in_range_with_variance_scaling_initializer_uniform(self):
    conv_hyperparams_text_proto = """
      regularizer {
        l2_regularizer {
        }
      }
      initializer {
        variance_scaling_initializer {
          factor: 2.0
          mode: FAN_IN
          uniform: true
        }
      }
    """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
                                         is_training=True)
    scope = scope_fn()
    conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
    initializer = conv_scope_arguments['weights_initializer']
    self._assert_variance_in_range(initializer, shape=[100, 40],
                                   variance=2. / 100.)

  def test_variance_in_range_with_variance_scaling_initializer_uniform_keras(
      self):
    conv_hyperparams_text_proto = """
      regularizer {
        l2_regularizer {
        }
      }
      initializer {
        variance_scaling_initializer {
          factor: 2.0
          mode: FAN_IN
          uniform: true
        }
      }
    """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    keras_config = hyperparams_builder.KerasLayerHyperparams(
        conv_hyperparams_proto)
    initializer = keras_config.params()['kernel_initializer']
    self._assert_variance_in_range(initializer, shape=[100, 40],
                                   variance=2. / 100.)
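
  # With stddev 0.8, a truncated normal initializer (values beyond two
  # standard deviations are redrawn) should land well below the full
  # 0.8**2 = 0.64 variance, at roughly 0.49, whereas a plain random normal
  # initializer keeps ~0.64; both checks use the looser 1e-1 tolerance.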
  def test_variance_in_range_with_truncated_normal_initializer(self):
    conv_hyperparams_text_proto = """
      regularizer {
        l2_regularizer {
        }
      }
      initializer {
        truncated_normal_initializer {
          mean: 0.0
          stddev: 0.8
        }
      }
    """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
                                         is_training=True)
    scope = scope_fn()
    conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
    initializer = conv_scope_arguments['weights_initializer']
    self._assert_variance_in_range(initializer, shape=[100, 40],
                                   variance=0.49, tol=1e-1)

  def test_variance_in_range_with_truncated_normal_initializer_keras(self):
    conv_hyperparams_text_proto = """
      regularizer {
        l2_regularizer {
        }
      }
      initializer {
        truncated_normal_initializer {
          mean: 0.0
          stddev: 0.8
        }
      }
    """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    keras_config = hyperparams_builder.KerasLayerHyperparams(
        conv_hyperparams_proto)
    initializer = keras_config.params()['kernel_initializer']
    self._assert_variance_in_range(initializer, shape=[100, 40],
                                   variance=0.49, tol=1e-1)

  def test_variance_in_range_with_random_normal_initializer(self):
    conv_hyperparams_text_proto = """
      regularizer {
        l2_regularizer {
        }
      }
      initializer {
        random_normal_initializer {
          mean: 0.0
          stddev: 0.8
        }
      }
    """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    scope_fn = hyperparams_builder.build(conv_hyperparams_proto,
                                         is_training=True)
    scope = scope_fn()
    conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
    initializer = conv_scope_arguments['weights_initializer']
    self._assert_variance_in_range(initializer, shape=[100, 40],
                                   variance=0.64, tol=1e-1)

  def test_variance_in_range_with_random_normal_initializer_keras(self):
    conv_hyperparams_text_proto = """
      regularizer {
        l2_regularizer {
        }
      }
      initializer {
        random_normal_initializer {
          mean: 0.0
          stddev: 0.8
        }
      }
    """
    conv_hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams_proto)
    keras_config = hyperparams_builder.KerasLayerHyperparams(
        conv_hyperparams_proto)
    initializer = keras_config.params()['kernel_initializer']
    self._assert_variance_in_range(initializer, shape=[100, 40],
                                   variance=0.64, tol=1e-1)


if __name__ == '__main__':
  tf.test.main()