You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 

291 lines
9.5 KiB

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for object_detection.trainer."""
import tensorflow as tf
from google.protobuf import text_format
from object_detection.core import losses
from object_detection.core import model
from object_detection.core import standard_fields as fields
from object_detection.legacy import trainer
from object_detection.protos import train_pb2
# Number of object classes shared by the synthetic inputs and the fake model.
NUMBER_OF_CLASSES = 2
def get_input_function():
    """Build the tensor dictionary for one synthetic training example.

    Returns:
        A dict keyed by `fields.InputDataFields` names that holds a random
        [32, 32, 3] float32 image, a constant string key, and a single
        groundtruth box together with its int32 class label and its
        per-class float32 scores.
    """
    num_boxes = 1

    # Ops are created in a fixed order: image, key, label, box, scores.
    image_tensor = tf.random_uniform([32, 32, 3], dtype=tf.float32)
    key_tensor = tf.constant('image_000000')
    label_tensor = tf.random_uniform(
        [num_boxes], minval=0, maxval=NUMBER_OF_CLASSES, dtype=tf.int32)
    box_tensor = tf.random_uniform(
        [num_boxes, 4], minval=0.4, maxval=0.6, dtype=tf.float32)
    score_tensor = tf.random_uniform(
        [num_boxes, NUMBER_OF_CLASSES], minval=0.4, maxval=0.6,
        dtype=tf.float32)

    tensor_dict = {}
    tensor_dict[fields.InputDataFields.image] = image_tensor
    tensor_dict[fields.InputDataFields.key] = key_tensor
    tensor_dict[fields.InputDataFields.groundtruth_classes] = label_tensor
    tensor_dict[fields.InputDataFields.groundtruth_boxes] = box_tensor
    tensor_dict[fields.InputDataFields.multiclass_scores] = score_tensor
    return tensor_dict
class FakeDetectionModel(model.DetectionModel):
    """A simple (and poor) DetectionModel for use in test.

    It predicts exactly one box and one class-score vector per image using
    two fully connected layers on the flattened, resized image.
    """

    def __init__(self):
        """Creates the model with NUMBER_OF_CLASSES classes and its losses."""
        super(FakeDetectionModel, self).__init__(
            num_classes=NUMBER_OF_CLASSES)
        self._classification_loss = (
            losses.WeightedSigmoidClassificationLoss())
        self._localization_loss = losses.WeightedSmoothL1LocalizationLoss()

    def preprocess(self, inputs):
        """Input preprocessing, resizes images to 28x28.

        Args:
            inputs: a [batch, height_in, width_in, channels] float32 tensor
                representing a batch of images with values between 0 and
                255.0.

        Returns:
            preprocessed_inputs: a [batch, 28, 28, channels] float32 tensor.
            true_image_shapes: a nested Python list derived from the static
                shape of `inputs`; it contains one copy of every dimension
                except the last, repeated once per entry of the last
                dimension.
        """
        # NOTE(review): this yields one [batch, height, width] row per
        # channel, not one [height, width, channels] row per image as the
        # DetectionModel docs describe. It is left unchanged because no
        # method of this fake model reads `true_image_shapes`.
        true_image_shapes = [inputs.shape[:-1].as_list()
                             for _ in range(inputs.shape[-1])]
        return tf.image.resize_images(inputs, [28, 28]), true_image_shapes

    def predict(self, preprocessed_inputs, true_image_shapes):
        """Prediction tensors from inputs tensor.

        Args:
            preprocessed_inputs: a [batch, 28, 28, channels] float32 tensor.
            true_image_shapes: accepted for interface compatibility; not
                used by this model.

        Returns:
            prediction_dict: a dictionary holding prediction tensors to be
                passed to the Loss or Postprocess functions. It has two
                entries: 'class_predictions_with_background', reshaped to
                [-1, 1, num_classes], and 'box_encodings', reshaped to
                [-1, 1, 4].
        """
        flattened_inputs = tf.contrib.layers.flatten(preprocessed_inputs)
        class_prediction = tf.contrib.layers.fully_connected(
            flattened_inputs, self._num_classes)
        box_prediction = tf.contrib.layers.fully_connected(
            flattened_inputs, 4)

        # The middle dimension of size 1 stands for the single predicted box.
        return {
            'class_predictions_with_background': tf.reshape(
                class_prediction, [-1, 1, self._num_classes]),
            'box_encodings': tf.reshape(box_prediction, [-1, 1, 4])
        }

    def postprocess(self, prediction_dict, true_image_shapes, **params):
        """Convert predicted output tensors to final detections. Unused.

        Args:
            prediction_dict: a dictionary holding prediction tensors; ignored.
            true_image_shapes: accepted for interface compatibility; ignored.
            **params: Additional keyword arguments for specific
                implementations of DetectionModel; ignored.

        Returns:
            detections: a dictionary with empty fields (every value is None).
        """
        return {
            'detection_boxes': None,
            'detection_scores': None,
            'detection_classes': None,
            'num_detections': None
        }

    def loss(self, prediction_dict, true_image_shapes):
        """Compute scalar loss tensors with respect to provided groundtruth.

        Calling this function requires that groundtruth tensors have been
        provided via the provide_groundtruth function.

        Args:
            prediction_dict: a dictionary holding predicted tensors; the
                'box_encodings' and 'class_predictions_with_background'
                entries are read.
            true_image_shapes: accepted for interface compatibility; not
                used by this model.

        Returns:
            a dictionary mapping strings (loss names) to scalar tensors
            representing loss values: 'localization_loss' and
            'classification_loss'.
        """
        groundtruth_boxes = self.groundtruth_lists(
            fields.BoxListFields.boxes)
        batch_reg_targets = tf.stack(groundtruth_boxes)
        batch_cls_targets = tf.stack(
            self.groundtruth_lists(fields.BoxListFields.classes))

        # Uniform weight of 1.0 for the single box of every image.
        weights = tf.constant(
            1.0, dtype=tf.float32, shape=[len(groundtruth_boxes), 1])

        location_losses = self._localization_loss(
            prediction_dict['box_encodings'], batch_reg_targets,
            weights=weights)
        cls_losses = self._classification_loss(
            prediction_dict['class_predictions_with_background'],
            batch_cls_targets, weights=weights)

        loss_dict = {
            'localization_loss': tf.reduce_sum(location_losses),
            'classification_loss': tf.reduce_sum(cls_losses),
        }
        return loss_dict

    def regularization_losses(self):
        """Returns a list of regularization losses for this model.

        Returns a list of regularization losses for this model that the
        estimator needs to use during training/optimization.

        Returns:
            A list of regularization loss tensors. This fake model defines
            none, so the list is always empty.
        """
        # Return an empty list rather than None so the result matches the
        # documented return type and can be iterated or concatenated.
        return []

    def restore_map(self, fine_tune_checkpoint_type='detection'):
        """Returns a map of variables to load from a foreign checkpoint.

        Args:
            fine_tune_checkpoint_type: whether to restore from a full
                detection checkpoint (with compatible variable names) or to
                restore from a classification checkpoint for initialization
                prior to training. Valid values: `detection`,
                `classification`. Default 'detection'. This fake model
                ignores the value.

        Returns:
            A dict mapping variable names to variables, covering every
            variable in `tf.global_variables()`.
        """
        return {var.op.name: var for var in tf.global_variables()}

    def updates(self):
        """Returns a list of update operators for this model.

        Returns a list of update operators for this model that must be
        executed at each training step. The estimator's train op needs to
        have a control dependency on these updates.

        Returns:
            A list of update operators. This fake model defines none, so the
            list is always empty.
        """
        # Return an empty list rather than None so the result matches the
        # documented return type and can be iterated or concatenated.
        return []
class TrainerTest(tf.test.TestCase):
    """Smoke tests that configure the legacy trainer and run two steps."""

    def _train_with_config(self, train_config_text_proto):
        """Parses a text-format TrainConfig and runs `trainer.train` on it.

        Training uses `get_input_function` for inputs, `FakeDetectionModel`
        as the model, a single CPU clone, and a fresh temporary directory as
        the train dir.

        Args:
            train_config_text_proto: string holding a text-format
                `train_pb2.TrainConfig` message.
        """
        train_config = train_pb2.TrainConfig()
        text_format.Merge(train_config_text_proto, train_config)
        train_dir = self.get_temp_dir()
        trainer.train(
            create_tensor_dict_fn=get_input_function,
            create_model_fn=FakeDetectionModel,
            train_config=train_config,
            master='',
            task=0,
            num_clones=1,
            worker_replicas=1,
            clone_on_cpu=True,
            ps_tasks=0,
            worker_job_name='worker',
            is_chief=True,
            train_dir=train_dir)

    def test_configure_trainer_and_train_two_steps(self):
        # Adam optimizer with brightness/contrast augmentation, two steps.
        train_config_text_proto = """
optimizer {
adam_optimizer {
learning_rate {
constant_learning_rate {
learning_rate: 0.01
}
}
}
}
data_augmentation_options {
random_adjust_brightness {
max_delta: 0.2
}
}
data_augmentation_options {
random_adjust_contrast {
min_delta: 0.7
max_delta: 1.1
}
}
num_steps: 2
"""
        self._train_with_config(train_config_text_proto)

    def test_configure_trainer_with_multiclass_scores_and_train_two_steps(
            self):
        # Same configuration as above, plus `use_multiclass_scores: true`.
        train_config_text_proto = """
optimizer {
adam_optimizer {
learning_rate {
constant_learning_rate {
learning_rate: 0.01
}
}
}
}
data_augmentation_options {
random_adjust_brightness {
max_delta: 0.2
}
}
data_augmentation_options {
random_adjust_contrast {
min_delta: 0.7
max_delta: 1.1
}
}
num_steps: 2
use_multiclass_scores: true
"""
        self._train_with_config(train_config_text_proto)
# Hand control to TensorFlow's test runner when executed as a script.
if __name__ == '__main__':
    tf.test.main()