You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

202 lines
7.4 KiB

  // Schema uses proto2 semantics (explicit `optional` labels and defaults).
  syntax = "proto2";

  package object_detection.protos;
  // Message for configuring the localization loss, classification loss and
  // hard example miner used for training object detection models. See
  // core/losses.py for details.
  message Loss {
    // Localization loss to use.
    optional LocalizationLoss localization_loss = 1;

    // Classification loss to use.
    optional ClassificationLoss classification_loss = 2;

    // If not left to default, applies hard example mining.
    optional HardExampleMiner hard_example_miner = 3;

    // Classification loss weight.
    optional float classification_weight = 4 [default = 1.0];

    // Localization loss weight.
    optional float localization_weight = 5 [default = 1.0];

    // If not left to default, applies random example sampling.
    optional RandomExampleSampler random_example_sampler = 6;

    // Equalization loss.
    message EqualizationLoss {
      // Weight equalization loss strength.
      optional float weight = 1 [default = 0.0];

      // When computing equalization loss, ops that start with any of
      // exclude_prefixes will be ignored. Only used when weight > 0.
      repeated string exclude_prefixes = 2;
    }

    // Equalization loss settings; see EqualizationLoss.
    optional EqualizationLoss equalization_loss = 7;

    // Available methods for computing expected loss weights; selected through
    // the expected_loss_weights field.
    enum ExpectedLossWeights {
      NONE = 0;

      // Use expected_classification_loss_by_expected_sampling
      // from third_party/tensorflow_models/object_detection/utils/ops.py
      EXPECTED_SAMPLING = 1;

      // Use expected_classification_loss_by_reweighting_unmatched_anchors
      // from third_party/tensorflow_models/object_detection/utils/ops.py
      REWEIGHTING_UNMATCHED_ANCHORS = 2;
    }

    // Method to compute expected loss weights with respect to balanced
    // positive/negative sampling scheme. If NONE, use explicit sampling.
    // TODO(birdbrain): Move under ExpectedLossWeights.
    optional ExpectedLossWeights expected_loss_weights = 18 [default = NONE];

    // Minimum number of effective negative samples.
    // Only applies if expected_loss_weights is not NONE.
    // TODO(birdbrain): Move under ExpectedLossWeights.
    optional float min_num_negative_samples = 19 [default = 0];

    // Desired number of effective negative samples per positive sample.
    // Only applies if expected_loss_weights is not NONE.
    // TODO(birdbrain): Move under ExpectedLossWeights.
    optional float desired_negative_sampling_ratio = 20 [default = 3];
  }
  // Configuration for bounding box localization loss function.
  // Exactly one of the variants below may be set.
  message LocalizationLoss {
    oneof localization_loss {
      // L2 location loss.
      WeightedL2LocalizationLoss weighted_l2 = 1;

      // SmoothL1 (Huber) location loss.
      WeightedSmoothL1LocalizationLoss weighted_smooth_l1 = 2;

      // Intersection over union location loss.
      WeightedIOULocalizationLoss weighted_iou = 3;
    }
  }
  // L2 location loss: 0.5 * ||weight * (a - b)|| ^ 2
  message WeightedL2LocalizationLoss {
    // DEPRECATED, do not use.
    // Output loss per anchor.
    optional bool anchorwise_output = 1 [default = false, deprecated = true];
  }
  // SmoothL1 (Huber) location loss.
  // The smooth L1_loss is defined elementwise as .5 x^2 if |x| <= delta and
  // delta * (|x| - 0.5 * delta) otherwise, where x is the difference between
  // predictions and target.
  message WeightedSmoothL1LocalizationLoss {
    // DEPRECATED, do not use.
    // Output loss per anchor.
    optional bool anchorwise_output = 1 [default = false, deprecated = true];

    // Delta value for huber loss.
    optional float delta = 2 [default = 1.0];
  }
  // Intersection over union location loss: 1 - IOU
  // This message currently declares no configurable fields.
  message WeightedIOULocalizationLoss {
  }
  // Configuration for class prediction loss function.
  // Exactly one of the variants below may be set.
  message ClassificationLoss {
    oneof classification_loss {
      // Sigmoid loss over class predictions.
      WeightedSigmoidClassificationLoss weighted_sigmoid = 1;

      // Softmax loss over class predictions.
      WeightedSoftmaxClassificationLoss weighted_softmax = 2;

      // Softmax loss where the groundtruth labels are assumed to be logits.
      WeightedSoftmaxClassificationAgainstLogitsLoss weighted_logits_softmax = 5;

      // Bootstrapped sigmoid loss.
      BootstrappedSigmoidClassificationLoss bootstrapped_sigmoid = 3;

      // Sigmoid focal cross entropy loss.
      SigmoidFocalClassificationLoss weighted_sigmoid_focal = 4;
    }
  }
  // Classification loss using a sigmoid function over class predictions.
  message WeightedSigmoidClassificationLoss {
    // DEPRECATED, do not use.
    // Output loss per anchor.
    optional bool anchorwise_output = 1 [default = false, deprecated = true];
  }
  // Sigmoid Focal cross entropy loss as described in
  // https://arxiv.org/abs/1708.02002
  message SigmoidFocalClassificationLoss {
    // DEPRECATED, do not use.
    optional bool anchorwise_output = 1 [default = false, deprecated = true];

    // Modulating factor for the loss.
    optional float gamma = 2 [default = 2.0];

    // Alpha weighting factor for the loss. No default is declared, so callers
    // can distinguish "unset" from an explicit value via field presence.
    optional float alpha = 3;
  }
  // Classification loss using a softmax function over class predictions.
  message WeightedSoftmaxClassificationLoss {
    // DEPRECATED, do not use.
    // Output loss per anchor.
    optional bool anchorwise_output = 1 [default = false, deprecated = true];

    // Scale logit (input) value before calculating softmax classification
    // loss. Typically used for softmax distillation.
    optional float logit_scale = 2 [default = 1.0];
  }
  // Classification loss using a softmax function over class predictions and
  // a softmax function over the groundtruth labels (assumed to be logits).
  message WeightedSoftmaxClassificationAgainstLogitsLoss {
    // DEPRECATED, do not use.
    optional bool anchorwise_output = 1 [default = false, deprecated = true];

    // Scale and softmax groundtruth logits before calculating softmax
    // classification loss. Typically used for softmax distillation with
    // teacher annotations stored as logits.
    optional float logit_scale = 2 [default = 1.0];
  }
  // Classification loss using a sigmoid function over the class prediction
  // with the highest prediction score.
  message BootstrappedSigmoidClassificationLoss {
    // Interpolation weight between 0 and 1.
    optional float alpha = 1;

    // Whether hard bootstrapping should be used or not. If true, will only use
    // one class favored by model. Otherwise, will use all predicted class
    // probabilities.
    optional bool hard_bootstrap = 2 [default = false];

    // DEPRECATED, do not use.
    // Output loss per anchor.
    optional bool anchorwise_output = 3 [default = false, deprecated = true];
  }
  // Configuration for hard example miner.
  message HardExampleMiner {
    // Maximum number of hard examples to be selected per image (prior to
    // enforcing max negative to positive ratio constraint). If set to 0,
    // all examples obtained after NMS are considered.
    optional int32 num_hard_examples = 1 [default = 64];

    // Minimum intersection over union for an example to be discarded during
    // NMS.
    optional float iou_threshold = 2 [default = 0.7];

    // Which losses drive the mining: both losses (BOTH), classification losses
    // only (CLASSIFICATION) or localization losses only (LOCALIZATION). In the
    // case of BOTH, cls_loss_weight and loc_loss_weight are used to compute a
    // weighted sum of the two losses.
    // NOTE(review): cls_loss_weight / loc_loss_weight are not fields of this
    // message; presumably they correspond to Loss.classification_weight and
    // Loss.localization_weight -- confirm against the consuming code.
    enum LossType {
      BOTH = 0;
      CLASSIFICATION = 1;
      LOCALIZATION = 2;
    }

    // Loss type used for mining. Defaults to BOTH.
    optional LossType loss_type = 3 [default = BOTH];

    // Maximum number of negatives to retain for each positive anchor. If
    // max_negatives_per_positive is 0 no prespecified negative:positive ratio
    // is enforced.
    optional int32 max_negatives_per_positive = 4 [default = 0];

    // Minimum number of negative anchors to sample for a given image. Setting
    // this to a positive number samples negatives in an image without any
    // positive anchors and thus does not bias the model towards having at
    // least one detection per image.
    optional int32 min_negatives_per_image = 5 [default = 0];
  }
  // Configuration for random example sampler.
  message RandomExampleSampler {
    // The desired fraction of positive samples in batch when applying random
    // example sampling.
    optional float positive_sample_fraction = 1 [default = 0.01];
  }