You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

136 lines
5.0 KiB

6 years ago
  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for box_coder_builder."""
  16. import tensorflow as tf
  17. from google.protobuf import text_format
  18. from object_detection.box_coders import faster_rcnn_box_coder
  19. from object_detection.box_coders import keypoint_box_coder
  20. from object_detection.box_coders import mean_stddev_box_coder
  21. from object_detection.box_coders import square_box_coder
  22. from object_detection.builders import box_coder_builder
  23. from object_detection.protos import box_coder_pb2
  24. class BoxCoderBuilderTest(tf.test.TestCase):
  25. def test_build_faster_rcnn_box_coder_with_defaults(self):
  26. box_coder_text_proto = """
  27. faster_rcnn_box_coder {
  28. }
  29. """
  30. box_coder_proto = box_coder_pb2.BoxCoder()
  31. text_format.Merge(box_coder_text_proto, box_coder_proto)
  32. box_coder_object = box_coder_builder.build(box_coder_proto)
  33. self.assertIsInstance(box_coder_object,
  34. faster_rcnn_box_coder.FasterRcnnBoxCoder)
  35. self.assertEqual(box_coder_object._scale_factors, [10.0, 10.0, 5.0, 5.0])
  36. def test_build_faster_rcnn_box_coder_with_non_default_parameters(self):
  37. box_coder_text_proto = """
  38. faster_rcnn_box_coder {
  39. y_scale: 6.0
  40. x_scale: 3.0
  41. height_scale: 7.0
  42. width_scale: 8.0
  43. }
  44. """
  45. box_coder_proto = box_coder_pb2.BoxCoder()
  46. text_format.Merge(box_coder_text_proto, box_coder_proto)
  47. box_coder_object = box_coder_builder.build(box_coder_proto)
  48. self.assertIsInstance(box_coder_object,
  49. faster_rcnn_box_coder.FasterRcnnBoxCoder)
  50. self.assertEqual(box_coder_object._scale_factors, [6.0, 3.0, 7.0, 8.0])
  51. def test_build_keypoint_box_coder_with_defaults(self):
  52. box_coder_text_proto = """
  53. keypoint_box_coder {
  54. }
  55. """
  56. box_coder_proto = box_coder_pb2.BoxCoder()
  57. text_format.Merge(box_coder_text_proto, box_coder_proto)
  58. box_coder_object = box_coder_builder.build(box_coder_proto)
  59. self.assertIsInstance(box_coder_object, keypoint_box_coder.KeypointBoxCoder)
  60. self.assertEqual(box_coder_object._scale_factors, [10.0, 10.0, 5.0, 5.0])
  61. def test_build_keypoint_box_coder_with_non_default_parameters(self):
  62. box_coder_text_proto = """
  63. keypoint_box_coder {
  64. num_keypoints: 6
  65. y_scale: 6.0
  66. x_scale: 3.0
  67. height_scale: 7.0
  68. width_scale: 8.0
  69. }
  70. """
  71. box_coder_proto = box_coder_pb2.BoxCoder()
  72. text_format.Merge(box_coder_text_proto, box_coder_proto)
  73. box_coder_object = box_coder_builder.build(box_coder_proto)
  74. self.assertIsInstance(box_coder_object, keypoint_box_coder.KeypointBoxCoder)
  75. self.assertEqual(box_coder_object._num_keypoints, 6)
  76. self.assertEqual(box_coder_object._scale_factors, [6.0, 3.0, 7.0, 8.0])
  77. def test_build_mean_stddev_box_coder(self):
  78. box_coder_text_proto = """
  79. mean_stddev_box_coder {
  80. }
  81. """
  82. box_coder_proto = box_coder_pb2.BoxCoder()
  83. text_format.Merge(box_coder_text_proto, box_coder_proto)
  84. box_coder_object = box_coder_builder.build(box_coder_proto)
  85. self.assertTrue(
  86. isinstance(box_coder_object,
  87. mean_stddev_box_coder.MeanStddevBoxCoder))
  88. def test_build_square_box_coder_with_defaults(self):
  89. box_coder_text_proto = """
  90. square_box_coder {
  91. }
  92. """
  93. box_coder_proto = box_coder_pb2.BoxCoder()
  94. text_format.Merge(box_coder_text_proto, box_coder_proto)
  95. box_coder_object = box_coder_builder.build(box_coder_proto)
  96. self.assertTrue(
  97. isinstance(box_coder_object, square_box_coder.SquareBoxCoder))
  98. self.assertEqual(box_coder_object._scale_factors, [10.0, 10.0, 5.0])
  99. def test_build_square_box_coder_with_non_default_parameters(self):
  100. box_coder_text_proto = """
  101. square_box_coder {
  102. y_scale: 6.0
  103. x_scale: 3.0
  104. length_scale: 7.0
  105. }
  106. """
  107. box_coder_proto = box_coder_pb2.BoxCoder()
  108. text_format.Merge(box_coder_text_proto, box_coder_proto)
  109. box_coder_object = box_coder_builder.build(box_coder_proto)
  110. self.assertTrue(
  111. isinstance(box_coder_object, square_box_coder.SquareBoxCoder))
  112. self.assertEqual(box_coder_object._scale_factors, [6.0, 3.0, 7.0])
  113. def test_raise_error_on_empty_box_coder(self):
  114. box_coder_text_proto = """
  115. """
  116. box_coder_proto = box_coder_pb2.BoxCoder()
  117. text_format.Merge(box_coder_text_proto, box_coder_proto)
  118. with self.assertRaises(ValueError):
  119. box_coder_builder.build(box_coder_proto)
  120. if __name__ == '__main__':
  121. tf.test.main()