You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

171 lines
6.2 KiB

6 years ago
  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Keypoint box coder.
  16. The keypoint box coder follows the coding scheme described below (this is
  17. similar to the FasterRcnnBoxCoder, except that it encodes keypoints in addition
  18. to box coordinates):
  19. ty = (y - ya) / ha
  20. tx = (x - xa) / wa
  21. th = log(h / ha)
  22. tw = log(w / wa)
  23. tky0 = (ky0 - ya) / ha
  24. tkx0 = (kx0 - xa) / wa
  25. tky1 = (ky1 - ya) / ha
  26. tkx1 = (kx1 - xa) / wa
  27. ...
  28. where x, y, w, h denote the box's center coordinates, width and height
  29. respectively. Similarly, xa, ya, wa, ha denote the anchor's center
  30. coordinates, width and height. tx, ty, tw and th denote the anchor-encoded
  31. center, width and height respectively. ky0, kx0, ky1, kx1, ... denote the
  32. keypoints' coordinates, and tky0, tkx0, tky1, tkx1, ... denote the
  33. anchor-encoded keypoint coordinates.
  34. """
  35. import tensorflow as tf
  36. from object_detection.core import box_coder
  37. from object_detection.core import box_list
  38. from object_detection.core import standard_fields as fields
  # Additive guard applied to sizes in _encode so division and log never see 0.
EPSILON = 1e-8


class KeypointBoxCoder(box_coder.BoxCoder):
  """Keypoint box coder.

  Encodes a box plus `num_keypoints` (y, x) keypoints relative to an anchor
  into a vector of length 4 + 2 * num_keypoints:
  [ty, tx, th, tw, tky0, tkx0, tky1, tkx1, ...], and decodes it back.
  """

  def __init__(self, num_keypoints, scale_factors=None):
    """Constructor for KeypointBoxCoder.

    Args:
      num_keypoints: Number of keypoints to encode/decode.
      scale_factors: List of 4 positive scalars to scale ty, tx, th and tw.
        In addition to scaling ty and tx, the first 2 scalars are used to scale
        the y and x coordinates of the keypoints as well. If set to None (or an
        empty sequence), does not perform scaling.
    """
    self._num_keypoints = num_keypoints
    # None and an empty sequence both mean "no scaling". Normalizing here keeps
    # this constructor consistent with the truthiness checks in _encode/_decode
    # (an empty list used to skip validation and then raise IndexError below).
    if not scale_factors:
      scale_factors = None
    else:
      assert len(scale_factors) == 4
      for scalar in scale_factors:
        assert scalar > 0
    self._scale_factors = scale_factors
    self._keypoint_scale_factors = None
    if scale_factors is not None:
      # Per-row keypoint factors [[sy], [sx], [sy], [sx], ...] (2 * num_keypoints
      # rows, one column) kept as plain Python floats. They are converted to a
      # tensor only when multiplied in _encode/_decode, so no TF ops are created
      # at construction time in whatever graph happens to be default then.
      self._keypoint_scale_factors = [
          [float(scale_factors[0])], [float(scale_factors[1])]
      ] * num_keypoints

  @property
  def code_size(self):
    """Length of one encoded vector: 4 box terms plus (y, x) per keypoint."""
    return 4 + self._num_keypoints * 2

  def _encode(self, boxes, anchors):
    """Encode a box and keypoint collection with respect to anchor collection.

    Args:
      boxes: BoxList holding N boxes and keypoints to be encoded. Boxes are
        tensors with the shape [N, 4], and keypoints are tensors with the shape
        [N, num_keypoints, 2].
      anchors: BoxList of anchors.

    Returns:
      a tensor representing N anchor-encoded boxes of the format
      [ty, tx, th, tw, tky0, tkx0, tky1, tkx1, ...] where tky0 and tkx0
      represent the y and x coordinates of the first keypoint, tky1 and tkx1
      represent the y and x coordinates of the second keypoint, and so on.
    """
    # Convert anchors and boxes to the center coordinate representation.
    ycenter_a, xcenter_a, ha, wa = anchors.get_center_coordinates_and_sizes()
    ycenter, xcenter, h, w = boxes.get_center_coordinates_and_sizes()
    keypoints = boxes.get_field(fields.BoxListFields.keypoints)
    # [N, num_keypoints, 2] -> [2 * num_keypoints, N]; rows alternate y, x.
    keypoints = tf.transpose(
        tf.reshape(keypoints, [-1, self._num_keypoints * 2]))

    # Avoid NaN in division and log below.
    ha += EPSILON
    wa += EPSILON
    h += EPSILON
    w += EPSILON

    tx = (xcenter - xcenter_a) / wa
    ty = (ycenter - ycenter_a) / ha
    tw = tf.log(w / wa)
    th = tf.log(h / ha)

    # Repeat the anchor (y, x) centers and (h, w) sizes once per keypoint so
    # they line up row-for-row with the [2 * num_keypoints, N] keypoints.
    tiled_anchor_centers = tf.tile(
        tf.stack([ycenter_a, xcenter_a]), [self._num_keypoints, 1])
    tiled_anchor_sizes = tf.tile(
        tf.stack([ha, wa]), [self._num_keypoints, 1])
    tkeypoints = (keypoints - tiled_anchor_centers) / tiled_anchor_sizes

    # Scales location targets as used in paper for joint training.
    if self._scale_factors:
      ty *= self._scale_factors[0]
      tx *= self._scale_factors[1]
      th *= self._scale_factors[2]
      tw *= self._scale_factors[3]
      # [2 * num_keypoints, 1] factors broadcast across the N box columns.
      tkeypoints *= self._keypoint_scale_factors

    tboxes = tf.stack([ty, tx, th, tw])
    return tf.transpose(tf.concat([tboxes, tkeypoints], 0))

  def _decode(self, rel_codes, anchors):
    """Decode relative codes to boxes and keypoints.

    Args:
      rel_codes: a tensor with shape [N, 4 + 2 * num_keypoints] representing N
        anchor-encoded boxes and keypoints
      anchors: BoxList of anchors.

    Returns:
      boxes: BoxList holding N bounding boxes and keypoints. The keypoints field
        has shape [N, num_keypoints, 2].
    """
    ycenter_a, xcenter_a, ha, wa = anchors.get_center_coordinates_and_sizes()

    # Box terms as four [N] rows; keypoint terms as one [2 * num_keypoints, N]
    # tensor (rows alternate y, x), taken by slicing instead of relying on
    # implicit packing of a Python list of tensors.
    ty, tx, th, tw = tf.unstack(tf.transpose(rel_codes[:, :4]))
    tkeypoints = tf.transpose(rel_codes[:, 4:])

    if self._scale_factors:
      ty /= self._scale_factors[0]
      tx /= self._scale_factors[1]
      th /= self._scale_factors[2]
      tw /= self._scale_factors[3]
      # [2 * num_keypoints, 1] factors broadcast across the N code columns.
      tkeypoints /= self._keypoint_scale_factors

    w = tf.exp(tw) * wa
    h = tf.exp(th) * ha
    ycenter = ty * ha + ycenter_a
    xcenter = tx * wa + xcenter_a
    ymin = ycenter - h / 2.
    xmin = xcenter - w / 2.
    ymax = ycenter + h / 2.
    xmax = xcenter + w / 2.
    decoded_boxes_keypoints = box_list.BoxList(
        tf.transpose(tf.stack([ymin, xmin, ymax, xmax])))

    tiled_anchor_centers = tf.tile(
        tf.stack([ycenter_a, xcenter_a]), [self._num_keypoints, 1])
    tiled_anchor_sizes = tf.tile(
        tf.stack([ha, wa]), [self._num_keypoints, 1])
    keypoints = tkeypoints * tiled_anchor_sizes + tiled_anchor_centers
    # [2 * num_keypoints, N] -> [N, num_keypoints, 2].
    keypoints = tf.reshape(tf.transpose(keypoints),
                           [-1, self._num_keypoints, 2])
    decoded_boxes_keypoints.add_field(fields.BoxListFields.keypoints, keypoints)
    return decoded_boxes_keypoints