|
|
- # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- """A function to build localization and classification losses from config."""
-
- import functools
- from object_detection.core import balanced_positive_negative_sampler as sampler
- from object_detection.core import losses
- from object_detection.protos import losses_pb2
- from object_detection.utils import ops
-
-
def build(loss_config):
  """Build losses based on the config.

  Builds classification and localization losses, and optionally a hard example
  miner, a balanced positive/negative sampler and an expected-loss-weights
  function, based on the config.

  Args:
    loss_config: A losses_pb2.Loss object.

  Returns:
    A 7-tuple of:
    classification_loss: Classification loss object.
    localization_loss: Localization loss object.
    classification_weight: Classification loss weight.
    localization_weight: Localization loss weight.
    hard_example_miner: Hard example miner object, or None when the config has
      no hard_example_miner field.
    random_example_sampler: BalancedPositiveNegativeSampler object, or None
      when the config has no random_example_sampler field.
    expected_loss_weights_fn: None when expected_loss_weights is NONE;
      otherwise a functools.partial of
      ops.expected_classification_loss_by_expected_sampling or
      ops.expected_classification_loss_by_reweighting_unmatched_anchors with
      min_num_negative_samples and desired_negative_sampling_ratio bound from
      the config.

  Raises:
    ValueError: If hard_example_miner is used with sigmoid_focal_loss.
    ValueError: If random_example_sampler is getting non-positive value as
      desired positive example fraction.
    ValueError: If expected_loss_weights is not one of NONE,
      EXPECTED_SAMPLING or REWEIGHTING_UNMATCHED_ANCHORS.
  """
  classification_loss = _build_classification_loss(
      loss_config.classification_loss)
  localization_loss = _build_localization_loss(
      loss_config.localization_loss)
  classification_weight = loss_config.classification_weight
  localization_weight = loss_config.localization_weight

  hard_example_miner = None
  if loss_config.HasField('hard_example_miner'):
    if (loss_config.classification_loss.WhichOneof('classification_loss') ==
        'weighted_sigmoid_focal'):
      raise ValueError('HardExampleMiner should not be used with sigmoid focal '
                       'loss')
    hard_example_miner = build_hard_example_miner(
        loss_config.hard_example_miner,
        classification_weight,
        localization_weight)

  random_example_sampler = None
  if loss_config.HasField('random_example_sampler'):
    positive_fraction = (
        loss_config.random_example_sampler.positive_sample_fraction)
    if positive_fraction <= 0:
      # Adjacent literals are concatenated verbatim; the trailing space in the
      # first one keeps "non-positive" and "value" from running together.
      raise ValueError('RandomExampleSampler should not use non-positive '
                       'value as positive sample fraction.')
    random_example_sampler = sampler.BalancedPositiveNegativeSampler(
        positive_fraction=positive_fraction)

  expected_loss_weights = loss_config.expected_loss_weights
  if expected_loss_weights == loss_config.NONE:
    expected_loss_weights_fn = None
  elif expected_loss_weights == loss_config.EXPECTED_SAMPLING:
    expected_loss_weights_fn = functools.partial(
        ops.expected_classification_loss_by_expected_sampling,
        min_num_negative_samples=loss_config.min_num_negative_samples,
        desired_negative_sampling_ratio=(
            loss_config.desired_negative_sampling_ratio))
  elif expected_loss_weights == loss_config.REWEIGHTING_UNMATCHED_ANCHORS:
    expected_loss_weights_fn = functools.partial(
        ops.expected_classification_loss_by_reweighting_unmatched_anchors,
        min_num_negative_samples=loss_config.min_num_negative_samples,
        desired_negative_sampling_ratio=(
            loss_config.desired_negative_sampling_ratio))
  else:
    # Name the field that was actually inspected and report its value.
    raise ValueError('Not a valid value for expected_loss_weights: {}.'.format(
        expected_loss_weights))

  return (classification_loss, localization_loss, classification_weight,
          localization_weight, hard_example_miner, random_example_sampler,
          expected_loss_weights_fn)
-
-
def build_hard_example_miner(config,
                             classification_weight,
                             localization_weight):
  """Creates a losses.HardExampleMiner configured from a proto message.

  Args:
    config: A losses_pb2.HardExampleMiner object.
    classification_weight: Classification loss weight, forwarded to the miner
      as cls_loss_weight.
    localization_weight: Localization loss weight, forwarded to the miner as
      loc_loss_weight.

  Returns:
    A losses.HardExampleMiner instance.
  """
  # Translate the proto enum into the string code passed as loss_type; an
  # enum value not listed here is passed through as None.
  miner_enum = losses_pb2.HardExampleMiner
  loss_type_codes = {
      miner_enum.BOTH: 'both',
      miner_enum.CLASSIFICATION: 'cls',
      miner_enum.LOCALIZATION: 'loc',
  }
  mined_loss_type = loss_type_codes.get(config.loss_type)

  # Non-positive proto values are forwarded as None (presumably meaning
  # "unset" to HardExampleMiner -- confirm against losses.HardExampleMiner).
  negatives_cap = (config.max_negatives_per_positive
                   if config.max_negatives_per_positive > 0 else None)
  hard_example_count = (config.num_hard_examples
                        if config.num_hard_examples > 0 else None)

  return losses.HardExampleMiner(
      num_hard_examples=hard_example_count,
      iou_threshold=config.iou_threshold,
      loss_type=mined_loss_type,
      cls_loss_weight=classification_weight,
      loc_loss_weight=localization_weight,
      max_negatives_per_positive=negatives_cap,
      min_negatives_per_image=config.min_negatives_per_image)
