You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

377 lines
15 KiB

  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Builder for preprocessing steps."""
  16. import tensorflow as tf
  17. from object_detection.core import preprocessor
  18. from object_detection.protos import preprocessor_pb2
  19. def _get_step_config_from_proto(preprocessor_step_config, step_name):
  20. """Returns the value of a field named step_name from proto.
  21. Args:
  22. preprocessor_step_config: A preprocessor_pb2.PreprocessingStep object.
  23. step_name: Name of the field to get value from.
  24. Returns:
  25. result_dict: a sub proto message from preprocessor_step_config which will be
  26. later converted to a dictionary.
  27. Raises:
  28. ValueError: If field does not exist in proto.
  29. """
  30. for field, value in preprocessor_step_config.ListFields():
  31. if field.name == step_name:
  32. return value
  33. raise ValueError('Could not get field %s from proto!' % step_name)
  34. def _get_dict_from_proto(config):
  35. """Helper function to put all proto fields into a dictionary.
  36. For many preprocessing steps, there's an trivial 1-1 mapping from proto fields
  37. to function arguments. This function automatically populates a dictionary with
  38. the arguments from the proto.
  39. Protos that CANNOT be trivially populated include:
  40. * nested messages.
  41. * steps that check if an optional field is set (ie. where None != 0).
  42. * protos that don't map 1-1 to arguments (ie. list should be reshaped).
  43. * fields requiring additional validation (ie. repeated field has n elements).
  44. Args:
  45. config: A protobuf object that does not violate the conditions above.
  46. Returns:
  47. result_dict: |config| converted into a python dictionary.
  48. """
  49. result_dict = {}
  50. for field, value in config.ListFields():
  51. result_dict[field.name] = value
  52. return result_dict
  53. # A map from a PreprocessingStep proto config field name to the preprocessing
  54. # function that should be used. The PreprocessingStep proto should be parsable
  55. # with _get_dict_from_proto.
  56. PREPROCESSING_FUNCTION_MAP = {
  57. 'normalize_image':
  58. preprocessor.normalize_image,
  59. 'random_pixel_value_scale':
  60. preprocessor.random_pixel_value_scale,
  61. 'random_image_scale':
  62. preprocessor.random_image_scale,
  63. 'random_rgb_to_gray':
  64. preprocessor.random_rgb_to_gray,
  65. 'random_adjust_brightness':
  66. preprocessor.random_adjust_brightness,
  67. 'random_adjust_contrast':
  68. preprocessor.random_adjust_contrast,
  69. 'random_adjust_hue':
  70. preprocessor.random_adjust_hue,
  71. 'random_adjust_saturation':
  72. preprocessor.random_adjust_saturation,
  73. 'random_distort_color':
  74. preprocessor.random_distort_color,
  75. 'random_jitter_boxes':
  76. preprocessor.random_jitter_boxes,
  77. 'random_crop_to_aspect_ratio':
  78. preprocessor.random_crop_to_aspect_ratio,
  79. 'random_black_patches':
  80. preprocessor.random_black_patches,
  81. 'rgb_to_gray':
  82. preprocessor.rgb_to_gray,
  83. 'scale_boxes_to_pixel_coordinates': (
  84. preprocessor.scale_boxes_to_pixel_coordinates),
  85. 'subtract_channel_mean':
  86. preprocessor.subtract_channel_mean,
  87. 'convert_class_logits_to_softmax':
  88. preprocessor.convert_class_logits_to_softmax,
  89. }
  90. # A map to convert from preprocessor_pb2.ResizeImage.Method enum to
  91. # tf.image.ResizeMethod.
  92. RESIZE_METHOD_MAP = {
  93. preprocessor_pb2.ResizeImage.AREA: tf.image.ResizeMethod.AREA,
  94. preprocessor_pb2.ResizeImage.BICUBIC: tf.image.ResizeMethod.BICUBIC,
  95. preprocessor_pb2.ResizeImage.BILINEAR: tf.image.ResizeMethod.BILINEAR,
  96. preprocessor_pb2.ResizeImage.NEAREST_NEIGHBOR: (
  97. tf.image.ResizeMethod.NEAREST_NEIGHBOR),
  98. }
  99. def build(preprocessor_step_config):
  100. """Builds preprocessing step based on the configuration.
  101. Args:
  102. preprocessor_step_config: PreprocessingStep configuration proto.
  103. Returns:
  104. function, argmap: A callable function and an argument map to call function
  105. with.
  106. Raises:
  107. ValueError: On invalid configuration.
  108. """
  109. step_type = preprocessor_step_config.WhichOneof('preprocessing_step')
  110. if step_type in PREPROCESSING_FUNCTION_MAP:
  111. preprocessing_function = PREPROCESSING_FUNCTION_MAP[step_type]
  112. step_config = _get_step_config_from_proto(preprocessor_step_config,
  113. step_type)
  114. function_args = _get_dict_from_proto(step_config)
  115. return (preprocessing_function, function_args)
  116. if step_type == 'random_horizontal_flip':
  117. config = preprocessor_step_config.random_horizontal_flip
  118. return (preprocessor.random_horizontal_flip,
  119. {
  120. 'keypoint_flip_permutation': tuple(
  121. config.keypoint_flip_permutation),
  122. })
  123. if step_type == 'random_vertical_flip':
  124. config = preprocessor_step_config.random_vertical_flip
  125. return (preprocessor.random_vertical_flip,
  126. {
  127. 'keypoint_flip_permutation': tuple(
  128. config.keypoint_flip_permutation),
  129. })
  130. if step_type == 'random_rotation90':
  131. return (preprocessor.random_rotation90, {})
  132. if step_type == 'random_crop_image':
  133. config = preprocessor_step_config.random_crop_image
  134. return (preprocessor.random_crop_image,
  135. {
  136. 'min_object_covered': config.min_object_covered,
  137. 'aspect_ratio_range': (config.min_aspect_ratio,
  138. config.max_aspect_ratio),
  139. 'area_range': (config.min_area, config.max_area),
  140. 'overlap_thresh': config.overlap_thresh,
  141. 'clip_boxes': config.clip_boxes,
  142. 'random_coef': config.random_coef,
  143. })
  144. if step_type == 'random_pad_image':
  145. config = preprocessor_step_config.random_pad_image
  146. min_image_size = None
  147. if (config.HasField('min_image_height') !=
  148. config.HasField('min_image_width')):
  149. raise ValueError('min_image_height and min_image_width should be either '
  150. 'both set or both unset.')
  151. if config.HasField('min_image_height'):
  152. min_image_size = (config.min_image_height, config.min_image_width)
  153. max_image_size = None
  154. if (config.HasField('max_image_height') !=
  155. config.HasField('max_image_width')):
  156. raise ValueError('max_image_height and max_image_width should be either '
  157. 'both set or both unset.')
  158. if config.HasField('max_image_height'):
  159. max_image_size = (config.max_image_height, config.max_image_width)
  160. pad_color = config.pad_color or None
  161. if pad_color:
  162. if len(pad_color) != 3:
  163. tf.logging.warn('pad_color should have 3 elements (RGB) if set!')
  164. pad_color = tf.cast([x for x in config.pad_color], dtype=tf.float32)
  165. return (preprocessor.random_pad_image,
  166. {
  167. 'min_image_size': min_image_size,
  168. 'max_image_size': max_image_size,
  169. 'pad_color': pad_color,
  170. })
  171. if step_type == 'random_absolute_pad_image':
  172. config = preprocessor_step_config.random_absolute_pad_image
  173. max_height_padding = config.max_height_padding or 1
  174. max_width_padding = config.max_width_padding or 1
  175. pad_color = config.pad_color or None
  176. if pad_color:
  177. if len(pad_color) != 3:
  178. tf.logging.warn('pad_color should have 3 elements (RGB) if set!')
  179. pad_color = tf.cast([x for x in config.pad_color], dtype=tf.float32)
  180. return (preprocessor.random_absolute_pad_image,
  181. {
  182. 'max_height_padding': max_height_padding,
  183. 'max_width_padding': max_width_padding,
  184. 'pad_color': pad_color,
  185. })
  186. if step_type == 'random_crop_pad_image':
  187. config = preprocessor_step_config.random_crop_pad_image
  188. min_padded_size_ratio = config.min_padded_size_ratio
  189. if min_padded_size_ratio and len(min_padded_size_ratio) != 2:
  190. raise ValueError('min_padded_size_ratio should have 2 elements if set!')
  191. max_padded_size_ratio = config.max_padded_size_ratio
  192. if max_padded_size_ratio and len(max_padded_size_ratio) != 2:
  193. raise ValueError('max_padded_size_ratio should have 2 elements if set!')
  194. pad_color = config.pad_color or None
  195. if pad_color:
  196. if len(pad_color) != 3:
  197. tf.logging.warn('pad_color should have 3 elements (RGB) if set!')
  198. pad_color = tf.cast([x for x in config.pad_color], dtype=tf.float32)
  199. kwargs = {
  200. 'min_object_covered': config.min_object_covered,
  201. 'aspect_ratio_range': (config.min_aspect_ratio,
  202. config.max_aspect_ratio),
  203. 'area_range': (config.min_area, config.max_area),
  204. 'overlap_thresh': config.overlap_thresh,
  205. 'clip_boxes': config.clip_boxes,
  206. 'random_coef': config.random_coef,
  207. 'pad_color': pad_color,
  208. }
  209. if min_padded_size_ratio:
  210. kwargs['min_padded_size_ratio'] = tuple(min_padded_size_ratio)
  211. if max_padded_size_ratio:
  212. kwargs['max_padded_size_ratio'] = tuple(max_padded_size_ratio)
  213. return (preprocessor.random_crop_pad_image, kwargs)
  214. if step_type == 'random_resize_method':
  215. config = preprocessor_step_config.random_resize_method
  216. return (preprocessor.random_resize_method,
  217. {
  218. 'target_size': [config.target_height, config.target_width],
  219. })
  220. if step_type == 'resize_image':
  221. config = preprocessor_step_config.resize_image
  222. method = RESIZE_METHOD_MAP[config.method]
  223. return (preprocessor.resize_image,
  224. {
  225. 'new_height': config.new_height,
  226. 'new_width': config.new_width,
  227. 'method': method
  228. })
  229. if step_type == 'random_self_concat_image':
  230. config = preprocessor_step_config.random_self_concat_image
  231. return (preprocessor.random_self_concat_image, {
  232. 'concat_vertical_probability': config.concat_vertical_probability,
  233. 'concat_horizontal_probability': config.concat_horizontal_probability
  234. })
  235. if step_type == 'ssd_random_crop':
  236. config = preprocessor_step_config.ssd_random_crop
  237. if config.operations:
  238. min_object_covered = [op.min_object_covered for op in config.operations]
  239. aspect_ratio_range = [(op.min_aspect_ratio, op.max_aspect_ratio)
  240. for op in config.operations]
  241. area_range = [(op.min_area, op.max_area) for op in config.operations]
  242. overlap_thresh = [op.overlap_thresh for op in config.operations]
  243. clip_boxes = [op.clip_boxes for op in config.operations]
  244. random_coef = [op.random_coef for op in config.operations]
  245. return (preprocessor.ssd_random_crop,
  246. {
  247. 'min_object_covered': min_object_covered,
  248. 'aspect_ratio_range': aspect_ratio_range,
  249. 'area_range': area_range,
  250. 'overlap_thresh': overlap_thresh,
  251. 'clip_boxes': clip_boxes,
  252. 'random_coef': random_coef,
  253. })
  254. return (preprocessor.ssd_random_crop, {})
  255. if step_type == 'ssd_random_crop_pad':
  256. config = preprocessor_step_config.ssd_random_crop_pad
  257. if config.operations:
  258. min_object_covered = [op.min_object_covered for op in config.operations]
  259. aspect_ratio_range = [(op.min_aspect_ratio, op.max_aspect_ratio)
  260. for op in config.operations]
  261. area_range = [(op.min_area, op.max_area) for op in config.operations]
  262. overlap_thresh = [op.overlap_thresh for op in config.operations]
  263. clip_boxes = [op.clip_boxes for op in config.operations]
  264. random_coef = [op.random_coef for op in config.operations]
  265. min_padded_size_ratio = [tuple(op.min_padded_size_ratio)
  266. for op in config.operations]
  267. max_padded_size_ratio = [tuple(op.max_padded_size_ratio)
  268. for op in config.operations]
  269. pad_color = [(op.pad_color_r, op.pad_color_g, op.pad_color_b)
  270. for op in config.operations]
  271. return (preprocessor.ssd_random_crop_pad,
  272. {
  273. 'min_object_covered': min_object_covered,
  274. 'aspect_ratio_range': aspect_ratio_range,
  275. 'area_range': area_range,
  276. 'overlap_thresh': overlap_thresh,
  277. 'clip_boxes': clip_boxes,
  278. 'random_coef': random_coef,
  279. 'min_padded_size_ratio': min_padded_size_ratio,
  280. 'max_padded_size_ratio': max_padded_size_ratio,
  281. 'pad_color': pad_color,
  282. })
  283. return (preprocessor.ssd_random_crop_pad, {})
  284. if step_type == 'ssd_random_crop_fixed_aspect_ratio':
  285. config = preprocessor_step_config.ssd_random_crop_fixed_aspect_ratio
  286. if config.operations:
  287. min_object_covered = [op.min_object_covered for op in config.operations]
  288. area_range = [(op.min_area, op.max_area) for op in config.operations]
  289. overlap_thresh = [op.overlap_thresh for op in config.operations]
  290. clip_boxes = [op.clip_boxes for op in config.operations]
  291. random_coef = [op.random_coef for op in config.operations]
  292. return (preprocessor.ssd_random_crop_fixed_aspect_ratio,
  293. {
  294. 'min_object_covered': min_object_covered,
  295. 'aspect_ratio': config.aspect_ratio,
  296. 'area_range': area_range,
  297. 'overlap_thresh': overlap_thresh,
  298. 'clip_boxes': clip_boxes,
  299. 'random_coef': random_coef,
  300. })
  301. return (preprocessor.ssd_random_crop_fixed_aspect_ratio, {})
  302. if step_type == 'ssd_random_crop_pad_fixed_aspect_ratio':
  303. config = preprocessor_step_config.ssd_random_crop_pad_fixed_aspect_ratio
  304. kwargs = {}
  305. aspect_ratio = config.aspect_ratio
  306. if aspect_ratio:
  307. kwargs['aspect_ratio'] = aspect_ratio
  308. min_padded_size_ratio = config.min_padded_size_ratio
  309. if min_padded_size_ratio:
  310. if len(min_padded_size_ratio) != 2:
  311. raise ValueError('min_padded_size_ratio should have 2 elements if set!')
  312. kwargs['min_padded_size_ratio'] = tuple(min_padded_size_ratio)
  313. max_padded_size_ratio = config.max_padded_size_ratio
  314. if max_padded_size_ratio:
  315. if len(max_padded_size_ratio) != 2:
  316. raise ValueError('max_padded_size_ratio should have 2 elements if set!')
  317. kwargs['max_padded_size_ratio'] = tuple(max_padded_size_ratio)
  318. if config.operations:
  319. kwargs['min_object_covered'] = [op.min_object_covered
  320. for op in config.operations]
  321. kwargs['aspect_ratio_range'] = [(op.min_aspect_ratio, op.max_aspect_ratio)
  322. for op in config.operations]
  323. kwargs['area_range'] = [(op.min_area, op.max_area)
  324. for op in config.operations]
  325. kwargs['overlap_thresh'] = [op.overlap_thresh for op in config.operations]
  326. kwargs['clip_boxes'] = [op.clip_boxes for op in config.operations]
  327. kwargs['random_coef'] = [op.random_coef for op in config.operations]
  328. return (preprocessor.ssd_random_crop_pad_fixed_aspect_ratio, kwargs)
  329. raise ValueError('Unknown preprocessing step.')