You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

66 lines
2.8 KiB

  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """A function to build an object detection box coder from configuration."""
  16. from object_detection.box_coders import faster_rcnn_box_coder
  17. from object_detection.box_coders import keypoint_box_coder
  18. from object_detection.box_coders import mean_stddev_box_coder
  19. from object_detection.box_coders import square_box_coder
  20. from object_detection.protos import box_coder_pb2
  21. def build(box_coder_config):
  22. """Builds a box coder object based on the box coder config.
  23. Args:
  24. box_coder_config: A box_coder.proto object containing the config for the
  25. desired box coder.
  26. Returns:
  27. BoxCoder based on the config.
  28. Raises:
  29. ValueError: On empty box coder proto.
  30. """
  31. if not isinstance(box_coder_config, box_coder_pb2.BoxCoder):
  32. raise ValueError('box_coder_config not of type box_coder_pb2.BoxCoder.')
  33. if box_coder_config.WhichOneof('box_coder_oneof') == 'faster_rcnn_box_coder':
  34. return faster_rcnn_box_coder.FasterRcnnBoxCoder(scale_factors=[
  35. box_coder_config.faster_rcnn_box_coder.y_scale,
  36. box_coder_config.faster_rcnn_box_coder.x_scale,
  37. box_coder_config.faster_rcnn_box_coder.height_scale,
  38. box_coder_config.faster_rcnn_box_coder.width_scale
  39. ])
  40. if box_coder_config.WhichOneof('box_coder_oneof') == 'keypoint_box_coder':
  41. return keypoint_box_coder.KeypointBoxCoder(
  42. box_coder_config.keypoint_box_coder.num_keypoints,
  43. scale_factors=[
  44. box_coder_config.keypoint_box_coder.y_scale,
  45. box_coder_config.keypoint_box_coder.x_scale,
  46. box_coder_config.keypoint_box_coder.height_scale,
  47. box_coder_config.keypoint_box_coder.width_scale
  48. ])
  49. if (box_coder_config.WhichOneof('box_coder_oneof') ==
  50. 'mean_stddev_box_coder'):
  51. return mean_stddev_box_coder.MeanStddevBoxCoder(
  52. stddev=box_coder_config.mean_stddev_box_coder.stddev)
  53. if box_coder_config.WhichOneof('box_coder_oneof') == 'square_box_coder':
  54. return square_box_coder.SquareBoxCoder(scale_factors=[
  55. box_coder_config.square_box_coder.y_scale,
  56. box_coder_config.square_box_coder.x_scale,
  57. box_coder_config.square_box_coder.length_scale
  58. ])
  59. raise ValueError('Empty box coder.')