# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""SSD Meta-architecture definition.

General tensorflow implementation of convolutional Multibox/SSD detection
models.
"""
import abc

import tensorflow as tf

from object_detection.core import box_list
from object_detection.core import box_list_ops
from object_detection.core import model
from object_detection.core import standard_fields as fields
from object_detection.core import target_assigner
from object_detection.utils import ops
from object_detection.utils import shape_utils
from object_detection.utils import visualization_utils

slim = tf.contrib.slim


class SSDFeatureExtractor(object):
  """SSD Slim Feature Extractor definition."""

  def __init__(self,
               is_training,
               depth_multiplier,
               min_depth,
               pad_to_multiple,
               conv_hyperparams_fn,
               reuse_weights=None,
               use_explicit_padding=False,
               use_depthwise=False,
               override_base_feature_extractor_hyperparams=False):
    """Constructor.

    Args:
      is_training: whether the network is in training mode.
      depth_multiplier: float depth multiplier for feature extractor.
      min_depth: minimum feature extractor depth.
      pad_to_multiple: the nearest multiple to zero pad the input height and
        width dimensions to.
      conv_hyperparams_fn: A function to construct tf slim arg_scope for conv2d
        and separable_conv2d ops in the layers that are added on top of the
        base feature extractor.
      reuse_weights: whether to reuse variables. Default is None.
      use_explicit_padding: Whether to use explicit padding when extracting
        features. Default is False.
      use_depthwise: Whether to use depthwise convolutions. Default is False.
      override_base_feature_extractor_hyperparams: Whether to override
        hyperparameters of the base feature extractor with the one from
        `conv_hyperparams_fn`.
    """
    self._is_training = is_training
    self._depth_multiplier = depth_multiplier
    self._min_depth = min_depth
    self._pad_to_multiple = pad_to_multiple
    self._conv_hyperparams_fn = conv_hyperparams_fn
    self._reuse_weights = reuse_weights
    self._use_explicit_padding = use_explicit_padding
    self._use_depthwise = use_depthwise
    self._override_base_feature_extractor_hyperparams = (
        override_base_feature_extractor_hyperparams)

  @property
  def is_keras_model(self):
    return False

  @abc.abstractmethod
  def preprocess(self, resized_inputs):
    """Preprocesses images for feature extraction (minus image resizing).

    Args:
      resized_inputs: a [batch, height, width, channels] float tensor
        representing a batch of images.

    Returns:
      preprocessed_inputs: a [batch, height, width, channels] float tensor
        representing a batch of images. (Note: unlike the meta-architecture's
        `preprocess`, the feature extractor's `preprocess` returns only this
        tensor; `true_image_shapes` is handled by SSDMetaArch.preprocess.)
    """
    pass

  @abc.abstractmethod
  def extract_features(self, preprocessed_inputs):
    """Extracts features from preprocessed inputs.

    This function is responsible for extracting feature maps from preprocessed
    images.

    Args:
      preprocessed_inputs: a [batch, height, width, channels] float tensor
        representing a batch of images.

    Returns:
      feature_maps: a list of tensors where the ith tensor has shape
        [batch, height_i, width_i, depth_i]
    """
    raise NotImplementedError

  def restore_from_classification_checkpoint_fn(self, feature_extractor_scope):
    """Returns a map of variables to load from a foreign checkpoint.

    Args:
      feature_extractor_scope: A scope name for the feature extractor.

    Returns:
      A dict mapping variable names (to load from a checkpoint) to variables in
      the model graph.
    """
    variables_to_restore = {}
    for variable in tf.global_variables():
      var_name = variable.op.name
      if var_name.startswith(feature_extractor_scope + '/'):
        var_name = var_name.replace(feature_extractor_scope + '/', '')
        variables_to_restore[var_name] = variable
    return variables_to_restore
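

# --- Illustrative sketch (not part of the original file) ---
# A hedged, minimal example of the contract a concrete slim-based
# SSDFeatureExtractor is expected to fulfill: `preprocess` maps raw pixel
# values into the network's input range, and `extract_features` returns a
# list of feature maps. The class name and layer choices below are
# hypothetical and for illustration only; real extractors wrap a
# classification backbone such as MobileNet or ResNet.
class _ExampleSlimFeatureExtractor(SSDFeatureExtractor):
  """Toy feature extractor sketch."""

  def preprocess(self, resized_inputs):
    # Map [0, 255] pixel values to [-1, 1], a common convention.
    return (2.0 / 255.0) * resized_inputs - 1.0

  def extract_features(self, preprocessed_inputs):
    # A single stride-16 feature map; real extractors return several scales.
    return [slim.conv2d(preprocessed_inputs, 64, [3, 3], stride=16)]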


class SSDKerasFeatureExtractor(tf.keras.Model):
  """SSD Keras-based Feature Extractor definition."""

  def __init__(self,
               is_training,
               depth_multiplier,
               min_depth,
               pad_to_multiple,
               conv_hyperparams,
               freeze_batchnorm,
               inplace_batchnorm_update,
               use_explicit_padding=False,
               use_depthwise=False,
               override_base_feature_extractor_hyperparams=False,
               name=None):
    """Constructor.

    Args:
      is_training: whether the network is in training mode.
      depth_multiplier: float depth multiplier for feature extractor.
      min_depth: minimum feature extractor depth.
      pad_to_multiple: the nearest multiple to zero pad the input height and
        width dimensions to.
      conv_hyperparams: `hyperparams_builder.KerasLayerHyperparams` object
        containing convolution hyperparameters for the layers added on top of
        the base feature extractor.
      freeze_batchnorm: Whether to freeze batch norm parameters during
        training or not. When training with a small batch size (e.g. 1), it is
        desirable to freeze batch norm update and use pretrained batch norm
        params.
      inplace_batchnorm_update: Whether to update batch norm moving average
        values inplace. When this is false, the train op must add a control
        dependency on the tf.GraphKeys.UPDATE_OPS collection in order to update
        batch norm statistics.
      use_explicit_padding: Whether to use explicit padding when extracting
        features. Default is False.
      use_depthwise: Whether to use depthwise convolutions. Default is False.
      override_base_feature_extractor_hyperparams: Whether to override
        hyperparameters of the base feature extractor with the one from
        `conv_hyperparams`.
      name: A string name scope to assign to the model. If None, Keras
        will auto-generate one from the class name.
    """
    super(SSDKerasFeatureExtractor, self).__init__(name=name)
    self._is_training = is_training
    self._depth_multiplier = depth_multiplier
    self._min_depth = min_depth
    self._pad_to_multiple = pad_to_multiple
    self._conv_hyperparams = conv_hyperparams
    self._freeze_batchnorm = freeze_batchnorm
    self._inplace_batchnorm_update = inplace_batchnorm_update
    self._use_explicit_padding = use_explicit_padding
    self._use_depthwise = use_depthwise
    self._override_base_feature_extractor_hyperparams = (
        override_base_feature_extractor_hyperparams)

  @property
  def is_keras_model(self):
    return True

  @abc.abstractmethod
  def preprocess(self, resized_inputs):
    """Preprocesses images for feature extraction (minus image resizing).

    Args:
      resized_inputs: a [batch, height, width, channels] float tensor
        representing a batch of images.

    Returns:
      preprocessed_inputs: a [batch, height, width, channels] float tensor
        representing a batch of images. (Note: unlike the meta-architecture's
        `preprocess`, the feature extractor's `preprocess` returns only this
        tensor; `true_image_shapes` is handled by SSDMetaArch.preprocess.)
    """
    raise NotImplementedError

  @abc.abstractmethod
  def _extract_features(self, preprocessed_inputs):
    """Extracts features from preprocessed inputs.

    This function is responsible for extracting feature maps from preprocessed
    images.

    Args:
      preprocessed_inputs: a [batch, height, width, channels] float tensor
        representing a batch of images.

    Returns:
      feature_maps: a list of tensors where the ith tensor has shape
        [batch, height_i, width_i, depth_i]
    """
    raise NotImplementedError

  # This overrides the keras.Model `call` method with the _extract_features
  # method.
  def call(self, inputs, **kwargs):
    return self._extract_features(inputs)

  def restore_from_classification_checkpoint_fn(self, feature_extractor_scope):
    """Returns a map of variables to load from a foreign checkpoint.

    Args:
      feature_extractor_scope: A scope name for the feature extractor.

    Returns:
      A dict mapping variable names (to load from a checkpoint) to variables in
      the model graph.
    """
    variables_to_restore = {}
    for variable in tf.global_variables():
      var_name = variable.op.name
      if var_name.startswith(feature_extractor_scope + '/'):
        var_name = var_name.replace(feature_extractor_scope + '/', '')
        variables_to_restore[var_name] = variable
    return variables_to_restore
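

# --- Illustrative sketch (not part of the original file) ---
# The Keras variant is invoked by calling the extractor object directly
# (`extractor(preprocessed_inputs)`), which routes through `call` to
# `_extract_features`. A hypothetical minimal subclass, assuming only the
# standard tf.keras.layers API; names below are for illustration only.
class _ExampleKerasFeatureExtractor(SSDKerasFeatureExtractor):
  """Toy Keras feature extractor sketch."""

  def __init__(self, **kwargs):
    super(_ExampleKerasFeatureExtractor, self).__init__(**kwargs)
    self._conv = tf.keras.layers.Conv2D(64, 3, strides=16, padding='same')

  def preprocess(self, resized_inputs):
    # Map [0, 255] pixel values to [-1, 1], a common convention.
    return (2.0 / 255.0) * resized_inputs - 1.0

  def _extract_features(self, preprocessed_inputs):
    # A single stride-16 feature map; real extractors return several scales.
    return [self._conv(preprocessed_inputs)]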


class SSDMetaArch(model.DetectionModel):
  """SSD Meta-architecture definition."""

  def __init__(self,
               is_training,
               anchor_generator,
               box_predictor,
               box_coder,
               feature_extractor,
               encode_background_as_zeros,
               image_resizer_fn,
               non_max_suppression_fn,
               score_conversion_fn,
               classification_loss,
               localization_loss,
               classification_loss_weight,
               localization_loss_weight,
               normalize_loss_by_num_matches,
               hard_example_miner,
               target_assigner_instance,
               add_summaries=True,
               normalize_loc_loss_by_codesize=False,
               freeze_batchnorm=False,
               inplace_batchnorm_update=False,
               add_background_class=True,
               explicit_background_class=False,
               random_example_sampler=None,
               expected_loss_weights_fn=None,
               use_confidences_as_targets=False,
               implicit_example_weight=0.5,
               equalization_loss_config=None):
    """SSDMetaArch Constructor.

    TODO(rathodv,jonathanhuang): group NMS parameters + score converter into
    a class and loss parameters into a class and write config protos for
    postprocessing and losses.

    Args:
      is_training: A boolean indicating whether the training version of the
        computation graph should be constructed.
      anchor_generator: an anchor_generator.AnchorGenerator object.
      box_predictor: a box_predictor.BoxPredictor object.
      box_coder: a box_coder.BoxCoder object.
      feature_extractor: a SSDFeatureExtractor object.
      encode_background_as_zeros: boolean determining whether background
        targets are to be encoded as an all zeros vector or a one-hot
        vector (where background is the 0th class).
      image_resizer_fn: a callable for image resizing. This callable always
        takes a rank-3 image tensor (corresponding to a single image) and
        returns a rank-3 image tensor, possibly with new spatial dimensions and
        a 1-D tensor of shape [3] indicating shape of true image within
        the resized image tensor as the resized image tensor could be padded.
        See builders/image_resizer_builder.py.
      non_max_suppression_fn: batch_multiclass_non_max_suppression
        callable that takes `boxes`, `scores` and optional `clip_window`
        inputs (with all other inputs already set) and returns a dictionary
        holding tensors with keys: `detection_boxes`, `detection_scores`,
        `detection_classes` and `num_detections`. See `post_processing.
        batch_multiclass_non_max_suppression` for the type and shape of these
        tensors.
      score_conversion_fn: callable elementwise nonlinearity (that takes tensors
        as inputs and returns tensors). This is usually used to convert logits
        to probabilities.
      classification_loss: an object_detection.core.losses.Loss object.
      localization_loss: an object_detection.core.losses.Loss object.
      classification_loss_weight: float weight applied to the classification
        loss.
      localization_loss_weight: float weight applied to the localization loss.
      normalize_loss_by_num_matches: boolean indicating whether to normalize
        losses by the number of positive matches.
      hard_example_miner: a losses.HardExampleMiner object (can be None).
      target_assigner_instance: target_assigner.TargetAssigner instance to use.
      add_summaries: boolean (default: True) controlling whether summary ops
        should be added to the tensorflow graph.
      normalize_loc_loss_by_codesize: whether to normalize localization loss
        by code size of the box encoder.
      freeze_batchnorm: Whether to freeze batch norm parameters during
        training or not. When training with a small batch size (e.g. 1), it is
        desirable to freeze batch norm update and use pretrained batch norm
        params.
      inplace_batchnorm_update: Whether to update batch norm moving average
        values inplace. When this is false, the train op must add a control
        dependency on the tf.GraphKeys.UPDATE_OPS collection in order to update
        batch norm statistics.
      add_background_class: Whether to add an implicit background class to
        one-hot encodings of groundtruth labels. Set to false if training a
        single class model or using groundtruth labels with an explicit
        background class.
      explicit_background_class: Set to true if using groundtruth labels with an
        explicit background class, as in multiclass scores.
      random_example_sampler: a BalancedPositiveNegativeSampler object that can
        perform random example sampling when computing loss. If None, random
        sampling process is skipped. Note that random example sampler and hard
        example miner can both be applied to the model. In that case, random
        sampler will take effect first and hard example miner can only process
        the random sampled examples.
      expected_loss_weights_fn: If not None, used to calculate
        loss by background/foreground weighting. Should take batch_cls_targets
        as inputs and return foreground_weights, background_weights. See
        expected_classification_loss_by_expected_sampling and
        expected_classification_loss_by_reweighting_unmatched_anchors in
        third_party/tensorflow_models/object_detection/utils/ops.py as examples.
      use_confidences_as_targets: Whether to use the groundtruth_confidences
        field to assign the targets.
      implicit_example_weight: a float number that specifies the weight used
        for the implicit negative examples.
      equalization_loss_config: a namedtuple that specifies configs for
        computing equalization loss.
    """
    super(SSDMetaArch, self).__init__(num_classes=box_predictor.num_classes)
    self._is_training = is_training
    self._freeze_batchnorm = freeze_batchnorm
    self._inplace_batchnorm_update = inplace_batchnorm_update

    self._anchor_generator = anchor_generator
    self._box_predictor = box_predictor

    self._box_coder = box_coder
    self._feature_extractor = feature_extractor
    self._add_background_class = add_background_class
    self._explicit_background_class = explicit_background_class

    if add_background_class and explicit_background_class:
      raise ValueError("Cannot have both 'add_background_class' and"
                       " 'explicit_background_class' true.")

    # Needed for fine-tuning from classification checkpoints whose
    # variables do not have the feature extractor scope.
    if self._feature_extractor.is_keras_model:
      # Keras feature extractors will have a name they implicitly use to scope.
      # So, all contained variables are prefixed by this name.
      # To load from classification checkpoints, need to filter out this name.
      self._extract_features_scope = feature_extractor.name
    else:
      # Slim feature extractors get an explicit naming scope
      self._extract_features_scope = 'FeatureExtractor'

    if encode_background_as_zeros:
      background_class = [0]
    else:
      background_class = [1]

    if self._add_background_class:
      num_foreground_classes = self.num_classes
    else:
      num_foreground_classes = self.num_classes - 1

    self._unmatched_class_label = tf.constant(
        background_class + num_foreground_classes * [0], tf.float32)
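    # For example, with num_classes=3 the unmatched-anchor label above is
    # [1, 0, 0, 0] (one-hot background) when encode_background_as_zeros is
    # False, and [0, 0, 0, 0] (all zeros) when it is True.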

    self._target_assigner = target_assigner_instance

    self._classification_loss = classification_loss
    self._localization_loss = localization_loss
    self._classification_loss_weight = classification_loss_weight
    self._localization_loss_weight = localization_loss_weight
    self._normalize_loss_by_num_matches = normalize_loss_by_num_matches
    self._normalize_loc_loss_by_codesize = normalize_loc_loss_by_codesize
    self._hard_example_miner = hard_example_miner
    self._random_example_sampler = random_example_sampler
    self._parallel_iterations = 16

    self._image_resizer_fn = image_resizer_fn
    self._non_max_suppression_fn = non_max_suppression_fn
    self._score_conversion_fn = score_conversion_fn

    self._anchors = None
    self._add_summaries = add_summaries
    self._batched_prediction_tensor_names = []
    self._expected_loss_weights_fn = expected_loss_weights_fn
    self._use_confidences_as_targets = use_confidences_as_targets
    self._implicit_example_weight = implicit_example_weight

    self._equalization_loss_config = equalization_loss_config

  @property
  def anchors(self):
    if not self._anchors:
      raise RuntimeError('anchors have not been constructed yet!')
    if not isinstance(self._anchors, box_list.BoxList):
      raise RuntimeError('anchors should be a BoxList object, but is not.')
    return self._anchors

  @property
  def batched_prediction_tensor_names(self):
    if not self._batched_prediction_tensor_names:
      raise RuntimeError('Must call predict() method to get batched prediction '
                         'tensor names.')
    return self._batched_prediction_tensor_names

  def preprocess(self, inputs):
    """Feature-extractor specific preprocessing.

    SSD meta architecture uses a default clip_window of [0, 0, 1, 1] during
    post-processing. On calling the `preprocess` method, clip_window gets
    updated based on `true_image_shapes` returned by `image_resizer_fn`.

    Args:
      inputs: a [batch, height_in, width_in, channels] float tensor representing
        a batch of images with values between 0 and 255.0.

    Returns:
      preprocessed_inputs: a [batch, height_out, width_out, channels] float
        tensor representing a batch of images.
      true_image_shapes: int32 tensor of shape [batch, 3] where each row is
        of the form [height, width, channels] indicating the shapes
        of true images in the resized images, as resized images can be padded
        with zeros.

    Raises:
      ValueError: if inputs tensor does not have type tf.float32
    """
    if inputs.dtype is not tf.float32:
      raise ValueError('`preprocess` expects a tf.float32 tensor')
    with tf.name_scope('Preprocessor'):
      # TODO(jonathanhuang): revisit whether to always use batch size as
      # the number of parallel iterations vs allow for dynamic batching.
      outputs = shape_utils.static_or_dynamic_map_fn(
          self._image_resizer_fn,
          elems=inputs,
          dtype=[tf.float32, tf.int32])
      resized_inputs = outputs[0]
      true_image_shapes = outputs[1]
      return (self._feature_extractor.preprocess(resized_inputs),
              true_image_shapes)

  def _compute_clip_window(self, preprocessed_images, true_image_shapes):
    """Computes clip window to use during post_processing.

    Computes a new clip window to use during post-processing based on
    `resized_image_shapes` and `true_image_shapes` only if `preprocess` method
    has been called. Otherwise returns a default clip window of [0, 0, 1, 1].

    Args:
      preprocessed_images: the [batch, height, width, channels] image
        tensor.
      true_image_shapes: int32 tensor of shape [batch, 3] where each row is
        of the form [height, width, channels] indicating the shapes
        of true images in the resized images, as resized images can be padded
        with zeros. Or None if the clip window should cover the full image.

    Returns:
      a 2-D float32 tensor of the form [batch_size, 4] containing the clip
      window for each image in the batch in normalized coordinates (relative to
      the resized dimensions) where each clip window is of the form [ymin, xmin,
      ymax, xmax] or a default clip window of [0, 0, 1, 1].
    """
    if true_image_shapes is None:
      return tf.constant([0, 0, 1, 1], dtype=tf.float32)

    resized_inputs_shape = shape_utils.combined_static_and_dynamic_shape(
        preprocessed_images)
    true_heights, true_widths, _ = tf.unstack(
        tf.to_float(true_image_shapes), axis=1)
    padded_height = tf.to_float(resized_inputs_shape[1])
    padded_width = tf.to_float(resized_inputs_shape[2])
    return tf.stack(
        [
            tf.zeros_like(true_heights),
            tf.zeros_like(true_widths), true_heights / padded_height,
            true_widths / padded_width
        ],
        axis=1)
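
  # For example, if a 300x300 padded batch contains a true image of shape
  # [150, 300, 3], _compute_clip_window returns [0.0, 0.0, 0.5, 1.0] for that
  # row: the clip window covers only the un-padded top half of the image.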

  def predict(self, preprocessed_inputs, true_image_shapes):
    """Predicts unpostprocessed tensors from input tensor.

    This function takes an input batch of images and runs it through the
    forward pass of the network to yield unpostprocessed predictions.

    A side effect of calling the predict method is that self._anchors is
    populated with a box_list.BoxList of anchors. These anchors must be
    constructed before the postprocess or loss functions can be called.

    Args:
      preprocessed_inputs: a [batch, height, width, channels] image tensor.
      true_image_shapes: int32 tensor of shape [batch, 3] where each row is
        of the form [height, width, channels] indicating the shapes
        of true images in the resized images, as resized images can be padded
        with zeros.

    Returns:
      prediction_dict: a dictionary holding "raw" prediction tensors:
        1) preprocessed_inputs: the [batch, height, width, channels] image
          tensor.
        2) box_encodings: 3-D float tensor of shape [batch_size, num_anchors,
          box_code_dimension] containing predicted boxes.
        3) class_predictions_with_background: 3-D float tensor of shape
          [batch_size, num_anchors, num_classes+1] containing class predictions
          (logits) for each of the anchors. Note that this tensor *includes*
          background class predictions (at class index 0).
        4) feature_maps: a list of tensors where the ith tensor has shape
          [batch, height_i, width_i, depth_i].
        5) anchors: 2-D float tensor of shape [num_anchors, 4] containing
          the generated anchors in normalized coordinates.
    """
    if self._inplace_batchnorm_update:
      batchnorm_updates_collections = None
    else:
      batchnorm_updates_collections = tf.GraphKeys.UPDATE_OPS
    if self._feature_extractor.is_keras_model:
      feature_maps = self._feature_extractor(preprocessed_inputs)
    else:
      with slim.arg_scope([slim.batch_norm],
                          is_training=(self._is_training and
                                       not self._freeze_batchnorm),
                          updates_collections=batchnorm_updates_collections):
        with tf.variable_scope(None, self._extract_features_scope,
                               [preprocessed_inputs]):
          feature_maps = self._feature_extractor.extract_features(
              preprocessed_inputs)
    feature_map_spatial_dims = self._get_feature_map_spatial_dims(
        feature_maps)
    image_shape = shape_utils.combined_static_and_dynamic_shape(
        preprocessed_inputs)
    self._anchors = box_list_ops.concatenate(
        self._anchor_generator.generate(
            feature_map_spatial_dims,
            im_height=image_shape[1],
            im_width=image_shape[2]))
    if self._box_predictor.is_keras_model:
      predictor_results_dict = self._box_predictor(feature_maps)
    else:
      with slim.arg_scope([slim.batch_norm],
                          is_training=(self._is_training and
                                       not self._freeze_batchnorm),
                          updates_collections=batchnorm_updates_collections):
        predictor_results_dict = self._box_predictor.predict(
            feature_maps, self._anchor_generator.num_anchors_per_location())
    predictions_dict = {
        'preprocessed_inputs': preprocessed_inputs,
        'feature_maps': feature_maps,
        'anchors': self._anchors.get()
    }
    for prediction_key, prediction_list in iter(predictor_results_dict.items()):
      prediction = tf.concat(prediction_list, axis=1)
      if (prediction_key == 'box_encodings' and prediction.shape.ndims == 4 and
          prediction.shape[2] == 1):
        prediction = tf.squeeze(prediction, axis=2)
      predictions_dict[prediction_key] = prediction
    self._batched_prediction_tensor_names = [x for x in predictions_dict
                                             if x != 'anchors']
    return predictions_dict

  def _get_feature_map_spatial_dims(self, feature_maps):
    """Return list of spatial dimensions for each feature map in a list.

    Args:
      feature_maps: a list of tensors where the ith tensor has shape
        [batch, height_i, width_i, depth_i].

    Returns:
      a list of pairs (height, width) for each feature map in feature_maps
    """
    feature_map_shapes = [
        shape_utils.combined_static_and_dynamic_shape(
            feature_map) for feature_map in feature_maps
    ]
    return [(shape[1], shape[2]) for shape in feature_map_shapes]

  def postprocess(self, prediction_dict, true_image_shapes):
    """Converts prediction tensors to final detections.

    This function converts raw prediction tensors to final detection results by
    slicing off the background class, decoding box predictions and applying
    non-max suppression and clipping to the image window.

    See base class for output format conventions. Note also that by default,
    scores are to be interpreted as logits, but if a score_conversion_fn is
    used, then scores are remapped (and may thus have a different
    interpretation).

    Args:
      prediction_dict: a dictionary holding prediction tensors with
        1) preprocessed_inputs: a [batch, height, width, channels] image
          tensor.
        2) box_encodings: 3-D float tensor of shape [batch_size, num_anchors,
          box_code_dimension] containing predicted boxes.
        3) class_predictions_with_background: 3-D float tensor of shape
          [batch_size, num_anchors, num_classes+1] containing class predictions
          (logits) for each of the anchors. Note that this tensor *includes*
          background class predictions.
        4) mask_predictions: (optional) a 5-D float tensor of shape
          [batch_size, num_anchors, q, mask_height, mask_width]. `q` can be
          either number of classes or 1 depending on whether a separate mask is
          predicted per class.
      true_image_shapes: int32 tensor of shape [batch, 3] where each row is
        of the form [height, width, channels] indicating the shapes
        of true images in the resized images, as resized images can be padded
        with zeros. Or None, if the clip window should cover the full image.

    Returns:
      detections: a dictionary containing the following fields
        detection_boxes: [batch, max_detections, 4] tensor with post-processed
          detection boxes.
        detection_scores: [batch, max_detections] tensor with scalar scores for
          post-processed detection boxes.
        detection_classes: [batch, max_detections] tensor with classes for
          post-processed detection classes.
        detection_keypoints: [batch, max_detections, num_keypoints, 2] (if
          encoded in the prediction_dict 'box_encodings')
        detection_masks: [batch_size, max_detections, mask_height, mask_width]
          (optional)
        num_detections: [batch]
        raw_detection_boxes: [batch, total_detections, 4] tensor with decoded
          detection boxes before Non-Max Suppression.
        raw_detection_scores: [batch, total_detections,
          num_classes_with_background] tensor of multi-class score logits for
          raw detection boxes.

    Raises:
      ValueError: if prediction_dict does not contain `box_encodings` or
        `class_predictions_with_background` fields.
    """
    if ('box_encodings' not in prediction_dict or
        'class_predictions_with_background' not in prediction_dict):
      raise ValueError('prediction_dict does not contain expected entries.')
    with tf.name_scope('Postprocessor'):
      preprocessed_images = prediction_dict['preprocessed_inputs']
      box_encodings = prediction_dict['box_encodings']
      box_encodings = tf.identity(box_encodings, 'raw_box_encodings')
      class_predictions = prediction_dict['class_predictions_with_background']
      detection_boxes, detection_keypoints = self._batch_decode(box_encodings)
      detection_boxes = tf.identity(detection_boxes, 'raw_box_locations')
      detection_boxes = tf.expand_dims(detection_boxes, axis=2)

      detection_scores = self._score_conversion_fn(class_predictions)
      detection_scores = tf.identity(detection_scores, 'raw_box_scores')
      if self._add_background_class or self._explicit_background_class:
        detection_scores = tf.slice(detection_scores, [0, 0, 1], [-1, -1, -1])
      additional_fields = None

      batch_size = (
          shape_utils.combined_static_and_dynamic_shape(preprocessed_images)[0])

      if 'feature_maps' in prediction_dict:
        feature_map_list = []
        for feature_map in prediction_dict['feature_maps']:
          feature_map_list.append(tf.reshape(feature_map, [batch_size, -1]))
        box_features = tf.concat(feature_map_list, 1)
        box_features = tf.identity(box_features, 'raw_box_features')

      if detection_keypoints is not None:
        additional_fields = {
            fields.BoxListFields.keypoints: detection_keypoints}
      (nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_masks,
       nmsed_additional_fields, num_detections) = self._non_max_suppression_fn(
           detection_boxes,
           detection_scores,
           clip_window=self._compute_clip_window(preprocessed_images,
                                                 true_image_shapes),
           additional_fields=additional_fields,
           masks=prediction_dict.get('mask_predictions'))
      detection_dict = {
          fields.DetectionResultFields.detection_boxes:
              nmsed_boxes,
          fields.DetectionResultFields.detection_scores:
              nmsed_scores,
          fields.DetectionResultFields.detection_classes:
              nmsed_classes,
          fields.DetectionResultFields.num_detections:
              tf.to_float(num_detections),
          fields.DetectionResultFields.raw_detection_boxes:
              tf.squeeze(detection_boxes, axis=2),
          fields.DetectionResultFields.raw_detection_scores:
              class_predictions
      }
      if (nmsed_additional_fields is not None and
          fields.BoxListFields.keypoints in nmsed_additional_fields):
        detection_dict[fields.DetectionResultFields.detection_keypoints] = (
            nmsed_additional_fields[fields.BoxListFields.keypoints])
      if nmsed_masks is not None:
        detection_dict[
            fields.DetectionResultFields.detection_masks] = nmsed_masks
      return detection_dict

  def loss(self, prediction_dict, true_image_shapes, scope=None):
    """Compute scalar loss tensors with respect to provided groundtruth.

    Calling this function requires that groundtruth tensors have been
    provided via the provide_groundtruth function.

    Args:
      prediction_dict: a dictionary holding prediction tensors with
        1) box_encodings: 3-D float tensor of shape [batch_size, num_anchors,
          box_code_dimension] containing predicted boxes.
        2) class_predictions_with_background: 3-D float tensor of shape
          [batch_size, num_anchors, num_classes+1] containing class predictions
          (logits) for each of the anchors. Note that this tensor *includes*
          background class predictions.
      true_image_shapes: int32 tensor of shape [batch, 3] where each row is
        of the form [height, width, channels] indicating the shapes
        of true images in the resized images, as resized images can be padded
        with zeros.
      scope: Optional scope name.

    Returns:
      a dictionary mapping loss keys (`localization_loss` and
      `classification_loss`) to scalar tensors representing corresponding loss
      values.
    """
    with tf.name_scope(scope, 'Loss', prediction_dict.values()):
      keypoints = None
      if self.groundtruth_has_field(fields.BoxListFields.keypoints):
        keypoints = self.groundtruth_lists(fields.BoxListFields.keypoints)

      weights = None
      if self.groundtruth_has_field(fields.BoxListFields.weights):
        weights = self.groundtruth_lists(fields.BoxListFields.weights)

      confidences = None
      if self.groundtruth_has_field(fields.BoxListFields.confidences):
        confidences = self.groundtruth_lists(fields.BoxListFields.confidences)

      (batch_cls_targets, batch_cls_weights, batch_reg_targets,
       batch_reg_weights, match_list) = self._assign_targets(
           self.groundtruth_lists(fields.BoxListFields.boxes),
           self.groundtruth_lists(fields.BoxListFields.classes),
           keypoints, weights, confidences)
      if self._add_summaries:
        self._summarize_target_assignment(
            self.groundtruth_lists(fields.BoxListFields.boxes), match_list)

      if self._random_example_sampler:
        batch_cls_per_anchor_weights = tf.reduce_mean(
            batch_cls_weights, axis=-1)
        batch_sampled_indicator = tf.to_float(
            shape_utils.static_or_dynamic_map_fn(
                self._minibatch_subsample_fn,
                [batch_cls_targets, batch_cls_per_anchor_weights],
                dtype=tf.bool,
                parallel_iterations=self._parallel_iterations,
                back_prop=True))
        batch_reg_weights = tf.multiply(batch_sampled_indicator,
                                        batch_reg_weights)
        batch_cls_weights = tf.multiply(
            tf.expand_dims(batch_sampled_indicator, -1),
            batch_cls_weights)

      losses_mask = None
      if self.groundtruth_has_field(fields.InputDataFields.is_annotated):
        losses_mask = tf.stack(self.groundtruth_lists(
            fields.InputDataFields.is_annotated))
      location_losses = self._localization_loss(
          prediction_dict['box_encodings'],
          batch_reg_targets,
          ignore_nan_targets=True,
          weights=batch_reg_weights,
          losses_mask=losses_mask)

      cls_losses = self._classification_loss(
          prediction_dict['class_predictions_with_background'],
          batch_cls_targets,
          weights=batch_cls_weights,
          losses_mask=losses_mask)

      if self._expected_loss_weights_fn:
        # Need to compute losses for assigned targets against the
        # unmatched_class_label as well as their assigned targets. The
        # simplest approach (though wasteful) is just to calculate all losses
        # twice.
        batch_size, num_anchors, num_classes = batch_cls_targets.get_shape()
        unmatched_targets = tf.ones(
            [batch_size, num_anchors, 1]) * self._unmatched_class_label

        unmatched_cls_losses = self._classification_loss(
            prediction_dict['class_predictions_with_background'],
            unmatched_targets,
            weights=batch_cls_weights,
            losses_mask=losses_mask)

        if cls_losses.get_shape().ndims == 3:
          batch_size, num_anchors, num_classes = cls_losses.get_shape()
          cls_losses = tf.reshape(cls_losses, [batch_size, -1])
          unmatched_cls_losses = tf.reshape(unmatched_cls_losses,
                                            [batch_size, -1])
          batch_cls_targets = tf.reshape(
              batch_cls_targets, [batch_size, num_anchors * num_classes, -1])
          batch_cls_targets = tf.concat(
              [1 - batch_cls_targets, batch_cls_targets], axis=-1)

          location_losses = tf.tile(location_losses, [1, num_classes])

        foreground_weights, background_weights = (
            self._expected_loss_weights_fn(batch_cls_targets))

        cls_losses = (
            foreground_weights * cls_losses +
            background_weights * unmatched_cls_losses)

        location_losses *= foreground_weights

        classification_loss = tf.reduce_sum(cls_losses)
        localization_loss = tf.reduce_sum(location_losses)
      elif self._hard_example_miner:
        cls_losses = ops.reduce_sum_trailing_dimensions(cls_losses, ndims=2)
        (localization_loss, classification_loss) = self._apply_hard_mining(
            location_losses, cls_losses, prediction_dict, match_list)
        if self._add_summaries:
          self._hard_example_miner.summarize()
      else:
        cls_losses = ops.reduce_sum_trailing_dimensions(cls_losses, ndims=2)
        localization_loss = tf.reduce_sum(location_losses)
        classification_loss = tf.reduce_sum(cls_losses)

      # Optionally normalize by number of positive matches
      normalizer = tf.constant(1.0, dtype=tf.float32)
      if self._normalize_loss_by_num_matches:
        normalizer = tf.maximum(tf.to_float(tf.reduce_sum(batch_reg_weights)),
                                1.0)

      localization_loss_normalizer = normalizer
      if self._normalize_loc_loss_by_codesize:
        localization_loss_normalizer *= self._box_coder.code_size
      localization_loss = tf.multiply((self._localization_loss_weight /
                                       localization_loss_normalizer),
                                      localization_loss,
                                      name='localization_loss')
      classification_loss = tf.multiply((self._classification_loss_weight /
                                         normalizer), classification_loss,
                                        name='classification_loss')

      loss_dict = {
          str(localization_loss.op.name): localization_loss,
          str(classification_loss.op.name): classification_loss
      }
    return loss_dict

  def _minibatch_subsample_fn(self, inputs):
    """Randomly samples anchors for one image.

    Args:
      inputs: a list of 2 inputs. First one is a tensor of shape [num_anchors,
        num_classes] indicating targets assigned to each anchor. Second one
        is a tensor of shape [num_anchors] indicating the class weight of each
        anchor.

    Returns:
      batch_sampled_indicator: bool tensor of shape [num_anchors] indicating
        whether the anchor should be selected for loss computation.
    """
    cls_targets, cls_weights = inputs
    if self._add_background_class:
      # Set background_class bits to 0 so that the positives_indicator
      # computation would not consider background class.
      background_class = tf.zeros_like(tf.slice(cls_targets, [0, 0], [-1, 1]))
      regular_class = tf.slice(cls_targets, [0, 1], [-1, -1])
      cls_targets = tf.concat([background_class, regular_class], 1)
    positives_indicator = tf.reduce_sum(cls_targets, axis=1)
    return self._random_example_sampler.subsample(
        tf.cast(cls_weights, tf.bool),
        batch_size=None,
        labels=tf.cast(positives_indicator, tf.bool))

  def _summarize_anchor_classification_loss(self, class_ids, cls_losses):
    positive_indices = tf.where(tf.greater(class_ids, 0))
    positive_anchor_cls_loss = tf.squeeze(
        tf.gather(cls_losses, positive_indices), axis=1)
    visualization_utils.add_cdf_image_summary(positive_anchor_cls_loss,
                                              'PositiveAnchorLossCDF')
    negative_indices = tf.where(tf.equal(class_ids, 0))
    negative_anchor_cls_loss = tf.squeeze(
        tf.gather(cls_losses, negative_indices), axis=1)
    visualization_utils.add_cdf_image_summary(negative_anchor_cls_loss,
                                              'NegativeAnchorLossCDF')

  def _assign_targets(self,
                      groundtruth_boxes_list,
                      groundtruth_classes_list,
                      groundtruth_keypoints_list=None,
                      groundtruth_weights_list=None,
                      groundtruth_confidences_list=None):
    """Assign groundtruth targets.

    Adds a background class to each one-hot encoding of groundtruth classes
    and uses target assigner to obtain regression and classification targets.

    Args:
      groundtruth_boxes_list: a list of 2-D tensors of shape [num_boxes, 4]
        containing coordinates of the groundtruth boxes.
        Groundtruth boxes are provided in [y_min, x_min, y_max, x_max]
        format and assumed to be normalized and clipped
        relative to the image window with y_min <= y_max and x_min <= x_max.
      groundtruth_classes_list: a list of 2-D one-hot (or k-hot) tensors of
        shape [num_boxes, num_classes] containing the class targets with the 0th
        index assumed to map to the first non-background class.
      groundtruth_keypoints_list: (optional) a list of 3-D tensors of shape
        [num_boxes, num_keypoints, 2]
      groundtruth_weights_list: A list of 1-D tf.float32 tensors of shape
        [num_boxes] containing weights for groundtruth boxes.
      groundtruth_confidences_list: A list of 2-D tf.float32 tensors of shape
        [num_boxes, num_classes] containing class confidences for
        groundtruth boxes.

    Returns:
      batch_cls_targets: a tensor with shape [batch_size, num_anchors,
        num_classes],
      batch_cls_weights: a tensor with shape [batch_size, num_anchors],
      batch_reg_targets: a tensor with shape [batch_size, num_anchors,
        box_code_dimension]
      batch_reg_weights: a tensor with shape [batch_size, num_anchors],
      match_list: a list of matcher.Match objects encoding the match between
        anchors and groundtruth boxes for each image of the batch,
        with rows of the Match objects corresponding to groundtruth boxes
        and columns corresponding to anchors.
    """
    groundtruth_boxlists = [
        box_list.BoxList(boxes) for boxes in groundtruth_boxes_list
    ]
    train_using_confidences = (self._is_training and
                               self._use_confidences_as_targets)
    if self._add_background_class:
      groundtruth_classes_with_background_list = [
          tf.pad(one_hot_encoding, [[0, 0], [1, 0]], mode='CONSTANT')
          for one_hot_encoding in groundtruth_classes_list
      ]
      if train_using_confidences:
        groundtruth_confidences_with_background_list = [
            tf.pad(groundtruth_confidences, [[0, 0], [1, 0]], mode='CONSTANT')
            for groundtruth_confidences in groundtruth_confidences_list
        ]
    else:
      groundtruth_classes_with_background_list = groundtruth_classes_list

    if groundtruth_keypoints_list is not None:
      for boxlist, keypoints in zip(
          groundtruth_boxlists, groundtruth_keypoints_list):
        boxlist.add_field(fields.BoxListFields.keypoints, keypoints)
    if train_using_confidences:
      return target_assigner.batch_assign_confidences(
          self._target_assigner,
          self.anchors,
          groundtruth_boxlists,
          groundtruth_confidences_with_background_list,
          groundtruth_weights_list,
          self._unmatched_class_label,
          self._add_background_class,
          self._implicit_example_weight)
    else:
      return target_assigner.batch_assign_targets(
          self._target_assigner,
          self.anchors,
          groundtruth_boxlists,
          groundtruth_classes_with_background_list,
          self._unmatched_class_label,
          groundtruth_weights_list)

  def _summarize_target_assignment(self, groundtruth_boxes_list, match_list):
    """Creates tensorflow summaries for the input boxes and anchors.

    This function creates four summaries corresponding to the average
    number (over images in a batch) of (1) groundtruth boxes, (2) anchors
    marked as positive, (3) anchors marked as negative, and (4) anchors marked
    as ignored.

    Args:
      groundtruth_boxes_list: a list of 2-D tensors of shape [num_boxes, 4]
        containing corners of the groundtruth boxes.
      match_list: a list of matcher.Match objects encoding the match between
        anchors and groundtruth boxes for each image of the batch,
        with rows of the Match objects corresponding to groundtruth boxes
        and columns corresponding to anchors.
    """
    num_boxes_per_image = tf.stack(
        [tf.shape(x)[0] for x in groundtruth_boxes_list])
    pos_anchors_per_image = tf.stack(
        [match.num_matched_columns() for match in match_list])
    neg_anchors_per_image = tf.stack(
        [match.num_unmatched_columns() for match in match_list])
    ignored_anchors_per_image = tf.stack(
        [match.num_ignored_columns() for match in match_list])
    tf.summary.scalar('AvgNumGroundtruthBoxesPerImage',
                      tf.reduce_mean(tf.to_float(num_boxes_per_image)),
                      family='TargetAssignment')
    tf.summary.scalar('AvgNumPositiveAnchorsPerImage',
                      tf.reduce_mean(tf.to_float(pos_anchors_per_image)),
                      family='TargetAssignment')
    tf.summary.scalar('AvgNumNegativeAnchorsPerImage',
                      tf.reduce_mean(tf.to_float(neg_anchors_per_image)),
                      family='TargetAssignment')
    tf.summary.scalar('AvgNumIgnoredAnchorsPerImage',
                      tf.reduce_mean(tf.to_float(ignored_anchors_per_image)),
                      family='TargetAssignment')

  def _apply_hard_mining(self, location_losses, cls_losses, prediction_dict,
                         match_list):
    """Applies hard mining to anchorwise losses.

    Args:
      location_losses: Float tensor of shape [batch_size, num_anchors]
        representing anchorwise location losses.
      cls_losses: Float tensor of shape [batch_size, num_anchors]
        representing anchorwise classification losses.
      prediction_dict: a dictionary holding prediction tensors with
        1) box_encodings: 3-D float tensor of shape [batch_size, num_anchors,
          box_code_dimension] containing predicted boxes.
        2) class_predictions_with_background: 3-D float tensor of shape
          [batch_size, num_anchors, num_classes+1] containing class predictions
          (logits) for each of the anchors. Note that this tensor *includes*
          background class predictions.
      match_list: a list of matcher.Match objects encoding the match between
        anchors and groundtruth boxes for each image of the batch,
        with rows of the Match objects corresponding to groundtruth boxes
        and columns corresponding to anchors.

    Returns:
      mined_location_loss: a float scalar with sum of localization losses from
        selected hard examples.
      mined_cls_loss: a float scalar with sum of classification losses from
        selected hard examples.
    """
    class_predictions = prediction_dict['class_predictions_with_background']
    if self._add_background_class:
      class_predictions = tf.slice(class_predictions, [0, 0, 1], [-1, -1, -1])
    decoded_boxes, _ = self._batch_decode(prediction_dict['box_encodings'])
    decoded_box_tensors_list = tf.unstack(decoded_boxes)
    class_prediction_list = tf.unstack(class_predictions)
    decoded_boxlist_list = []
    for box_location, box_score in zip(decoded_box_tensors_list,
                                       class_prediction_list):
      decoded_boxlist = box_list.BoxList(box_location)
      decoded_boxlist.add_field('scores', box_score)
      decoded_boxlist_list.append(decoded_boxlist)
    return self._hard_example_miner(
        location_losses=location_losses,
        cls_losses=cls_losses,
        decoded_boxlist_list=decoded_boxlist_list,
        match_list=match_list)

  def _batch_decode(self, box_encodings):
    """Decodes a batch of box encodings with respect to the anchors.

    Args:
      box_encodings: A float32 tensor of shape
        [batch_size, num_anchors, box_code_size] containing box encodings.

    Returns:
      decoded_boxes: A float32 tensor of shape
        [batch_size, num_anchors, 4] containing the decoded boxes.
      decoded_keypoints: A float32 tensor of shape
        [batch_size, num_anchors, num_keypoints, 2] containing the decoded
        keypoints if present in the input `box_encodings`, None otherwise.
    """
    combined_shape = shape_utils.combined_static_and_dynamic_shape(
        box_encodings)
    batch_size = combined_shape[0]
    tiled_anchor_boxes = tf.tile(
        tf.expand_dims(self.anchors.get(), 0), [batch_size, 1, 1])
    tiled_anchors_boxlist = box_list.BoxList(
        tf.reshape(tiled_anchor_boxes, [-1, 4]))
    decoded_boxes = self._box_coder.decode(
        tf.reshape(box_encodings, [-1, self._box_coder.code_size]),
        tiled_anchors_boxlist)
    decoded_keypoints = None
    if decoded_boxes.has_field(fields.BoxListFields.keypoints):
      decoded_keypoints = decoded_boxes.get_field(
          fields.BoxListFields.keypoints)
      num_keypoints = decoded_keypoints.get_shape()[1]
      decoded_keypoints = tf.reshape(
          decoded_keypoints,
          tf.stack([combined_shape[0], combined_shape[1], num_keypoints, 2]))
    decoded_boxes = tf.reshape(decoded_boxes.get(), tf.stack(
        [combined_shape[0], combined_shape[1], 4]))
    return decoded_boxes, decoded_keypoints
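
  # Note on _batch_decode: the batch and anchor dimensions are flattened to
  # [batch_size * num_anchors, code_size] so the box coder (which decodes
  # rank-2 encodings against a BoxList of anchors) can decode all boxes at
  # once; the result is then reshaped back to [batch_size, num_anchors, 4].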

  def regularization_losses(self):
    """Returns a list of regularization losses for this model.

    Returns a list of regularization losses for this model that the estimator
    needs to use during training/optimization.

    Returns:
      A list of regularization loss tensors.
    """
    losses = []
    slim_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    # Copy the slim losses to avoid modifying the collection
    if slim_losses:
      losses.extend(slim_losses)
    if self._box_predictor.is_keras_model:
      losses.extend(self._box_predictor.losses)
    if self._feature_extractor.is_keras_model:
      losses.extend(self._feature_extractor.losses)
    return losses

  def restore_map(self,
                  fine_tune_checkpoint_type='detection',
                  load_all_detection_checkpoint_vars=False):
    """Returns a map of variables to load from a foreign checkpoint.

    See parent class for details.

    Args:
      fine_tune_checkpoint_type: whether to restore from a full detection
        checkpoint (with compatible variable names) or to restore from a
        classification checkpoint for initialization prior to training.
        Valid values: `detection`, `classification`. Default 'detection'.
      load_all_detection_checkpoint_vars: whether to load all variables (when
        `fine_tune_checkpoint_type='detection'`). If False, only variables
        within the appropriate scopes are included. Default False.

    Returns:
      A dict mapping variable names (to load from a checkpoint) to variables in
      the model graph.

    Raises:
      ValueError: if fine_tune_checkpoint_type is neither `classification`
        nor `detection`.
    """
    if fine_tune_checkpoint_type not in ['detection', 'classification']:
      raise ValueError('Unsupported fine_tune_checkpoint_type: {}'.format(
          fine_tune_checkpoint_type))
    if fine_tune_checkpoint_type == 'classification':
      return self._feature_extractor.restore_from_classification_checkpoint_fn(
          self._extract_features_scope)

    if fine_tune_checkpoint_type == 'detection':
      variables_to_restore = {}
      for variable in tf.global_variables():
        var_name = variable.op.name
        if load_all_detection_checkpoint_vars:
          variables_to_restore[var_name] = variable
        else:
          if var_name.startswith(self._extract_features_scope):
            variables_to_restore[var_name] = variable
      return variables_to_restore

  def updates(self):
    """Returns a list of update operators for this model.

    Returns a list of update operators for this model that must be executed at
    each training step. The estimator's train op needs to have a control
    dependency on these updates.

    Returns:
      A list of update operators.
    """
    update_ops = []
    slim_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # Copy the slim ops to avoid modifying the collection
    if slim_update_ops:
      update_ops.extend(slim_update_ops)
    if self._box_predictor.is_keras_model:
      update_ops.extend(self._box_predictor.get_updates_for(None))
      update_ops.extend(self._box_predictor.get_updates_for(
          self._box_predictor.inputs))
    if self._feature_extractor.is_keras_model:
      update_ops.extend(self._feature_extractor.get_updates_for(None))
      update_ops.extend(self._feature_extractor.get_updates_for(
          self._feature_extractor.inputs))
    return update_ops
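

# --- Illustrative sketch (not part of the original file) ---
# A hedged outline of the standard inference flow through SSDMetaArch,
# assuming `detection_model` is a fully constructed SSDMetaArch instance and
# `images` is a float32 [batch, height, width, 3] tensor of raw pixel values.
def _example_inference_flow(detection_model, images):
  # Resize images and run feature-extractor-specific preprocessing.
  preprocessed_inputs, true_image_shapes = detection_model.preprocess(images)
  # Forward pass: produces raw box encodings and class logits, and populates
  # detection_model.anchors as a side effect.
  prediction_dict = detection_model.predict(preprocessed_inputs,
                                            true_image_shapes)
  # Decode boxes, convert scores, and apply non-max suppression.
  detections = detection_model.postprocess(prediction_dict, true_image_shapes)
  return detections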