You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

252 lines
9.0 KiB

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A function to build localization and classification losses from config."""
import functools
from object_detection.core import balanced_positive_negative_sampler as sampler
from object_detection.core import losses
from object_detection.protos import losses_pb2
from object_detection.utils import ops
def build(loss_config):
"""Build losses based on the config.
Builds classification, localization losses and optionally a hard example miner
based on the config.
Args:
loss_config: A losses_pb2.Loss object.
Returns:
classification_loss: Classification loss object.
localization_loss: Localization loss object.
classification_weight: Classification loss weight.
localization_weight: Localization loss weight.
hard_example_miner: Hard example miner object.
random_example_sampler: BalancedPositiveNegativeSampler object.
Raises:
ValueError: If hard_example_miner is used with sigmoid_focal_loss.
ValueError: If random_example_sampler is getting non-positive value as
desired positive example fraction.
"""
classification_loss = _build_classification_loss(
loss_config.classification_loss)
localization_loss = _build_localization_loss(
loss_config.localization_loss)
classification_weight = loss_config.classification_weight
localization_weight = loss_config.localization_weight
hard_example_miner = None
if loss_config.HasField('hard_example_miner'):
if (loss_config.classification_loss.WhichOneof('classification_loss') ==
'weighted_sigmoid_focal'):
raise ValueError('HardExampleMiner should not be used with sigmoid focal '
'loss')
hard_example_miner = build_hard_example_miner(
loss_config.hard_example_miner,
classification_weight,
localization_weight)
random_example_sampler = None
if loss_config.HasField('random_example_sampler'):
if loss_config.random_example_sampler.positive_sample_fraction <= 0:
raise ValueError('RandomExampleSampler should not use non-positive'
'value as positive sample fraction.')
random_example_sampler = sampler.BalancedPositiveNegativeSampler(
positive_fraction=loss_config.random_example_sampler.
positive_sample_fraction)
if loss_config.expected_loss_weights == loss_config.NONE:
expected_loss_weights_fn = None
elif loss_config.expected_loss_weights == loss_config.EXPECTED_SAMPLING:
expected_loss_weights_fn = functools.partial(
ops.expected_classification_loss_by_expected_sampling,
min_num_negative_samples=loss_config.min_num_negative_samples,
desired_negative_sampling_ratio=loss_config
.desired_negative_sampling_ratio)
elif (loss_config.expected_loss_weights == loss_config
.REWEIGHTING_UNMATCHED_ANCHORS):
expected_loss_weights_fn = functools.partial(
ops.expected_classification_loss_by_reweighting_unmatched_anchors,
min_num_negative_samples=loss_config.min_num_negative_samples,
desired_negative_sampling_ratio=loss_config
.desired_negative_sampling_ratio)
else:
raise ValueError('Not a valid value for expected_classification_loss.')
return (classification_loss, localization_loss, classification_weight,
localization_weight, hard_example_miner, random_example_sampler,
expected_loss_weights_fn)
def build_hard_example_miner(config,
classification_weight,
localization_weight):
"""Builds hard example miner based on the config.
Args:
config: A losses_pb2.HardExampleMiner object.
classification_weight: Classification loss weight.
localization_weight: Localization loss weight.
Returns:
Hard example miner.
"""
loss_type = None
if config.loss_type == losses_pb2.HardExampleMiner.BOTH:
loss_type = 'both'
if config.loss_type == losses_pb2.HardExampleMiner.CLASSIFICATION:
loss_type = 'cls'
if config.loss_type == losses_pb2.HardExampleMiner.LOCALIZATION:
loss_type = 'loc'
max_negatives_per_positive = None
num_hard_examples = None
if config.max_negatives_per_positive > 0:
max_negatives_per_positive = config.max_negatives_per_positive
if config.num_hard_examples > 0:
num_hard_examples = config.num_hard_examples
hard_example_miner = losses.HardExampleMiner(
num_hard_examples=num_hard_examples,
iou_threshold=config.iou_threshold,
loss_type=loss_type,
cls_loss_weight=classification_weight,
loc_loss_weight=localization_weight,
max_negatives_per_positive=max_negatives_per_positive,
min_negatives_per_image=config.min_negatives_per_image)
return hard_example_miner
def build_faster_rcnn_classification_loss(loss_config):
"""Builds a classification loss for Faster RCNN based on the loss config.
Args:
loss_config: A losses_pb2.ClassificationLoss object.
Returns:
Loss based on the config.
Raises:
ValueError: On invalid loss_config.
"""
if not isinstance(loss_config, losses_pb2.ClassificationLoss):
raise ValueError('loss_config not of type losses_pb2.ClassificationLoss.')
loss_type = loss_config.WhichOneof('classification_loss')
if loss_type == 'weighted_sigmoid':
return losses.WeightedSigmoidClassificationLoss()
if loss_type == 'weighted_softmax':
config = loss_config.weighted_softmax
return losses.WeightedSoftmaxClassificationLoss(
logit_scale=config.logit_scale)
if loss_type == 'weighted_logits_softmax':
config = loss_config.weighted_logits_softmax
return losses.WeightedSoftmaxClassificationAgainstLogitsLoss(
logit_scale=config.logit_scale)
if loss_type == 'weighted_sigmoid_focal':
config = loss_config.weighted_sigmoid_focal
alpha = None
if config.HasField('alpha'):
alpha = config.alpha
return losses.SigmoidFocalClassificationLoss(
gamma=config.gamma,
alpha=alpha)
# By default, Faster RCNN second stage classifier uses Softmax loss
# with anchor-wise outputs.
config = loss_config.weighted_softmax
return losses.WeightedSoftmaxClassificationLoss(
logit_scale=config.logit_scale)
def _build_localization_loss(loss_config):
"""Builds a localization loss based on the loss config.
Args:
loss_config: A losses_pb2.LocalizationLoss object.
Returns:
Loss based on the config.
Raises:
ValueError: On invalid loss_config.
"""
if not isinstance(loss_config, losses_pb2.LocalizationLoss):
raise ValueError('loss_config not of type losses_pb2.LocalizationLoss.')
loss_type = loss_config.WhichOneof('localization_loss')
if loss_type == 'weighted_l2':
return losses.WeightedL2LocalizationLoss()
if loss_type == 'weighted_smooth_l1':
return losses.WeightedSmoothL1LocalizationLoss(
loss_config.weighted_smooth_l1.delta)
if loss_type == 'weighted_iou':
return losses.WeightedIOULocalizationLoss()
raise ValueError('Empty loss config.')
def _build_classification_loss(loss_config):
"""Builds a classification loss based on the loss config.
Args:
loss_config: A losses_pb2.ClassificationLoss object.
Returns:
Loss based on the config.
Raises:
ValueError: On invalid loss_config.
"""
if not isinstance(loss_config, losses_pb2.ClassificationLoss):
raise ValueError('loss_config not of type losses_pb2.ClassificationLoss.')
loss_type = loss_config.WhichOneof('classification_loss')
if loss_type == 'weighted_sigmoid':
return losses.WeightedSigmoidClassificationLoss()
if loss_type == 'weighted_sigmoid_focal':
config = loss_config.weighted_sigmoid_focal
alpha = None
if config.HasField('alpha'):
alpha = config.alpha
return losses.SigmoidFocalClassificationLoss(
gamma=config.gamma,
alpha=alpha)
if loss_type == 'weighted_softmax':
config = loss_config.weighted_softmax
return losses.WeightedSoftmaxClassificationLoss(
logit_scale=config.logit_scale)
if loss_type == 'weighted_logits_softmax':
config = loss_config.weighted_logits_softmax
return losses.WeightedSoftmaxClassificationAgainstLogitsLoss(
logit_scale=config.logit_scale)
if loss_type == 'bootstrapped_sigmoid':
config = loss_config.bootstrapped_sigmoid
return losses.BootstrappedSigmoidClassificationLoss(
alpha=config.alpha,
bootstrap_type=('hard' if config.hard_bootstrap else 'soft'))
raise ValueError('Empty loss config.')