You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

291 lines
9.5 KiB

  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
"""Tests for object_detection.legacy.trainer."""
  16. import tensorflow as tf
  17. from google.protobuf import text_format
  18. from object_detection.core import losses
  19. from object_detection.core import model
  20. from object_detection.core import standard_fields as fields
  21. from object_detection.legacy import trainer
  22. from object_detection.protos import train_pb2
  23. NUMBER_OF_CLASSES = 2
  24. def get_input_function():
  25. """A function to get test inputs. Returns an image with one box."""
  26. image = tf.random_uniform([32, 32, 3], dtype=tf.float32)
  27. key = tf.constant('image_000000')
  28. class_label = tf.random_uniform(
  29. [1], minval=0, maxval=NUMBER_OF_CLASSES, dtype=tf.int32)
  30. box_label = tf.random_uniform(
  31. [1, 4], minval=0.4, maxval=0.6, dtype=tf.float32)
  32. multiclass_scores = tf.random_uniform(
  33. [1, NUMBER_OF_CLASSES], minval=0.4, maxval=0.6, dtype=tf.float32)
  34. return {
  35. fields.InputDataFields.image: image,
  36. fields.InputDataFields.key: key,
  37. fields.InputDataFields.groundtruth_classes: class_label,
  38. fields.InputDataFields.groundtruth_boxes: box_label,
  39. fields.InputDataFields.multiclass_scores: multiclass_scores
  40. }
  41. class FakeDetectionModel(model.DetectionModel):
  42. """A simple (and poor) DetectionModel for use in test."""
  43. def __init__(self):
  44. super(FakeDetectionModel, self).__init__(num_classes=NUMBER_OF_CLASSES)
  45. self._classification_loss = losses.WeightedSigmoidClassificationLoss()
  46. self._localization_loss = losses.WeightedSmoothL1LocalizationLoss()
  47. def preprocess(self, inputs):
  48. """Input preprocessing, resizes images to 28x28.
  49. Args:
  50. inputs: a [batch, height_in, width_in, channels] float32 tensor
  51. representing a batch of images with values between 0 and 255.0.
  52. Returns:
  53. preprocessed_inputs: a [batch, 28, 28, channels] float32 tensor.
  54. true_image_shapes: int32 tensor of shape [batch, 3] where each row is
  55. of the form [height, width, channels] indicating the shapes
  56. of true images in the resized images, as resized images can be padded
  57. with zeros.
  58. """
  59. true_image_shapes = [inputs.shape[:-1].as_list()
  60. for _ in range(inputs.shape[-1])]
  61. return tf.image.resize_images(inputs, [28, 28]), true_image_shapes
  62. def predict(self, preprocessed_inputs, true_image_shapes):
  63. """Prediction tensors from inputs tensor.
  64. Args:
  65. preprocessed_inputs: a [batch, 28, 28, channels] float32 tensor.
  66. true_image_shapes: int32 tensor of shape [batch, 3] where each row is
  67. of the form [height, width, channels] indicating the shapes
  68. of true images in the resized images, as resized images can be padded
  69. with zeros.
  70. Returns:
  71. prediction_dict: a dictionary holding prediction tensors to be
  72. passed to the Loss or Postprocess functions.
  73. """
  74. flattened_inputs = tf.contrib.layers.flatten(preprocessed_inputs)
  75. class_prediction = tf.contrib.layers.fully_connected(
  76. flattened_inputs, self._num_classes)
  77. box_prediction = tf.contrib.layers.fully_connected(flattened_inputs, 4)
  78. return {
  79. 'class_predictions_with_background': tf.reshape(
  80. class_prediction, [-1, 1, self._num_classes]),
  81. 'box_encodings': tf.reshape(box_prediction, [-1, 1, 4])
  82. }
  83. def postprocess(self, prediction_dict, true_image_shapes, **params):
  84. """Convert predicted output tensors to final detections. Unused.
  85. Args:
  86. prediction_dict: a dictionary holding prediction tensors.
  87. true_image_shapes: int32 tensor of shape [batch, 3] where each row is
  88. of the form [height, width, channels] indicating the shapes
  89. of true images in the resized images, as resized images can be padded
  90. with zeros.
  91. **params: Additional keyword arguments for specific implementations of
  92. DetectionModel.
  93. Returns:
  94. detections: a dictionary with empty fields.
  95. """
  96. return {
  97. 'detection_boxes': None,
  98. 'detection_scores': None,
  99. 'detection_classes': None,
  100. 'num_detections': None
  101. }
  102. def loss(self, prediction_dict, true_image_shapes):
  103. """Compute scalar loss tensors with respect to provided groundtruth.
  104. Calling this function requires that groundtruth tensors have been
  105. provided via the provide_groundtruth function.
  106. Args:
  107. prediction_dict: a dictionary holding predicted tensors
  108. true_image_shapes: int32 tensor of shape [batch, 3] where each row is
  109. of the form [height, width, channels] indicating the shapes
  110. of true images in the resized images, as resized images can be padded
  111. with zeros.
  112. Returns:
  113. a dictionary mapping strings (loss names) to scalar tensors representing
  114. loss values.
  115. """
  116. batch_reg_targets = tf.stack(
  117. self.groundtruth_lists(fields.BoxListFields.boxes))
  118. batch_cls_targets = tf.stack(
  119. self.groundtruth_lists(fields.BoxListFields.classes))
  120. weights = tf.constant(
  121. 1.0, dtype=tf.float32,
  122. shape=[len(self.groundtruth_lists(fields.BoxListFields.boxes)), 1])
  123. location_losses = self._localization_loss(
  124. prediction_dict['box_encodings'], batch_reg_targets,
  125. weights=weights)
  126. cls_losses = self._classification_loss(
  127. prediction_dict['class_predictions_with_background'], batch_cls_targets,
  128. weights=weights)
  129. loss_dict = {
  130. 'localization_loss': tf.reduce_sum(location_losses),
  131. 'classification_loss': tf.reduce_sum(cls_losses),
  132. }
  133. return loss_dict
  134. def regularization_losses(self):
  135. """Returns a list of regularization losses for this model.
  136. Returns a list of regularization losses for this model that the estimator
  137. needs to use during training/optimization.
  138. Returns:
  139. A list of regularization loss tensors.
  140. """
  141. pass
  142. def restore_map(self, fine_tune_checkpoint_type='detection'):
  143. """Returns a map of variables to load from a foreign checkpoint.
  144. Args:
  145. fine_tune_checkpoint_type: whether to restore from a full detection
  146. checkpoint (with compatible variable names) or to restore from a
  147. classification checkpoint for initialization prior to training.
  148. Valid values: `detection`, `classification`. Default 'detection'.
  149. Returns:
  150. A dict mapping variable names to variables.
  151. """
  152. return {var.op.name: var for var in tf.global_variables()}
  153. def updates(self):
  154. """Returns a list of update operators for this model.
  155. Returns a list of update operators for this model that must be executed at
  156. each training step. The estimator's train op needs to have a control
  157. dependency on these updates.
  158. Returns:
  159. A list of update operators.
  160. """
  161. pass
  162. class TrainerTest(tf.test.TestCase):
  163. def test_configure_trainer_and_train_two_steps(self):
  164. train_config_text_proto = """
  165. optimizer {
  166. adam_optimizer {
  167. learning_rate {
  168. constant_learning_rate {
  169. learning_rate: 0.01
  170. }
  171. }
  172. }
  173. }
  174. data_augmentation_options {
  175. random_adjust_brightness {
  176. max_delta: 0.2
  177. }
  178. }
  179. data_augmentation_options {
  180. random_adjust_contrast {
  181. min_delta: 0.7
  182. max_delta: 1.1
  183. }
  184. }
  185. num_steps: 2
  186. """
  187. train_config = train_pb2.TrainConfig()
  188. text_format.Merge(train_config_text_proto, train_config)
  189. train_dir = self.get_temp_dir()
  190. trainer.train(
  191. create_tensor_dict_fn=get_input_function,
  192. create_model_fn=FakeDetectionModel,
  193. train_config=train_config,
  194. master='',
  195. task=0,
  196. num_clones=1,
  197. worker_replicas=1,
  198. clone_on_cpu=True,
  199. ps_tasks=0,
  200. worker_job_name='worker',
  201. is_chief=True,
  202. train_dir=train_dir)
  203. def test_configure_trainer_with_multiclass_scores_and_train_two_steps(self):
  204. train_config_text_proto = """
  205. optimizer {
  206. adam_optimizer {
  207. learning_rate {
  208. constant_learning_rate {
  209. learning_rate: 0.01
  210. }
  211. }
  212. }
  213. }
  214. data_augmentation_options {
  215. random_adjust_brightness {
  216. max_delta: 0.2
  217. }
  218. }
  219. data_augmentation_options {
  220. random_adjust_contrast {
  221. min_delta: 0.7
  222. max_delta: 1.1
  223. }
  224. }
  225. num_steps: 2
  226. use_multiclass_scores: true
  227. """
  228. train_config = train_pb2.TrainConfig()
  229. text_format.Merge(train_config_text_proto, train_config)
  230. train_dir = self.get_temp_dir()
  231. trainer.train(create_tensor_dict_fn=get_input_function,
  232. create_model_fn=FakeDetectionModel,
  233. train_config=train_config,
  234. master='',
  235. task=0,
  236. num_clones=1,
  237. worker_replicas=1,
  238. clone_on_cpu=True,
  239. ps_tasks=0,
  240. worker_job_name='worker',
  241. is_chief=True,
  242. train_dir=train_dir)
  243. if __name__ == '__main__':
  244. tf.test.main()