You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

554 lines
20 KiB

6 years ago
  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Bounding Box List operations for Numpy BoxLists.
  16. Example box operations that are supported:
  17. * Areas: compute bounding box areas
  18. * IOU: pairwise intersection-over-union scores
  19. """
  20. import numpy as np
  21. from object_detection.utils import np_box_list
  22. from object_detection.utils import np_box_ops
  23. class SortOrder(object):
  24. """Enum class for sort order.
  25. Attributes:
  26. ascend: ascend order.
  27. descend: descend order.
  28. """
  29. ASCEND = 1
  30. DESCEND = 2
  31. def area(boxlist):
  32. """Computes area of boxes.
  33. Args:
  34. boxlist: BoxList holding N boxes
  35. Returns:
  36. a numpy array with shape [N*1] representing box areas
  37. """
  38. y_min, x_min, y_max, x_max = boxlist.get_coordinates()
  39. return (y_max - y_min) * (x_max - x_min)
  40. def intersection(boxlist1, boxlist2):
  41. """Compute pairwise intersection areas between boxes.
  42. Args:
  43. boxlist1: BoxList holding N boxes
  44. boxlist2: BoxList holding M boxes
  45. Returns:
  46. a numpy array with shape [N*M] representing pairwise intersection area
  47. """
  48. return np_box_ops.intersection(boxlist1.get(), boxlist2.get())
  49. def iou(boxlist1, boxlist2):
  50. """Computes pairwise intersection-over-union between box collections.
  51. Args:
  52. boxlist1: BoxList holding N boxes
  53. boxlist2: BoxList holding M boxes
  54. Returns:
  55. a numpy array with shape [N, M] representing pairwise iou scores.
  56. """
  57. return np_box_ops.iou(boxlist1.get(), boxlist2.get())
  58. def ioa(boxlist1, boxlist2):
  59. """Computes pairwise intersection-over-area between box collections.
  60. Intersection-over-area (ioa) between two boxes box1 and box2 is defined as
  61. their intersection area over box2's area. Note that ioa is not symmetric,
  62. that is, IOA(box1, box2) != IOA(box2, box1).
  63. Args:
  64. boxlist1: BoxList holding N boxes
  65. boxlist2: BoxList holding M boxes
  66. Returns:
  67. a numpy array with shape [N, M] representing pairwise ioa scores.
  68. """
  69. return np_box_ops.ioa(boxlist1.get(), boxlist2.get())
  70. def gather(boxlist, indices, fields=None):
  71. """Gather boxes from BoxList according to indices and return new BoxList.
  72. By default, gather returns boxes corresponding to the input index list, as
  73. well as all additional fields stored in the boxlist (indexing into the
  74. first dimension). However one can optionally only gather from a
  75. subset of fields.
  76. Args:
  77. boxlist: BoxList holding N boxes
  78. indices: a 1-d numpy array of type int_
  79. fields: (optional) list of fields to also gather from. If None (default),
  80. all fields are gathered from. Pass an empty fields list to only gather
  81. the box coordinates.
  82. Returns:
  83. subboxlist: a BoxList corresponding to the subset of the input BoxList
  84. specified by indices
  85. Raises:
  86. ValueError: if specified field is not contained in boxlist or if the
  87. indices are not of type int_
  88. """
  89. if indices.size:
  90. if np.amax(indices) >= boxlist.num_boxes() or np.amin(indices) < 0:
  91. raise ValueError('indices are out of valid range.')
  92. subboxlist = np_box_list.BoxList(boxlist.get()[indices, :])
  93. if fields is None:
  94. fields = boxlist.get_extra_fields()
  95. for field in fields:
  96. extra_field_data = boxlist.get_field(field)
  97. subboxlist.add_field(field, extra_field_data[indices, ...])
  98. return subboxlist
  99. def sort_by_field(boxlist, field, order=SortOrder.DESCEND):
  100. """Sort boxes and associated fields according to a scalar field.
  101. A common use case is reordering the boxes according to descending scores.
  102. Args:
  103. boxlist: BoxList holding N boxes.
  104. field: A BoxList field for sorting and reordering the BoxList.
  105. order: (Optional) 'descend' or 'ascend'. Default is descend.
  106. Returns:
  107. sorted_boxlist: A sorted BoxList with the field in the specified order.
  108. Raises:
  109. ValueError: if specified field does not exist or is not of single dimension.
  110. ValueError: if the order is not either descend or ascend.
  111. """
  112. if not boxlist.has_field(field):
  113. raise ValueError('Field ' + field + ' does not exist')
  114. if len(boxlist.get_field(field).shape) != 1:
  115. raise ValueError('Field ' + field + 'should be single dimension.')
  116. if order != SortOrder.DESCEND and order != SortOrder.ASCEND:
  117. raise ValueError('Invalid sort order')
  118. field_to_sort = boxlist.get_field(field)
  119. sorted_indices = np.argsort(field_to_sort)
  120. if order == SortOrder.DESCEND:
  121. sorted_indices = sorted_indices[::-1]
  122. return gather(boxlist, sorted_indices)
  123. def non_max_suppression(boxlist,
  124. max_output_size=10000,
  125. iou_threshold=1.0,
  126. score_threshold=-10.0):
  127. """Non maximum suppression.
  128. This op greedily selects a subset of detection bounding boxes, pruning
  129. away boxes that have high IOU (intersection over union) overlap (> thresh)
  130. with already selected boxes. In each iteration, the detected bounding box with
  131. highest score in the available pool is selected.
  132. Args:
  133. boxlist: BoxList holding N boxes. Must contain a 'scores' field
  134. representing detection scores. All scores belong to the same class.
  135. max_output_size: maximum number of retained boxes
  136. iou_threshold: intersection over union threshold.
  137. score_threshold: minimum score threshold. Remove the boxes with scores
  138. less than this value. Default value is set to -10. A very
  139. low threshold to pass pretty much all the boxes, unless
  140. the user sets a different score threshold.
  141. Returns:
  142. a BoxList holding M boxes where M <= max_output_size
  143. Raises:
  144. ValueError: if 'scores' field does not exist
  145. ValueError: if threshold is not in [0, 1]
  146. ValueError: if max_output_size < 0
  147. """
  148. if not boxlist.has_field('scores'):
  149. raise ValueError('Field scores does not exist')
  150. if iou_threshold < 0. or iou_threshold > 1.0:
  151. raise ValueError('IOU threshold must be in [0, 1]')
  152. if max_output_size < 0:
  153. raise ValueError('max_output_size must be bigger than 0.')
  154. boxlist = filter_scores_greater_than(boxlist, score_threshold)
  155. if boxlist.num_boxes() == 0:
  156. return boxlist
  157. boxlist = sort_by_field(boxlist, 'scores')
  158. # Prevent further computation if NMS is disabled.
  159. if iou_threshold == 1.0:
  160. if boxlist.num_boxes() > max_output_size:
  161. selected_indices = np.arange(max_output_size)
  162. return gather(boxlist, selected_indices)
  163. else:
  164. return boxlist
  165. boxes = boxlist.get()
  166. num_boxes = boxlist.num_boxes()
  167. # is_index_valid is True only for all remaining valid boxes,
  168. is_index_valid = np.full(num_boxes, 1, dtype=bool)
  169. selected_indices = []
  170. num_output = 0
  171. for i in range(num_boxes):
  172. if num_output < max_output_size:
  173. if is_index_valid[i]:
  174. num_output += 1
  175. selected_indices.append(i)
  176. is_index_valid[i] = False
  177. valid_indices = np.where(is_index_valid)[0]
  178. if valid_indices.size == 0:
  179. break
  180. intersect_over_union = np_box_ops.iou(
  181. np.expand_dims(boxes[i, :], axis=0), boxes[valid_indices, :])
  182. intersect_over_union = np.squeeze(intersect_over_union, axis=0)
  183. is_index_valid[valid_indices] = np.logical_and(
  184. is_index_valid[valid_indices],
  185. intersect_over_union <= iou_threshold)
  186. return gather(boxlist, np.array(selected_indices))
  187. def multi_class_non_max_suppression(boxlist, score_thresh, iou_thresh,
  188. max_output_size):
  189. """Multi-class version of non maximum suppression.
  190. This op greedily selects a subset of detection bounding boxes, pruning
  191. away boxes that have high IOU (intersection over union) overlap (> thresh)
  192. with already selected boxes. It operates independently for each class for
  193. which scores are provided (via the scores field of the input box_list),
  194. pruning boxes with score less than a provided threshold prior to
  195. applying NMS.
  196. Args:
  197. boxlist: BoxList holding N boxes. Must contain a 'scores' field
  198. representing detection scores. This scores field is a tensor that can
  199. be 1 dimensional (in the case of a single class) or 2-dimensional, which
  200. which case we assume that it takes the shape [num_boxes, num_classes].
  201. We further assume that this rank is known statically and that
  202. scores.shape[1] is also known (i.e., the number of classes is fixed
  203. and known at graph construction time).
  204. score_thresh: scalar threshold for score (low scoring boxes are removed).
  205. iou_thresh: scalar threshold for IOU (boxes that that high IOU overlap
  206. with previously selected boxes are removed).
  207. max_output_size: maximum number of retained boxes per class.
  208. Returns:
  209. a BoxList holding M boxes with a rank-1 scores field representing
  210. corresponding scores for each box with scores sorted in decreasing order
  211. and a rank-1 classes field representing a class label for each box.
  212. Raises:
  213. ValueError: if iou_thresh is not in [0, 1] or if input boxlist does not have
  214. a valid scores field.
  215. """
  216. if not 0 <= iou_thresh <= 1.0:
  217. raise ValueError('thresh must be between 0 and 1')
  218. if not isinstance(boxlist, np_box_list.BoxList):
  219. raise ValueError('boxlist must be a BoxList')
  220. if not boxlist.has_field('scores'):
  221. raise ValueError('input boxlist must have \'scores\' field')
  222. scores = boxlist.get_field('scores')
  223. if len(scores.shape) == 1:
  224. scores = np.reshape(scores, [-1, 1])
  225. elif len(scores.shape) == 2:
  226. if scores.shape[1] is None:
  227. raise ValueError('scores field must have statically defined second '
  228. 'dimension')
  229. else:
  230. raise ValueError('scores field must be of rank 1 or 2')
  231. num_boxes = boxlist.num_boxes()
  232. num_scores = scores.shape[0]
  233. num_classes = scores.shape[1]
  234. if num_boxes != num_scores:
  235. raise ValueError('Incorrect scores field length: actual vs expected.')
  236. selected_boxes_list = []
  237. for class_idx in range(num_classes):
  238. boxlist_and_class_scores = np_box_list.BoxList(boxlist.get())
  239. class_scores = np.reshape(scores[0:num_scores, class_idx], [-1])
  240. boxlist_and_class_scores.add_field('scores', class_scores)
  241. boxlist_filt = filter_scores_greater_than(boxlist_and_class_scores,
  242. score_thresh)
  243. nms_result = non_max_suppression(boxlist_filt,
  244. max_output_size=max_output_size,
  245. iou_threshold=iou_thresh,
  246. score_threshold=score_thresh)
  247. nms_result.add_field(
  248. 'classes', np.zeros_like(nms_result.get_field('scores')) + class_idx)
  249. selected_boxes_list.append(nms_result)
  250. selected_boxes = concatenate(selected_boxes_list)
  251. sorted_boxes = sort_by_field(selected_boxes, 'scores')
  252. return sorted_boxes
  253. def scale(boxlist, y_scale, x_scale):
  254. """Scale box coordinates in x and y dimensions.
  255. Args:
  256. boxlist: BoxList holding N boxes
  257. y_scale: float
  258. x_scale: float
  259. Returns:
  260. boxlist: BoxList holding N boxes
  261. """
  262. y_min, x_min, y_max, x_max = np.array_split(boxlist.get(), 4, axis=1)
  263. y_min = y_scale * y_min
  264. y_max = y_scale * y_max
  265. x_min = x_scale * x_min
  266. x_max = x_scale * x_max
  267. scaled_boxlist = np_box_list.BoxList(np.hstack([y_min, x_min, y_max, x_max]))
  268. fields = boxlist.get_extra_fields()
  269. for field in fields:
  270. extra_field_data = boxlist.get_field(field)
  271. scaled_boxlist.add_field(field, extra_field_data)
  272. return scaled_boxlist
  273. def clip_to_window(boxlist, window):
  274. """Clip bounding boxes to a window.
  275. This op clips input bounding boxes (represented by bounding box
  276. corners) to a window, optionally filtering out boxes that do not
  277. overlap at all with the window.
  278. Args:
  279. boxlist: BoxList holding M_in boxes
  280. window: a numpy array of shape [4] representing the
  281. [y_min, x_min, y_max, x_max] window to which the op
  282. should clip boxes.
  283. Returns:
  284. a BoxList holding M_out boxes where M_out <= M_in
  285. """
  286. y_min, x_min, y_max, x_max = np.array_split(boxlist.get(), 4, axis=1)
  287. win_y_min = window[0]
  288. win_x_min = window[1]
  289. win_y_max = window[2]
  290. win_x_max = window[3]
  291. y_min_clipped = np.fmax(np.fmin(y_min, win_y_max), win_y_min)
  292. y_max_clipped = np.fmax(np.fmin(y_max, win_y_max), win_y_min)
  293. x_min_clipped = np.fmax(np.fmin(x_min, win_x_max), win_x_min)
  294. x_max_clipped = np.fmax(np.fmin(x_max, win_x_max), win_x_min)
  295. clipped = np_box_list.BoxList(
  296. np.hstack([y_min_clipped, x_min_clipped, y_max_clipped, x_max_clipped]))
  297. clipped = _copy_extra_fields(clipped, boxlist)
  298. areas = area(clipped)
  299. nonzero_area_indices = np.reshape(np.nonzero(np.greater(areas, 0.0)),
  300. [-1]).astype(np.int32)
  301. return gather(clipped, nonzero_area_indices)
  302. def prune_non_overlapping_boxes(boxlist1, boxlist2, minoverlap=0.0):
  303. """Prunes the boxes in boxlist1 that overlap less than thresh with boxlist2.
  304. For each box in boxlist1, we want its IOA to be more than minoverlap with
  305. at least one of the boxes in boxlist2. If it does not, we remove it.
  306. Args:
  307. boxlist1: BoxList holding N boxes.
  308. boxlist2: BoxList holding M boxes.
  309. minoverlap: Minimum required overlap between boxes, to count them as
  310. overlapping.
  311. Returns:
  312. A pruned boxlist with size [N', 4].
  313. """
  314. intersection_over_area = ioa(boxlist2, boxlist1) # [M, N] tensor
  315. intersection_over_area = np.amax(intersection_over_area, axis=0) # [N] tensor
  316. keep_bool = np.greater_equal(intersection_over_area, np.array(minoverlap))
  317. keep_inds = np.nonzero(keep_bool)[0]
  318. new_boxlist1 = gather(boxlist1, keep_inds)
  319. return new_boxlist1
  320. def prune_outside_window(boxlist, window):
  321. """Prunes bounding boxes that fall outside a given window.
  322. This function prunes bounding boxes that even partially fall outside the given
  323. window. See also ClipToWindow which only prunes bounding boxes that fall
  324. completely outside the window, and clips any bounding boxes that partially
  325. overflow.
  326. Args:
  327. boxlist: a BoxList holding M_in boxes.
  328. window: a numpy array of size 4, representing [ymin, xmin, ymax, xmax]
  329. of the window.
  330. Returns:
  331. pruned_corners: a tensor with shape [M_out, 4] where M_out <= M_in.
  332. valid_indices: a tensor with shape [M_out] indexing the valid bounding boxes
  333. in the input tensor.
  334. """
  335. y_min, x_min, y_max, x_max = np.array_split(boxlist.get(), 4, axis=1)
  336. win_y_min = window[0]
  337. win_x_min = window[1]
  338. win_y_max = window[2]
  339. win_x_max = window[3]
  340. coordinate_violations = np.hstack([np.less(y_min, win_y_min),
  341. np.less(x_min, win_x_min),
  342. np.greater(y_max, win_y_max),
  343. np.greater(x_max, win_x_max)])
  344. valid_indices = np.reshape(
  345. np.where(np.logical_not(np.max(coordinate_violations, axis=1))), [-1])
  346. return gather(boxlist, valid_indices), valid_indices
  347. def concatenate(boxlists, fields=None):
  348. """Concatenate list of BoxLists.
  349. This op concatenates a list of input BoxLists into a larger BoxList. It also
  350. handles concatenation of BoxList fields as long as the field tensor shapes
  351. are equal except for the first dimension.
  352. Args:
  353. boxlists: list of BoxList objects
  354. fields: optional list of fields to also concatenate. By default, all
  355. fields from the first BoxList in the list are included in the
  356. concatenation.
  357. Returns:
  358. a BoxList with number of boxes equal to
  359. sum([boxlist.num_boxes() for boxlist in BoxList])
  360. Raises:
  361. ValueError: if boxlists is invalid (i.e., is not a list, is empty, or
  362. contains non BoxList objects), or if requested fields are not contained in
  363. all boxlists
  364. """
  365. if not isinstance(boxlists, list):
  366. raise ValueError('boxlists should be a list')
  367. if not boxlists:
  368. raise ValueError('boxlists should have nonzero length')
  369. for boxlist in boxlists:
  370. if not isinstance(boxlist, np_box_list.BoxList):
  371. raise ValueError('all elements of boxlists should be BoxList objects')
  372. concatenated = np_box_list.BoxList(
  373. np.vstack([boxlist.get() for boxlist in boxlists]))
  374. if fields is None:
  375. fields = boxlists[0].get_extra_fields()
  376. for field in fields:
  377. first_field_shape = boxlists[0].get_field(field).shape
  378. first_field_shape = first_field_shape[1:]
  379. for boxlist in boxlists:
  380. if not boxlist.has_field(field):
  381. raise ValueError('boxlist must contain all requested fields')
  382. field_shape = boxlist.get_field(field).shape
  383. field_shape = field_shape[1:]
  384. if field_shape != first_field_shape:
  385. raise ValueError('field %s must have same shape for all boxlists '
  386. 'except for the 0th dimension.' % field)
  387. concatenated_field = np.concatenate(
  388. [boxlist.get_field(field) for boxlist in boxlists], axis=0)
  389. concatenated.add_field(field, concatenated_field)
  390. return concatenated
  391. def filter_scores_greater_than(boxlist, thresh):
  392. """Filter to keep only boxes with score exceeding a given threshold.
  393. This op keeps the collection of boxes whose corresponding scores are
  394. greater than the input threshold.
  395. Args:
  396. boxlist: BoxList holding N boxes. Must contain a 'scores' field
  397. representing detection scores.
  398. thresh: scalar threshold
  399. Returns:
  400. a BoxList holding M boxes where M <= N
  401. Raises:
  402. ValueError: if boxlist not a BoxList object or if it does not
  403. have a scores field
  404. """
  405. if not isinstance(boxlist, np_box_list.BoxList):
  406. raise ValueError('boxlist must be a BoxList')
  407. if not boxlist.has_field('scores'):
  408. raise ValueError('input boxlist must have \'scores\' field')
  409. scores = boxlist.get_field('scores')
  410. if len(scores.shape) > 2:
  411. raise ValueError('Scores should have rank 1 or 2')
  412. if len(scores.shape) == 2 and scores.shape[1] != 1:
  413. raise ValueError('Scores should have rank 1 or have shape '
  414. 'consistent with [None, 1]')
  415. high_score_indices = np.reshape(np.where(np.greater(scores, thresh)),
  416. [-1]).astype(np.int32)
  417. return gather(boxlist, high_score_indices)
  418. def change_coordinate_frame(boxlist, window):
  419. """Change coordinate frame of the boxlist to be relative to window's frame.
  420. Given a window of the form [ymin, xmin, ymax, xmax],
  421. changes bounding box coordinates from boxlist to be relative to this window
  422. (e.g., the min corner maps to (0,0) and the max corner maps to (1,1)).
  423. An example use case is data augmentation: where we are given groundtruth
  424. boxes (boxlist) and would like to randomly crop the image to some
  425. window (window). In this case we need to change the coordinate frame of
  426. each groundtruth box to be relative to this new window.
  427. Args:
  428. boxlist: A BoxList object holding N boxes.
  429. window: a size 4 1-D numpy array.
  430. Returns:
  431. Returns a BoxList object with N boxes.
  432. """
  433. win_height = window[2] - window[0]
  434. win_width = window[3] - window[1]
  435. boxlist_new = scale(
  436. np_box_list.BoxList(boxlist.get() -
  437. [window[0], window[1], window[0], window[1]]),
  438. 1.0 / win_height, 1.0 / win_width)
  439. _copy_extra_fields(boxlist_new, boxlist)
  440. return boxlist_new
  441. def _copy_extra_fields(boxlist_to_copy_to, boxlist_to_copy_from):
  442. """Copies the extra fields of boxlist_to_copy_from to boxlist_to_copy_to.
  443. Args:
  444. boxlist_to_copy_to: BoxList to which extra fields are copied.
  445. boxlist_to_copy_from: BoxList from which fields are copied.
  446. Returns:
  447. boxlist_to_copy_to with extra fields.
  448. """
  449. for field in boxlist_to_copy_from.get_extra_fields():
  450. boxlist_to_copy_to.add_field(field, boxlist_to_copy_from.get_field(field))
  451. return boxlist_to_copy_to
  452. def _update_valid_indices_by_removing_high_iou_boxes(
  453. selected_indices, is_index_valid, intersect_over_union, threshold):
  454. max_iou = np.max(intersect_over_union[:, selected_indices], axis=1)
  455. return np.logical_and(is_index_valid, max_iou <= threshold)