You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

118 lines
4.5 KiB

  1. # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Object detection calibration metrics.
  16. """
  17. from __future__ import absolute_import
  18. from __future__ import division
  19. from __future__ import print_function
  20. import tensorflow as tf
  21. from tensorflow.python.ops import metrics_impl
  22. def _safe_div(numerator, denominator):
  23. """Divides two tensors element-wise, returning 0 if the denominator is <= 0.
  24. Args:
  25. numerator: A real `Tensor`.
  26. denominator: A real `Tensor`, with dtype matching `numerator`.
  27. Returns:
  28. 0 if `denominator` <= 0, else `numerator` / `denominator`
  29. """
  30. t = tf.truediv(numerator, denominator)
  31. zero = tf.zeros_like(t, dtype=denominator.dtype)
  32. condition = tf.greater(denominator, zero)
  33. zero = tf.cast(zero, t.dtype)
  34. return tf.where(condition, t, zero)
  35. def _ece_from_bins(bin_counts, bin_true_sum, bin_preds_sum, name):
  36. """Calculates Expected Calibration Error from accumulated statistics."""
  37. bin_accuracies = _safe_div(bin_true_sum, bin_counts)
  38. bin_confidences = _safe_div(bin_preds_sum, bin_counts)
  39. abs_bin_errors = tf.abs(bin_accuracies - bin_confidences)
  40. bin_weights = _safe_div(bin_counts, tf.reduce_sum(bin_counts))
  41. return tf.reduce_sum(abs_bin_errors * bin_weights, name=name)
  42. def expected_calibration_error(y_true, y_pred, nbins=20):
  43. """Calculates Expected Calibration Error (ECE).
  44. ECE is a scalar summary statistic of calibration error. It is the
  45. sample-weighted average of the difference between the predicted and true
  46. probabilities of a positive detection across uniformly-spaced model
  47. confidences [0, 1]. See referenced paper for a thorough explanation.
  48. Reference:
  49. Guo, et. al, "On Calibration of Modern Neural Networks"
  50. Page 2, Expected Calibration Error (ECE).
  51. https://arxiv.org/pdf/1706.04599.pdf
  52. This function creates three local variables, `bin_counts`, `bin_true_sum`, and
  53. `bin_preds_sum` that are used to compute ECE. For estimation of the metric
  54. over a stream of data, the function creates an `update_op` operation that
  55. updates these variables and returns the ECE.
  56. Args:
  57. y_true: 1-D tf.int64 Tensor of binarized ground truth, corresponding to each
  58. prediction in y_pred.
  59. y_pred: 1-D tf.float32 tensor of model confidence scores in range
  60. [0.0, 1.0].
  61. nbins: int specifying the number of uniformly-spaced bins into which y_pred
  62. will be bucketed.
  63. Returns:
  64. value_op: A value metric op that returns ece.
  65. update_op: An operation that increments the `bin_counts`, `bin_true_sum`,
  66. and `bin_preds_sum` variables appropriately and whose value matches `ece`.
  67. Raises:
  68. InvalidArgumentError: if y_pred is not in [0.0, 1.0].
  69. """
  70. bin_counts = metrics_impl.metric_variable(
  71. [nbins], tf.float32, name='bin_counts')
  72. bin_true_sum = metrics_impl.metric_variable(
  73. [nbins], tf.float32, name='true_sum')
  74. bin_preds_sum = metrics_impl.metric_variable(
  75. [nbins], tf.float32, name='preds_sum')
  76. with tf.control_dependencies([
  77. tf.assert_greater_equal(y_pred, 0.0),
  78. tf.assert_less_equal(y_pred, 1.0),
  79. ]):
  80. bin_ids = tf.histogram_fixed_width_bins(y_pred, [0.0, 1.0], nbins=nbins)
  81. with tf.control_dependencies([bin_ids]):
  82. update_bin_counts_op = tf.assign_add(
  83. bin_counts, tf.cast(tf.bincount(bin_ids, minlength=nbins),
  84. dtype=tf.float32))
  85. update_bin_true_sum_op = tf.assign_add(
  86. bin_true_sum,
  87. tf.cast(tf.bincount(bin_ids, weights=y_true, minlength=nbins),
  88. dtype=tf.float32))
  89. update_bin_preds_sum_op = tf.assign_add(
  90. bin_preds_sum,
  91. tf.cast(tf.bincount(bin_ids, weights=y_pred, minlength=nbins),
  92. dtype=tf.float32))
  93. ece_update_op = _ece_from_bins(
  94. update_bin_counts_op,
  95. update_bin_true_sum_op,
  96. update_bin_preds_sum_op,
  97. name='update_op')
  98. ece = _ece_from_bins(bin_counts, bin_true_sum, bin_preds_sum, name='value')
  99. return ece, ece_update_op