You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

332 lines
12 KiB

6 years ago
  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for object_detection.builders.model_builder."""
  16. from absl.testing import parameterized
  17. import tensorflow as tf
  18. from google.protobuf import text_format
  19. from object_detection.builders import model_builder
  20. from object_detection.meta_architectures import faster_rcnn_meta_arch
  21. from object_detection.meta_architectures import rfcn_meta_arch
  22. from object_detection.meta_architectures import ssd_meta_arch
  23. from object_detection.models import ssd_resnet_v1_fpn_feature_extractor as ssd_resnet_v1_fpn
  24. from object_detection.protos import hyperparams_pb2
  25. from object_detection.protos import losses_pb2
  26. from object_detection.protos import model_pb2
  27. class ModelBuilderTest(tf.test.TestCase, parameterized.TestCase):
  28. def create_model(self, model_config, is_training=True):
  29. """Builds a DetectionModel based on the model config.
  30. Args:
  31. model_config: A model.proto object containing the config for the desired
  32. DetectionModel.
  33. is_training: True if this model is being built for training purposes.
  34. Returns:
  35. DetectionModel based on the config.
  36. """
  37. return model_builder.build(model_config, is_training=is_training)
  38. def create_default_ssd_model_proto(self):
  39. """Creates a DetectionModel proto with ssd model fields populated."""
  40. model_text_proto = """
  41. ssd {
  42. feature_extractor {
  43. type: 'ssd_inception_v2'
  44. conv_hyperparams {
  45. regularizer {
  46. l2_regularizer {
  47. }
  48. }
  49. initializer {
  50. truncated_normal_initializer {
  51. }
  52. }
  53. }
  54. override_base_feature_extractor_hyperparams: true
  55. }
  56. box_coder {
  57. faster_rcnn_box_coder {
  58. }
  59. }
  60. matcher {
  61. argmax_matcher {
  62. }
  63. }
  64. similarity_calculator {
  65. iou_similarity {
  66. }
  67. }
  68. anchor_generator {
  69. ssd_anchor_generator {
  70. aspect_ratios: 1.0
  71. }
  72. }
  73. image_resizer {
  74. fixed_shape_resizer {
  75. height: 320
  76. width: 320
  77. }
  78. }
  79. box_predictor {
  80. convolutional_box_predictor {
  81. conv_hyperparams {
  82. regularizer {
  83. l2_regularizer {
  84. }
  85. }
  86. initializer {
  87. truncated_normal_initializer {
  88. }
  89. }
  90. }
  91. }
  92. }
  93. loss {
  94. classification_loss {
  95. weighted_softmax {
  96. }
  97. }
  98. localization_loss {
  99. weighted_smooth_l1 {
  100. }
  101. }
  102. }
  103. }"""
  104. model_proto = model_pb2.DetectionModel()
  105. text_format.Merge(model_text_proto, model_proto)
  106. return model_proto
  107. def create_default_faster_rcnn_model_proto(self):
  108. """Creates a DetectionModel proto with FasterRCNN model fields populated."""
  109. model_text_proto = """
  110. faster_rcnn {
  111. inplace_batchnorm_update: false
  112. num_classes: 3
  113. image_resizer {
  114. keep_aspect_ratio_resizer {
  115. min_dimension: 600
  116. max_dimension: 1024
  117. }
  118. }
  119. feature_extractor {
  120. type: 'faster_rcnn_resnet101'
  121. }
  122. first_stage_anchor_generator {
  123. grid_anchor_generator {
  124. scales: [0.25, 0.5, 1.0, 2.0]
  125. aspect_ratios: [0.5, 1.0, 2.0]
  126. height_stride: 16
  127. width_stride: 16
  128. }
  129. }
  130. first_stage_box_predictor_conv_hyperparams {
  131. regularizer {
  132. l2_regularizer {
  133. }
  134. }
  135. initializer {
  136. truncated_normal_initializer {
  137. }
  138. }
  139. }
  140. initial_crop_size: 14
  141. maxpool_kernel_size: 2
  142. maxpool_stride: 2
  143. second_stage_box_predictor {
  144. mask_rcnn_box_predictor {
  145. conv_hyperparams {
  146. regularizer {
  147. l2_regularizer {
  148. }
  149. }
  150. initializer {
  151. truncated_normal_initializer {
  152. }
  153. }
  154. }
  155. fc_hyperparams {
  156. op: FC
  157. regularizer {
  158. l2_regularizer {
  159. }
  160. }
  161. initializer {
  162. truncated_normal_initializer {
  163. }
  164. }
  165. }
  166. }
  167. }
  168. second_stage_post_processing {
  169. batch_non_max_suppression {
  170. score_threshold: 0.01
  171. iou_threshold: 0.6
  172. max_detections_per_class: 100
  173. max_total_detections: 300
  174. }
  175. score_converter: SOFTMAX
  176. }
  177. }"""
  178. model_proto = model_pb2.DetectionModel()
  179. text_format.Merge(model_text_proto, model_proto)
  180. return model_proto
  181. def test_create_ssd_models_from_config(self):
  182. model_proto = self.create_default_ssd_model_proto()
  183. ssd_feature_extractor_map = {}
  184. ssd_feature_extractor_map.update(
  185. model_builder.SSD_FEATURE_EXTRACTOR_CLASS_MAP)
  186. ssd_feature_extractor_map.update(
  187. model_builder.SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP)
  188. for extractor_type, extractor_class in ssd_feature_extractor_map.items():
  189. model_proto.ssd.feature_extractor.type = extractor_type
  190. model = model_builder.build(model_proto, is_training=True)
  191. self.assertIsInstance(model, ssd_meta_arch.SSDMetaArch)
  192. self.assertIsInstance(model._feature_extractor, extractor_class)
  193. def test_create_ssd_fpn_model_from_config(self):
  194. model_proto = self.create_default_ssd_model_proto()
  195. model_proto.ssd.feature_extractor.type = 'ssd_resnet101_v1_fpn'
  196. model_proto.ssd.feature_extractor.fpn.min_level = 3
  197. model_proto.ssd.feature_extractor.fpn.max_level = 7
  198. model = model_builder.build(model_proto, is_training=True)
  199. self.assertIsInstance(model._feature_extractor,
  200. ssd_resnet_v1_fpn.SSDResnet101V1FpnFeatureExtractor)
  201. self.assertEqual(model._feature_extractor._fpn_min_level, 3)
  202. self.assertEqual(model._feature_extractor._fpn_max_level, 7)
  203. @parameterized.named_parameters(
  204. {
  205. 'testcase_name': 'mask_rcnn_with_matmul',
  206. 'use_matmul_crop_and_resize': False,
  207. 'enable_mask_prediction': True
  208. },
  209. {
  210. 'testcase_name': 'mask_rcnn_without_matmul',
  211. 'use_matmul_crop_and_resize': True,
  212. 'enable_mask_prediction': True
  213. },
  214. {
  215. 'testcase_name': 'faster_rcnn_with_matmul',
  216. 'use_matmul_crop_and_resize': False,
  217. 'enable_mask_prediction': False
  218. },
  219. {
  220. 'testcase_name': 'faster_rcnn_without_matmul',
  221. 'use_matmul_crop_and_resize': True,
  222. 'enable_mask_prediction': False
  223. },
  224. )
  225. def test_create_faster_rcnn_models_from_config(
  226. self, use_matmul_crop_and_resize, enable_mask_prediction):
  227. model_proto = self.create_default_faster_rcnn_model_proto()
  228. faster_rcnn_config = model_proto.faster_rcnn
  229. faster_rcnn_config.use_matmul_crop_and_resize = use_matmul_crop_and_resize
  230. if enable_mask_prediction:
  231. faster_rcnn_config.second_stage_mask_prediction_loss_weight = 3.0
  232. mask_predictor_config = (
  233. faster_rcnn_config.second_stage_box_predictor.mask_rcnn_box_predictor)
  234. mask_predictor_config.predict_instance_masks = True
  235. for extractor_type, extractor_class in (
  236. model_builder.FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP.items()):
  237. faster_rcnn_config.feature_extractor.type = extractor_type
  238. model = model_builder.build(model_proto, is_training=True)
  239. self.assertIsInstance(model, faster_rcnn_meta_arch.FasterRCNNMetaArch)
  240. self.assertIsInstance(model._feature_extractor, extractor_class)
  241. if enable_mask_prediction:
  242. self.assertAlmostEqual(model._second_stage_mask_loss_weight, 3.0)
  243. def test_create_faster_rcnn_model_from_config_with_example_miner(self):
  244. model_proto = self.create_default_faster_rcnn_model_proto()
  245. model_proto.faster_rcnn.hard_example_miner.num_hard_examples = 64
  246. model = model_builder.build(model_proto, is_training=True)
  247. self.assertIsNotNone(model._hard_example_miner)
  248. def test_create_rfcn_model_from_config(self):
  249. model_proto = self.create_default_faster_rcnn_model_proto()
  250. rfcn_predictor_config = (
  251. model_proto.faster_rcnn.second_stage_box_predictor.rfcn_box_predictor)
  252. rfcn_predictor_config.conv_hyperparams.op = hyperparams_pb2.Hyperparams.CONV
  253. for extractor_type, extractor_class in (
  254. model_builder.FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP.items()):
  255. model_proto.faster_rcnn.feature_extractor.type = extractor_type
  256. model = model_builder.build(model_proto, is_training=True)
  257. self.assertIsInstance(model, rfcn_meta_arch.RFCNMetaArch)
  258. self.assertIsInstance(model._feature_extractor, extractor_class)
  259. def test_invalid_model_config_proto(self):
  260. model_proto = ''
  261. with self.assertRaisesRegexp(
  262. ValueError, 'model_config not of type model_pb2.DetectionModel.'):
  263. model_builder.build(model_proto, is_training=True)
  264. def test_unknown_meta_architecture(self):
  265. model_proto = model_pb2.DetectionModel()
  266. with self.assertRaisesRegexp(ValueError, 'Unknown meta architecture'):
  267. model_builder.build(model_proto, is_training=True)
  268. def test_unknown_ssd_feature_extractor(self):
  269. model_proto = self.create_default_ssd_model_proto()
  270. model_proto.ssd.feature_extractor.type = 'unknown_feature_extractor'
  271. with self.assertRaisesRegexp(ValueError, 'Unknown ssd feature_extractor'):
  272. model_builder.build(model_proto, is_training=True)
  273. def test_unknown_faster_rcnn_feature_extractor(self):
  274. model_proto = self.create_default_faster_rcnn_model_proto()
  275. model_proto.faster_rcnn.feature_extractor.type = 'unknown_feature_extractor'
  276. with self.assertRaisesRegexp(ValueError,
  277. 'Unknown Faster R-CNN feature_extractor'):
  278. model_builder.build(model_proto, is_training=True)
  279. def test_invalid_first_stage_nms_iou_threshold(self):
  280. model_proto = self.create_default_faster_rcnn_model_proto()
  281. model_proto.faster_rcnn.first_stage_nms_iou_threshold = 1.1
  282. with self.assertRaisesRegexp(ValueError,
  283. r'iou_threshold not in \[0, 1\.0\]'):
  284. model_builder.build(model_proto, is_training=True)
  285. model_proto.faster_rcnn.first_stage_nms_iou_threshold = -0.1
  286. with self.assertRaisesRegexp(ValueError,
  287. r'iou_threshold not in \[0, 1\.0\]'):
  288. model_builder.build(model_proto, is_training=True)
  289. def test_invalid_second_stage_batch_size(self):
  290. model_proto = self.create_default_faster_rcnn_model_proto()
  291. model_proto.faster_rcnn.first_stage_max_proposals = 1
  292. model_proto.faster_rcnn.second_stage_batch_size = 2
  293. with self.assertRaisesRegexp(
  294. ValueError, 'second_stage_batch_size should be no greater '
  295. 'than first_stage_max_proposals.'):
  296. model_builder.build(model_proto, is_training=True)
  297. def test_invalid_faster_rcnn_batchnorm_update(self):
  298. model_proto = self.create_default_faster_rcnn_model_proto()
  299. model_proto.faster_rcnn.inplace_batchnorm_update = True
  300. with self.assertRaisesRegexp(ValueError,
  301. 'inplace batchnorm updates not supported'):
  302. model_builder.build(model_proto, is_training=True)
  303. if __name__ == '__main__':
  304. tf.test.main()