|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
|
|
"""A function to build an object detection matcher from configuration."""
|
|
|
|
from object_detection.matchers import argmax_matcher
|
|
from object_detection.matchers import bipartite_matcher
|
|
from object_detection.protos import matcher_pb2
|
|
|
|
|
|
def build(matcher_config):
  """Builds a matcher object based on the matcher config.

  Args:
    matcher_config: A matcher.proto object containing the config for the
      desired Matcher.

  Returns:
    Matcher based on the config: an `argmax_matcher.ArgMaxMatcher` when the
    `argmax_matcher` field is set, or a
    `bipartite_matcher.GreedyBipartiteMatcher` when the `bipartite_matcher`
    field is set.

  Raises:
    ValueError: If `matcher_config` is not a `matcher_pb2.Matcher`, if no
      matcher is set in the `matcher_oneof` field, or if the selected matcher
      type is not handled by this builder.
  """
  if not isinstance(matcher_config, matcher_pb2.Matcher):
    raise ValueError('matcher_config not of type matcher_pb2.Matcher.')

  # WhichOneof returns the name of the populated field in the oneof, or None
  # when no field is set; look it up once and dispatch on the result.
  matcher_type = matcher_config.WhichOneof('matcher_oneof')

  if matcher_type == 'argmax_matcher':
    matcher = matcher_config.argmax_matcher
    # When ignore_thresholds is set, both thresholds are passed as None
    # instead of the configured values.
    matched_threshold = unmatched_threshold = None
    if not matcher.ignore_thresholds:
      matched_threshold = matcher.matched_threshold
      unmatched_threshold = matcher.unmatched_threshold
    return argmax_matcher.ArgMaxMatcher(
        matched_threshold=matched_threshold,
        unmatched_threshold=unmatched_threshold,
        negatives_lower_than_unmatched=matcher.negatives_lower_than_unmatched,
        force_match_for_each_row=matcher.force_match_for_each_row,
        use_matmul_gather=matcher.use_matmul_gather)

  if matcher_type == 'bipartite_matcher':
    matcher = matcher_config.bipartite_matcher
    return bipartite_matcher.GreedyBipartiteMatcher(matcher.use_matmul_gather)

  if matcher_type is None:
    raise ValueError('Empty matcher.')
  # A matcher type is set but this builder has no branch for it; name it
  # instead of misreporting the config as empty.
  raise ValueError('Unsupported matcher type: {}'.format(matcher_type))
|