# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""A function to build a DetectionModel from configuration."""

import functools

from object_detection.builders import anchor_generator_builder
from object_detection.builders import box_coder_builder
from object_detection.builders import box_predictor_builder
from object_detection.builders import hyperparams_builder
from object_detection.builders import image_resizer_builder
from object_detection.builders import losses_builder
from object_detection.builders import matcher_builder
from object_detection.builders import post_processing_builder
from object_detection.builders import region_similarity_calculator_builder as sim_calc
from object_detection.core import balanced_positive_negative_sampler as sampler
from object_detection.core import post_processing
from object_detection.core import target_assigner
from object_detection.meta_architectures import faster_rcnn_meta_arch
from object_detection.meta_architectures import rfcn_meta_arch
from object_detection.meta_architectures import ssd_meta_arch
from object_detection.models import faster_rcnn_inception_resnet_v2_feature_extractor as frcnn_inc_res
from object_detection.models import faster_rcnn_inception_resnet_v2_keras_feature_extractor as frcnn_inc_res_keras
from object_detection.models import faster_rcnn_inception_v2_feature_extractor as frcnn_inc_v2
from object_detection.models import faster_rcnn_nas_feature_extractor as frcnn_nas
from object_detection.models import faster_rcnn_pnas_feature_extractor as frcnn_pnas
from object_detection.models import faster_rcnn_resnet_v1_feature_extractor as frcnn_resnet_v1
from object_detection.models import ssd_resnet_v1_fpn_feature_extractor as ssd_resnet_v1_fpn
from object_detection.models import ssd_resnet_v1_ppn_feature_extractor as ssd_resnet_v1_ppn
from object_detection.models.embedded_ssd_mobilenet_v1_feature_extractor import EmbeddedSSDMobileNetV1FeatureExtractor
from object_detection.models.ssd_inception_v2_feature_extractor import SSDInceptionV2FeatureExtractor
from object_detection.models.ssd_inception_v3_feature_extractor import SSDInceptionV3FeatureExtractor
from object_detection.models.ssd_mobilenet_v1_feature_extractor import SSDMobileNetV1FeatureExtractor
from object_detection.models.ssd_mobilenet_v1_fpn_feature_extractor import SSDMobileNetV1FpnFeatureExtractor
from object_detection.models.ssd_mobilenet_v1_fpn_keras_feature_extractor import SSDMobileNetV1FpnKerasFeatureExtractor
from object_detection.models.ssd_mobilenet_v1_keras_feature_extractor import SSDMobileNetV1KerasFeatureExtractor
from object_detection.models.ssd_mobilenet_v1_ppn_feature_extractor import SSDMobileNetV1PpnFeatureExtractor
from object_detection.models.ssd_mobilenet_v2_feature_extractor import SSDMobileNetV2FeatureExtractor
from object_detection.models.ssd_mobilenet_v2_fpn_feature_extractor import SSDMobileNetV2FpnFeatureExtractor
from object_detection.models.ssd_mobilenet_v2_fpn_keras_feature_extractor import SSDMobileNetV2FpnKerasFeatureExtractor
from object_detection.models.ssd_mobilenet_v2_keras_feature_extractor import SSDMobileNetV2KerasFeatureExtractor
from object_detection.models.ssd_pnasnet_feature_extractor import SSDPNASNetFeatureExtractor
from object_detection.predictors import rfcn_box_predictor
from object_detection.predictors import rfcn_keras_box_predictor
from object_detection.predictors.heads import mask_head
from object_detection.protos import model_pb2
from object_detection.utils import ops
# A map of names to SSD feature extractors.
SSD_FEATURE_EXTRACTOR_CLASS_MAP = {
    'ssd_inception_v2': SSDInceptionV2FeatureExtractor,
    'ssd_inception_v3': SSDInceptionV3FeatureExtractor,
    'ssd_mobilenet_v1': SSDMobileNetV1FeatureExtractor,
    'ssd_mobilenet_v1_fpn': SSDMobileNetV1FpnFeatureExtractor,
    'ssd_mobilenet_v1_ppn': SSDMobileNetV1PpnFeatureExtractor,
    'ssd_mobilenet_v2': SSDMobileNetV2FeatureExtractor,
    'ssd_mobilenet_v2_fpn': SSDMobileNetV2FpnFeatureExtractor,
    'ssd_resnet50_v1_fpn': ssd_resnet_v1_fpn.SSDResnet50V1FpnFeatureExtractor,
    'ssd_resnet101_v1_fpn': ssd_resnet_v1_fpn.SSDResnet101V1FpnFeatureExtractor,
    'ssd_resnet152_v1_fpn': ssd_resnet_v1_fpn.SSDResnet152V1FpnFeatureExtractor,
    'ssd_resnet50_v1_ppn': ssd_resnet_v1_ppn.SSDResnet50V1PpnFeatureExtractor,
    'ssd_resnet101_v1_ppn': ssd_resnet_v1_ppn.SSDResnet101V1PpnFeatureExtractor,
    'ssd_resnet152_v1_ppn': ssd_resnet_v1_ppn.SSDResnet152V1PpnFeatureExtractor,
    'embedded_ssd_mobilenet_v1': EmbeddedSSDMobileNetV1FeatureExtractor,
    'ssd_pnasnet': SSDPNASNetFeatureExtractor,
}

# A map of names to SSD Keras feature extractors.
SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP = {
    'ssd_mobilenet_v1_keras': SSDMobileNetV1KerasFeatureExtractor,
    'ssd_mobilenet_v1_fpn_keras': SSDMobileNetV1FpnKerasFeatureExtractor,
    'ssd_mobilenet_v2_keras': SSDMobileNetV2KerasFeatureExtractor,
    'ssd_mobilenet_v2_fpn_keras': SSDMobileNetV2FpnKerasFeatureExtractor,
}

# A map of names to Faster R-CNN feature extractors.
FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP = {
    'faster_rcnn_nas': frcnn_nas.FasterRCNNNASFeatureExtractor,
    'faster_rcnn_pnas': frcnn_pnas.FasterRCNNPNASFeatureExtractor,
    'faster_rcnn_inception_resnet_v2':
        frcnn_inc_res.FasterRCNNInceptionResnetV2FeatureExtractor,
    'faster_rcnn_inception_v2':
        frcnn_inc_v2.FasterRCNNInceptionV2FeatureExtractor,
    'faster_rcnn_resnet50': frcnn_resnet_v1.FasterRCNNResnet50FeatureExtractor,
    'faster_rcnn_resnet101':
        frcnn_resnet_v1.FasterRCNNResnet101FeatureExtractor,
    'faster_rcnn_resnet152':
        frcnn_resnet_v1.FasterRCNNResnet152FeatureExtractor,
}

# A map of names to Faster R-CNN Keras feature extractors.
FASTER_RCNN_KERAS_FEATURE_EXTRACTOR_CLASS_MAP = {
    'faster_rcnn_inception_resnet_v2_keras':
        frcnn_inc_res_keras.FasterRCNNInceptionResnetV2KerasFeatureExtractor,
}


def build(model_config, is_training, add_summaries=True):
  """Builds a DetectionModel based on the model config.

  Args:
    model_config: A model.proto object containing the config for the desired
      DetectionModel.
    is_training: True if this model is being built for training purposes.
    add_summaries: Whether to add tensorflow summaries in the model graph.

  Returns:
    DetectionModel based on the config.

  Raises:
    ValueError: On invalid meta architecture or model.
  """
  if not isinstance(model_config, model_pb2.DetectionModel):
    raise ValueError('model_config not of type model_pb2.DetectionModel.')
  meta_architecture = model_config.WhichOneof('model')
  if meta_architecture == 'ssd':
    return _build_ssd_model(model_config.ssd, is_training, add_summaries)
  if meta_architecture == 'faster_rcnn':
    return _build_faster_rcnn_model(model_config.faster_rcnn, is_training,
                                    add_summaries)
  raise ValueError('Unknown meta architecture: {}'.format(meta_architecture))
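
# Example usage (a minimal sketch; the config path below is a placeholder --
# any TrainEvalPipelineConfig text proto with a `model` field works):
#
#   from google.protobuf import text_format
#   from object_detection.protos import pipeline_pb2
#
#   pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
#   with open('path/to/pipeline.config') as f:
#     text_format.Merge(f.read(), pipeline_config)
#   detection_model = build(pipeline_config.model, is_training=True)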


def _build_ssd_feature_extractor(feature_extractor_config,
                                 is_training,
                                 freeze_batchnorm,
                                 reuse_weights=None):
  """Builds a ssd_meta_arch.SSDFeatureExtractor based on config.

  Args:
    feature_extractor_config: A SSDFeatureExtractor proto config from
      ssd.proto.
    is_training: True if this feature extractor is being built for training.
    freeze_batchnorm: Whether to freeze batch norm parameters during training
      or not. When training with a small batch size (e.g. 1), it is desirable
      to freeze batch norm update and use pretrained batch norm params.
    reuse_weights: Whether the feature extractor should reuse weights.

  Returns:
    ssd_meta_arch.SSDFeatureExtractor based on config.

  Raises:
    ValueError: On invalid feature extractor type.
  """
  feature_type = feature_extractor_config.type
  is_keras_extractor = feature_type in SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP
  depth_multiplier = feature_extractor_config.depth_multiplier
  min_depth = feature_extractor_config.min_depth
  pad_to_multiple = feature_extractor_config.pad_to_multiple
  use_explicit_padding = feature_extractor_config.use_explicit_padding
  use_depthwise = feature_extractor_config.use_depthwise

  if is_keras_extractor:
    conv_hyperparams = hyperparams_builder.KerasLayerHyperparams(
        feature_extractor_config.conv_hyperparams)
  else:
    conv_hyperparams = hyperparams_builder.build(
        feature_extractor_config.conv_hyperparams, is_training)
  override_base_feature_extractor_hyperparams = (
      feature_extractor_config.override_base_feature_extractor_hyperparams)

  if (feature_type not in SSD_FEATURE_EXTRACTOR_CLASS_MAP) and (
      not is_keras_extractor):
    raise ValueError('Unknown ssd feature_extractor: {}'.format(feature_type))

  if is_keras_extractor:
    feature_extractor_class = SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP[
        feature_type]
  else:
    feature_extractor_class = SSD_FEATURE_EXTRACTOR_CLASS_MAP[feature_type]
  kwargs = {
      'is_training': is_training,
      'depth_multiplier': depth_multiplier,
      'min_depth': min_depth,
      'pad_to_multiple': pad_to_multiple,
      'use_explicit_padding': use_explicit_padding,
      'use_depthwise': use_depthwise,
      'override_base_feature_extractor_hyperparams':
          override_base_feature_extractor_hyperparams
  }

  if feature_extractor_config.HasField('replace_preprocessor_with_placeholder'):
    kwargs.update({
        'replace_preprocessor_with_placeholder':
            feature_extractor_config.replace_preprocessor_with_placeholder
    })

  if is_keras_extractor:
    kwargs.update({
        'conv_hyperparams': conv_hyperparams,
        'inplace_batchnorm_update': False,
        'freeze_batchnorm': freeze_batchnorm
    })
  else:
    kwargs.update({
        'conv_hyperparams_fn': conv_hyperparams,
        'reuse_weights': reuse_weights,
    })

  if feature_extractor_config.HasField('fpn'):
    kwargs.update({
        'fpn_min_level': feature_extractor_config.fpn.min_level,
        'fpn_max_level': feature_extractor_config.fpn.max_level,
        'additional_layer_depth':
            feature_extractor_config.fpn.additional_layer_depth,
    })

  return feature_extractor_class(**kwargs)
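
# For reference, a text-format ssd.proto feature_extractor snippet that this
# builder consumes (an illustrative sketch; the numeric values are assumptions,
# not stated defaults):
#
#   feature_extractor {
#     type: 'ssd_mobilenet_v2'
#     depth_multiplier: 1.0
#     min_depth: 16
#     pad_to_multiple: 1
#     conv_hyperparams {
#       ...
#     }
#   }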


def _build_ssd_model(ssd_config, is_training, add_summaries):
  """Builds an SSD detection model based on the model config.

  Args:
    ssd_config: A ssd.proto object containing the config for the desired
      SSDMetaArch.
    is_training: True if this model is being built for training purposes.
    add_summaries: Whether to add tf summaries in the model.

  Returns:
    SSDMetaArch based on the config.

  Raises:
    ValueError: If ssd_config.feature_extractor.type is not recognized (i.e.
      not registered in SSD_FEATURE_EXTRACTOR_CLASS_MAP or
      SSD_KERAS_FEATURE_EXTRACTOR_CLASS_MAP).
  """
  num_classes = ssd_config.num_classes

  # Feature extractor.
  feature_extractor = _build_ssd_feature_extractor(
      feature_extractor_config=ssd_config.feature_extractor,
      freeze_batchnorm=ssd_config.freeze_batchnorm,
      is_training=is_training)

  box_coder = box_coder_builder.build(ssd_config.box_coder)
  matcher = matcher_builder.build(ssd_config.matcher)
  region_similarity_calculator = sim_calc.build(
      ssd_config.similarity_calculator)
  encode_background_as_zeros = ssd_config.encode_background_as_zeros
  negative_class_weight = ssd_config.negative_class_weight
  anchor_generator = anchor_generator_builder.build(
      ssd_config.anchor_generator)
  if feature_extractor.is_keras_model:
    ssd_box_predictor = box_predictor_builder.build_keras(
        hyperparams_fn=hyperparams_builder.KerasLayerHyperparams,
        freeze_batchnorm=ssd_config.freeze_batchnorm,
        inplace_batchnorm_update=False,
        num_predictions_per_location_list=anchor_generator
        .num_anchors_per_location(),
        box_predictor_config=ssd_config.box_predictor,
        is_training=is_training,
        num_classes=num_classes,
        add_background_class=ssd_config.add_background_class)
  else:
    ssd_box_predictor = box_predictor_builder.build(
        hyperparams_builder.build, ssd_config.box_predictor, is_training,
        num_classes, ssd_config.add_background_class)
  image_resizer_fn = image_resizer_builder.build(ssd_config.image_resizer)
  non_max_suppression_fn, score_conversion_fn = post_processing_builder.build(
      ssd_config.post_processing)
  (classification_loss, localization_loss, classification_weight,
   localization_weight, hard_example_miner, random_example_sampler,
   expected_loss_weights_fn) = losses_builder.build(ssd_config.loss)
  normalize_loss_by_num_matches = ssd_config.normalize_loss_by_num_matches
  normalize_loc_loss_by_codesize = ssd_config.normalize_loc_loss_by_codesize

  equalization_loss_config = ops.EqualizationLossConfig(
      weight=ssd_config.loss.equalization_loss.weight,
      exclude_prefixes=ssd_config.loss.equalization_loss.exclude_prefixes)

  target_assigner_instance = target_assigner.TargetAssigner(
      region_similarity_calculator,
      matcher,
      box_coder,
      negative_class_weight=negative_class_weight)

  ssd_meta_arch_fn = ssd_meta_arch.SSDMetaArch
  kwargs = {}

  return ssd_meta_arch_fn(
      is_training=is_training,
      anchor_generator=anchor_generator,
      box_predictor=ssd_box_predictor,
      box_coder=box_coder,
      feature_extractor=feature_extractor,
      encode_background_as_zeros=encode_background_as_zeros,
      image_resizer_fn=image_resizer_fn,
      non_max_suppression_fn=non_max_suppression_fn,
      score_conversion_fn=score_conversion_fn,
      classification_loss=classification_loss,
      localization_loss=localization_loss,
      classification_loss_weight=classification_weight,
      localization_loss_weight=localization_weight,
      normalize_loss_by_num_matches=normalize_loss_by_num_matches,
      hard_example_miner=hard_example_miner,
      target_assigner_instance=target_assigner_instance,
      add_summaries=add_summaries,
      normalize_loc_loss_by_codesize=normalize_loc_loss_by_codesize,
      freeze_batchnorm=ssd_config.freeze_batchnorm,
      inplace_batchnorm_update=ssd_config.inplace_batchnorm_update,
      add_background_class=ssd_config.add_background_class,
      explicit_background_class=ssd_config.explicit_background_class,
      random_example_sampler=random_example_sampler,
      expected_loss_weights_fn=expected_loss_weights_fn,
      use_confidences_as_targets=ssd_config.use_confidences_as_targets,
      implicit_example_weight=ssd_config.implicit_example_weight,
      equalization_loss_config=equalization_loss_config,
      **kwargs)
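
# Note on loss wiring above: losses_builder.build() returns the classification
# and localization losses together with the optional hard-example miner and
# random example sampler, all of which are threaded directly into SSDMetaArch.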


def _build_faster_rcnn_feature_extractor(
    feature_extractor_config, is_training, reuse_weights=None,
    inplace_batchnorm_update=False):
  """Builds a faster_rcnn_meta_arch.FasterRCNNFeatureExtractor based on config.

  Args:
    feature_extractor_config: A FasterRcnnFeatureExtractor proto config from
      faster_rcnn.proto.
    is_training: True if this feature extractor is being built for training.
    reuse_weights: Whether the feature extractor should reuse weights.
    inplace_batchnorm_update: Whether to update batch norm inplace during
      training. This is required for batch norm to work correctly on TPUs. When
      this is false, the user must add a control dependency on
      tf.GraphKeys.UPDATE_OPS for the train/loss op in order to update the
      batch norm moving average parameters.

  Returns:
    faster_rcnn_meta_arch.FasterRCNNFeatureExtractor based on config.

  Raises:
    ValueError: On invalid feature extractor type.
  """
  if inplace_batchnorm_update:
    raise ValueError('inplace batchnorm updates not supported.')
  feature_type = feature_extractor_config.type
  first_stage_features_stride = (
      feature_extractor_config.first_stage_features_stride)
  batch_norm_trainable = feature_extractor_config.batch_norm_trainable

  if feature_type not in FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP:
    raise ValueError('Unknown Faster R-CNN feature_extractor: {}'.format(
        feature_type))
  feature_extractor_class = FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP[
      feature_type]
  return feature_extractor_class(
      is_training, first_stage_features_stride,
      batch_norm_trainable, reuse_weights=reuse_weights)
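
# For reference, a text-format faster_rcnn.proto feature_extractor snippet that
# this builder consumes (an illustrative sketch; the stride value is an
# assumption, not a stated default):
#
#   feature_extractor {
#     type: 'faster_rcnn_resnet101'
#     first_stage_features_stride: 16
#   }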


def _build_faster_rcnn_keras_feature_extractor(
    feature_extractor_config, is_training,
    inplace_batchnorm_update=False):
  """Builds a faster_rcnn_meta_arch.FasterRCNNKerasFeatureExtractor from config.

  Args:
    feature_extractor_config: A FasterRcnnFeatureExtractor proto config from
      faster_rcnn.proto.
    is_training: True if this feature extractor is being built for training.
    inplace_batchnorm_update: Whether to update batch norm inplace during
      training. This is required for batch norm to work correctly on TPUs. When
      this is false, the user must add a control dependency on
      tf.GraphKeys.UPDATE_OPS for the train/loss op in order to update the
      batch norm moving average parameters.

  Returns:
    faster_rcnn_meta_arch.FasterRCNNKerasFeatureExtractor based on config.

  Raises:
    ValueError: On invalid feature extractor type.
  """
  if inplace_batchnorm_update:
    raise ValueError('inplace batchnorm updates not supported.')
  feature_type = feature_extractor_config.type
  first_stage_features_stride = (
      feature_extractor_config.first_stage_features_stride)
  batch_norm_trainable = feature_extractor_config.batch_norm_trainable

  if feature_type not in FASTER_RCNN_KERAS_FEATURE_EXTRACTOR_CLASS_MAP:
    raise ValueError('Unknown Faster R-CNN feature_extractor: {}'.format(
        feature_type))
  feature_extractor_class = FASTER_RCNN_KERAS_FEATURE_EXTRACTOR_CLASS_MAP[
      feature_type]
  return feature_extractor_class(
      is_training, first_stage_features_stride,
      batch_norm_trainable)


def _build_faster_rcnn_model(frcnn_config, is_training, add_summaries):
  """Builds a Faster R-CNN or R-FCN detection model based on the model config.

  Builds an R-FCN model if the second_stage_box_predictor in the config is of
  type `rfcn_box_predictor`; otherwise builds a Faster R-CNN model.

  Args:
    frcnn_config: A faster_rcnn.proto object containing the config for the
      desired FasterRCNNMetaArch or RFCNMetaArch.
    is_training: True if this model is being built for training purposes.
    add_summaries: Whether to add tf summaries in the model.

  Returns:
    FasterRCNNMetaArch based on the config.

  Raises:
    ValueError: If frcnn_config.feature_extractor.type is not recognized (i.e.
      not registered in FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP or
      FASTER_RCNN_KERAS_FEATURE_EXTRACTOR_CLASS_MAP).
  """
  num_classes = frcnn_config.num_classes
  image_resizer_fn = image_resizer_builder.build(frcnn_config.image_resizer)

  is_keras = (frcnn_config.feature_extractor.type in
              FASTER_RCNN_KERAS_FEATURE_EXTRACTOR_CLASS_MAP)

  if is_keras:
    feature_extractor = _build_faster_rcnn_keras_feature_extractor(
        frcnn_config.feature_extractor, is_training,
        inplace_batchnorm_update=frcnn_config.inplace_batchnorm_update)
  else:
    feature_extractor = _build_faster_rcnn_feature_extractor(
        frcnn_config.feature_extractor, is_training,
        inplace_batchnorm_update=frcnn_config.inplace_batchnorm_update)

  number_of_stages = frcnn_config.number_of_stages
  first_stage_anchor_generator = anchor_generator_builder.build(
      frcnn_config.first_stage_anchor_generator)

  first_stage_target_assigner = target_assigner.create_target_assigner(
      'FasterRCNN',
      'proposal',
      use_matmul_gather=frcnn_config.use_matmul_gather_in_matcher)
  first_stage_atrous_rate = frcnn_config.first_stage_atrous_rate
  if is_keras:
    first_stage_box_predictor_arg_scope_fn = (
        hyperparams_builder.KerasLayerHyperparams(
            frcnn_config.first_stage_box_predictor_conv_hyperparams))
  else:
    first_stage_box_predictor_arg_scope_fn = hyperparams_builder.build(
        frcnn_config.first_stage_box_predictor_conv_hyperparams, is_training)
  first_stage_box_predictor_kernel_size = (
      frcnn_config.first_stage_box_predictor_kernel_size)
  first_stage_box_predictor_depth = frcnn_config.first_stage_box_predictor_depth
  first_stage_minibatch_size = frcnn_config.first_stage_minibatch_size
  use_static_shapes = frcnn_config.use_static_shapes and (
      frcnn_config.use_static_shapes_for_eval or is_training)
  first_stage_sampler = sampler.BalancedPositiveNegativeSampler(
      positive_fraction=frcnn_config.first_stage_positive_balance_fraction,
      is_static=(frcnn_config.use_static_balanced_label_sampler and
                 use_static_shapes))
  first_stage_max_proposals = frcnn_config.first_stage_max_proposals
  if (frcnn_config.first_stage_nms_iou_threshold < 0 or
      frcnn_config.first_stage_nms_iou_threshold > 1.0):
    raise ValueError('iou_threshold not in [0, 1.0].')
  if (is_training and frcnn_config.second_stage_batch_size >
      first_stage_max_proposals):
    raise ValueError('second_stage_batch_size should be no greater than '
                     'first_stage_max_proposals.')
  first_stage_non_max_suppression_fn = functools.partial(
      post_processing.batch_multiclass_non_max_suppression,
      score_thresh=frcnn_config.first_stage_nms_score_threshold,
      iou_thresh=frcnn_config.first_stage_nms_iou_threshold,
      max_size_per_class=frcnn_config.first_stage_max_proposals,
      max_total_size=frcnn_config.first_stage_max_proposals,
      use_static_shapes=use_static_shapes)
  first_stage_loc_loss_weight = (
      frcnn_config.first_stage_localization_loss_weight)
  first_stage_obj_loss_weight = frcnn_config.first_stage_objectness_loss_weight

  initial_crop_size = frcnn_config.initial_crop_size
  maxpool_kernel_size = frcnn_config.maxpool_kernel_size
  maxpool_stride = frcnn_config.maxpool_stride

  second_stage_target_assigner = target_assigner.create_target_assigner(
      'FasterRCNN',
      'detection',
      use_matmul_gather=frcnn_config.use_matmul_gather_in_matcher)
  if is_keras:
    second_stage_box_predictor = box_predictor_builder.build_keras(
        hyperparams_builder.KerasLayerHyperparams,
        freeze_batchnorm=False,
        inplace_batchnorm_update=False,
        num_predictions_per_location_list=[1],
        box_predictor_config=frcnn_config.second_stage_box_predictor,
        is_training=is_training,
        num_classes=num_classes)
  else:
    second_stage_box_predictor = box_predictor_builder.build(
        hyperparams_builder.build,
        frcnn_config.second_stage_box_predictor,
        is_training=is_training,
        num_classes=num_classes)
  second_stage_batch_size = frcnn_config.second_stage_batch_size
  second_stage_sampler = sampler.BalancedPositiveNegativeSampler(
      positive_fraction=frcnn_config.second_stage_balance_fraction,
      is_static=(frcnn_config.use_static_balanced_label_sampler and
                 use_static_shapes))
  (second_stage_non_max_suppression_fn, second_stage_score_conversion_fn
  ) = post_processing_builder.build(frcnn_config.second_stage_post_processing)
  second_stage_localization_loss_weight = (
      frcnn_config.second_stage_localization_loss_weight)
  second_stage_classification_loss = (
      losses_builder.build_faster_rcnn_classification_loss(
          frcnn_config.second_stage_classification_loss))
  second_stage_classification_loss_weight = (
      frcnn_config.second_stage_classification_loss_weight)
  second_stage_mask_prediction_loss_weight = (
      frcnn_config.second_stage_mask_prediction_loss_weight)

  hard_example_miner = None
  if frcnn_config.HasField('hard_example_miner'):
    hard_example_miner = losses_builder.build_hard_example_miner(
        frcnn_config.hard_example_miner,
        second_stage_classification_loss_weight,
        second_stage_localization_loss_weight)

  crop_and_resize_fn = (
      ops.matmul_crop_and_resize if frcnn_config.use_matmul_crop_and_resize
      else ops.native_crop_and_resize)
  clip_anchors_to_image = frcnn_config.clip_anchors_to_image

  common_kwargs = {
      'is_training': is_training,
      'num_classes': num_classes,
      'image_resizer_fn': image_resizer_fn,
      'feature_extractor': feature_extractor,
      'number_of_stages': number_of_stages,
      'first_stage_anchor_generator': first_stage_anchor_generator,
      'first_stage_target_assigner': first_stage_target_assigner,
      'first_stage_atrous_rate': first_stage_atrous_rate,
      'first_stage_box_predictor_arg_scope_fn':
          first_stage_box_predictor_arg_scope_fn,
      'first_stage_box_predictor_kernel_size':
          first_stage_box_predictor_kernel_size,
      'first_stage_box_predictor_depth': first_stage_box_predictor_depth,
      'first_stage_minibatch_size': first_stage_minibatch_size,
      'first_stage_sampler': first_stage_sampler,
      'first_stage_non_max_suppression_fn': first_stage_non_max_suppression_fn,
      'first_stage_max_proposals': first_stage_max_proposals,
      'first_stage_localization_loss_weight': first_stage_loc_loss_weight,
      'first_stage_objectness_loss_weight': first_stage_obj_loss_weight,
      'second_stage_target_assigner': second_stage_target_assigner,
      'second_stage_batch_size': second_stage_batch_size,
      'second_stage_sampler': second_stage_sampler,
      'second_stage_non_max_suppression_fn':
          second_stage_non_max_suppression_fn,
      'second_stage_score_conversion_fn': second_stage_score_conversion_fn,
      'second_stage_localization_loss_weight':
          second_stage_localization_loss_weight,
      'second_stage_classification_loss': second_stage_classification_loss,
      'second_stage_classification_loss_weight':
          second_stage_classification_loss_weight,
      'hard_example_miner': hard_example_miner,
      'add_summaries': add_summaries,
      'crop_and_resize_fn': crop_and_resize_fn,
      'clip_anchors_to_image': clip_anchors_to_image,
      'use_static_shapes': use_static_shapes,
      'resize_masks': frcnn_config.resize_masks
  }

  if isinstance(second_stage_box_predictor,
                (rfcn_box_predictor.RfcnBoxPredictor,
                 rfcn_keras_box_predictor.RfcnKerasBoxPredictor)):
    return rfcn_meta_arch.RFCNMetaArch(
        second_stage_rfcn_box_predictor=second_stage_box_predictor,
        **common_kwargs)
  else:
    return faster_rcnn_meta_arch.FasterRCNNMetaArch(
        initial_crop_size=initial_crop_size,
        maxpool_kernel_size=maxpool_kernel_size,
        maxpool_stride=maxpool_stride,
        second_stage_mask_rcnn_box_predictor=second_stage_box_predictor,
        second_stage_mask_prediction_loss_weight=(
            second_stage_mask_prediction_loss_weight),
        **common_kwargs)
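
# Whether an RFCNMetaArch or a FasterRCNNMetaArch is returned above is driven
# entirely by the second_stage_box_predictor oneof in the config. A hedged
# text-format sketch of the R-FCN variant (hyperparams elided):
#
#   second_stage_box_predictor {
#     rfcn_box_predictor {
#       conv_hyperparams {
#         ...
#       }
#     }
#   }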