You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

208 lines
9.3 KiB

6 years ago
  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  """Argmax matcher implementation.

  The ArgMaxMatcher class takes a similarity matrix and matches columns to rows
  based on the maximum value per column. One can specify matched_threshold and
  unmatched_threshold to control the outcome. With the default settings,
  columns whose best similarity is below unmatched_threshold are left unmatched
  (generally resulting in a negative training example). Columns whose best
  similarity lies between unmatched_threshold and matched_threshold are ignored
  (generally resulting in neither a positive nor a negative training example).

  This matcher is used in Fast(er)-RCNN.

  Note: matchers are used in TargetAssigners. There is a create_target_assigner
  factory function for popular implementations.
  """
  25. import tensorflow as tf
  26. from object_detection.core import matcher
  27. from object_detection.utils import shape_utils
  class ArgMaxMatcher(matcher.Matcher):
    """Matcher based on highest value.

    This class computes matches from a similarity matrix. Each column is matched
    to a single row.

    To support object detection target assignment, this class enables setting
    both matched_threshold (upper threshold) and unmatched_threshold (lower
    threshold). Together they define three categories of similarity, which
    decide whether examples are positive, negative, or ignored:
    (1) similarity >= matched_threshold: Highest similarity. Matched/Positive!
    (2) matched_threshold > similarity >= unmatched_threshold: Medium similarity.
        Depending on negatives_lower_than_unmatched, this is either
        Unmatched/Negative OR Ignore.
    (3) unmatched_threshold > similarity: Lowest similarity. Depending on flag
        negatives_lower_than_unmatched, either Unmatched/Negative OR Ignore.
    Unmatched (negative) columns get the value -1, and ignored columns get the
    value -2.
    """

    def __init__(self,
                 matched_threshold,
                 unmatched_threshold=None,
                 negatives_lower_than_unmatched=True,
                 force_match_for_each_row=False,
                 use_matmul_gather=False):
      """Construct ArgMaxMatcher.

      Args:
        matched_threshold: Threshold for positive matches. Positive if
          sim >= matched_threshold, where sim is the maximum value of the
          similarity matrix for a given column. Set to None for no threshold.
        unmatched_threshold: Threshold for negative matches. Negative if
          sim < unmatched_threshold (when negatives_lower_than_unmatched is
          True). Defaults to matched_threshold when set to None.
        negatives_lower_than_unmatched: Boolean which defaults to True. If True,
          negative matches are the ones below the unmatched_threshold, and
          ignored matches lie between the matched and unmatched thresholds. If
          False, negative matches lie between the matched and unmatched
          thresholds, and everything lower than unmatched is ignored.
        force_match_for_each_row: If True, each valid row claims the column it
          is most similar to, even when the matched_threshold is high.
          Only one row can claim a given column, so rows competing for the same
          column may still end up unmatched. Defaults to False. See
          argmax_matcher_test.testMatcherForceMatch() for an example.
        use_matmul_gather: Force constructed match objects to use matrix
          multiplication based gather instead of standard tf.gather.
          (Default: False).

      Raises:
        ValueError: in any of these cases:
          - unmatched_threshold is set but matched_threshold is not.
          - unmatched_threshold > matched_threshold.
          - negatives_lower_than_unmatched is False and the two effective
            thresholds are equal. This includes the case where both are None.
      """
      super(ArgMaxMatcher, self).__init__(use_matmul_gather=use_matmul_gather)
      if (matched_threshold is None) and (unmatched_threshold is not None):
        raise ValueError('Need to also define matched_threshold when '
                         'unmatched_threshold is defined')
      self._matched_threshold = matched_threshold
      if unmatched_threshold is None:
        self._unmatched_threshold = matched_threshold
      else:
        if unmatched_threshold > matched_threshold:
          raise ValueError('unmatched_threshold needs to be smaller or equal '
                           'to matched_threshold')
        self._unmatched_threshold = unmatched_threshold
      if not negatives_lower_than_unmatched:
        # With equal thresholds the "between" band is empty, so no column
        # could ever become a negative in this mode.
        if self._unmatched_threshold == self._matched_threshold:
          raise ValueError('When negatives are in between matched and '
                           'unmatched thresholds, these cannot be of equal '
                           'value. matched: {}, unmatched: {}'.format(
                               self._matched_threshold,
                               self._unmatched_threshold))
      self._force_match_for_each_row = force_match_for_each_row
      self._negatives_lower_than_unmatched = negatives_lower_than_unmatched

    def _match(self, similarity_matrix, valid_rows):
      """Tries to match each column of the similarity matrix to a row.

      Args:
        similarity_matrix: tensor of shape [N, M] representing any similarity
          metric.
        valid_rows: a boolean tensor of shape [N] indicating valid rows. It is
          only consulted when force_match_for_each_row is True.

      Returns:
        An int32 tensor of shape [M]. For each column it holds the index of the
        matched row, -1 if unmatched, or -2 if ignored. Presumably the base
        Matcher wraps this in a Match object; confirm against matcher.Matcher.
      """

      def _match_when_rows_are_empty():
        """Performs matching when the rows of similarity matrix are empty.

        With no rows, no column can be matched, so every column is marked -1.

        Returns:
          matches: int32 tensor indicating the row each column matches to.
        """
        similarity_matrix_shape = shape_utils.combined_static_and_dynamic_shape(
            similarity_matrix)
        return -1 * tf.ones([similarity_matrix_shape[1]], dtype=tf.int32)

      def _match_when_rows_are_non_empty():
        """Performs matching when the rows of similarity matrix are non empty.

        Returns:
          matches: int32 tensor indicating the row each column matches to.
        """
        # Best row for each column.
        matches = tf.argmax(similarity_matrix, 0, output_type=tf.int32)

        # Apply the matched/unmatched thresholds.
        if self._matched_threshold is not None:
          # Boolean masks over columns, based on each column's best similarity.
          matched_vals = tf.reduce_max(similarity_matrix, 0)
          below_unmatched_threshold = tf.greater(self._unmatched_threshold,
                                                 matched_vals)
          between_thresholds = tf.logical_and(
              tf.greater_equal(matched_vals, self._unmatched_threshold),
              tf.greater(self._matched_threshold, matched_vals))

          if self._negatives_lower_than_unmatched:
            matches = self._set_values_using_indicator(matches,
                                                       below_unmatched_threshold,
                                                       -1)
            matches = self._set_values_using_indicator(matches,
                                                       between_thresholds,
                                                       -2)
          else:
            matches = self._set_values_using_indicator(matches,
                                                       below_unmatched_threshold,
                                                       -2)
            matches = self._set_values_using_indicator(matches,
                                                       between_thresholds,
                                                       -1)

        if self._force_match_for_each_row:
          # Every valid row claims its most similar column, overriding the
          # threshold-based result for that column. Invalid rows are zeroed out
          # of the one-hot indicators, so they claim nothing.
          similarity_matrix_shape = shape_utils.combined_static_and_dynamic_shape(
              similarity_matrix)
          force_match_column_ids = tf.argmax(similarity_matrix, 1,
                                             output_type=tf.int32)
          force_match_column_indicators = (
              tf.one_hot(
                  force_match_column_ids, depth=similarity_matrix_shape[1]) *
              tf.cast(tf.expand_dims(valid_rows, axis=-1), dtype=tf.float32))
          # If several rows claim the same column, argmax keeps only one of them.
          force_match_row_ids = tf.argmax(force_match_column_indicators, 0,
                                          output_type=tf.int32)
          force_match_column_mask = tf.cast(
              tf.reduce_max(force_match_column_indicators, 0), tf.bool)
          final_matches = tf.where(force_match_column_mask,
                                   force_match_row_ids, matches)
          return final_matches
        return matches

      if similarity_matrix.shape.is_fully_defined():
        # shape[0] is a Dimension (exposing .value) under TF1 but a plain int
        # under TF2; normalize so the emptiness check works with either.
        num_rows = similarity_matrix.shape[0]
        num_rows = getattr(num_rows, 'value', num_rows)
        if num_rows == 0:
          return _match_when_rows_are_empty()
        return _match_when_rows_are_non_empty()
      # Row count unknown statically: choose the branch at execution time.
      return tf.cond(
          tf.greater(tf.shape(similarity_matrix)[0], 0),
          _match_when_rows_are_non_empty, _match_when_rows_are_empty)

    def _set_values_using_indicator(self, x, indicator, val):
      """Set the indicated fields of x to val.

      The indicator is cast to x's dtype, and the result is computed as
      x * (1 - indicator) + val * indicator. Entries where indicator is True
      become val, and all other entries keep their value from x.

      Args:
        x: tensor.
        indicator: boolean with same shape as x.
        val: scalar with value to set.

      Returns:
        modified tensor.
      """
      indicator = tf.cast(indicator, x.dtype)
      return tf.add(tf.multiply(x, 1 - indicator), val * indicator)