You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

88 lines
3.4 KiB

6 years ago
  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for object_detection.np_mask_ops."""
  16. import numpy as np
  17. import tensorflow as tf
  18. from object_detection.utils import np_mask_ops
  class MaskOpsTests(tf.test.TestCase):
    """Checks np_mask_ops.area/intersection/iou/ioa against hand-worked values.

    setUp builds two stacks of 5x8 binary uint8 masks:
      masks1: 2 masks with areas 8 and 10.
      masks2: 3 masks with areas 8, 15 and 25; masks2[0] equals masks1[0].
    Every expected matrix below is indexed [masks1 index, masks2 index].
    """

    @staticmethod
    def _mask(*regions):
      """Returns a 5x8 uint8 mask that is 1 on `regions` and 0 elsewhere.

      Args:
        *regions: index expressions (typically written with np.s_) naming
          the pixels to switch on.
      """
      canvas = np.zeros((5, 8), dtype=np.uint8)
      for region in regions:
        canvas[region] = 1
      return canvas

    def setUp(self):
      """Builds the two mask stacks shared by every test."""
      # 2x4 block in the bottom-left corner (rows 3-4, cols 0-3): area 8.
      bottom_left = self._mask(np.s_[3:, :4])
      # Whole top row plus the first two pixels of row 1: area 10.
      top_strip = self._mask(np.s_[0, :], np.s_[1, :2])
      # Left-aligned staircase of 7, 5 and 3 pixels on rows 0-2: area 15.
      staircase = self._mask(np.s_[0, :7], np.s_[1, :5], np.s_[2, :3])
      # Leftmost five columns on every row: area 25.
      left_band = self._mask(np.s_[:, :5])
      # np.stack copies, so the two stacks never share storage.
      self.masks1 = np.stack([bottom_left, top_strip])
      self.masks2 = np.stack([bottom_left, staircase, left_band])

    def testArea(self):
      """area() reports the number of set pixels of each mask in masks1."""
      actual = np_mask_ops.area(self.masks1)
      self.assertAllClose(np.array([8.0, 10.0], dtype=np.float32), actual)

    def testIntersection(self):
      """intersection() reports pairwise overlap pixel counts."""
      actual = np_mask_ops.intersection(self.masks1, self.masks2)
      # Entry [i, j]: pixels set in both masks1[i] and masks2[j].
      want = np.array([[8.0, 0.0, 8.0],
                       [0.0, 9.0, 7.0]], dtype=np.float32)
      self.assertAllClose(actual, want)

    def testIOU(self):
      """iou() reports overlap / union for every masks1 x masks2 pair."""
      actual = np_mask_ops.iou(self.masks1, self.masks2)
      # Entry [i, j] = overlap / (area1[i] + area2[j] - overlap),
      # e.g. 9 / (10 + 15 - 9) = 9 / 16.
      want = np.array([[1.0, 0.0, 8.0 / 25.0],
                       [0.0, 9.0 / 16.0, 7.0 / 28.0]], dtype=np.float32)
      self.assertAllClose(actual, want)

    def testIOA(self):
      """ioa(masks1, masks2) reports overlap relative to each masks2 area."""
      actual = np_mask_ops.ioa(self.masks1, self.masks2)
      # Entry [i, j] = overlap / area(masks2[j]); masks2 areas are 8, 15, 25.
      want = np.array([[1.0, 0.0, 8.0 / 25.0],
                       [0.0, 9.0 / 15.0, 7.0 / 25.0]], dtype=np.float32)
      self.assertAllClose(actual, want)
  if __name__ == '__main__':
    # Executed only when the file is run directly: delegate to TensorFlow's
    # test runner for the test cases defined in this module.
    tf.test.main()