|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
|
|
"""Tests for object_detection.matchers.argmax_matcher."""
|
|
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
|
|
from object_detection.matchers import argmax_matcher
|
|
from object_detection.utils import test_case
|
|
|
|
|
|
class ArgMaxMatcherTest(test_case.TestCase):
  """Tests for `argmax_matcher.ArgMaxMatcher`.

  Each graph-based test builds the matcher inside a `graph_fn`, runs it via
  `self.execute`, and then checks which columns are reported as matched or
  unmatched, plus the row index each matched column points to.
  """

  def test_return_correct_matches_with_default_thresholds(self):
    """With matched_threshold=None every column matches its argmax row."""

    def graph_fn(similarity_matrix):
      matcher = argmax_matcher.ArgMaxMatcher(matched_threshold=None)
      match = matcher.match(similarity_matrix)
      matched_cols = match.matched_column_indicator()
      unmatched_cols = match.unmatched_column_indicator()
      match_results = match.match_results
      return (matched_cols, unmatched_cols, match_results)

    similarity = np.array([[1., 1, 1, 3, 1],
                           [2, -1, 2, 0, 4],
                           [3, 0, -1, 0, 0]], dtype=np.float32)
    # Row holding the column-wise maximum of each of the five columns.
    expected_matched_rows = np.array([2, 0, 1, 0, 1])
    (res_matched_cols, res_unmatched_cols,
     res_match_results) = self.execute(graph_fn, [similarity])

    self.assertAllEqual(res_match_results[res_matched_cols],
                        expected_matched_rows)
    self.assertAllEqual(np.nonzero(res_matched_cols)[0], [0, 1, 2, 3, 4])
    # Every column is matched, so no column may be flagged as unmatched.
    # np.any is required here: np.all would only prove that *some* column
    # is not unmatched.
    self.assertFalse(np.any(res_unmatched_cols))

  def test_return_correct_matches_with_empty_rows(self):
    """A similarity matrix with zero rows leaves every column unmatched."""

    def graph_fn(similarity_matrix):
      matcher = argmax_matcher.ArgMaxMatcher(matched_threshold=None)
      match = matcher.match(similarity_matrix)
      return match.unmatched_column_indicator()

    similarity = 0.2 * np.ones([0, 5], dtype=np.float32)
    res_unmatched_cols = self.execute(graph_fn, [similarity])
    self.assertAllEqual(np.nonzero(res_unmatched_cols)[0], np.arange(5))

  def test_return_correct_matches_with_matched_threshold(self):
    """Only columns whose max similarity reaches matched_threshold match."""

    def graph_fn(similarity):
      matcher = argmax_matcher.ArgMaxMatcher(matched_threshold=3.)
      match = matcher.match(similarity)
      matched_cols = match.matched_column_indicator()
      unmatched_cols = match.unmatched_column_indicator()
      match_results = match.match_results
      return (matched_cols, unmatched_cols, match_results)

    similarity = np.array([[1, 1, 1, 3, 1],
                           [2, -1, 2, 0, 4],
                           [3, 0, -1, 0, 0]], dtype=np.float32)
    expected_matched_cols = np.array([0, 3, 4])
    expected_matched_rows = np.array([2, 0, 1])
    # Cols 1 and 2 have max similarity 1 and 2, both below 3.
    expected_unmatched_cols = np.array([1, 2])

    (res_matched_cols, res_unmatched_cols,
     match_results) = self.execute(graph_fn, [similarity])
    self.assertAllEqual(match_results[res_matched_cols], expected_matched_rows)
    self.assertAllEqual(np.nonzero(res_matched_cols)[0], expected_matched_cols)
    self.assertAllEqual(np.nonzero(res_unmatched_cols)[0],
                        expected_unmatched_cols)

  def test_return_correct_matches_with_matched_and_unmatched_threshold(self):
    """Columns between the two thresholds are neither matched nor unmatched."""

    def graph_fn(similarity):
      matcher = argmax_matcher.ArgMaxMatcher(matched_threshold=3.,
                                             unmatched_threshold=2.)
      match = matcher.match(similarity)
      matched_cols = match.matched_column_indicator()
      unmatched_cols = match.unmatched_column_indicator()
      match_results = match.match_results
      return (matched_cols, unmatched_cols, match_results)

    similarity = np.array([[1, 1, 1, 3, 1],
                           [2, -1, 2, 0, 4],
                           [3, 0, -1, 0, 0]], dtype=np.float32)
    expected_matched_cols = np.array([0, 3, 4])
    expected_matched_rows = np.array([2, 0, 1])
    expected_unmatched_cols = np.array([1])  # col 2 has too high maximum val

    (res_matched_cols, res_unmatched_cols,
     match_results) = self.execute(graph_fn, [similarity])
    self.assertAllEqual(match_results[res_matched_cols], expected_matched_rows)
    self.assertAllEqual(np.nonzero(res_matched_cols)[0], expected_matched_cols)
    self.assertAllEqual(np.nonzero(res_unmatched_cols)[0],
                        expected_unmatched_cols)

  def test_return_correct_matches_negatives_lower_than_unmatched_false(self):
    """With negatives_lower_than_unmatched=False, low columns are ignored."""

    def graph_fn(similarity):
      matcher = argmax_matcher.ArgMaxMatcher(
          matched_threshold=3.,
          unmatched_threshold=2.,
          negatives_lower_than_unmatched=False)
      match = matcher.match(similarity)
      matched_cols = match.matched_column_indicator()
      unmatched_cols = match.unmatched_column_indicator()
      match_results = match.match_results
      return (matched_cols, unmatched_cols, match_results)

    similarity = np.array([[1, 1, 1, 3, 1],
                           [2, -1, 2, 0, 4],
                           [3, 0, -1, 0, 0]], dtype=np.float32)
    expected_matched_cols = np.array([0, 3, 4])
    expected_matched_rows = np.array([2, 0, 1])
    expected_unmatched_cols = np.array([2])  # col 1 has too low maximum val

    (res_matched_cols, res_unmatched_cols,
     match_results) = self.execute(graph_fn, [similarity])
    self.assertAllEqual(match_results[res_matched_cols], expected_matched_rows)
    self.assertAllEqual(np.nonzero(res_matched_cols)[0], expected_matched_cols)
    self.assertAllEqual(np.nonzero(res_unmatched_cols)[0],
                        expected_unmatched_cols)

  def test_return_correct_matches_unmatched_row_not_using_force_match(self):
    """Without force matching, a row with low similarities gets no column."""

    def graph_fn(similarity):
      matcher = argmax_matcher.ArgMaxMatcher(matched_threshold=3.,
                                             unmatched_threshold=2.)
      match = matcher.match(similarity)
      matched_cols = match.matched_column_indicator()
      unmatched_cols = match.unmatched_column_indicator()
      match_results = match.match_results
      return (matched_cols, unmatched_cols, match_results)

    similarity = np.array([[1, 1, 1, 3, 1],
                           [-1, 0, -2, -2, -1],
                           [3, 0, -1, 2, 0]], dtype=np.float32)
    expected_matched_cols = np.array([0, 3])
    expected_matched_rows = np.array([2, 0])
    # Cols 1, 2 and 4 have max similarity 1, below unmatched_threshold=2.
    expected_unmatched_cols = np.array([1, 2, 4])

    (res_matched_cols, res_unmatched_cols,
     match_results) = self.execute(graph_fn, [similarity])
    self.assertAllEqual(match_results[res_matched_cols], expected_matched_rows)
    self.assertAllEqual(np.nonzero(res_matched_cols)[0], expected_matched_cols)
    self.assertAllEqual(np.nonzero(res_unmatched_cols)[0],
                        expected_unmatched_cols)

  def test_return_correct_matches_unmatched_row_while_using_force_match(self):
    """force_match_for_each_row gives row 1 its best column (col 1)."""

    def graph_fn(similarity):
      matcher = argmax_matcher.ArgMaxMatcher(matched_threshold=3.,
                                             unmatched_threshold=2.,
                                             force_match_for_each_row=True)
      match = matcher.match(similarity)
      matched_cols = match.matched_column_indicator()
      unmatched_cols = match.unmatched_column_indicator()
      match_results = match.match_results
      return (matched_cols, unmatched_cols, match_results)

    similarity = np.array([[1, 1, 1, 3, 1],
                           [-1, 0, -2, -2, -1],
                           [3, 0, -1, 2, 0]], dtype=np.float32)
    expected_matched_cols = np.array([0, 1, 3])
    expected_matched_rows = np.array([2, 1, 0])
    # Col 1 is force-matched to row 1 (that row's argmax). Cols 2 and 4 have
    # max similarity 1, below unmatched_threshold=2.
    expected_unmatched_cols = np.array([2, 4])

    (res_matched_cols, res_unmatched_cols,
     match_results) = self.execute(graph_fn, [similarity])
    self.assertAllEqual(match_results[res_matched_cols], expected_matched_rows)
    self.assertAllEqual(np.nonzero(res_matched_cols)[0], expected_matched_cols)
    self.assertAllEqual(np.nonzero(res_unmatched_cols)[0],
                        expected_unmatched_cols)

  def test_return_correct_matches_using_force_match_padded_groundtruth(self):
    """Rows masked out by valid_rows are excluded from force matching."""

    def graph_fn(similarity, valid_rows):
      matcher = argmax_matcher.ArgMaxMatcher(matched_threshold=3.,
                                             unmatched_threshold=2.,
                                             force_match_for_each_row=True)
      match = matcher.match(similarity, valid_rows)
      matched_cols = match.matched_column_indicator()
      unmatched_cols = match.unmatched_column_indicator()
      match_results = match.match_results
      return (matched_cols, unmatched_cols, match_results)

    # Rows 2 and 4 are all-zero padding rows, marked invalid below.
    similarity = np.array([[1, 1, 1, 3, 1],
                           [-1, 0, -2, -2, -1],
                           [0, 0, 0, 0, 0],
                           [3, 0, -1, 2, 0],
                           [0, 0, 0, 0, 0]], dtype=np.float32)
    valid_rows = np.array([True, True, False, True, False])
    expected_matched_cols = np.array([0, 1, 3])
    expected_matched_rows = np.array([3, 1, 0])
    # Col 1 is force-matched to row 1. Cols 2 and 4 have max similarity 1,
    # below unmatched_threshold=2.
    expected_unmatched_cols = np.array([2, 4])

    (res_matched_cols, res_unmatched_cols,
     match_results) = self.execute(graph_fn, [similarity, valid_rows])
    self.assertAllEqual(match_results[res_matched_cols], expected_matched_rows)
    self.assertAllEqual(np.nonzero(res_matched_cols)[0], expected_matched_cols)
    self.assertAllEqual(np.nonzero(res_unmatched_cols)[0],
                        expected_unmatched_cols)

  def test_valid_arguments_corner_case(self):
    """Equal matched and unmatched thresholds are accepted (no exception)."""
    argmax_matcher.ArgMaxMatcher(matched_threshold=1,
                                 unmatched_threshold=1)

  def test_invalid_arguments_corner_case_negatives_lower_than_thres_false(self):
    """Equal thresholds raise when negatives_lower_than_unmatched=False."""
    with self.assertRaises(ValueError):
      argmax_matcher.ArgMaxMatcher(matched_threshold=1,
                                   unmatched_threshold=1,
                                   negatives_lower_than_unmatched=False)

  def test_invalid_arguments_no_matched_threshold(self):
    """An unmatched_threshold without a matched_threshold raises."""
    with self.assertRaises(ValueError):
      argmax_matcher.ArgMaxMatcher(matched_threshold=None,
                                   unmatched_threshold=4)

  def test_invalid_arguments_unmatched_thres_larger_than_matched_thres(self):
    """An unmatched_threshold above matched_threshold raises."""
    with self.assertRaises(ValueError):
      argmax_matcher.ArgMaxMatcher(matched_threshold=1,
                                   unmatched_threshold=2)
|
|
|
|
|
|
if __name__ == '__main__':
  # Executed directly as a script: delegate discovery and running of the
  # test cases above to TensorFlow's test runner.
  tf.test.main()
|