- # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- """Tests for object_detection.trainer."""
- import tensorflow as tf
- from google.protobuf import text_format
- from object_detection.core import losses
- from object_detection.core import model
- from object_detection.core import standard_fields as fields
- from object_detection.legacy import trainer
- from object_detection.protos import train_pb2
def get_input_function():
  """Builds one randomly generated training example containing a single box.

  NOTE(review): relies on a module-level `NUMBER_OF_CLASSES` that is not
  defined in the visible part of this file -- confirm it is declared.

  Returns:
    A dict keyed by `fields.InputDataFields` entries holding:
      image: a [32, 32, 3] float32 tensor of uniform random values.
      key: a constant string tensor, 'image_000000'.
      groundtruth_classes: a [1] int32 tensor sampled with minval 0 and
        maxval NUMBER_OF_CLASSES.
      groundtruth_boxes: a [1, 4] float32 tensor sampled with minval 0.4 and
        maxval 0.6.
      multiclass_scores: a [1, NUMBER_OF_CLASSES] float32 tensor sampled with
        minval 0.4 and maxval 0.6.
  """
  input_fields = fields.InputDataFields
  tensor_dict = {}
  # Tensors are created in the same order as the keys are inserted.
  tensor_dict[input_fields.image] = tf.random_uniform(
      [32, 32, 3], dtype=tf.float32)
  tensor_dict[input_fields.key] = tf.constant('image_000000')
  tensor_dict[input_fields.groundtruth_classes] = tf.random_uniform(
      [1], minval=0, maxval=NUMBER_OF_CLASSES, dtype=tf.int32)
  tensor_dict[input_fields.groundtruth_boxes] = tf.random_uniform(
      [1, 4], minval=0.4, maxval=0.6, dtype=tf.float32)
  tensor_dict[input_fields.multiclass_scores] = tf.random_uniform(
      [1, NUMBER_OF_CLASSES], minval=0.4, maxval=0.6, dtype=tf.float32)
  return tensor_dict
class FakeDetectionModel(model.DetectionModel):
  """Minimal DetectionModel stand-in used to exercise the trainer in tests.

  Predictions come from two fully connected layers applied to the flattened
  input; each row of the flattened input yields one class prediction vector
  and one 4-value box encoding.
  """

  def __init__(self):
    """Initializes the base model and the two loss callables used by `loss`."""
    super(FakeDetectionModel, self).__init__(num_classes=NUMBER_OF_CLASSES)
    self._classification_loss = losses.WeightedSigmoidClassificationLoss()
    self._localization_loss = losses.WeightedSmoothL1LocalizationLoss()

  def preprocess(self, inputs):
    """Resizes the inputs to 28x28 and reports their static leading shape.

    Args:
      inputs: an image tensor with a fully usable static shape; the last
        dimension is treated as the channel dimension.

    Returns:
      preprocessed_inputs: `inputs` resized to 28x28 via
        `tf.image.resize_images`.
      true_image_shapes: a Python list with one entry per channel (size of
        the last dimension of `inputs`); every entry is the static shape of
        `inputs` without its last dimension, as a list.
    """
    static_shape = inputs.shape
    leading_dims = static_shape[:-1]
    num_channels = static_shape[-1]
    true_image_shapes = [leading_dims.as_list() for _ in range(num_channels)]
    resized_inputs = tf.image.resize_images(inputs, [28, 28])
    return resized_inputs, true_image_shapes

  def predict(self, preprocessed_inputs, true_image_shapes):
    """Computes class and box predictions with two fully connected layers.

    Args:
      preprocessed_inputs: the image tensor produced by `preprocess`.
      true_image_shapes: unused by this fake model.

    Returns:
      prediction_dict: a dict with
        'class_predictions_with_background': tensor reshaped to
          [-1, 1, num_classes].
        'box_encodings': tensor reshaped to [-1, 1, 4].
    """
    num_classes = self._num_classes
    flat_inputs = tf.contrib.layers.flatten(preprocessed_inputs)
    # Layer creation order is kept (class head first, then box head).
    class_outputs = tf.contrib.layers.fully_connected(flat_inputs, num_classes)
    box_outputs = tf.contrib.layers.fully_connected(flat_inputs, 4)
    prediction_dict = {}
    prediction_dict['class_predictions_with_background'] = tf.reshape(
        class_outputs, [-1, 1, num_classes])
    prediction_dict['box_encodings'] = tf.reshape(box_outputs, [-1, 1, 4])
    return prediction_dict

  def postprocess(self, prediction_dict, true_image_shapes, **params):
    """Returns placeholder detections; no argument is inspected.

    Args:
      prediction_dict: unused.
      true_image_shapes: unused.
      **params: unused additional keyword arguments.

    Returns:
      detections: a dict whose four detection fields are all None.
    """
    detection_keys = (
        'detection_boxes',
        'detection_scores',
        'detection_classes',
        'num_detections',
    )
    return dict.fromkeys(detection_keys, None)

  def loss(self, prediction_dict, true_image_shapes):
    """Computes summed localization and classification losses.

    Groundtruth is read through `self.groundtruth_lists`, so it has to be
    available on the model before this method is called.

    Args:
      prediction_dict: a dict holding 'box_encodings' and
        'class_predictions_with_background' tensors.
      true_image_shapes: unused by this fake model.

    Returns:
      A dict mapping 'localization_loss' and 'classification_loss' to the
      `tf.reduce_sum` of the respective per-entry losses.
    """
    box_field = fields.BoxListFields.boxes
    class_field = fields.BoxListFields.classes
    batch_reg_targets = tf.stack(self.groundtruth_lists(box_field))
    batch_cls_targets = tf.stack(self.groundtruth_lists(class_field))
    # Uniform weight of 1.0 with one row per groundtruth box list entry.
    num_groundtruth = len(self.groundtruth_lists(box_field))
    weights = tf.constant(1.0, dtype=tf.float32, shape=[num_groundtruth, 1])
    location_losses = self._localization_loss(
        prediction_dict['box_encodings'],
        batch_reg_targets,
        weights=weights)
    cls_losses = self._classification_loss(
        prediction_dict['class_predictions_with_background'],
        batch_cls_targets,
        weights=weights)
    loss_dict = {}
    loss_dict['localization_loss'] = tf.reduce_sum(location_losses)
    loss_dict['classification_loss'] = tf.reduce_sum(cls_losses)
    return loss_dict

  def regularization_losses(self):
    """Placeholder: this fake model builds no regularization losses.

    Returns:
      None.
    """
    return None

  def restore_map(self, fine_tune_checkpoint_type='detection'):
    """Maps every global variable's op name to the variable itself.

    Args:
      fine_tune_checkpoint_type: accepted for interface compatibility; the
        value is not inspected by this fake model. Default 'detection'.

    Returns:
      A dict mapping variable op names to variables, built from
      `tf.global_variables()`.
    """
    variables_by_name = {}
    for variable in tf.global_variables():
      variables_by_name[variable.op.name] = variable
    return variables_by_name

  def updates(self):
    """Placeholder: this fake model defines no update operators.

    Returns:
      None.
    """
    return None
