|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
|
|
"""A function to build localization and classification losses from config."""
|
|
|
|
import functools
|
|
from object_detection.core import balanced_positive_negative_sampler as sampler
|
|
from object_detection.core import losses
|
|
from object_detection.protos import losses_pb2
|
|
from object_detection.utils import ops
|
|
|
|
|
|
def build(loss_config):
  """Build losses based on the config.

  Builds classification and localization losses and, optionally, a hard
  example miner, a random example sampler and an expected-loss-weights
  function, all based on the config.

  Args:
    loss_config: A losses_pb2.Loss object.

  Returns:
    classification_loss: Classification loss object.
    localization_loss: Localization loss object.
    classification_weight: Classification loss weight.
    localization_weight: Localization loss weight.
    hard_example_miner: Hard example miner object, or None when the config
      has no `hard_example_miner` field set.
    random_example_sampler: BalancedPositiveNegativeSampler object, or None
      when the config has no `random_example_sampler` field set.
    expected_loss_weights_fn: None when `expected_loss_weights` is NONE;
      otherwise a functools.partial of the matching `ops.expected_*`
      function with `min_num_negative_samples` and
      `desired_negative_sampling_ratio` bound from the config.

  Raises:
    ValueError: If hard_example_miner is used with sigmoid_focal_loss.
    ValueError: If random_example_sampler is getting non-positive value as
      desired positive example fraction.
    ValueError: If expected_loss_weights is not a recognized value.
  """
  classification_loss = _build_classification_loss(
      loss_config.classification_loss)
  localization_loss = _build_localization_loss(
      loss_config.localization_loss)
  classification_weight = loss_config.classification_weight
  localization_weight = loss_config.localization_weight

  hard_example_miner = None
  if loss_config.HasField('hard_example_miner'):
    # This combination is explicitly rejected by the builder.
    if (loss_config.classification_loss.WhichOneof('classification_loss') ==
        'weighted_sigmoid_focal'):
      raise ValueError('HardExampleMiner should not be used with sigmoid focal '
                       'loss')
    hard_example_miner = build_hard_example_miner(
        loss_config.hard_example_miner,
        classification_weight,
        localization_weight)

  random_example_sampler = None
  if loss_config.HasField('random_example_sampler'):
    positive_fraction = (
        loss_config.random_example_sampler.positive_sample_fraction)
    if positive_fraction <= 0:
      # Note the trailing space: the two literals are concatenated.
      raise ValueError('RandomExampleSampler should not use non-positive '
                       'value as positive sample fraction.')
    random_example_sampler = sampler.BalancedPositiveNegativeSampler(
        positive_fraction=positive_fraction)

  if loss_config.expected_loss_weights == loss_config.NONE:
    expected_loss_weights_fn = None
  elif loss_config.expected_loss_weights == loss_config.EXPECTED_SAMPLING:
    expected_loss_weights_fn = functools.partial(
        ops.expected_classification_loss_by_expected_sampling,
        min_num_negative_samples=loss_config.min_num_negative_samples,
        desired_negative_sampling_ratio=(
            loss_config.desired_negative_sampling_ratio))
  elif (loss_config.expected_loss_weights ==
        loss_config.REWEIGHTING_UNMATCHED_ANCHORS):
    expected_loss_weights_fn = functools.partial(
        ops.expected_classification_loss_by_reweighting_unmatched_anchors,
        min_num_negative_samples=loss_config.min_num_negative_samples,
        desired_negative_sampling_ratio=(
            loss_config.desired_negative_sampling_ratio))
  else:
    raise ValueError(
        'Not a valid value for expected_loss_weights: {}.'.format(
            loss_config.expected_loss_weights))

  return (classification_loss, localization_loss, classification_weight,
          localization_weight, hard_example_miner, random_example_sampler,
          expected_loss_weights_fn)
|
|
|
|
|
|
def build_hard_example_miner(config,
                             classification_weight,
                             localization_weight):
  """Creates a losses.HardExampleMiner from a HardExampleMiner proto.

  Args:
    config: A losses_pb2.HardExampleMiner object.
    classification_weight: Classification loss weight, forwarded to the miner
      as `cls_loss_weight`.
    localization_weight: Localization loss weight, forwarded to the miner as
      `loc_loss_weight`.

  Returns:
    Hard example miner.
  """
  miner_enum = losses_pb2.HardExampleMiner
  # Translate the proto enum into the string codes passed as `loss_type`;
  # a value not present in this table yields None.
  loss_type_by_enum = {
      miner_enum.BOTH: 'both',
      miner_enum.CLASSIFICATION: 'cls',
      miner_enum.LOCALIZATION: 'loc',
  }
  loss_type = loss_type_by_enum.get(config.loss_type)

  # Only strictly positive values are forwarded; anything else becomes None.
  max_negatives_per_positive = (
      config.max_negatives_per_positive
      if config.max_negatives_per_positive > 0 else None)
  num_hard_examples = (
      config.num_hard_examples if config.num_hard_examples > 0 else None)

  return losses.HardExampleMiner(
      num_hard_examples=num_hard_examples,
      iou_threshold=config.iou_threshold,
      loss_type=loss_type,
      cls_loss_weight=classification_weight,
      loc_loss_weight=localization_weight,
      max_negatives_per_positive=max_negatives_per_positive,
      min_negatives_per_image=config.min_negatives_per_image)