-
-
def build_faster_rcnn_classification_loss(loss_config):
  """Creates the Faster R-CNN classification loss selected by a proto.

  Recognized oneof choices are weighted_sigmoid, weighted_sigmoid_focal,
  weighted_logits_softmax and weighted_softmax. Any other choice, including an
  unset oneof, falls back to WeightedSoftmaxClassificationLoss built from the
  weighted_softmax sub-message.

  Args:
    loss_config: A losses_pb2.ClassificationLoss object.

  Returns:
    Loss based on the config.

  Raises:
    ValueError: On invalid loss_config.
  """
  if not isinstance(loss_config, losses_pb2.ClassificationLoss):
    raise ValueError('loss_config not of type losses_pb2.ClassificationLoss.')

  chosen = loss_config.WhichOneof('classification_loss')

  if chosen == 'weighted_sigmoid':
    return losses.WeightedSigmoidClassificationLoss()

  if chosen == 'weighted_sigmoid_focal':
    focal_params = loss_config.weighted_sigmoid_focal
    # alpha is optional in the proto; only forward it when explicitly set.
    focal_alpha = focal_params.alpha if focal_params.HasField('alpha') else None
    return losses.SigmoidFocalClassificationLoss(
        gamma=focal_params.gamma,
        alpha=focal_alpha)

  if chosen == 'weighted_logits_softmax':
    logits_params = loss_config.weighted_logits_softmax
    return losses.WeightedSoftmaxClassificationAgainstLogitsLoss(
        logit_scale=logits_params.logit_scale)

  # 'weighted_softmax' and every unhandled case share this path: by default
  # the Faster RCNN second stage classifier uses Softmax loss with
  # anchor-wise outputs.
  softmax_params = loss_config.weighted_softmax
  return losses.WeightedSoftmaxClassificationLoss(
      logit_scale=softmax_params.logit_scale)
-
-
def _build_localization_loss(loss_config):
  """Creates the localization loss chosen in a LocalizationLoss proto.

  Args:
    loss_config: A losses_pb2.LocalizationLoss object.

  Returns:
    Loss based on the config.

  Raises:
    ValueError: If loss_config is not a LocalizationLoss, or if none of
      weighted_l2, weighted_smooth_l1 or weighted_iou is set.
  """
  if not isinstance(loss_config, losses_pb2.LocalizationLoss):
    raise ValueError('loss_config not of type losses_pb2.LocalizationLoss.')

  selected = loss_config.WhichOneof('localization_loss')

  if selected == 'weighted_smooth_l1':
    smooth_l1_params = loss_config.weighted_smooth_l1
    return losses.WeightedSmoothL1LocalizationLoss(smooth_l1_params.delta)
  if selected == 'weighted_iou':
    return losses.WeightedIOULocalizationLoss()
  if selected == 'weighted_l2':
    return losses.WeightedL2LocalizationLoss()

  # WhichOneof returned None or an unsupported choice.
  raise ValueError('Empty loss config.')
-
-
def _build_classification_loss(loss_config):
  """Creates the classification loss chosen in a ClassificationLoss proto.

  Args:
    loss_config: A losses_pb2.ClassificationLoss object.

  Returns:
    Loss based on the config.

  Raises:
    ValueError: If loss_config is not a ClassificationLoss, or if none of the
      supported oneof choices is set.
  """
  if not isinstance(loss_config, losses_pb2.ClassificationLoss):
    raise ValueError('loss_config not of type losses_pb2.ClassificationLoss.')

  selected = loss_config.WhichOneof('classification_loss')

  if selected == 'bootstrapped_sigmoid':
    bootstrap_params = loss_config.bootstrapped_sigmoid
    bootstrap_mode = 'hard' if bootstrap_params.hard_bootstrap else 'soft'
    return losses.BootstrappedSigmoidClassificationLoss(
        alpha=bootstrap_params.alpha,
        bootstrap_type=bootstrap_mode)

  if selected == 'weighted_logits_softmax':
    logits_params = loss_config.weighted_logits_softmax
    return losses.WeightedSoftmaxClassificationAgainstLogitsLoss(
        logit_scale=logits_params.logit_scale)

  if selected == 'weighted_softmax':
    softmax_params = loss_config.weighted_softmax
    return losses.WeightedSoftmaxClassificationLoss(
        logit_scale=softmax_params.logit_scale)

  if selected == 'weighted_sigmoid_focal':
    focal_params = loss_config.weighted_sigmoid_focal
    # alpha is optional in the proto; only forward it when explicitly set.
    focal_alpha = focal_params.alpha if focal_params.HasField('alpha') else None
    return losses.SigmoidFocalClassificationLoss(
        gamma=focal_params.gamma,
        alpha=focal_alpha)

  if selected == 'weighted_sigmoid':
    return losses.WeightedSigmoidClassificationLoss()

  # WhichOneof returned None or an unsupported choice.
  raise ValueError('Empty loss config.')
|