class TrainerTest(tf.test.TestCase):
  """Runs the legacy trainer for two steps against `FakeDetectionModel`."""

  def _train_with_config(self, train_config_text_proto):
    """Parses a TrainConfig text proto and runs `trainer.train` with it.

    Args:
      train_config_text_proto: string holding a text-format
        `train_pb2.TrainConfig` message.
    """
    train_config = train_pb2.TrainConfig()
    text_format.Merge(train_config_text_proto, train_config)
    # Single local clone on CPU, acting as chief, writing to a temp directory.
    train_kwargs = dict(
        create_tensor_dict_fn=get_input_function,
        create_model_fn=FakeDetectionModel,
        train_config=train_config,
        master='',
        task=0,
        num_clones=1,
        worker_replicas=1,
        clone_on_cpu=True,
        ps_tasks=0,
        worker_job_name='worker',
        is_chief=True,
        train_dir=self.get_temp_dir())
    trainer.train(**train_kwargs)

  def test_configure_trainer_and_train_two_steps(self):
    """Trains two steps with Adam and two data augmentation options."""
    train_config_text_proto = """
    optimizer {
      adam_optimizer {
        learning_rate {
          constant_learning_rate {
            learning_rate: 0.01
          }
        }
      }
    }
    data_augmentation_options {
      random_adjust_brightness {
        max_delta: 0.2
      }
    }
    data_augmentation_options {
      random_adjust_contrast {
        min_delta: 0.7
        max_delta: 1.1
      }
    }
    num_steps: 2
    """
    self._train_with_config(train_config_text_proto)

  def test_configure_trainer_with_multiclass_scores_and_train_two_steps(self):
    """Same as the basic test, with `use_multiclass_scores` enabled."""
    train_config_text_proto = """
    optimizer {
      adam_optimizer {
        learning_rate {
          constant_learning_rate {
            learning_rate: 0.01
          }
        }
      }
    }
    data_augmentation_options {
      random_adjust_brightness {
        max_delta: 0.2
      }
    }
    data_augmentation_options {
      random_adjust_contrast {
        min_delta: 0.7
        max_delta: 1.1
      }
    }
    num_steps: 2
    use_multiclass_scores: true
    """
    self._train_with_config(train_config_text_proto)
if __name__ == '__main__':
  # Hand control to the TensorFlow test runner when executed as a script.
  tf.test.main()