|
|
|
|
|
|
def build_faster_rcnn_classification_loss(loss_config):
  """Returns the Faster RCNN classification loss selected by `loss_config`.

  The `weighted_sigmoid`, `weighted_logits_softmax` and
  `weighted_sigmoid_focal` options are honored explicitly. Every other case,
  including `weighted_softmax`, an unset oneof or any other option, yields a
  WeightedSoftmaxClassificationLoss configured from the `weighted_softmax`
  sub-message.

  Args:
    loss_config: A losses_pb2.ClassificationLoss object.

  Returns:
    Loss based on the config.

  Raises:
    ValueError: On invalid loss_config.
  """
  if not isinstance(loss_config, losses_pb2.ClassificationLoss):
    raise ValueError('loss_config not of type losses_pb2.ClassificationLoss.')

  selected = loss_config.WhichOneof('classification_loss')

  if selected == 'weighted_sigmoid':
    return losses.WeightedSigmoidClassificationLoss()
  elif selected == 'weighted_logits_softmax':
    return losses.WeightedSoftmaxClassificationAgainstLogitsLoss(
        logit_scale=loss_config.weighted_logits_softmax.logit_scale)
  elif selected == 'weighted_sigmoid_focal':
    focal_config = loss_config.weighted_sigmoid_focal
    # `alpha` is only forwarded when explicitly present in the proto.
    focal_alpha = (
        focal_config.alpha if focal_config.HasField('alpha') else None)
    return losses.SigmoidFocalClassificationLoss(
        gamma=focal_config.gamma,
        alpha=focal_alpha)

  # By default, Faster RCNN second stage classifier uses Softmax loss
  # with anchor-wise outputs; this also covers the explicit
  # 'weighted_softmax' option.
  return losses.WeightedSoftmaxClassificationLoss(
      logit_scale=loss_config.weighted_softmax.logit_scale)
|
|
|
|
|
|
def _build_localization_loss(loss_config):
  """Maps a LocalizationLoss proto to the corresponding loss object.

  Args:
    loss_config: A losses_pb2.LocalizationLoss object.

  Returns:
    Loss based on the config.

  Raises:
    ValueError: If `loss_config` is not a LocalizationLoss proto, or if its
      `localization_loss` oneof is unset or holds an option not handled here.
  """
  if not isinstance(loss_config, losses_pb2.LocalizationLoss):
    raise ValueError('loss_config not of type losses_pb2.LocalizationLoss.')

  chosen = loss_config.WhichOneof('localization_loss')
  if chosen == 'weighted_l2':
    localization_loss = losses.WeightedL2LocalizationLoss()
  elif chosen == 'weighted_smooth_l1':
    localization_loss = losses.WeightedSmoothL1LocalizationLoss(
        loss_config.weighted_smooth_l1.delta)
  elif chosen == 'weighted_iou':
    localization_loss = losses.WeightedIOULocalizationLoss()
  else:
    raise ValueError('Empty loss config.')
  return localization_loss
|
|
|
|
|
|
def _build_classification_loss(loss_config):
  """Maps a ClassificationLoss proto to the corresponding loss object.

  Args:
    loss_config: A losses_pb2.ClassificationLoss object.

  Returns:
    Loss based on the config.

  Raises:
    ValueError: If `loss_config` is not a ClassificationLoss proto, or if its
      `classification_loss` oneof is unset or holds an option not handled
      here.
  """
  if not isinstance(loss_config, losses_pb2.ClassificationLoss):
    raise ValueError('loss_config not of type losses_pb2.ClassificationLoss.')

  option = loss_config.WhichOneof('classification_loss')

  def _make_sigmoid_focal():
    focal = loss_config.weighted_sigmoid_focal
    # `alpha` is only forwarded when explicitly present in the proto.
    focal_alpha = focal.alpha if focal.HasField('alpha') else None
    return losses.SigmoidFocalClassificationLoss(
        gamma=focal.gamma,
        alpha=focal_alpha)

  def _make_bootstrapped_sigmoid():
    bootstrap = loss_config.bootstrapped_sigmoid
    return losses.BootstrappedSigmoidClassificationLoss(
        alpha=bootstrap.alpha,
        bootstrap_type='hard' if bootstrap.hard_bootstrap else 'soft')

  # Builders are zero-argument callables so only the selected loss is
  # constructed and only its sub-message is read.
  builders = {
      'weighted_sigmoid':
          lambda: losses.WeightedSigmoidClassificationLoss(),
      'weighted_sigmoid_focal':
          _make_sigmoid_focal,
      'weighted_softmax':
          lambda: losses.WeightedSoftmaxClassificationLoss(
              logit_scale=loss_config.weighted_softmax.logit_scale),
      'weighted_logits_softmax':
          lambda: losses.WeightedSoftmaxClassificationAgainstLogitsLoss(
              logit_scale=loss_config.weighted_logits_softmax.logit_scale),
      'bootstrapped_sigmoid':
          _make_bootstrapped_sigmoid,
  }

  builder = builders.get(option)
  if builder is None:
    raise ValueError('Empty loss config.')
  return builder()
|