You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

63 lines
2.4 KiB

6 years ago
  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Numpy BoxMaskList classes and functions."""
  16. import numpy as np
  17. from object_detection.utils import np_box_list
  18. class BoxMaskList(np_box_list.BoxList):
  19. """Convenience wrapper for BoxList with masks.
  20. BoxMaskList extends the np_box_list.BoxList to contain masks as well.
  21. In particular, its constructor receives both boxes and masks. Note that the
  22. masks correspond to the full image.
  23. """
  24. def __init__(self, box_data, mask_data):
  25. """Constructs box collection.
  26. Args:
  27. box_data: a numpy array of shape [N, 4] representing box coordinates
  28. mask_data: a numpy array of shape [N, height, width] representing masks
  29. with values are in {0,1}. The masks correspond to the full
  30. image. The height and the width will be equal to image height and width.
  31. Raises:
  32. ValueError: if bbox data is not a numpy array
  33. ValueError: if invalid dimensions for bbox data
  34. ValueError: if mask data is not a numpy array
  35. ValueError: if invalid dimension for mask data
  36. """
  37. super(BoxMaskList, self).__init__(box_data)
  38. if not isinstance(mask_data, np.ndarray):
  39. raise ValueError('Mask data must be a numpy array.')
  40. if len(mask_data.shape) != 3:
  41. raise ValueError('Invalid dimensions for mask data.')
  42. if mask_data.dtype != np.uint8:
  43. raise ValueError('Invalid data type for mask data: uint8 is required.')
  44. if mask_data.shape[0] != box_data.shape[0]:
  45. raise ValueError('There should be the same number of boxes and masks.')
  46. self.data['masks'] = mask_data
  47. def get_masks(self):
  48. """Convenience function for accessing masks.
  49. Returns:
  50. a numpy array of shape [N, height, width] representing masks
  51. """
  52. return self.get_field('masks')