You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

865 lines
29 KiB

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests object_detection.core.hyperparams_builder."""
import numpy as np
import tensorflow as tf
from google.protobuf import text_format
from object_detection.builders import hyperparams_builder
from object_detection.core import freezable_batch_norm
from object_detection.protos import hyperparams_pb2
# Short alias for TF-Slim; the tests below use it to name slim ops
# (slim.conv2d, slim.batch_norm, ...) when looking up arg_scope entries.
slim = tf.contrib.slim
def _get_scope_key(op):
  """Returns the lookup key used for `op` in a slim arg_scope dictionary.

  Prefers the op's `_key_op` attribute (presumably attached by slim's
  arg_scope decoration) and falls back to `str(op)` when it is absent.
  """
  # Computed eagerly, exactly like the default argument of getattr would be.
  fallback_key = str(op)
  try:
    return op._key_op
  except AttributeError:
    return fallback_key
class HyperparamsBuilderTest(tf.test.TestCase):
  """Tests for `hyperparams_builder`.

  Covers both the slim path (`hyperparams_builder.build`, which returns an
  arg_scope function) and the Keras path (`KerasLayerHyperparams`). Every test
  starts from a text-format `Hyperparams` proto.
  """

  def _parse_hyperparams(self, text_proto):
    """Parses a text-format proto string into a `Hyperparams` message."""
    hyperparams_proto = hyperparams_pb2.Hyperparams()
    text_format.Merge(text_proto, hyperparams_proto)
    return hyperparams_proto

  def _build_slim_scope(self, text_proto, is_training=True):
    """Builds the slim arg_scope described by `text_proto`.

    Args:
      text_proto: text-format `Hyperparams` proto.
      is_training: passed through to `hyperparams_builder.build`.

    Returns:
      The result of calling the scope function returned by
      `hyperparams_builder.build`; tests look up per-op keyword arguments in
      it with `_get_scope_key(op)`.
    """
    scope_fn = hyperparams_builder.build(
        self._parse_hyperparams(text_proto), is_training=is_training)
    return scope_fn()

  def _build_keras_config(self, text_proto):
    """Returns the `KerasLayerHyperparams` described by `text_proto`."""
    return hyperparams_builder.KerasLayerHyperparams(
        self._parse_hyperparams(text_proto))

  def test_default_arg_scope_has_conv2d_op(self):
    conv_hyperparams_text_proto = """
      regularizer { l1_regularizer { } }
      initializer { truncated_normal_initializer { } }
    """
    scope = self._build_slim_scope(conv_hyperparams_text_proto)
    self.assertIn(_get_scope_key(slim.conv2d), scope)

  def test_default_arg_scope_has_separable_conv2d_op(self):
    conv_hyperparams_text_proto = """
      regularizer { l1_regularizer { } }
      initializer { truncated_normal_initializer { } }
    """
    scope = self._build_slim_scope(conv_hyperparams_text_proto)
    self.assertIn(_get_scope_key(slim.separable_conv2d), scope)

  def test_default_arg_scope_has_conv2d_transpose_op(self):
    conv_hyperparams_text_proto = """
      regularizer { l1_regularizer { } }
      initializer { truncated_normal_initializer { } }
    """
    scope = self._build_slim_scope(conv_hyperparams_text_proto)
    self.assertIn(_get_scope_key(slim.conv2d_transpose), scope)

  def test_explicit_fc_op_arg_scope_has_fully_connected_op(self):
    conv_hyperparams_text_proto = """
      op: FC
      regularizer { l1_regularizer { } }
      initializer { truncated_normal_initializer { } }
    """
    scope = self._build_slim_scope(conv_hyperparams_text_proto)
    self.assertIn(_get_scope_key(slim.fully_connected), scope)

  def test_separable_conv2d_and_conv2d_and_transpose_have_same_parameters(self):
    conv_hyperparams_text_proto = """
      regularizer { l1_regularizer { } }
      initializer { truncated_normal_initializer { } }
    """
    scope = self._build_slim_scope(conv_hyperparams_text_proto)
    # Without batch norm the scope is expected to hold exactly three entries
    # (conv2d, separable_conv2d, conv2d_transpose) sharing one argument set.
    kwargs_1, kwargs_2, kwargs_3 = scope.values()
    self.assertDictEqual(kwargs_1, kwargs_2)
    self.assertDictEqual(kwargs_1, kwargs_3)

  def test_return_l1_regularized_weights(self):
    conv_hyperparams_text_proto = """
      regularizer { l1_regularizer { weight: 0.5 } }
      initializer { truncated_normal_initializer { } }
    """
    scope = self._build_slim_scope(conv_hyperparams_text_proto)
    # Look the entry up by op key: dict views are not indexable on Python 3.
    conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
    regularizer = conv_scope_arguments['weights_regularizer']
    weights = np.array([1., -1, 4., 2.])
    with self.test_session() as sess:
      result = sess.run(regularizer(tf.constant(weights)))
    # L1 penalty: weight * sum(|w|).
    self.assertAllClose(np.abs(weights).sum() * 0.5, result)

  def test_return_l1_regularized_weights_keras(self):
    conv_hyperparams_text_proto = """
      regularizer { l1_regularizer { weight: 0.5 } }
      initializer { truncated_normal_initializer { } }
    """
    keras_config = self._build_keras_config(conv_hyperparams_text_proto)
    regularizer = keras_config.params()['kernel_regularizer']
    weights = np.array([1., -1, 4., 2.])
    with self.test_session() as sess:
      result = sess.run(regularizer(tf.constant(weights)))
    self.assertAllClose(np.abs(weights).sum() * 0.5, result)

  def test_return_l2_regularizer_weights(self):
    conv_hyperparams_text_proto = """
      regularizer { l2_regularizer { weight: 0.42 } }
      initializer { truncated_normal_initializer { } }
    """
    scope = self._build_slim_scope(conv_hyperparams_text_proto)
    conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
    regularizer = conv_scope_arguments['weights_regularizer']
    weights = np.array([1., -1, 4., 2.])
    with self.test_session() as sess:
      result = sess.run(regularizer(tf.constant(weights)))
    # L2 penalty is expected to be weight * sum(w**2) / 2.
    self.assertAllClose(np.power(weights, 2).sum() / 2.0 * 0.42, result)

  def test_return_l2_regularizer_weights_keras(self):
    conv_hyperparams_text_proto = """
      regularizer { l2_regularizer { weight: 0.42 } }
      initializer { truncated_normal_initializer { } }
    """
    keras_config = self._build_keras_config(conv_hyperparams_text_proto)
    regularizer = keras_config.params()['kernel_regularizer']
    weights = np.array([1., -1, 4., 2.])
    with self.test_session() as sess:
      result = sess.run(regularizer(tf.constant(weights)))
    self.assertAllClose(np.power(weights, 2).sum() / 2.0 * 0.42, result)

  def test_return_non_default_batch_norm_params_with_train_during_train(self):
    conv_hyperparams_text_proto = """
      regularizer { l2_regularizer { } }
      initializer { truncated_normal_initializer { } }
      batch_norm {
        decay: 0.7
        center: false
        scale: true
        epsilon: 0.03
        train: true
      }
    """
    scope = self._build_slim_scope(conv_hyperparams_text_proto,
                                   is_training=True)
    conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
    self.assertEqual(conv_scope_arguments['normalizer_fn'], slim.batch_norm)
    batch_norm_params = scope[_get_scope_key(slim.batch_norm)]
    self.assertAlmostEqual(batch_norm_params['decay'], 0.7)
    self.assertAlmostEqual(batch_norm_params['epsilon'], 0.03)
    self.assertFalse(batch_norm_params['center'])
    self.assertTrue(batch_norm_params['scale'])
    self.assertTrue(batch_norm_params['is_training'])

  def test_return_non_default_batch_norm_params_keras(self):
    conv_hyperparams_text_proto = """
      regularizer { l2_regularizer { } }
      initializer { truncated_normal_initializer { } }
      batch_norm {
        decay: 0.7
        center: false
        scale: true
        epsilon: 0.03
      }
    """
    keras_config = self._build_keras_config(conv_hyperparams_text_proto)
    self.assertTrue(keras_config.use_batch_norm())
    # The proto's `decay` is exposed to Keras as `momentum`.
    batch_norm_params = keras_config.batch_norm_params()
    self.assertAlmostEqual(batch_norm_params['momentum'], 0.7)
    self.assertAlmostEqual(batch_norm_params['epsilon'], 0.03)
    self.assertFalse(batch_norm_params['center'])
    self.assertTrue(batch_norm_params['scale'])
    batch_norm_layer = keras_config.build_batch_norm()
    self.assertIsInstance(batch_norm_layer,
                          freezable_batch_norm.FreezableBatchNorm)

  def test_return_non_default_batch_norm_params_keras_override(self):
    conv_hyperparams_text_proto = """
      regularizer { l2_regularizer { } }
      initializer { truncated_normal_initializer { } }
      batch_norm {
        decay: 0.7
        center: false
        scale: true
        epsilon: 0.03
      }
    """
    keras_config = self._build_keras_config(conv_hyperparams_text_proto)
    self.assertTrue(keras_config.use_batch_norm())
    # Keyword arguments override the value derived from the proto.
    batch_norm_params = keras_config.batch_norm_params(momentum=0.4)
    self.assertAlmostEqual(batch_norm_params['momentum'], 0.4)
    self.assertAlmostEqual(batch_norm_params['epsilon'], 0.03)
    self.assertFalse(batch_norm_params['center'])
    self.assertTrue(batch_norm_params['scale'])

  def test_return_batch_norm_params_with_notrain_during_eval(self):
    conv_hyperparams_text_proto = """
      regularizer { l2_regularizer { } }
      initializer { truncated_normal_initializer { } }
      batch_norm {
        decay: 0.7
        center: false
        scale: true
        epsilon: 0.03
        train: true
      }
    """
    # `train: true` in the proto does not override is_training=False.
    scope = self._build_slim_scope(conv_hyperparams_text_proto,
                                   is_training=False)
    conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
    self.assertEqual(conv_scope_arguments['normalizer_fn'], slim.batch_norm)
    batch_norm_params = scope[_get_scope_key(slim.batch_norm)]
    self.assertAlmostEqual(batch_norm_params['decay'], 0.7)
    self.assertAlmostEqual(batch_norm_params['epsilon'], 0.03)
    self.assertFalse(batch_norm_params['center'])
    self.assertTrue(batch_norm_params['scale'])
    self.assertFalse(batch_norm_params['is_training'])

  def test_return_batch_norm_params_with_notrain_when_train_is_false(self):
    conv_hyperparams_text_proto = """
      regularizer { l2_regularizer { } }
      initializer { truncated_normal_initializer { } }
      batch_norm {
        decay: 0.7
        center: false
        scale: true
        epsilon: 0.03
        train: false
      }
    """
    # `train: false` disables batch-norm training even when is_training=True.
    scope = self._build_slim_scope(conv_hyperparams_text_proto,
                                   is_training=True)
    conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
    self.assertEqual(conv_scope_arguments['normalizer_fn'], slim.batch_norm)
    batch_norm_params = scope[_get_scope_key(slim.batch_norm)]
    self.assertAlmostEqual(batch_norm_params['decay'], 0.7)
    self.assertAlmostEqual(batch_norm_params['epsilon'], 0.03)
    self.assertFalse(batch_norm_params['center'])
    self.assertTrue(batch_norm_params['scale'])
    self.assertFalse(batch_norm_params['is_training'])

  def test_do_not_use_batch_norm_if_default(self):
    conv_hyperparams_text_proto = """
      regularizer { l2_regularizer { } }
      initializer { truncated_normal_initializer { } }
    """
    scope = self._build_slim_scope(conv_hyperparams_text_proto)
    conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
    self.assertIsNone(conv_scope_arguments['normalizer_fn'])

  def test_do_not_use_batch_norm_if_default_keras(self):
    conv_hyperparams_text_proto = """
      regularizer { l2_regularizer { } }
      initializer { truncated_normal_initializer { } }
    """
    keras_config = self._build_keras_config(conv_hyperparams_text_proto)
    self.assertFalse(keras_config.use_batch_norm())
    self.assertEqual(keras_config.batch_norm_params(), {})
    # The batch norm builder is expected to build an identity Lambda layer.
    identity_layer = keras_config.build_batch_norm()
    self.assertIsInstance(identity_layer, tf.keras.layers.Lambda)

  def test_use_none_activation(self):
    conv_hyperparams_text_proto = """
      regularizer { l2_regularizer { } }
      initializer { truncated_normal_initializer { } }
      activation: NONE
    """
    scope = self._build_slim_scope(conv_hyperparams_text_proto)
    conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
    self.assertIsNone(conv_scope_arguments['activation_fn'])

  def test_use_none_activation_keras(self):
    conv_hyperparams_text_proto = """
      regularizer { l2_regularizer { } }
      initializer { truncated_normal_initializer { } }
      activation: NONE
    """
    keras_config = self._build_keras_config(conv_hyperparams_text_proto)
    self.assertIsNone(keras_config.params()['activation'])
    self.assertIsNone(
        keras_config.params(include_activation=True)['activation'])
    activation_layer = keras_config.build_activation_layer()
    self.assertIsInstance(activation_layer, tf.keras.layers.Lambda)
    self.assertEqual(activation_layer.function, tf.identity)

  def test_use_relu_activation(self):
    conv_hyperparams_text_proto = """
      regularizer { l2_regularizer { } }
      initializer { truncated_normal_initializer { } }
      activation: RELU
    """
    scope = self._build_slim_scope(conv_hyperparams_text_proto)
    conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
    self.assertEqual(conv_scope_arguments['activation_fn'], tf.nn.relu)

  def test_use_relu_activation_keras(self):
    conv_hyperparams_text_proto = """
      regularizer { l2_regularizer { } }
      initializer { truncated_normal_initializer { } }
      activation: RELU
    """
    keras_config = self._build_keras_config(conv_hyperparams_text_proto)
    # params() reports activation=None unless include_activation=True;
    # build_activation_layer() provides the activation as a Lambda layer.
    self.assertIsNone(keras_config.params()['activation'])
    self.assertEqual(
        keras_config.params(include_activation=True)['activation'], tf.nn.relu)
    activation_layer = keras_config.build_activation_layer()
    self.assertIsInstance(activation_layer, tf.keras.layers.Lambda)
    self.assertEqual(activation_layer.function, tf.nn.relu)

  def test_use_relu_6_activation(self):
    conv_hyperparams_text_proto = """
      regularizer { l2_regularizer { } }
      initializer { truncated_normal_initializer { } }
      activation: RELU_6
    """
    scope = self._build_slim_scope(conv_hyperparams_text_proto)
    conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
    self.assertEqual(conv_scope_arguments['activation_fn'], tf.nn.relu6)

  def test_use_relu_6_activation_keras(self):
    conv_hyperparams_text_proto = """
      regularizer { l2_regularizer { } }
      initializer { truncated_normal_initializer { } }
      activation: RELU_6
    """
    keras_config = self._build_keras_config(conv_hyperparams_text_proto)
    self.assertIsNone(keras_config.params()['activation'])
    self.assertEqual(
        keras_config.params(include_activation=True)['activation'], tf.nn.relu6)
    activation_layer = keras_config.build_activation_layer()
    self.assertIsInstance(activation_layer, tf.keras.layers.Lambda)
    self.assertEqual(activation_layer.function, tf.nn.relu6)

  def test_override_activation_keras(self):
    conv_hyperparams_text_proto = """
      regularizer { l2_regularizer { } }
      initializer { truncated_normal_initializer { } }
      activation: RELU_6
    """
    keras_config = self._build_keras_config(conv_hyperparams_text_proto)
    # An explicit `activation` keyword replaces the proto's RELU_6.
    new_params = keras_config.params(activation=tf.nn.relu)
    self.assertEqual(new_params['activation'], tf.nn.relu)

  def _assert_variance_in_range(self, initializer, shape, variance,
                                tol=1e-2):
    """Asserts that values drawn by `initializer` have roughly `variance`.

    Args:
      initializer: initializer under test.
      shape: shape of the variable to initialize.
      variance: expected empirical variance of the initialized values.
      tol: used as both the relative and absolute tolerance.
    """
    # A fresh graph keeps the fixed variable name 'test' from colliding
    # between calls.
    with tf.Graph().as_default() as g:
      with self.test_session(graph=g) as sess:
        var = tf.get_variable(
            name='test',
            shape=shape,
            dtype=tf.float32,
            initializer=initializer)
        sess.run(tf.global_variables_initializer())
        values = sess.run(var)
        self.assertAllClose(np.var(values), variance, tol, tol)

  def test_variance_in_range_with_variance_scaling_initializer_fan_in(self):
    conv_hyperparams_text_proto = """
      regularizer { l2_regularizer { } }
      initializer {
        variance_scaling_initializer {
          factor: 2.0
          mode: FAN_IN
          uniform: false
        }
      }
    """
    scope = self._build_slim_scope(conv_hyperparams_text_proto)
    conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
    initializer = conv_scope_arguments['weights_initializer']
    # For a [100, 40] kernel fan_in is 100, so variance = factor / 100.
    self._assert_variance_in_range(initializer, shape=[100, 40],
                                   variance=2. / 100.)

  def test_variance_in_range_with_variance_scaling_initializer_fan_in_keras(
      self):
    conv_hyperparams_text_proto = """
      regularizer { l2_regularizer { } }
      initializer {
        variance_scaling_initializer {
          factor: 2.0
          mode: FAN_IN
          uniform: false
        }
      }
    """
    keras_config = self._build_keras_config(conv_hyperparams_text_proto)
    initializer = keras_config.params()['kernel_initializer']
    self._assert_variance_in_range(initializer, shape=[100, 40],
                                   variance=2. / 100.)

  def test_variance_in_range_with_variance_scaling_initializer_fan_out(self):
    conv_hyperparams_text_proto = """
      regularizer { l2_regularizer { } }
      initializer {
        variance_scaling_initializer {
          factor: 2.0
          mode: FAN_OUT
          uniform: false
        }
      }
    """
    scope = self._build_slim_scope(conv_hyperparams_text_proto)
    conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
    initializer = conv_scope_arguments['weights_initializer']
    # For a [100, 40] kernel fan_out is 40, so variance = factor / 40.
    self._assert_variance_in_range(initializer, shape=[100, 40],
                                   variance=2. / 40.)

  def test_variance_in_range_with_variance_scaling_initializer_fan_out_keras(
      self):
    conv_hyperparams_text_proto = """
      regularizer { l2_regularizer { } }
      initializer {
        variance_scaling_initializer {
          factor: 2.0
          mode: FAN_OUT
          uniform: false
        }
      }
    """
    keras_config = self._build_keras_config(conv_hyperparams_text_proto)
    initializer = keras_config.params()['kernel_initializer']
    self._assert_variance_in_range(initializer, shape=[100, 40],
                                   variance=2. / 40.)

  def test_variance_in_range_with_variance_scaling_initializer_fan_avg(self):
    conv_hyperparams_text_proto = """
      regularizer { l2_regularizer { } }
      initializer {
        variance_scaling_initializer {
          factor: 2.0
          mode: FAN_AVG
          uniform: false
        }
      }
    """
    scope = self._build_slim_scope(conv_hyperparams_text_proto)
    conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
    initializer = conv_scope_arguments['weights_initializer']
    # variance = factor / ((fan_in + fan_out) / 2) = 4 / (100 + 40).
    self._assert_variance_in_range(initializer, shape=[100, 40],
                                   variance=4. / (100. + 40.))

  def test_variance_in_range_with_variance_scaling_initializer_fan_avg_keras(
      self):
    conv_hyperparams_text_proto = """
      regularizer { l2_regularizer { } }
      initializer {
        variance_scaling_initializer {
          factor: 2.0
          mode: FAN_AVG
          uniform: false
        }
      }
    """
    keras_config = self._build_keras_config(conv_hyperparams_text_proto)
    initializer = keras_config.params()['kernel_initializer']
    self._assert_variance_in_range(initializer, shape=[100, 40],
                                   variance=4. / (100. + 40.))

  def test_variance_in_range_with_variance_scaling_initializer_uniform(self):
    conv_hyperparams_text_proto = """
      regularizer { l2_regularizer { } }
      initializer {
        variance_scaling_initializer {
          factor: 2.0
          mode: FAN_IN
          uniform: true
        }
      }
    """
    scope = self._build_slim_scope(conv_hyperparams_text_proto)
    conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
    initializer = conv_scope_arguments['weights_initializer']
    # The uniform variant is expected to match the normal variant's variance.
    self._assert_variance_in_range(initializer, shape=[100, 40],
                                   variance=2. / 100.)

  def test_variance_in_range_with_variance_scaling_initializer_uniform_keras(
      self):
    conv_hyperparams_text_proto = """
      regularizer { l2_regularizer { } }
      initializer {
        variance_scaling_initializer {
          factor: 2.0
          mode: FAN_IN
          uniform: true
        }
      }
    """
    keras_config = self._build_keras_config(conv_hyperparams_text_proto)
    initializer = keras_config.params()['kernel_initializer']
    self._assert_variance_in_range(initializer, shape=[100, 40],
                                   variance=2. / 100.)

  def test_variance_in_range_with_truncated_normal_initializer(self):
    conv_hyperparams_text_proto = """
      regularizer { l2_regularizer { } }
      initializer {
        truncated_normal_initializer {
          mean: 0.0
          stddev: 0.8
        }
      }
    """
    scope = self._build_slim_scope(conv_hyperparams_text_proto)
    conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
    initializer = conv_scope_arguments['weights_initializer']
    # Expected 0.49 rather than 0.8**2 = 0.64, presumably because truncation
    # removes the tails of the distribution.
    self._assert_variance_in_range(initializer, shape=[100, 40],
                                   variance=0.49, tol=1e-1)

  def test_variance_in_range_with_truncated_normal_initializer_keras(self):
    conv_hyperparams_text_proto = """
      regularizer { l2_regularizer { } }
      initializer {
        truncated_normal_initializer {
          mean: 0.0
          stddev: 0.8
        }
      }
    """
    keras_config = self._build_keras_config(conv_hyperparams_text_proto)
    initializer = keras_config.params()['kernel_initializer']
    self._assert_variance_in_range(initializer, shape=[100, 40],
                                   variance=0.49, tol=1e-1)

  def test_variance_in_range_with_random_normal_initializer(self):
    conv_hyperparams_text_proto = """
      regularizer { l2_regularizer { } }
      initializer {
        random_normal_initializer {
          mean: 0.0
          stddev: 0.8
        }
      }
    """
    scope = self._build_slim_scope(conv_hyperparams_text_proto)
    conv_scope_arguments = scope[_get_scope_key(slim.conv2d)]
    initializer = conv_scope_arguments['weights_initializer']
    # Untruncated normal: variance = stddev**2 = 0.64.
    self._assert_variance_in_range(initializer, shape=[100, 40],
                                   variance=0.64, tol=1e-1)

  def test_variance_in_range_with_random_normal_initializer_keras(self):
    conv_hyperparams_text_proto = """
      regularizer { l2_regularizer { } }
      initializer {
        random_normal_initializer {
          mean: 0.0
          stddev: 0.8
        }
      }
    """
    keras_config = self._build_keras_config(conv_hyperparams_text_proto)
    initializer = keras_config.params()['kernel_initializer']
    self._assert_variance_in_range(initializer, shape=[100, 40],
                                   variance=0.64, tol=1e-1)
if __name__ == '__main__':
  # Run this module's tests through TensorFlow's test runner.
  tf.test.main()