You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

97 lines
3.5 KiB

  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for google3.image.understanding.object_detection.utils.json_utils."""
  16. import os
  17. import tensorflow as tf
  18. from object_detection.utils import json_utils
  class JsonUtilsTest(tf.test.TestCase):
    """Checks float formatting and keyword handling of json_utils.

    Exercises `Dump` (writes to a file object), `Dumps` (returns a string) and
    `PrettyParams` (whose result is unpacked as keyword arguments for `Dumps`).
    """

    def _dump_to_temp_file(self, obj, **kwargs):
      """Writes `obj` via json_utils.Dump to a temp file and returns its text.

      Args:
        obj: value handed to json_utils.Dump.
        **kwargs: keyword arguments forwarded unchanged to json_utils.Dump.

      Returns:
        The full contents of the file that was written.
      """
      output_path = os.path.join(tf.test.get_temp_dir(), 'test.json')
      with tf.gfile.GFile(output_path, 'w') as f:
        json_utils.Dump(obj, f, **kwargs)
      with tf.gfile.GFile(output_path, 'r') as f:
        return f.read()

    def testDumpReasonablePrecision(self):
      """float_digits=2 writes exactly two decimals."""
      self.assertEqual(self._dump_to_temp_file(1.0, float_digits=2), '1.00')

    def testDumpPassExtraParams(self):
      """Extra keyword arguments (here indent=3) affect the written output."""
      # indent=3 means three spaces per nesting level in the pretty-printed
      # output, so the expected literal must contain three spaces.
      self.assertEqual(
          self._dump_to_temp_file([1.0], float_digits=2, indent=3),
          '[\n   1.00\n]')

    def testDumpZeroPrecision(self):
      """float_digits=0 writes the float without a fractional part."""
      # A scalar has no nesting, so indent=3 adds no whitespace here.
      self.assertEqual(
          self._dump_to_temp_file(1.0, float_digits=0, indent=3), '1')

    def testDumpUnspecifiedPrecision(self):
      """Without float_digits the float is written unrounded."""
      self.assertEqual(self._dump_to_temp_file(1.012345), '1.012345')

    def testDumpsReasonablePrecision(self):
      """float_digits=2 yields exactly two decimals in the returned string."""
      s = json_utils.Dumps(1.0, float_digits=2)
      self.assertEqual(s, '1.00')

    def testDumpsPassExtraParams(self):
      """Extra keyword arguments (here indent=3) affect the returned string."""
      s = json_utils.Dumps([1.0], float_digits=2, indent=3)
      # Three spaces of indentation, matching indent=3.
      self.assertEqual(s, '[\n   1.00\n]')

    def testDumpsZeroPrecision(self):
      """float_digits=0 yields the float without a fractional part."""
      s = json_utils.Dumps(1.0, float_digits=0)
      self.assertEqual(s, '1')

    def testDumpsUnspecifiedPrecision(self):
      """Without float_digits the float is returned unrounded."""
      s = json_utils.Dumps(1.012345)
      self.assertEqual(s, '1.012345')

    def testPrettyParams(self):
      """PrettyParams() output is sorted by key with floats cut to 4 digits."""
      s = json_utils.Dumps({'v': 1.012345, 'n': 2}, **json_utils.PrettyParams())
      # NOTE(review): the indentation inside this literal has to match the
      # indent PrettyParams() selects, which is not visible from this file --
      # confirm the width against json_utils.PrettyParams.
      self.assertEqual(s, '{\n "n": 2,\n "v": 1.0123\n}')

    def testPrettyParamsExtraParamsInside(self):
      """allow_nan given to PrettyParams() is honoured by Dumps."""
      s = json_utils.Dumps(
          {'v': 1.012345,
           'n': float('nan')}, **json_utils.PrettyParams(allow_nan=True))
      # NOTE(review): indentation width depends on PrettyParams() -- confirm.
      self.assertEqual(s, '{\n "n": NaN,\n "v": 1.0123\n}')

      # With allow_nan=False a NaN value must be rejected.
      with self.assertRaises(ValueError):
        json_utils.Dumps(
            {'v': 1.012345,
             'n': float('nan')}, **json_utils.PrettyParams(allow_nan=False))

    def testPrettyParamsExtraParamsOutside(self):
      """allow_nan given next to **PrettyParams() is honoured by Dumps."""
      s = json_utils.Dumps(
          {'v': 1.012345,
           'n': float('nan')}, allow_nan=True, **json_utils.PrettyParams())
      # NOTE(review): indentation width depends on PrettyParams() -- confirm.
      self.assertEqual(s, '{\n "n": NaN,\n "v": 1.0123\n}')

      # With allow_nan=False a NaN value must be rejected.
      with self.assertRaises(ValueError):
        json_utils.Dumps(
            {'v': 1.012345,
             'n': float('nan')}, allow_nan=False, **json_utils.PrettyParams())
  # Script entry point: hand control to the TensorFlow test runner.
  if __name__ == '__main__':
    tf.test.main()