|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
|
|
"""Argmax matcher implementation.
|
|
|
|
This class takes a similarity matrix and matches columns to rows based on the
|
|
maximum value per column. One can specify matched_thresholds and
|
|
to prevent columns from matching to rows (generally resulting in a negative
|
|
training example) and unmatched_theshold to ignore the match (generally
|
|
resulting in neither a positive or negative training example).
|
|
|
|
This matcher is used in Fast(er)-RCNN.
|
|
|
|
Note: matchers are used in TargetAssigners. There is a create_target_assigner
|
|
factory function for popular implementations.
|
|
"""
|
|
import tensorflow as tf
|
|
|
|
from object_detection.core import matcher
|
|
from object_detection.utils import shape_utils
|
|
|
|
|
|
class ArgMaxMatcher(matcher.Matcher):
|
|
"""Matcher based on highest value.
|
|
|
|
This class computes matches from a similarity matrix. Each column is matched
|
|
to a single row.
|
|
|
|
To support object detection target assignment this class enables setting both
|
|
matched_threshold (upper threshold) and unmatched_threshold (lower thresholds)
|
|
defining three categories of similarity which define whether examples are
|
|
positive, negative, or ignored:
|
|
(1) similarity >= matched_threshold: Highest similarity. Matched/Positive!
|
|
(2) matched_threshold > similarity >= unmatched_threshold: Medium similarity.
|
|
Depending on negatives_lower_than_unmatched, this is either
|
|
Unmatched/Negative OR Ignore.
|
|
(3) unmatched_threshold > similarity: Lowest similarity. Depending on flag
|
|
negatives_lower_than_unmatched, either Unmatched/Negative OR Ignore.
|
|
For ignored matches this class sets the values in the Match object to -2.
|
|
"""
|
|
|
|
def __init__(self,
|
|
matched_threshold,
|
|
unmatched_threshold=None,
|
|
negatives_lower_than_unmatched=True,
|
|
force_match_for_each_row=False,
|
|
use_matmul_gather=False):
|
|
"""Construct ArgMaxMatcher.
|
|
|
|
Args:
|
|
matched_threshold: Threshold for positive matches. Positive if
|
|
sim >= matched_threshold, where sim is the maximum value of the
|
|
similarity matrix for a given column. Set to None for no threshold.
|
|
unmatched_threshold: Threshold for negative matches. Negative if
|
|
sim < unmatched_threshold. Defaults to matched_threshold
|
|
when set to None.
|
|
negatives_lower_than_unmatched: Boolean which defaults to True. If True
|
|
then negative matches are the ones below the unmatched_threshold,
|
|
whereas ignored matches are in between the matched and umatched
|
|
threshold. If False, then negative matches are in between the matched
|
|
and unmatched threshold, and everything lower than unmatched is ignored.
|
|
force_match_for_each_row: If True, ensures that each row is matched to
|
|
at least one column (which is not guaranteed otherwise if the
|
|
matched_threshold is high). Defaults to False. See
|
|
argmax_matcher_test.testMatcherForceMatch() for an example.
|
|
use_matmul_gather: Force constructed match objects to use matrix
|
|
multiplication based gather instead of standard tf.gather.
|
|
(Default: False).
|
|
|
|
Raises:
|
|
ValueError: if unmatched_threshold is set but matched_threshold is not set
|
|
or if unmatched_threshold > matched_threshold.
|
|
"""
|
|
super(ArgMaxMatcher, self).__init__(use_matmul_gather=use_matmul_gather)
|
|
if (matched_threshold is None) and (unmatched_threshold is not None):
|
|
raise ValueError('Need to also define matched_threshold when'
|
|
'unmatched_threshold is defined')
|
|
self._matched_threshold = matched_threshold
|
|
if unmatched_threshold is None:
|
|
self._unmatched_threshold = matched_threshold
|
|
else:
|
|
if unmatched_threshold > matched_threshold:
|
|
raise ValueError('unmatched_threshold needs to be smaller or equal'
|
|
'to matched_threshold')
|
|
self._unmatched_threshold = unmatched_threshold
|
|
if not negatives_lower_than_unmatched:
|
|
if self._unmatched_threshold == self._matched_threshold:
|
|
raise ValueError('When negatives are in between matched and '
|
|
'unmatched thresholds, these cannot be of equal '
|
|
'value. matched: {}, unmatched: {}'.format(
|
|
self._matched_threshold,
|
|
self._unmatched_threshold))
|
|
self._force_match_for_each_row = force_match_for_each_row
|
|
self._negatives_lower_than_unmatched = negatives_lower_than_unmatched
|
|
|
|
def _match(self, similarity_matrix, valid_rows):
|
|
"""Tries to match each column of the similarity matrix to a row.
|
|
|
|
Args:
|
|
similarity_matrix: tensor of shape [N, M] representing any similarity
|
|
metric.
|
|
valid_rows: a boolean tensor of shape [N] indicating valid rows.
|
|
|
|
Returns:
|
|
Match object with corresponding matches for each of M columns.
|
|
"""
|
|
|
|
def _match_when_rows_are_empty():
|
|
"""Performs matching when the rows of similarity matrix are empty.
|
|
|
|
When the rows are empty, all detections are false positives. So we return
|
|
a tensor of -1's to indicate that the columns do not match to any rows.
|
|
|
|
Returns:
|
|
matches: int32 tensor indicating the row each column matches to.
|
|
"""
|
|
similarity_matrix_shape = shape_utils.combined_static_and_dynamic_shape(
|
|
similarity_matrix)
|
|
return -1 * tf.ones([similarity_matrix_shape[1]], dtype=tf.int32)
|
|
|
|
def _match_when_rows_are_non_empty():
|
|
"""Performs matching when the rows of similarity matrix are non empty.
|
|
|
|
Returns:
|
|
matches: int32 tensor indicating the row each column matches to.
|
|
"""
|
|
# Matches for each column
|
|
matches = tf.argmax(similarity_matrix, 0, output_type=tf.int32)
|
|
|
|
# Deal with matched and unmatched threshold
|
|
if self._matched_threshold is not None:
|
|
# Get logical indices of ignored and unmatched columns as tf.int64
|
|
matched_vals = tf.reduce_max(similarity_matrix, 0)
|
|
below_unmatched_threshold = tf.greater(self._unmatched_threshold,
|
|
matched_vals)
|
|
between_thresholds = tf.logical_and(
|
|
tf.greater_equal(matched_vals, self._unmatched_threshold),
|
|
tf.greater(self._matched_threshold, matched_vals))
|
|
|
|
if self._negatives_lower_than_unmatched:
|
|
matches = self._set_values_using_indicator(matches,
|
|
below_unmatched_threshold,
|
|
-1)
|
|
matches = self._set_values_using_indicator(matches,
|
|
between_thresholds,
|
|
-2)
|
|
else:
|
|
matches = self._set_values_using_indicator(matches,
|
|
below_unmatched_threshold,
|
|
-2)
|
|
matches = self._set_values_using_indicator(matches,
|
|
between_thresholds,
|
|
-1)
|
|
|
|
if self._force_match_for_each_row:
|
|
similarity_matrix_shape = shape_utils.combined_static_and_dynamic_shape(
|
|
similarity_matrix)
|
|
force_match_column_ids = tf.argmax(similarity_matrix, 1,
|
|
output_type=tf.int32)
|
|
force_match_column_indicators = (
|
|
tf.one_hot(
|
|
force_match_column_ids, depth=similarity_matrix_shape[1]) *
|
|
tf.cast(tf.expand_dims(valid_rows, axis=-1), dtype=tf.float32))
|
|
force_match_row_ids = tf.argmax(force_match_column_indicators, 0,
|
|
output_type=tf.int32)
|
|
force_match_column_mask = tf.cast(
|
|
tf.reduce_max(force_match_column_indicators, 0), tf.bool)
|
|
final_matches = tf.where(force_match_column_mask,
|
|
force_match_row_ids, matches)
|
|
return final_matches
|
|
else:
|
|
return matches
|
|
|
|
if similarity_matrix.shape.is_fully_defined():
|
|
if similarity_matrix.shape[0].value == 0:
|
|
return _match_when_rows_are_empty()
|
|
else:
|
|
return _match_when_rows_are_non_empty()
|
|
else:
|
|
return tf.cond(
|
|
tf.greater(tf.shape(similarity_matrix)[0], 0),
|
|
_match_when_rows_are_non_empty, _match_when_rows_are_empty)
|
|
|
|
def _set_values_using_indicator(self, x, indicator, val):
|
|
"""Set the indicated fields of x to val.
|
|
|
|
Args:
|
|
x: tensor.
|
|
indicator: boolean with same shape as x.
|
|
val: scalar with value to set.
|
|
|
|
Returns:
|
|
modified tensor.
|
|
"""
|
|
indicator = tf.cast(indicator, x.dtype)
|
|
return tf.add(tf.multiply(x, 1 - indicator), val * indicator)
|