You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

712 lines
30 KiB

  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Model input function for tf-learn object detection model."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import functools
  20. import tensorflow as tf
  21. from object_detection.builders import dataset_builder
  22. from object_detection.builders import image_resizer_builder
  23. from object_detection.builders import model_builder
  24. from object_detection.builders import preprocessor_builder
  25. from object_detection.core import preprocessor
  26. from object_detection.core import standard_fields as fields
  27. from object_detection.data_decoders import tf_example_decoder
  28. from object_detection.protos import eval_pb2
  29. from object_detection.protos import input_reader_pb2
  30. from object_detection.protos import model_pb2
  31. from object_detection.protos import train_pb2
  32. from object_detection.utils import config_util
  33. from object_detection.utils import ops as util_ops
  34. from object_detection.utils import shape_utils
  35. HASH_KEY = 'hash'
  36. HASH_BINS = 1 << 31
  37. SERVING_FED_EXAMPLE_KEY = 'serialized_example'
  38. # A map of names to methods that help build the input pipeline.
  39. INPUT_BUILDER_UTIL_MAP = {
  40. 'dataset_build': dataset_builder.build,
  41. 'model_build': model_builder.build,
  42. }
def transform_input_data(tensor_dict,
                         model_preprocess_fn,
                         image_resizer_fn,
                         num_classes,
                         data_augmentation_fn=None,
                         merge_multiple_boxes=False,
                         retain_original_image=False,
                         use_multiclass_scores=False,
                         use_bfloat16=False):
  """A single function that is responsible for all input data transformations.

  Before the numbered steps below, the function:
    a. reshapes a flattened fields.InputDataFields.multiclass_scores tensor
       (if present) to [num_boxes, num_classes];
    b. when groundtruth boxes are present, drops groundtruth entries with NaN
       box coordinates and entries with unrecognized classes;
    c. when `retain_original_image` is set, stores a resized uint8 copy of the
       image under fields.InputDataFields.original_image. This happens BEFORE
       additional channels are merged, so the copy has only the base channels.

  Data transformation functions are then applied in the following order.
  1. If key fields.InputDataFields.image_additional_channels is present in
     tensor_dict, the additional channels will be merged into
     fields.InputDataFields.image.
  2. data_augmentation_fn (optional): applied on tensor_dict.
  3. model_preprocess_fn: applied only on image tensor in tensor_dict.
  4. image_resizer_fn: applied on original image and instance mask tensor in
     tensor_dict.
  5. one_hot_encoding: applied to classes tensor in tensor_dict.
  6. merge_multiple_boxes (optional): when groundtruth boxes are exactly the
     same they can be merged into a single box with an associated k-hot class
     label.

  Args:
    tensor_dict: dictionary containing input tensors keyed by
      fields.InputDataFields. Entries are assigned in place, so the caller's
      dict (or the dict returned by the filtering/augmentation helpers) may be
      modified.
    model_preprocess_fn: model's preprocess function to apply on image tensor.
      This function must take in a 4-D float tensor and return a 4-D preprocess
      float tensor and a tensor containing the true image shape.
    image_resizer_fn: image resizer function to apply on groundtruth instance
      masks (and, when `retain_original_image` is set, on the original image).
      This function must take a 3-D float tensor of an image and a 3-D
      tensor of instance masks and return a resized version of these along with
      the true shapes.
    num_classes: number of max classes to one-hot (or k-hot) encode the class
      labels.
    data_augmentation_fn: (optional) data augmentation function to apply on
      input `tensor_dict`.
    merge_multiple_boxes: (optional) whether to merge multiple groundtruth boxes
      and classes for a given image if the boxes are exactly the same.
    retain_original_image: (optional) whether to retain original image in the
      output dictionary.
    use_multiclass_scores: whether to use multiclass scores as
      class targets instead of one-hot encoding of `groundtruth_classes`.
    use_bfloat16: (optional) a bool, whether to use bfloat16 in training.

  Returns:
    A dictionary keyed by fields.InputDataFields containing the tensors obtained
    after applying all the transformations. multiclass_scores is always removed;
    true_image_shape, groundtruth_confidences and (when boxes are present)
    num_groundtruth_boxes are added.
  """
  # Reshape flattened multiclass scores tensor into a 2D tensor of shape
  # [num_boxes, num_classes].
  # NOTE(review): this reads groundtruth_boxes without checking it is present;
  # a dict with multiclass_scores but no boxes raises KeyError here.
  if fields.InputDataFields.multiclass_scores in tensor_dict:
    tensor_dict[fields.InputDataFields.multiclass_scores] = tf.reshape(
        tensor_dict[fields.InputDataFields.multiclass_scores], [
            tf.shape(tensor_dict[fields.InputDataFields.groundtruth_boxes])[0],
            num_classes
        ])
  if fields.InputDataFields.groundtruth_boxes in tensor_dict:
    tensor_dict = util_ops.filter_groundtruth_with_nan_box_coordinates(
        tensor_dict)
    tensor_dict = util_ops.filter_unrecognized_classes(tensor_dict)

  # Keep a resized uint8 copy of the image before any channel merging,
  # augmentation or model preprocessing touches it.
  if retain_original_image:
    tensor_dict[fields.InputDataFields.original_image] = tf.cast(
        image_resizer_fn(tensor_dict[fields.InputDataFields.image], None)[0],
        tf.uint8)

  # Append the additional channels to the image along the channel axis.
  if fields.InputDataFields.image_additional_channels in tensor_dict:
    channels = tensor_dict[fields.InputDataFields.image_additional_channels]
    tensor_dict[fields.InputDataFields.image] = tf.concat(
        [tensor_dict[fields.InputDataFields.image], channels], axis=2)

  # Apply data augmentation ops.
  if data_augmentation_fn is not None:
    tensor_dict = data_augmentation_fn(tensor_dict)

  # Apply model preprocessing ops and resize instance masks.
  # model_preprocess_fn works on a batch, so add and later remove a batch dim.
  image = tensor_dict[fields.InputDataFields.image]
  preprocessed_resized_image, true_image_shape = model_preprocess_fn(
      tf.expand_dims(tf.cast(image, dtype=tf.float32), axis=0))
  if use_bfloat16:
    preprocessed_resized_image = tf.cast(
        preprocessed_resized_image, tf.bfloat16)
  tensor_dict[fields.InputDataFields.image] = tf.squeeze(
      preprocessed_resized_image, axis=0)
  tensor_dict[fields.InputDataFields.true_image_shape] = tf.squeeze(
      true_image_shape, axis=0)
  if fields.InputDataFields.groundtruth_instance_masks in tensor_dict:
    masks = tensor_dict[fields.InputDataFields.groundtruth_instance_masks]
    # `image` is the pre-preprocessing image captured above; only the resized
    # masks are kept from the resizer output.
    _, resized_masks, _ = image_resizer_fn(image, masks)
    if use_bfloat16:
      resized_masks = tf.cast(resized_masks, tf.bfloat16)
    tensor_dict[fields.InputDataFields.
                groundtruth_instance_masks] = resized_masks

  # Transform groundtruth classes to one hot encodings.
  # Class ids are shifted down by 1 before one-hot encoding (label_offset);
  # presumably ids are 1-based in the input — confirm against the decoder.
  label_offset = 1
  zero_indexed_groundtruth_classes = tensor_dict[
      fields.InputDataFields.groundtruth_classes] - label_offset
  tensor_dict[fields.InputDataFields.groundtruth_classes] = tf.one_hot(
      zero_indexed_groundtruth_classes, num_classes)

  # Optionally replace the one-hot targets with the multiclass scores; the
  # scores entry itself is always dropped from the output.
  if use_multiclass_scores:
    tensor_dict[fields.InputDataFields.groundtruth_classes] = tensor_dict[
        fields.InputDataFields.multiclass_scores]
  tensor_dict.pop(fields.InputDataFields.multiclass_scores, None)

  if fields.InputDataFields.groundtruth_confidences in tensor_dict:
    groundtruth_confidences = tensor_dict[
        fields.InputDataFields.groundtruth_confidences]
    # Map the confidences to the one-hot encoding of classes
    tensor_dict[fields.InputDataFields.groundtruth_confidences] = (
        tf.reshape(groundtruth_confidences, [-1, 1]) *
        tensor_dict[fields.InputDataFields.groundtruth_classes])
  else:
    # Per-box confidence of 1.0 is used for merging below; the stored
    # confidences are simply the class targets.
    groundtruth_confidences = tf.ones_like(
        zero_indexed_groundtruth_classes, dtype=tf.float32)
    tensor_dict[fields.InputDataFields.groundtruth_confidences] = (
        tensor_dict[fields.InputDataFields.groundtruth_classes])

  if merge_multiple_boxes:
    # Merging uses the zero-indexed class ids and per-box confidences (not the
    # one-hot / multiclass targets) and overwrites boxes, classes and
    # confidences with the merged results.
    merged_boxes, merged_classes, merged_confidences, _ = (
        util_ops.merge_boxes_with_multiple_labels(
            tensor_dict[fields.InputDataFields.groundtruth_boxes],
            zero_indexed_groundtruth_classes,
            groundtruth_confidences,
            num_classes))
    merged_classes = tf.cast(merged_classes, tf.float32)
    tensor_dict[fields.InputDataFields.groundtruth_boxes] = merged_boxes
    tensor_dict[fields.InputDataFields.groundtruth_classes] = merged_classes
    tensor_dict[fields.InputDataFields.groundtruth_confidences] = (
        merged_confidences)
  # Box count is taken after any merging so it matches the final boxes tensor.
  if fields.InputDataFields.groundtruth_boxes in tensor_dict:
    tensor_dict[fields.InputDataFields.num_groundtruth_boxes] = tf.shape(
        tensor_dict[fields.InputDataFields.groundtruth_boxes])[0]

  return tensor_dict
  169. def pad_input_data_to_static_shapes(tensor_dict, max_num_boxes, num_classes,
  170. spatial_image_shape=None):
  171. """Pads input tensors to static shapes.
  172. In case num_additional_channels > 0, we assume that the additional channels
  173. have already been concatenated to the base image.
  174. Args:
  175. tensor_dict: Tensor dictionary of input data
  176. max_num_boxes: Max number of groundtruth boxes needed to compute shapes for
  177. padding.
  178. num_classes: Number of classes in the dataset needed to compute shapes for
  179. padding.
  180. spatial_image_shape: A list of two integers of the form [height, width]
  181. containing expected spatial shape of the image.
  182. Returns:
  183. A dictionary keyed by fields.InputDataFields containing padding shapes for
  184. tensors in the dataset.
  185. Raises:
  186. ValueError: If groundtruth classes is neither rank 1 nor rank 2, or if we
  187. detect that additional channels have not been concatenated yet.
  188. """
  189. if not spatial_image_shape or spatial_image_shape == [-1, -1]:
  190. height, width = None, None
  191. else:
  192. height, width = spatial_image_shape # pylint: disable=unpacking-non-sequence
  193. num_additional_channels = 0
  194. if fields.InputDataFields.image_additional_channels in tensor_dict:
  195. num_additional_channels = shape_utils.get_dim_as_int(tensor_dict[
  196. fields.InputDataFields.image_additional_channels].shape[2])
  197. # We assume that if num_additional_channels > 0, then it has already been
  198. # concatenated to the base image (but not the ground truth).
  199. num_channels = 3
  200. if fields.InputDataFields.image in tensor_dict:
  201. num_channels = shape_utils.get_dim_as_int(
  202. tensor_dict[fields.InputDataFields.image].shape[2])
  203. if num_additional_channels:
  204. if num_additional_channels >= num_channels:
  205. raise ValueError(
  206. 'Image must be already concatenated with additional channels.')
  207. if (fields.InputDataFields.original_image in tensor_dict and
  208. shape_utils.get_dim_as_int(
  209. tensor_dict[fields.InputDataFields.original_image].shape[2]) ==
  210. num_channels):
  211. raise ValueError(
  212. 'Image must be already concatenated with additional channels.')
  213. padding_shapes = {
  214. fields.InputDataFields.image: [
  215. height, width, num_channels
  216. ],
  217. fields.InputDataFields.original_image_spatial_shape: [2],
  218. fields.InputDataFields.image_additional_channels: [
  219. height, width, num_additional_channels
  220. ],
  221. fields.InputDataFields.source_id: [],
  222. fields.InputDataFields.filename: [],
  223. fields.InputDataFields.key: [],
  224. fields.InputDataFields.groundtruth_difficult: [max_num_boxes],
  225. fields.InputDataFields.groundtruth_boxes: [max_num_boxes, 4],
  226. fields.InputDataFields.groundtruth_classes: [max_num_boxes, num_classes],
  227. fields.InputDataFields.groundtruth_instance_masks: [
  228. max_num_boxes, height, width
  229. ],
  230. fields.InputDataFields.groundtruth_is_crowd: [max_num_boxes],
  231. fields.InputDataFields.groundtruth_group_of: [max_num_boxes],
  232. fields.InputDataFields.groundtruth_area: [max_num_boxes],
  233. fields.InputDataFields.groundtruth_weights: [max_num_boxes],
  234. fields.InputDataFields.groundtruth_confidences: [
  235. max_num_boxes, num_classes
  236. ],
  237. fields.InputDataFields.num_groundtruth_boxes: [],
  238. fields.InputDataFields.groundtruth_label_types: [max_num_boxes],
  239. fields.InputDataFields.groundtruth_label_weights: [max_num_boxes],
  240. fields.InputDataFields.true_image_shape: [3],
  241. fields.InputDataFields.groundtruth_image_classes: [num_classes],
  242. fields.InputDataFields.groundtruth_image_confidences: [num_classes],
  243. }
  244. if fields.InputDataFields.original_image in tensor_dict:
  245. padding_shapes[fields.InputDataFields.original_image] = [
  246. height, width,
  247. shape_utils.get_dim_as_int(tensor_dict[fields.InputDataFields.
  248. original_image].shape[2])
  249. ]
  250. if fields.InputDataFields.groundtruth_keypoints in tensor_dict:
  251. tensor_shape = (
  252. tensor_dict[fields.InputDataFields.groundtruth_keypoints].shape)
  253. padding_shape = [max_num_boxes,
  254. shape_utils.get_dim_as_int(tensor_shape[1]),
  255. shape_utils.get_dim_as_int(tensor_shape[2])]
  256. padding_shapes[fields.InputDataFields.groundtruth_keypoints] = padding_shape
  257. if fields.InputDataFields.groundtruth_keypoint_visibilities in tensor_dict:
  258. tensor_shape = tensor_dict[fields.InputDataFields.
  259. groundtruth_keypoint_visibilities].shape
  260. padding_shape = [max_num_boxes, shape_utils.get_dim_as_int(tensor_shape[1])]
  261. padding_shapes[fields.InputDataFields.
  262. groundtruth_keypoint_visibilities] = padding_shape
  263. padded_tensor_dict = {}
  264. for tensor_name in tensor_dict:
  265. padded_tensor_dict[tensor_name] = shape_utils.pad_or_clip_nd(
  266. tensor_dict[tensor_name], padding_shapes[tensor_name])
  267. # Make sure that the number of groundtruth boxes now reflects the
  268. # padded/clipped tensors.
  269. if fields.InputDataFields.num_groundtruth_boxes in padded_tensor_dict:
  270. padded_tensor_dict[fields.InputDataFields.num_groundtruth_boxes] = (
  271. tf.minimum(
  272. padded_tensor_dict[fields.InputDataFields.num_groundtruth_boxes],
  273. max_num_boxes))
  274. return padded_tensor_dict
  275. def augment_input_data(tensor_dict, data_augmentation_options):
  276. """Applies data augmentation ops to input tensors.
  277. Args:
  278. tensor_dict: A dictionary of input tensors keyed by fields.InputDataFields.
  279. data_augmentation_options: A list of tuples, where each tuple contains a
  280. function and a dictionary that contains arguments and their values.
  281. Usually, this is the output of core/preprocessor.build.
  282. Returns:
  283. A dictionary of tensors obtained by applying data augmentation ops to the
  284. input tensor dictionary.
  285. """
  286. tensor_dict[fields.InputDataFields.image] = tf.expand_dims(
  287. tf.cast(tensor_dict[fields.InputDataFields.image], dtype=tf.float32), 0)
  288. include_instance_masks = (fields.InputDataFields.groundtruth_instance_masks
  289. in tensor_dict)
  290. include_keypoints = (fields.InputDataFields.groundtruth_keypoints
  291. in tensor_dict)
  292. include_label_weights = (fields.InputDataFields.groundtruth_weights
  293. in tensor_dict)
  294. include_label_confidences = (fields.InputDataFields.groundtruth_confidences
  295. in tensor_dict)
  296. include_multiclass_scores = (fields.InputDataFields.multiclass_scores in
  297. tensor_dict)
  298. tensor_dict = preprocessor.preprocess(
  299. tensor_dict, data_augmentation_options,
  300. func_arg_map=preprocessor.get_default_func_arg_map(
  301. include_label_weights=include_label_weights,
  302. include_label_confidences=include_label_confidences,
  303. include_multiclass_scores=include_multiclass_scores,
  304. include_instance_masks=include_instance_masks,
  305. include_keypoints=include_keypoints))
  306. tensor_dict[fields.InputDataFields.image] = tf.squeeze(
  307. tensor_dict[fields.InputDataFields.image], axis=0)
  308. return tensor_dict
  309. def _get_labels_dict(input_dict):
  310. """Extracts labels dict from input dict."""
  311. required_label_keys = [
  312. fields.InputDataFields.num_groundtruth_boxes,
  313. fields.InputDataFields.groundtruth_boxes,
  314. fields.InputDataFields.groundtruth_classes,
  315. fields.InputDataFields.groundtruth_weights,
  316. ]
  317. labels_dict = {}
  318. for key in required_label_keys:
  319. labels_dict[key] = input_dict[key]
  320. optional_label_keys = [
  321. fields.InputDataFields.groundtruth_confidences,
  322. fields.InputDataFields.groundtruth_keypoints,
  323. fields.InputDataFields.groundtruth_instance_masks,
  324. fields.InputDataFields.groundtruth_area,
  325. fields.InputDataFields.groundtruth_is_crowd,
  326. fields.InputDataFields.groundtruth_difficult
  327. ]
  328. for key in optional_label_keys:
  329. if key in input_dict:
  330. labels_dict[key] = input_dict[key]
  331. if fields.InputDataFields.groundtruth_difficult in labels_dict:
  332. labels_dict[fields.InputDataFields.groundtruth_difficult] = tf.cast(
  333. labels_dict[fields.InputDataFields.groundtruth_difficult], tf.int32)
  334. return labels_dict
  335. def _replace_empty_string_with_random_number(string_tensor):
  336. """Returns string unchanged if non-empty, and random string tensor otherwise.
  337. The random string is an integer 0 and 2**63 - 1, casted as string.
  338. Args:
  339. string_tensor: A tf.tensor of dtype string.
  340. Returns:
  341. out_string: A tf.tensor of dtype string. If string_tensor contains the empty
  342. string, out_string will contain a random integer casted to a string.
  343. Otherwise string_tensor is returned unchanged.
  344. """
  345. empty_string = tf.constant('', dtype=tf.string, name='EmptyString')
  346. random_source_id = tf.as_string(
  347. tf.random_uniform(shape=[], maxval=2**63 - 1, dtype=tf.int64))
  348. out_string = tf.cond(
  349. tf.equal(string_tensor, empty_string),
  350. true_fn=lambda: random_source_id,
  351. false_fn=lambda: string_tensor)
  352. return out_string
  353. def _get_features_dict(input_dict):
  354. """Extracts features dict from input dict."""
  355. source_id = _replace_empty_string_with_random_number(
  356. input_dict[fields.InputDataFields.source_id])
  357. hash_from_source_id = tf.string_to_hash_bucket_fast(source_id, HASH_BINS)
  358. features = {
  359. fields.InputDataFields.image:
  360. input_dict[fields.InputDataFields.image],
  361. HASH_KEY: tf.cast(hash_from_source_id, tf.int32),
  362. fields.InputDataFields.true_image_shape:
  363. input_dict[fields.InputDataFields.true_image_shape],
  364. fields.InputDataFields.original_image_spatial_shape:
  365. input_dict[fields.InputDataFields.original_image_spatial_shape]
  366. }
  367. if fields.InputDataFields.original_image in input_dict:
  368. features[fields.InputDataFields.original_image] = input_dict[
  369. fields.InputDataFields.original_image]
  370. return features
  371. def create_train_input_fn(train_config, train_input_config,
  372. model_config):
  373. """Creates a train `input` function for `Estimator`.
  374. Args:
  375. train_config: A train_pb2.TrainConfig.
  376. train_input_config: An input_reader_pb2.InputReader.
  377. model_config: A model_pb2.DetectionModel.
  378. Returns:
  379. `input_fn` for `Estimator` in TRAIN mode.
  380. """
  381. def _train_input_fn(params=None):
  382. return train_input(train_config, train_input_config, model_config,
  383. params=params)
  384. return _train_input_fn
  385. def train_input(train_config, train_input_config,
  386. model_config, model=None, params=None):
  387. """Returns `features` and `labels` tensor dictionaries for training.
  388. Args:
  389. train_config: A train_pb2.TrainConfig.
  390. train_input_config: An input_reader_pb2.InputReader.
  391. model_config: A model_pb2.DetectionModel.
  392. model: A pre-constructed Detection Model.
  393. If None, one will be created from the config.
  394. params: Parameter dictionary passed from the estimator.
  395. Returns:
  396. A tf.data.Dataset that holds (features, labels) tuple.
  397. features: Dictionary of feature tensors.
  398. features[fields.InputDataFields.image] is a [batch_size, H, W, C]
  399. float32 tensor with preprocessed images.
  400. features[HASH_KEY] is a [batch_size] int32 tensor representing unique
  401. identifiers for the images.
  402. features[fields.InputDataFields.true_image_shape] is a [batch_size, 3]
  403. int32 tensor representing the true image shapes, as preprocessed
  404. images could be padded.
  405. features[fields.InputDataFields.original_image] (optional) is a
  406. [batch_size, H, W, C] float32 tensor with original images.
  407. labels: Dictionary of groundtruth tensors.
  408. labels[fields.InputDataFields.num_groundtruth_boxes] is a [batch_size]
  409. int32 tensor indicating the number of groundtruth boxes.
  410. labels[fields.InputDataFields.groundtruth_boxes] is a
  411. [batch_size, num_boxes, 4] float32 tensor containing the corners of
  412. the groundtruth boxes.
  413. labels[fields.InputDataFields.groundtruth_classes] is a
  414. [batch_size, num_boxes, num_classes] float32 one-hot tensor of
  415. classes.
  416. labels[fields.InputDataFields.groundtruth_weights] is a
  417. [batch_size, num_boxes] float32 tensor containing groundtruth weights
  418. for the boxes.
  419. -- Optional --
  420. labels[fields.InputDataFields.groundtruth_instance_masks] is a
  421. [batch_size, num_boxes, H, W] float32 tensor containing only binary
  422. values, which represent instance masks for objects.
  423. labels[fields.InputDataFields.groundtruth_keypoints] is a
  424. [batch_size, num_boxes, num_keypoints, 2] float32 tensor containing
  425. keypoints for each box.
  426. Raises:
  427. TypeError: if the `train_config`, `train_input_config` or `model_config`
  428. are not of the correct type.
  429. """
  430. if not isinstance(train_config, train_pb2.TrainConfig):
  431. raise TypeError('For training mode, the `train_config` must be a '
  432. 'train_pb2.TrainConfig.')
  433. if not isinstance(train_input_config, input_reader_pb2.InputReader):
  434. raise TypeError('The `train_input_config` must be a '
  435. 'input_reader_pb2.InputReader.')
  436. if not isinstance(model_config, model_pb2.DetectionModel):
  437. raise TypeError('The `model_config` must be a '
  438. 'model_pb2.DetectionModel.')
  439. if model is None:
  440. model_preprocess_fn = INPUT_BUILDER_UTIL_MAP['model_build'](
  441. model_config, is_training=True).preprocess
  442. else:
  443. model_preprocess_fn = model.preprocess
  444. def transform_and_pad_input_data_fn(tensor_dict):
  445. """Combines transform and pad operation."""
  446. data_augmentation_options = [
  447. preprocessor_builder.build(step)
  448. for step in train_config.data_augmentation_options
  449. ]
  450. data_augmentation_fn = functools.partial(
  451. augment_input_data,
  452. data_augmentation_options=data_augmentation_options)
  453. image_resizer_config = config_util.get_image_resizer_config(model_config)
  454. image_resizer_fn = image_resizer_builder.build(image_resizer_config)
  455. transform_data_fn = functools.partial(
  456. transform_input_data, model_preprocess_fn=model_preprocess_fn,
  457. image_resizer_fn=image_resizer_fn,
  458. num_classes=config_util.get_number_of_classes(model_config),
  459. data_augmentation_fn=data_augmentation_fn,
  460. merge_multiple_boxes=train_config.merge_multiple_label_boxes,
  461. retain_original_image=train_config.retain_original_images,
  462. use_multiclass_scores=train_config.use_multiclass_scores,
  463. use_bfloat16=train_config.use_bfloat16)
  464. tensor_dict = pad_input_data_to_static_shapes(
  465. tensor_dict=transform_data_fn(tensor_dict),
  466. max_num_boxes=train_input_config.max_number_of_boxes,
  467. num_classes=config_util.get_number_of_classes(model_config),
  468. spatial_image_shape=config_util.get_spatial_image_size(
  469. image_resizer_config))
  470. return (_get_features_dict(tensor_dict), _get_labels_dict(tensor_dict))
  471. dataset = INPUT_BUILDER_UTIL_MAP['dataset_build'](
  472. train_input_config,
  473. transform_input_data_fn=transform_and_pad_input_data_fn,
  474. batch_size=params['batch_size'] if params else train_config.batch_size)
  475. return dataset
  476. def create_eval_input_fn(eval_config, eval_input_config, model_config):
  477. """Creates an eval `input` function for `Estimator`.
  478. Args:
  479. eval_config: An eval_pb2.EvalConfig.
  480. eval_input_config: An input_reader_pb2.InputReader.
  481. model_config: A model_pb2.DetectionModel.
  482. Returns:
  483. `input_fn` for `Estimator` in EVAL mode.
  484. """
  485. def _eval_input_fn(params=None):
  486. return eval_input(eval_config, eval_input_config, model_config,
  487. params=params)
  488. return _eval_input_fn
  489. def eval_input(eval_config, eval_input_config, model_config,
  490. model=None, params=None):
  491. """Returns `features` and `labels` tensor dictionaries for evaluation.
  492. Args:
  493. eval_config: An eval_pb2.EvalConfig.
  494. eval_input_config: An input_reader_pb2.InputReader.
  495. model_config: A model_pb2.DetectionModel.
  496. model: A pre-constructed Detection Model.
  497. If None, one will be created from the config.
  498. params: Parameter dictionary passed from the estimator.
  499. Returns:
  500. A tf.data.Dataset that holds (features, labels) tuple.
  501. features: Dictionary of feature tensors.
  502. features[fields.InputDataFields.image] is a [1, H, W, C] float32 tensor
  503. with preprocessed images.
  504. features[HASH_KEY] is a [1] int32 tensor representing unique
  505. identifiers for the images.
  506. features[fields.InputDataFields.true_image_shape] is a [1, 3]
  507. int32 tensor representing the true image shapes, as preprocessed
  508. images could be padded.
  509. features[fields.InputDataFields.original_image] is a [1, H', W', C]
  510. float32 tensor with the original image.
  511. labels: Dictionary of groundtruth tensors.
  512. labels[fields.InputDataFields.groundtruth_boxes] is a [1, num_boxes, 4]
  513. float32 tensor containing the corners of the groundtruth boxes.
  514. labels[fields.InputDataFields.groundtruth_classes] is a
  515. [num_boxes, num_classes] float32 one-hot tensor of classes.
  516. labels[fields.InputDataFields.groundtruth_area] is a [1, num_boxes]
  517. float32 tensor containing object areas.
  518. labels[fields.InputDataFields.groundtruth_is_crowd] is a [1, num_boxes]
  519. bool tensor indicating if the boxes enclose a crowd.
  520. labels[fields.InputDataFields.groundtruth_difficult] is a [1, num_boxes]
  521. int32 tensor indicating if the boxes represent difficult instances.
  522. -- Optional --
  523. labels[fields.InputDataFields.groundtruth_instance_masks] is a
  524. [1, num_boxes, H, W] float32 tensor containing only binary values,
  525. which represent instance masks for objects.
  526. Raises:
  527. TypeError: if the `eval_config`, `eval_input_config` or `model_config`
  528. are not of the correct type.
  529. """
  530. params = params or {}
  531. if not isinstance(eval_config, eval_pb2.EvalConfig):
  532. raise TypeError('For eval mode, the `eval_config` must be a '
  533. 'train_pb2.EvalConfig.')
  534. if not isinstance(eval_input_config, input_reader_pb2.InputReader):
  535. raise TypeError('The `eval_input_config` must be a '
  536. 'input_reader_pb2.InputReader.')
  537. if not isinstance(model_config, model_pb2.DetectionModel):
  538. raise TypeError('The `model_config` must be a '
  539. 'model_pb2.DetectionModel.')
  540. if model is None:
  541. model_preprocess_fn = INPUT_BUILDER_UTIL_MAP['model_build'](
  542. model_config, is_training=False).preprocess
  543. else:
  544. model_preprocess_fn = model.preprocess
  545. def transform_and_pad_input_data_fn(tensor_dict):
  546. """Combines transform and pad operation."""
  547. num_classes = config_util.get_number_of_classes(model_config)
  548. image_resizer_config = config_util.get_image_resizer_config(model_config)
  549. image_resizer_fn = image_resizer_builder.build(image_resizer_config)
  550. transform_data_fn = functools.partial(
  551. transform_input_data, model_preprocess_fn=model_preprocess_fn,
  552. image_resizer_fn=image_resizer_fn,
  553. num_classes=num_classes,
  554. data_augmentation_fn=None,
  555. retain_original_image=eval_config.retain_original_images)
  556. tensor_dict = pad_input_data_to_static_shapes(
  557. tensor_dict=transform_data_fn(tensor_dict),
  558. max_num_boxes=eval_input_config.max_number_of_boxes,
  559. num_classes=config_util.get_number_of_classes(model_config),
  560. spatial_image_shape=config_util.get_spatial_image_size(
  561. image_resizer_config))
  562. return (_get_features_dict(tensor_dict), _get_labels_dict(tensor_dict))
  563. dataset = INPUT_BUILDER_UTIL_MAP['dataset_build'](
  564. eval_input_config,
  565. batch_size=params['batch_size'] if params else eval_config.batch_size,
  566. transform_input_data_fn=transform_and_pad_input_data_fn)
  567. return dataset
  568. def create_predict_input_fn(model_config, predict_input_config):
  569. """Creates a predict `input` function for `Estimator`.
  570. Args:
  571. model_config: A model_pb2.DetectionModel.
  572. predict_input_config: An input_reader_pb2.InputReader.
  573. Returns:
  574. `input_fn` for `Estimator` in PREDICT mode.
  575. """
  576. def _predict_input_fn(params=None):
  577. """Decodes serialized tf.Examples and returns `ServingInputReceiver`.
  578. Args:
  579. params: Parameter dictionary passed from the estimator.
  580. Returns:
  581. `ServingInputReceiver`.
  582. """
  583. del params
  584. example = tf.placeholder(dtype=tf.string, shape=[], name='tf_example')
  585. num_classes = config_util.get_number_of_classes(model_config)
  586. model_preprocess_fn = INPUT_BUILDER_UTIL_MAP['model_build'](
  587. model_config, is_training=False).preprocess
  588. image_resizer_config = config_util.get_image_resizer_config(model_config)
  589. image_resizer_fn = image_resizer_builder.build(image_resizer_config)
  590. transform_fn = functools.partial(
  591. transform_input_data, model_preprocess_fn=model_preprocess_fn,
  592. image_resizer_fn=image_resizer_fn,
  593. num_classes=num_classes,
  594. data_augmentation_fn=None)
  595. decoder = tf_example_decoder.TfExampleDecoder(
  596. load_instance_masks=False,
  597. num_additional_channels=predict_input_config.num_additional_channels)
  598. input_dict = transform_fn(decoder.decode(example))
  599. images = tf.cast(input_dict[fields.InputDataFields.image], dtype=tf.float32)
  600. images = tf.expand_dims(images, axis=0)
  601. true_image_shape = tf.expand_dims(
  602. input_dict[fields.InputDataFields.true_image_shape], axis=0)
  603. return tf.estimator.export.ServingInputReceiver(
  604. features={
  605. fields.InputDataFields.image: images,
  606. fields.InputDataFields.true_image_shape: true_image_shape},
  607. receiver_tensors={SERVING_FED_EXAMPLE_KEY: example})
  608. return _predict_input_fn