You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

81 lines
3.3 KiB

6 years ago
  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for object_detection.matchers.bipartite_matcher."""
  16. import tensorflow as tf
  17. from object_detection.matchers import bipartite_matcher
  18. class GreedyBipartiteMatcherTest(tf.test.TestCase):
  19. def test_get_expected_matches_when_all_rows_are_valid(self):
  20. similarity_matrix = tf.constant([[0.50, 0.1, 0.8], [0.15, 0.2, 0.3]])
  21. valid_rows = tf.ones([2], dtype=tf.bool)
  22. expected_match_results = [-1, 1, 0]
  23. matcher = bipartite_matcher.GreedyBipartiteMatcher()
  24. match = matcher.match(similarity_matrix, valid_rows=valid_rows)
  25. with self.test_session() as sess:
  26. match_results_out = sess.run(match._match_results)
  27. self.assertAllEqual(match_results_out, expected_match_results)
  28. def test_get_expected_matches_with_all_rows_be_default(self):
  29. similarity_matrix = tf.constant([[0.50, 0.1, 0.8], [0.15, 0.2, 0.3]])
  30. expected_match_results = [-1, 1, 0]
  31. matcher = bipartite_matcher.GreedyBipartiteMatcher()
  32. match = matcher.match(similarity_matrix)
  33. with self.test_session() as sess:
  34. match_results_out = sess.run(match._match_results)
  35. self.assertAllEqual(match_results_out, expected_match_results)
  36. def test_get_no_matches_with_zero_valid_rows(self):
  37. similarity_matrix = tf.constant([[0.50, 0.1, 0.8], [0.15, 0.2, 0.3]])
  38. valid_rows = tf.zeros([2], dtype=tf.bool)
  39. expected_match_results = [-1, -1, -1]
  40. matcher = bipartite_matcher.GreedyBipartiteMatcher()
  41. match = matcher.match(similarity_matrix, valid_rows)
  42. with self.test_session() as sess:
  43. match_results_out = sess.run(match._match_results)
  44. self.assertAllEqual(match_results_out, expected_match_results)
  45. def test_get_expected_matches_with_only_one_valid_row(self):
  46. similarity_matrix = tf.constant([[0.50, 0.1, 0.8], [0.15, 0.2, 0.3]])
  47. valid_rows = tf.constant([True, False], dtype=tf.bool)
  48. expected_match_results = [-1, -1, 0]
  49. matcher = bipartite_matcher.GreedyBipartiteMatcher()
  50. match = matcher.match(similarity_matrix, valid_rows)
  51. with self.test_session() as sess:
  52. match_results_out = sess.run(match._match_results)
  53. self.assertAllEqual(match_results_out, expected_match_results)
  54. def test_get_expected_matches_with_only_one_valid_row_at_bottom(self):
  55. similarity_matrix = tf.constant([[0.15, 0.2, 0.3], [0.50, 0.1, 0.8]])
  56. valid_rows = tf.constant([False, True], dtype=tf.bool)
  57. expected_match_results = [-1, -1, 0]
  58. matcher = bipartite_matcher.GreedyBipartiteMatcher()
  59. match = matcher.match(similarity_matrix, valid_rows)
  60. with self.test_session() as sess:
  61. match_results_out = sess.run(match._match_results)
  62. self.assertAllEqual(match_results_out, expected_match_results)
  63. if __name__ == '__main__':
  64. tf.test.main()