You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

109 lines
4.3 KiB

6 years ago
  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Keypoint Head.
  16. Contains Keypoint prediction head classes for different meta architectures.
  17. All the keypoint prediction heads have a predict function that receives the
  18. `features` as the first argument and returns `keypoint_predictions`.
  19. Keypoints could be used to represent human body joint locations, as in the
  20. Mask RCNN paper, or they could be used to represent the locations of different
  21. parts of objects.
  22. """
  23. import tensorflow as tf
  24. from object_detection.predictors.heads import head
  25. slim = tf.contrib.slim
  26. class MaskRCNNKeypointHead(head.Head):
  27. """Mask RCNN keypoint prediction head.
  28. Please refer to Mask RCNN paper:
  29. https://arxiv.org/abs/1703.06870
  30. """
  31. def __init__(self,
  32. num_keypoints=17,
  33. conv_hyperparams_fn=None,
  34. keypoint_heatmap_height=56,
  35. keypoint_heatmap_width=56,
  36. keypoint_prediction_num_conv_layers=8,
  37. keypoint_prediction_conv_depth=512):
  38. """Constructor.
  39. Args:
  40. num_keypoints: (int scalar) number of keypoints.
  41. conv_hyperparams_fn: A function to generate tf-slim arg_scope with
  42. hyperparameters for convolution ops.
  43. keypoint_heatmap_height: Desired output mask height. The default value
  44. is 14.
  45. keypoint_heatmap_width: Desired output mask width. The default value
  46. is 14.
  47. keypoint_prediction_num_conv_layers: Number of convolution layers applied
  48. to the image_features in mask prediction branch.
  49. keypoint_prediction_conv_depth: The depth for the first conv2d_transpose
  50. op applied to the image_features in the mask prediction branch. If set
  51. to 0, the depth of the convolution layers will be automatically chosen
  52. based on the number of object classes and the number of channels in the
  53. image features.
  54. """
  55. super(MaskRCNNKeypointHead, self).__init__()
  56. self._num_keypoints = num_keypoints
  57. self._conv_hyperparams_fn = conv_hyperparams_fn
  58. self._keypoint_heatmap_height = keypoint_heatmap_height
  59. self._keypoint_heatmap_width = keypoint_heatmap_width
  60. self._keypoint_prediction_num_conv_layers = (
  61. keypoint_prediction_num_conv_layers)
  62. self._keypoint_prediction_conv_depth = keypoint_prediction_conv_depth
  63. def predict(self, features, num_predictions_per_location=1):
  64. """Performs keypoint prediction.
  65. Args:
  66. features: A float tensor of shape [batch_size, height, width,
  67. channels] containing features for a batch of images.
  68. num_predictions_per_location: Int containing number of predictions per
  69. location.
  70. Returns:
  71. instance_masks: A float tensor of shape
  72. [batch_size, 1, num_keypoints, heatmap_height, heatmap_width].
  73. Raises:
  74. ValueError: If num_predictions_per_location is not 1.
  75. """
  76. if num_predictions_per_location != 1:
  77. raise ValueError('Only num_predictions_per_location=1 is supported')
  78. with slim.arg_scope(self._conv_hyperparams_fn()):
  79. net = slim.conv2d(
  80. features,
  81. self._keypoint_prediction_conv_depth, [3, 3],
  82. scope='conv_1')
  83. for i in range(1, self._keypoint_prediction_num_conv_layers):
  84. net = slim.conv2d(
  85. net,
  86. self._keypoint_prediction_conv_depth, [3, 3],
  87. scope='conv_%d' % (i + 1))
  88. net = slim.conv2d_transpose(
  89. net, self._num_keypoints, [2, 2], scope='deconv1')
  90. heatmaps_mask = tf.image.resize_bilinear(
  91. net, [self._keypoint_heatmap_height, self._keypoint_heatmap_width],
  92. align_corners=True,
  93. name='upsample')
  94. return tf.expand_dims(
  95. tf.transpose(heatmaps_mask, perm=[0, 3, 1, 2]),
  96. axis=1,
  97. name='KeypointPredictor')