You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

235 lines
10 KiB

  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for object_detection.matchers.argmax_matcher."""
  16. import numpy as np
  17. import tensorflow as tf
  18. from object_detection.matchers import argmax_matcher
  19. from object_detection.utils import test_case
  20. class ArgMaxMatcherTest(test_case.TestCase):
  21. def test_return_correct_matches_with_default_thresholds(self):
  22. def graph_fn(similarity_matrix):
  23. matcher = argmax_matcher.ArgMaxMatcher(matched_threshold=None)
  24. match = matcher.match(similarity_matrix)
  25. matched_cols = match.matched_column_indicator()
  26. unmatched_cols = match.unmatched_column_indicator()
  27. match_results = match.match_results
  28. return (matched_cols, unmatched_cols, match_results)
  29. similarity = np.array([[1., 1, 1, 3, 1],
  30. [2, -1, 2, 0, 4],
  31. [3, 0, -1, 0, 0]], dtype=np.float32)
  32. expected_matched_rows = np.array([2, 0, 1, 0, 1])
  33. (res_matched_cols, res_unmatched_cols,
  34. res_match_results) = self.execute(graph_fn, [similarity])
  35. self.assertAllEqual(res_match_results[res_matched_cols],
  36. expected_matched_rows)
  37. self.assertAllEqual(np.nonzero(res_matched_cols)[0], [0, 1, 2, 3, 4])
  38. self.assertFalse(np.all(res_unmatched_cols))
  39. def test_return_correct_matches_with_empty_rows(self):
  40. def graph_fn(similarity_matrix):
  41. matcher = argmax_matcher.ArgMaxMatcher(matched_threshold=None)
  42. match = matcher.match(similarity_matrix)
  43. return match.unmatched_column_indicator()
  44. similarity = 0.2 * np.ones([0, 5], dtype=np.float32)
  45. res_unmatched_cols = self.execute(graph_fn, [similarity])
  46. self.assertAllEqual(np.nonzero(res_unmatched_cols)[0], np.arange(5))
  47. def test_return_correct_matches_with_matched_threshold(self):
  48. def graph_fn(similarity):
  49. matcher = argmax_matcher.ArgMaxMatcher(matched_threshold=3.)
  50. match = matcher.match(similarity)
  51. matched_cols = match.matched_column_indicator()
  52. unmatched_cols = match.unmatched_column_indicator()
  53. match_results = match.match_results
  54. return (matched_cols, unmatched_cols, match_results)
  55. similarity = np.array([[1, 1, 1, 3, 1],
  56. [2, -1, 2, 0, 4],
  57. [3, 0, -1, 0, 0]], dtype=np.float32)
  58. expected_matched_cols = np.array([0, 3, 4])
  59. expected_matched_rows = np.array([2, 0, 1])
  60. expected_unmatched_cols = np.array([1, 2])
  61. (res_matched_cols, res_unmatched_cols,
  62. match_results) = self.execute(graph_fn, [similarity])
  63. self.assertAllEqual(match_results[res_matched_cols], expected_matched_rows)
  64. self.assertAllEqual(np.nonzero(res_matched_cols)[0], expected_matched_cols)
  65. self.assertAllEqual(np.nonzero(res_unmatched_cols)[0],
  66. expected_unmatched_cols)
  67. def test_return_correct_matches_with_matched_and_unmatched_threshold(self):
  68. def graph_fn(similarity):
  69. matcher = argmax_matcher.ArgMaxMatcher(matched_threshold=3.,
  70. unmatched_threshold=2.)
  71. match = matcher.match(similarity)
  72. matched_cols = match.matched_column_indicator()
  73. unmatched_cols = match.unmatched_column_indicator()
  74. match_results = match.match_results
  75. return (matched_cols, unmatched_cols, match_results)
  76. similarity = np.array([[1, 1, 1, 3, 1],
  77. [2, -1, 2, 0, 4],
  78. [3, 0, -1, 0, 0]], dtype=np.float32)
  79. expected_matched_cols = np.array([0, 3, 4])
  80. expected_matched_rows = np.array([2, 0, 1])
  81. expected_unmatched_cols = np.array([1]) # col 2 has too high maximum val
  82. (res_matched_cols, res_unmatched_cols,
  83. match_results) = self.execute(graph_fn, [similarity])
  84. self.assertAllEqual(match_results[res_matched_cols], expected_matched_rows)
  85. self.assertAllEqual(np.nonzero(res_matched_cols)[0], expected_matched_cols)
  86. self.assertAllEqual(np.nonzero(res_unmatched_cols)[0],
  87. expected_unmatched_cols)
  88. def test_return_correct_matches_negatives_lower_than_unmatched_false(self):
  89. def graph_fn(similarity):
  90. matcher = argmax_matcher.ArgMaxMatcher(
  91. matched_threshold=3.,
  92. unmatched_threshold=2.,
  93. negatives_lower_than_unmatched=False)
  94. match = matcher.match(similarity)
  95. matched_cols = match.matched_column_indicator()
  96. unmatched_cols = match.unmatched_column_indicator()
  97. match_results = match.match_results
  98. return (matched_cols, unmatched_cols, match_results)
  99. similarity = np.array([[1, 1, 1, 3, 1],
  100. [2, -1, 2, 0, 4],
  101. [3, 0, -1, 0, 0]], dtype=np.float32)
  102. expected_matched_cols = np.array([0, 3, 4])
  103. expected_matched_rows = np.array([2, 0, 1])
  104. expected_unmatched_cols = np.array([2]) # col 1 has too low maximum val
  105. (res_matched_cols, res_unmatched_cols,
  106. match_results) = self.execute(graph_fn, [similarity])
  107. self.assertAllEqual(match_results[res_matched_cols], expected_matched_rows)
  108. self.assertAllEqual(np.nonzero(res_matched_cols)[0], expected_matched_cols)
  109. self.assertAllEqual(np.nonzero(res_unmatched_cols)[0],
  110. expected_unmatched_cols)
  111. def test_return_correct_matches_unmatched_row_not_using_force_match(self):
  112. def graph_fn(similarity):
  113. matcher = argmax_matcher.ArgMaxMatcher(matched_threshold=3.,
  114. unmatched_threshold=2.)
  115. match = matcher.match(similarity)
  116. matched_cols = match.matched_column_indicator()
  117. unmatched_cols = match.unmatched_column_indicator()
  118. match_results = match.match_results
  119. return (matched_cols, unmatched_cols, match_results)
  120. similarity = np.array([[1, 1, 1, 3, 1],
  121. [-1, 0, -2, -2, -1],
  122. [3, 0, -1, 2, 0]], dtype=np.float32)
  123. expected_matched_cols = np.array([0, 3])
  124. expected_matched_rows = np.array([2, 0])
  125. expected_unmatched_cols = np.array([1, 2, 4])
  126. (res_matched_cols, res_unmatched_cols,
  127. match_results) = self.execute(graph_fn, [similarity])
  128. self.assertAllEqual(match_results[res_matched_cols], expected_matched_rows)
  129. self.assertAllEqual(np.nonzero(res_matched_cols)[0], expected_matched_cols)
  130. self.assertAllEqual(np.nonzero(res_unmatched_cols)[0],
  131. expected_unmatched_cols)
  132. def test_return_correct_matches_unmatched_row_while_using_force_match(self):
  133. def graph_fn(similarity):
  134. matcher = argmax_matcher.ArgMaxMatcher(matched_threshold=3.,
  135. unmatched_threshold=2.,
  136. force_match_for_each_row=True)
  137. match = matcher.match(similarity)
  138. matched_cols = match.matched_column_indicator()
  139. unmatched_cols = match.unmatched_column_indicator()
  140. match_results = match.match_results
  141. return (matched_cols, unmatched_cols, match_results)
  142. similarity = np.array([[1, 1, 1, 3, 1],
  143. [-1, 0, -2, -2, -1],
  144. [3, 0, -1, 2, 0]], dtype=np.float32)
  145. expected_matched_cols = np.array([0, 1, 3])
  146. expected_matched_rows = np.array([2, 1, 0])
  147. expected_unmatched_cols = np.array([2, 4]) # col 2 has too high max val
  148. (res_matched_cols, res_unmatched_cols,
  149. match_results) = self.execute(graph_fn, [similarity])
  150. self.assertAllEqual(match_results[res_matched_cols], expected_matched_rows)
  151. self.assertAllEqual(np.nonzero(res_matched_cols)[0], expected_matched_cols)
  152. self.assertAllEqual(np.nonzero(res_unmatched_cols)[0],
  153. expected_unmatched_cols)
  154. def test_return_correct_matches_using_force_match_padded_groundtruth(self):
  155. def graph_fn(similarity, valid_rows):
  156. matcher = argmax_matcher.ArgMaxMatcher(matched_threshold=3.,
  157. unmatched_threshold=2.,
  158. force_match_for_each_row=True)
  159. match = matcher.match(similarity, valid_rows)
  160. matched_cols = match.matched_column_indicator()
  161. unmatched_cols = match.unmatched_column_indicator()
  162. match_results = match.match_results
  163. return (matched_cols, unmatched_cols, match_results)
  164. similarity = np.array([[1, 1, 1, 3, 1],
  165. [-1, 0, -2, -2, -1],
  166. [0, 0, 0, 0, 0],
  167. [3, 0, -1, 2, 0],
  168. [0, 0, 0, 0, 0]], dtype=np.float32)
  169. valid_rows = np.array([True, True, False, True, False])
  170. expected_matched_cols = np.array([0, 1, 3])
  171. expected_matched_rows = np.array([3, 1, 0])
  172. expected_unmatched_cols = np.array([2, 4]) # col 2 has too high max val
  173. (res_matched_cols, res_unmatched_cols,
  174. match_results) = self.execute(graph_fn, [similarity, valid_rows])
  175. self.assertAllEqual(match_results[res_matched_cols], expected_matched_rows)
  176. self.assertAllEqual(np.nonzero(res_matched_cols)[0], expected_matched_cols)
  177. self.assertAllEqual(np.nonzero(res_unmatched_cols)[0],
  178. expected_unmatched_cols)
  179. def test_valid_arguments_corner_case(self):
  180. argmax_matcher.ArgMaxMatcher(matched_threshold=1,
  181. unmatched_threshold=1)
  182. def test_invalid_arguments_corner_case_negatives_lower_than_thres_false(self):
  183. with self.assertRaises(ValueError):
  184. argmax_matcher.ArgMaxMatcher(matched_threshold=1,
  185. unmatched_threshold=1,
  186. negatives_lower_than_unmatched=False)
  187. def test_invalid_arguments_no_matched_threshold(self):
  188. with self.assertRaises(ValueError):
  189. argmax_matcher.ArgMaxMatcher(matched_threshold=None,
  190. unmatched_threshold=4)
  191. def test_invalid_arguments_unmatched_thres_larger_than_matched_thres(self):
  192. with self.assertRaises(ValueError):
  193. argmax_matcher.ArgMaxMatcher(matched_threshold=1,
  194. unmatched_threshold=2)
  195. if __name__ == '__main__':
  196. tf.test.main()