|
|
- # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- """Numpy BoxMaskList classes and functions."""
-
- import numpy as np
- from object_detection.utils import np_box_list
-
-
- class BoxMaskList(np_box_list.BoxList):
- """Convenience wrapper for BoxList with masks.
-
- BoxMaskList extends the np_box_list.BoxList to contain masks as well.
- In particular, its constructor receives both boxes and masks. Note that the
- masks correspond to the full image.
- """
-
- def __init__(self, box_data, mask_data):
- """Constructs box collection.
-
- Args:
- box_data: a numpy array of shape [N, 4] representing box coordinates
- mask_data: a numpy array of shape [N, height, width] representing masks
- with values are in {0,1}. The masks correspond to the full
- image. The height and the width will be equal to image height and width.
-
- Raises:
- ValueError: if bbox data is not a numpy array
- ValueError: if invalid dimensions for bbox data
- ValueError: if mask data is not a numpy array
- ValueError: if invalid dimension for mask data
- """
- super(BoxMaskList, self).__init__(box_data)
- if not isinstance(mask_data, np.ndarray):
- raise ValueError('Mask data must be a numpy array.')
- if len(mask_data.shape) != 3:
- raise ValueError('Invalid dimensions for mask data.')
- if mask_data.dtype != np.uint8:
- raise ValueError('Invalid data type for mask data: uint8 is required.')
- if mask_data.shape[0] != box_data.shape[0]:
- raise ValueError('There should be the same number of boxes and masks.')
- self.data['masks'] = mask_data
-
- def get_masks(self):
- """Convenience function for accessing masks.
-
- Returns:
- a numpy array of shape [N, height, width] representing masks
- """
- return self.get_field('masks')
-
|