You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

138 lines
5.7 KiB

6 years ago
  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for post_processing_builder."""
  16. import tensorflow as tf
  17. from google.protobuf import text_format
  18. from object_detection.builders import post_processing_builder
  19. from object_detection.protos import post_processing_pb2
  20. class PostProcessingBuilderTest(tf.test.TestCase):
  21. def test_build_non_max_suppressor_with_correct_parameters(self):
  22. post_processing_text_proto = """
  23. batch_non_max_suppression {
  24. score_threshold: 0.7
  25. iou_threshold: 0.6
  26. max_detections_per_class: 100
  27. max_total_detections: 300
  28. }
  29. """
  30. post_processing_config = post_processing_pb2.PostProcessing()
  31. text_format.Merge(post_processing_text_proto, post_processing_config)
  32. non_max_suppressor, _ = post_processing_builder.build(
  33. post_processing_config)
  34. self.assertEqual(non_max_suppressor.keywords['max_size_per_class'], 100)
  35. self.assertEqual(non_max_suppressor.keywords['max_total_size'], 300)
  36. self.assertAlmostEqual(non_max_suppressor.keywords['score_thresh'], 0.7)
  37. self.assertAlmostEqual(non_max_suppressor.keywords['iou_thresh'], 0.6)
  38. def test_build_identity_score_converter(self):
  39. post_processing_text_proto = """
  40. score_converter: IDENTITY
  41. """
  42. post_processing_config = post_processing_pb2.PostProcessing()
  43. text_format.Merge(post_processing_text_proto, post_processing_config)
  44. _, score_converter = post_processing_builder.build(
  45. post_processing_config)
  46. self.assertEqual(score_converter.__name__, 'identity_with_logit_scale')
  47. inputs = tf.constant([1, 1], tf.float32)
  48. outputs = score_converter(inputs)
  49. with self.test_session() as sess:
  50. converted_scores = sess.run(outputs)
  51. expected_converted_scores = sess.run(inputs)
  52. self.assertAllClose(converted_scores, expected_converted_scores)
  53. def test_build_identity_score_converter_with_logit_scale(self):
  54. post_processing_text_proto = """
  55. score_converter: IDENTITY
  56. logit_scale: 2.0
  57. """
  58. post_processing_config = post_processing_pb2.PostProcessing()
  59. text_format.Merge(post_processing_text_proto, post_processing_config)
  60. _, score_converter = post_processing_builder.build(post_processing_config)
  61. self.assertEqual(score_converter.__name__, 'identity_with_logit_scale')
  62. inputs = tf.constant([1, 1], tf.float32)
  63. outputs = score_converter(inputs)
  64. with self.test_session() as sess:
  65. converted_scores = sess.run(outputs)
  66. expected_converted_scores = sess.run(tf.constant([.5, .5], tf.float32))
  67. self.assertAllClose(converted_scores, expected_converted_scores)
  68. def test_build_sigmoid_score_converter(self):
  69. post_processing_text_proto = """
  70. score_converter: SIGMOID
  71. """
  72. post_processing_config = post_processing_pb2.PostProcessing()
  73. text_format.Merge(post_processing_text_proto, post_processing_config)
  74. _, score_converter = post_processing_builder.build(post_processing_config)
  75. self.assertEqual(score_converter.__name__, 'sigmoid_with_logit_scale')
  76. def test_build_softmax_score_converter(self):
  77. post_processing_text_proto = """
  78. score_converter: SOFTMAX
  79. """
  80. post_processing_config = post_processing_pb2.PostProcessing()
  81. text_format.Merge(post_processing_text_proto, post_processing_config)
  82. _, score_converter = post_processing_builder.build(post_processing_config)
  83. self.assertEqual(score_converter.__name__, 'softmax_with_logit_scale')
  84. def test_build_softmax_score_converter_with_temperature(self):
  85. post_processing_text_proto = """
  86. score_converter: SOFTMAX
  87. logit_scale: 2.0
  88. """
  89. post_processing_config = post_processing_pb2.PostProcessing()
  90. text_format.Merge(post_processing_text_proto, post_processing_config)
  91. _, score_converter = post_processing_builder.build(post_processing_config)
  92. self.assertEqual(score_converter.__name__, 'softmax_with_logit_scale')
  93. def test_build_calibrator_with_nonempty_config(self):
  94. """Test that identity function used when no calibration_config specified."""
  95. # Calibration config maps all scores to 0.5.
  96. post_processing_text_proto = """
  97. score_converter: SOFTMAX
  98. calibration_config {
  99. function_approximation {
  100. x_y_pairs {
  101. x_y_pair {
  102. x: 0.0
  103. y: 0.5
  104. }
  105. x_y_pair {
  106. x: 1.0
  107. y: 0.5
  108. }}}}"""
  109. post_processing_config = post_processing_pb2.PostProcessing()
  110. text_format.Merge(post_processing_text_proto, post_processing_config)
  111. _, calibrated_score_conversion_fn = post_processing_builder.build(
  112. post_processing_config)
  113. self.assertEqual(calibrated_score_conversion_fn.__name__,
  114. 'calibrate_with_function_approximation')
  115. input_scores = tf.constant([1, 1], tf.float32)
  116. outputs = calibrated_score_conversion_fn(input_scores)
  117. with self.test_session() as sess:
  118. calibrated_scores = sess.run(outputs)
  119. expected_calibrated_scores = sess.run(tf.constant([0.5, 0.5], tf.float32))
  120. self.assertAllClose(calibrated_scores, expected_calibrated_scores)
  121. if __name__ == '__main__':
  122. tf.test.main()