You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

135 lines
5.1 KiB

6 years ago
  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for object_detection.utils.np_box_list_test."""
  16. import numpy as np
  17. import tensorflow as tf
  18. from object_detection.utils import np_box_list
  19. class BoxListTest(tf.test.TestCase):
  20. def test_invalid_box_data(self):
  21. with self.assertRaises(ValueError):
  22. np_box_list.BoxList([0, 0, 1, 1])
  23. with self.assertRaises(ValueError):
  24. np_box_list.BoxList(np.array([[0, 0, 1, 1]], dtype=int))
  25. with self.assertRaises(ValueError):
  26. np_box_list.BoxList(np.array([0, 1, 1, 3, 4], dtype=float))
  27. with self.assertRaises(ValueError):
  28. np_box_list.BoxList(np.array([[0, 1, 1, 3], [3, 1, 1, 5]], dtype=float))
  29. def test_has_field_with_existed_field(self):
  30. boxes = np.array([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0],
  31. [0.0, 0.0, 20.0, 20.0]],
  32. dtype=float)
  33. boxlist = np_box_list.BoxList(boxes)
  34. self.assertTrue(boxlist.has_field('boxes'))
  35. def test_has_field_with_nonexisted_field(self):
  36. boxes = np.array([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0],
  37. [0.0, 0.0, 20.0, 20.0]],
  38. dtype=float)
  39. boxlist = np_box_list.BoxList(boxes)
  40. self.assertFalse(boxlist.has_field('scores'))
  41. def test_get_field_with_existed_field(self):
  42. boxes = np.array([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0],
  43. [0.0, 0.0, 20.0, 20.0]],
  44. dtype=float)
  45. boxlist = np_box_list.BoxList(boxes)
  46. self.assertTrue(np.allclose(boxlist.get_field('boxes'), boxes))
  47. def test_get_field_with_nonexited_field(self):
  48. boxes = np.array([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0],
  49. [0.0, 0.0, 20.0, 20.0]],
  50. dtype=float)
  51. boxlist = np_box_list.BoxList(boxes)
  52. with self.assertRaises(ValueError):
  53. boxlist.get_field('scores')
  54. class AddExtraFieldTest(tf.test.TestCase):
  55. def setUp(self):
  56. boxes = np.array([[3.0, 4.0, 6.0, 8.0], [14.0, 14.0, 15.0, 15.0],
  57. [0.0, 0.0, 20.0, 20.0]],
  58. dtype=float)
  59. self.boxlist = np_box_list.BoxList(boxes)
  60. def test_add_already_existed_field(self):
  61. with self.assertRaises(ValueError):
  62. self.boxlist.add_field('boxes', np.array([[0, 0, 0, 1, 0]], dtype=float))
  63. def test_add_invalid_field_data(self):
  64. with self.assertRaises(ValueError):
  65. self.boxlist.add_field('scores', np.array([0.5, 0.7], dtype=float))
  66. with self.assertRaises(ValueError):
  67. self.boxlist.add_field('scores',
  68. np.array([0.5, 0.7, 0.9, 0.1], dtype=float))
  69. def test_add_single_dimensional_field_data(self):
  70. boxlist = self.boxlist
  71. scores = np.array([0.5, 0.7, 0.9], dtype=float)
  72. boxlist.add_field('scores', scores)
  73. self.assertTrue(np.allclose(scores, self.boxlist.get_field('scores')))
  74. def test_add_multi_dimensional_field_data(self):
  75. boxlist = self.boxlist
  76. labels = np.array([[0, 0, 0, 1, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 1]],
  77. dtype=int)
  78. boxlist.add_field('labels', labels)
  79. self.assertTrue(np.allclose(labels, self.boxlist.get_field('labels')))
  80. def test_get_extra_fields(self):
  81. boxlist = self.boxlist
  82. self.assertItemsEqual(boxlist.get_extra_fields(), [])
  83. scores = np.array([0.5, 0.7, 0.9], dtype=float)
  84. boxlist.add_field('scores', scores)
  85. self.assertItemsEqual(boxlist.get_extra_fields(), ['scores'])
  86. labels = np.array([[0, 0, 0, 1, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 1]],
  87. dtype=int)
  88. boxlist.add_field('labels', labels)
  89. self.assertItemsEqual(boxlist.get_extra_fields(), ['scores', 'labels'])
  90. def test_get_coordinates(self):
  91. y_min, x_min, y_max, x_max = self.boxlist.get_coordinates()
  92. expected_y_min = np.array([3.0, 14.0, 0.0], dtype=float)
  93. expected_x_min = np.array([4.0, 14.0, 0.0], dtype=float)
  94. expected_y_max = np.array([6.0, 15.0, 20.0], dtype=float)
  95. expected_x_max = np.array([8.0, 15.0, 20.0], dtype=float)
  96. self.assertTrue(np.allclose(y_min, expected_y_min))
  97. self.assertTrue(np.allclose(x_min, expected_x_min))
  98. self.assertTrue(np.allclose(y_max, expected_y_max))
  99. self.assertTrue(np.allclose(x_max, expected_x_max))
  100. def test_num_boxes(self):
  101. boxes = np.array([[0., 0., 100., 100.], [10., 30., 50., 70.]], dtype=float)
  102. boxlist = np_box_list.BoxList(boxes)
  103. expected_num_boxes = 2
  104. self.assertEquals(boxlist.num_boxes(), expected_num_boxes)
  105. if __name__ == '__main__':
  106. tf.test.main()