You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

109 lines
4.2 KiB

  1. # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for calibration_metrics."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import numpy as np
  20. import tensorflow as tf
  21. from object_detection.metrics import calibration_metrics
  22. class CalibrationLibTest(tf.test.TestCase):
  23. @staticmethod
  24. def _get_calibration_placeholders():
  25. """Returns TF placeholders for y_true and y_pred."""
  26. return (tf.placeholder(tf.int64, shape=(None)),
  27. tf.placeholder(tf.float32, shape=(None)))
  28. def test_expected_calibration_error_all_bins_filled(self):
  29. """Test expected calibration error when all bins contain predictions."""
  30. y_true, y_pred = self._get_calibration_placeholders()
  31. expected_ece_op, update_op = calibration_metrics.expected_calibration_error(
  32. y_true, y_pred, nbins=2)
  33. with self.test_session() as sess:
  34. metrics_vars = tf.get_collection(tf.GraphKeys.METRIC_VARIABLES)
  35. sess.run(tf.variables_initializer(var_list=metrics_vars))
  36. # Bin calibration errors (|confidence - accuracy| * bin_weight):
  37. # - [0,0.5): |0.2 - 0.333| * (3/5) = 0.08
  38. # - [0.5, 1]: |0.75 - 0.5| * (2/5) = 0.1
  39. sess.run(
  40. update_op,
  41. feed_dict={
  42. y_pred: np.array([0., 0.2, 0.4, 0.5, 1.0]),
  43. y_true: np.array([0, 0, 1, 0, 1])
  44. })
  45. actual_ece = 0.08 + 0.1
  46. expected_ece = sess.run(expected_ece_op)
  47. self.assertAlmostEqual(actual_ece, expected_ece)
  48. def test_expected_calibration_error_all_bins_not_filled(self):
  49. """Test expected calibration error when no predictions for one bin."""
  50. y_true, y_pred = self._get_calibration_placeholders()
  51. expected_ece_op, update_op = calibration_metrics.expected_calibration_error(
  52. y_true, y_pred, nbins=2)
  53. with self.test_session() as sess:
  54. metrics_vars = tf.get_collection(tf.GraphKeys.METRIC_VARIABLES)
  55. sess.run(tf.variables_initializer(var_list=metrics_vars))
  56. # Bin calibration errors (|confidence - accuracy| * bin_weight):
  57. # - [0,0.5): |0.2 - 0.333| * (3/5) = 0.08
  58. # - [0.5, 1]: |0.75 - 0.5| * (2/5) = 0.1
  59. sess.run(
  60. update_op,
  61. feed_dict={
  62. y_pred: np.array([0., 0.2, 0.4]),
  63. y_true: np.array([0, 0, 1])
  64. })
  65. actual_ece = np.abs(0.2 - (1 / 3.))
  66. expected_ece = sess.run(expected_ece_op)
  67. self.assertAlmostEqual(actual_ece, expected_ece)
  68. def test_expected_calibration_error_with_multiple_data_streams(self):
  69. """Test expected calibration error when multiple data batches provided."""
  70. y_true, y_pred = self._get_calibration_placeholders()
  71. expected_ece_op, update_op = calibration_metrics.expected_calibration_error(
  72. y_true, y_pred, nbins=2)
  73. with self.test_session() as sess:
  74. metrics_vars = tf.get_collection(tf.GraphKeys.METRIC_VARIABLES)
  75. sess.run(tf.variables_initializer(var_list=metrics_vars))
  76. # Identical data to test_expected_calibration_error_all_bins_filled,
  77. # except split over three batches.
  78. sess.run(
  79. update_op,
  80. feed_dict={
  81. y_pred: np.array([0., 0.2]),
  82. y_true: np.array([0, 0])
  83. })
  84. sess.run(
  85. update_op,
  86. feed_dict={
  87. y_pred: np.array([0.4, 0.5]),
  88. y_true: np.array([1, 0])
  89. })
  90. sess.run(
  91. update_op, feed_dict={
  92. y_pred: np.array([1.0]),
  93. y_true: np.array([1])
  94. })
  95. actual_ece = 0.08 + 0.1
  96. expected_ece = sess.run(expected_ece_op)
  97. self.assertAlmostEqual(actual_ece, expected_ece)
  98. if __name__ == '__main__':
  99. tf.test.main()