|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
|
|
"""Keypoint Head.
|
|
|
|
Contains Keypoint prediction head classes for different meta architectures.
|
|
All the keypoint prediction heads have a predict function that receives the
|
|
`features` as the first argument and returns `keypoint_predictions`.
|
|
Keypoints can be used to represent human body joint locations, as in the
Mask RCNN paper, or to represent the locations of different parts of objects.
|
|
"""
|
|
import tensorflow as tf
|
|
|
|
from object_detection.predictors.heads import head
|
|
# TF-Slim alias (TF 1.x `tf.contrib`). It supplies the `arg_scope`, `conv2d`
# and `conv2d_transpose` used to build the keypoint prediction branch below.
slim = tf.contrib.slim
|
|
|
|
|
|
class MaskRCNNKeypointHead(head.Head):
  """Mask RCNN keypoint prediction head.

  Please refer to Mask RCNN paper:
  https://arxiv.org/abs/1703.06870
  """

  def __init__(self,
               num_keypoints=17,
               conv_hyperparams_fn=None,
               keypoint_heatmap_height=56,
               keypoint_heatmap_width=56,
               keypoint_prediction_num_conv_layers=8,
               keypoint_prediction_conv_depth=512):
    """Constructor.

    Args:
      num_keypoints: (int scalar) number of keypoints. This is the number of
        output channels of the final conv2d_transpose layer, i.e. one heatmap
        channel per keypoint.
      conv_hyperparams_fn: A function to generate tf-slim arg_scope with
        hyperparameters for convolution ops. It must be provided (not None)
        before `predict` is called.
      keypoint_heatmap_height: Desired output heatmap height. The default
        value is 56.
      keypoint_heatmap_width: Desired output heatmap width. The default value
        is 56.
      keypoint_prediction_num_conv_layers: Number of 3x3 convolution layers
        applied to the features in the keypoint prediction branch. At least
        one convolution layer is always applied, even if this value is less
        than 1.
      keypoint_prediction_conv_depth: The output depth of each of the 3x3
        convolution layers in the keypoint prediction branch.
    """
    super(MaskRCNNKeypointHead, self).__init__()
    self._num_keypoints = num_keypoints
    self._conv_hyperparams_fn = conv_hyperparams_fn
    self._keypoint_heatmap_height = keypoint_heatmap_height
    self._keypoint_heatmap_width = keypoint_heatmap_width
    self._keypoint_prediction_num_conv_layers = (
        keypoint_prediction_num_conv_layers)
    self._keypoint_prediction_conv_depth = keypoint_prediction_conv_depth

  def predict(self, features, num_predictions_per_location=1):
    """Performs keypoint prediction.

    The features go through a stack of 3x3 convolutions, then a 2x2
    conv2d_transpose that produces one channel per keypoint. The result is
    bilinearly resized to [keypoint_heatmap_height, keypoint_heatmap_width].

    Args:
      features: A float tensor of shape [batch_size, height, width,
        channels] containing features for a batch of images.
      num_predictions_per_location: Int containing number of predictions per
        location.

    Returns:
      keypoint_heatmaps: A float tensor of shape
        [batch_size, 1, num_keypoints, heatmap_height, heatmap_width].

    Raises:
      ValueError: If num_predictions_per_location is not 1, or if the head
        was constructed without a conv_hyperparams_fn.
    """
    if num_predictions_per_location != 1:
      raise ValueError('Only num_predictions_per_location=1 is supported')
    if self._conv_hyperparams_fn is None:
      # Fail with a clear message instead of "'NoneType' object is not
      # callable" from the arg_scope call below.
      raise ValueError('conv_hyperparams_fn must be provided to '
                       'MaskRCNNKeypointHead before calling predict.')
    # The first layer ('conv_1') is always built, so at least one conv layer
    # runs even when the configured count is smaller than 1.
    num_conv_layers = max(1, self._keypoint_prediction_num_conv_layers)
    with slim.arg_scope(self._conv_hyperparams_fn()):
      net = features
      for i in range(num_conv_layers):
        net = slim.conv2d(
            net,
            self._keypoint_prediction_conv_depth, [3, 3],
            scope='conv_%d' % (i + 1))
      net = slim.conv2d_transpose(
          net, self._num_keypoints, [2, 2], scope='deconv1')
    heatmaps_mask = tf.image.resize_bilinear(
        net, [self._keypoint_heatmap_height, self._keypoint_heatmap_width],
        align_corners=True,
        name='upsample')
    # [batch, H, W, num_keypoints] -> [batch, num_keypoints, H, W], then add
    # a singleton axis at position 1.
    return tf.expand_dims(
        tf.transpose(heatmaps_mask, perm=[0, 3, 1, 2]),
        axis=1,
        name='KeypointPredictor')
|