You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

133 lines
4.4 KiB

6 years ago
  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Numpy BoxList classes and functions."""
  16. import numpy as np
  17. class BoxList(object):
  18. """Box collection.
  19. BoxList represents a list of bounding boxes as numpy array, where each
  20. bounding box is represented as a row of 4 numbers,
  21. [y_min, x_min, y_max, x_max]. It is assumed that all bounding boxes within a
  22. given list correspond to a single image.
  23. Optionally, users can add additional related fields (such as
  24. objectness/classification scores).
  25. """
  26. def __init__(self, data):
  27. """Constructs box collection.
  28. Args:
  29. data: a numpy array of shape [N, 4] representing box coordinates
  30. Raises:
  31. ValueError: if bbox data is not a numpy array
  32. ValueError: if invalid dimensions for bbox data
  33. """
  34. if not isinstance(data, np.ndarray):
  35. raise ValueError('data must be a numpy array.')
  36. if len(data.shape) != 2 or data.shape[1] != 4:
  37. raise ValueError('Invalid dimensions for box data.')
  38. if data.dtype != np.float32 and data.dtype != np.float64:
  39. raise ValueError('Invalid data type for box data: float is required.')
  40. if not self._is_valid_boxes(data):
  41. raise ValueError('Invalid box data. data must be a numpy array of '
  42. 'N*[y_min, x_min, y_max, x_max]')
  43. self.data = {'boxes': data}
  44. def num_boxes(self):
  45. """Return number of boxes held in collections."""
  46. return self.data['boxes'].shape[0]
  47. def get_extra_fields(self):
  48. """Return all non-box fields."""
  49. return [k for k in self.data.keys() if k != 'boxes']
  50. def has_field(self, field):
  51. return field in self.data
  52. def add_field(self, field, field_data):
  53. """Add data to a specified field.
  54. Args:
  55. field: a string parameter used to speficy a related field to be accessed.
  56. field_data: a numpy array of [N, ...] representing the data associated
  57. with the field.
  58. Raises:
  59. ValueError: if the field is already exist or the dimension of the field
  60. data does not matches the number of boxes.
  61. """
  62. if self.has_field(field):
  63. raise ValueError('Field ' + field + 'already exists')
  64. if len(field_data.shape) < 1 or field_data.shape[0] != self.num_boxes():
  65. raise ValueError('Invalid dimensions for field data')
  66. self.data[field] = field_data
  67. def get(self):
  68. """Convenience function for accesssing box coordinates.
  69. Returns:
  70. a numpy array of shape [N, 4] representing box corners
  71. """
  72. return self.get_field('boxes')
  73. def get_field(self, field):
  74. """Accesses data associated with the specified field in the box collection.
  75. Args:
  76. field: a string parameter used to speficy a related field to be accessed.
  77. Returns:
  78. a numpy 1-d array representing data of an associated field
  79. Raises:
  80. ValueError: if invalid field
  81. """
  82. if not self.has_field(field):
  83. raise ValueError('field {} does not exist'.format(field))
  84. return self.data[field]
  85. def get_coordinates(self):
  86. """Get corner coordinates of boxes.
  87. Returns:
  88. a list of 4 1-d numpy arrays [y_min, x_min, y_max, x_max]
  89. """
  90. box_coordinates = self.get()
  91. y_min = box_coordinates[:, 0]
  92. x_min = box_coordinates[:, 1]
  93. y_max = box_coordinates[:, 2]
  94. x_max = box_coordinates[:, 3]
  95. return [y_min, x_min, y_max, x_max]
  96. def _is_valid_boxes(self, data):
  97. """Check whether data fullfills the format of N*[ymin, xmin, ymax, xmin].
  98. Args:
  99. data: a numpy array of shape [N, 4] representing box coordinates
  100. Returns:
  101. a boolean indicating whether all ymax of boxes are equal or greater than
  102. ymin, and all xmax of boxes are equal or greater than xmin.
  103. """
  104. if data.shape[0] > 0:
  105. for i in range(data.shape[0]):
  106. if data[i, 0] > data[i, 2] or data[i, 1] > data[i, 3]:
  107. return False
  108. return True