You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

99 lines
3.8 KiB

  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for matcher_builder."""
  16. import tensorflow as tf
  17. from google.protobuf import text_format
  18. from object_detection.builders import matcher_builder
  19. from object_detection.matchers import argmax_matcher
  20. from object_detection.matchers import bipartite_matcher
  21. from object_detection.protos import matcher_pb2
  class MatcherBuilderTest(tf.test.TestCase):
    """Checks that matcher_builder.build turns Matcher protos into matchers."""

    def _matcher_proto_from_text(self, matcher_text_proto):
      """Parses a text-format string into a fresh matcher_pb2.Matcher.

      Args:
        matcher_text_proto: String containing a Matcher message in protobuf
          text format.

      Returns:
        The populated matcher_pb2.Matcher message.
      """
      matcher_proto = matcher_pb2.Matcher()
      text_format.Merge(matcher_text_proto, matcher_proto)
      return matcher_proto

    def test_build_arg_max_matcher_with_defaults(self):
      """An empty argmax_matcher config builds an ArgMaxMatcher with defaults."""
      matcher_proto = self._matcher_proto_from_text("""
        argmax_matcher {
        }
      """)
      matcher_object = matcher_builder.build(matcher_proto)
      self.assertIsInstance(matcher_object, argmax_matcher.ArgMaxMatcher)
      self.assertAlmostEqual(matcher_object._matched_threshold, 0.5)
      self.assertAlmostEqual(matcher_object._unmatched_threshold, 0.5)
      self.assertTrue(matcher_object._negatives_lower_than_unmatched)
      self.assertFalse(matcher_object._force_match_for_each_row)

    def test_build_arg_max_matcher_without_thresholds(self):
      """ignore_thresholds: true leaves both thresholds unset (None)."""
      matcher_proto = self._matcher_proto_from_text("""
        argmax_matcher {
          ignore_thresholds: true
        }
      """)
      matcher_object = matcher_builder.build(matcher_proto)
      self.assertIsInstance(matcher_object, argmax_matcher.ArgMaxMatcher)
      self.assertIsNone(matcher_object._matched_threshold)
      self.assertIsNone(matcher_object._unmatched_threshold)
      self.assertTrue(matcher_object._negatives_lower_than_unmatched)
      self.assertFalse(matcher_object._force_match_for_each_row)

    def test_build_arg_max_matcher_with_non_default_parameters(self):
      """Explicit argmax_matcher fields are forwarded to the ArgMaxMatcher."""
      matcher_proto = self._matcher_proto_from_text("""
        argmax_matcher {
          matched_threshold: 0.7
          unmatched_threshold: 0.3
          negatives_lower_than_unmatched: false
          force_match_for_each_row: true
          use_matmul_gather: true
        }
      """)
      matcher_object = matcher_builder.build(matcher_proto)
      self.assertIsInstance(matcher_object, argmax_matcher.ArgMaxMatcher)
      self.assertAlmostEqual(matcher_object._matched_threshold, 0.7)
      self.assertAlmostEqual(matcher_object._unmatched_threshold, 0.3)
      self.assertFalse(matcher_object._negatives_lower_than_unmatched)
      self.assertTrue(matcher_object._force_match_for_each_row)
      self.assertTrue(matcher_object._use_matmul_gather)

    def test_build_bipartite_matcher(self):
      """A bipartite_matcher config builds a GreedyBipartiteMatcher."""
      matcher_proto = self._matcher_proto_from_text("""
        bipartite_matcher {
        }
      """)
      matcher_object = matcher_builder.build(matcher_proto)
      self.assertIsInstance(matcher_object,
                            bipartite_matcher.GreedyBipartiteMatcher)

    def test_raise_error_on_empty_matcher(self):
      """A Matcher proto with no fields set is rejected with ValueError."""
      # Parse outside assertRaises so only build() is expected to raise.
      matcher_proto = self._matcher_proto_from_text("""
      """)
      with self.assertRaises(ValueError):
        matcher_builder.build(matcher_proto)
  if __name__ == '__main__':
    # Executed directly (not imported): hand control to TensorFlow's test
    # runner so the test cases defined in this module are run.
    tf.test.main()