You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

671 lines
29 KiB

6 years ago
  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Model input function for tf-learn object detection model."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import functools
  20. import tensorflow as tf
  21. from object_detection.builders import dataset_builder
  22. from object_detection.builders import image_resizer_builder
  23. from object_detection.builders import model_builder
  24. from object_detection.builders import preprocessor_builder
  25. from object_detection.core import preprocessor
  26. from object_detection.core import standard_fields as fields
  27. from object_detection.data_decoders import tf_example_decoder
  28. from object_detection.protos import eval_pb2
  29. from object_detection.protos import input_reader_pb2
  30. from object_detection.protos import model_pb2
  31. from object_detection.protos import train_pb2
  32. from object_detection.utils import config_util
  33. from object_detection.utils import ops as util_ops
  34. from object_detection.utils import shape_utils
  35. HASH_KEY = 'hash'
  36. HASH_BINS = 1 << 31
  37. SERVING_FED_EXAMPLE_KEY = 'serialized_example'
  38. # A map of names to methods that help build the input pipeline.
  39. INPUT_BUILDER_UTIL_MAP = {
  40. 'dataset_build': dataset_builder.build,
  41. }
  42. def transform_input_data(tensor_dict,
  43. model_preprocess_fn,
  44. image_resizer_fn,
  45. num_classes,
  46. data_augmentation_fn=None,
  47. merge_multiple_boxes=False,
  48. retain_original_image=False,
  49. use_multiclass_scores=False,
  50. use_bfloat16=False):
  51. """A single function that is responsible for all input data transformations.
  52. Data transformation functions are applied in the following order.
  53. 1. If key fields.InputDataFields.image_additional_channels is present in
  54. tensor_dict, the additional channels will be merged into
  55. fields.InputDataFields.image.
  56. 2. data_augmentation_fn (optional): applied on tensor_dict.
  57. 3. model_preprocess_fn: applied only on image tensor in tensor_dict.
  58. 4. image_resizer_fn: applied on original image and instance mask tensor in
  59. tensor_dict.
  60. 5. one_hot_encoding: applied to classes tensor in tensor_dict.
  61. 6. merge_multiple_boxes (optional): when groundtruth boxes are exactly the
  62. same they can be merged into a single box with an associated k-hot class
  63. label.
  64. Args:
  65. tensor_dict: dictionary containing input tensors keyed by
  66. fields.InputDataFields.
  67. model_preprocess_fn: model's preprocess function to apply on image tensor.
  68. This function must take in a 4-D float tensor and return a 4-D preprocess
  69. float tensor and a tensor containing the true image shape.
  70. image_resizer_fn: image resizer function to apply on groundtruth instance
  71. `masks. This function must take a 3-D float tensor of an image and a 3-D
  72. tensor of instance masks and return a resized version of these along with
  73. the true shapes.
  74. num_classes: number of max classes to one-hot (or k-hot) encode the class
  75. labels.
  76. data_augmentation_fn: (optional) data augmentation function to apply on
  77. input `tensor_dict`.
  78. merge_multiple_boxes: (optional) whether to merge multiple groundtruth boxes
  79. and classes for a given image if the boxes are exactly the same.
  80. retain_original_image: (optional) whether to retain original image in the
  81. output dictionary.
  82. use_multiclass_scores: whether to use multiclass scores as
  83. class targets instead of one-hot encoding of `groundtruth_classes`.
  84. use_bfloat16: (optional) a bool, whether to use bfloat16 in training.
  85. Returns:
  86. A dictionary keyed by fields.InputDataFields containing the tensors obtained
  87. after applying all the transformations.
  88. """
  89. # Reshape flattened multiclass scores tensor into a 2D tensor of shape
  90. # [num_boxes, num_classes].
  91. if fields.InputDataFields.multiclass_scores in tensor_dict:
  92. tensor_dict[fields.InputDataFields.multiclass_scores] = tf.reshape(
  93. tensor_dict[fields.InputDataFields.multiclass_scores], [
  94. tf.shape(tensor_dict[fields.InputDataFields.groundtruth_boxes])[0],
  95. num_classes
  96. ])
  97. if fields.InputDataFields.groundtruth_boxes in tensor_dict:
  98. tensor_dict = util_ops.filter_groundtruth_with_nan_box_coordinates(
  99. tensor_dict)
  100. tensor_dict = util_ops.filter_unrecognized_classes(tensor_dict)
  101. if retain_original_image:
  102. tensor_dict[fields.InputDataFields.original_image] = tf.cast(
  103. image_resizer_fn(tensor_dict[fields.InputDataFields.image], None)[0],
  104. tf.uint8)
  105. if fields.InputDataFields.image_additional_channels in tensor_dict:
  106. channels = tensor_dict[fields.InputDataFields.image_additional_channels]
  107. tensor_dict[fields.InputDataFields.image] = tf.concat(
  108. [tensor_dict[fields.InputDataFields.image], channels], axis=2)
  109. # Apply data augmentation ops.
  110. if data_augmentation_fn is not None:
  111. tensor_dict = data_augmentation_fn(tensor_dict)
  112. # Apply model preprocessing ops and resize instance masks.
  113. image = tensor_dict[fields.InputDataFields.image]
  114. preprocessed_resized_image, true_image_shape = model_preprocess_fn(
  115. tf.expand_dims(tf.to_float(image), axis=0))
  116. if use_bfloat16:
  117. preprocessed_resized_image = tf.cast(
  118. preprocessed_resized_image, tf.bfloat16)
  119. tensor_dict[fields.InputDataFields.image] = tf.squeeze(
  120. preprocessed_resized_image, axis=0)
  121. tensor_dict[fields.InputDataFields.true_image_shape] = tf.squeeze(
  122. true_image_shape, axis=0)
  123. if fields.InputDataFields.groundtruth_instance_masks in tensor_dict:
  124. masks = tensor_dict[fields.InputDataFields.groundtruth_instance_masks]
  125. _, resized_masks, _ = image_resizer_fn(image, masks)
  126. if use_bfloat16:
  127. resized_masks = tf.cast(resized_masks, tf.bfloat16)
  128. tensor_dict[fields.InputDataFields.
  129. groundtruth_instance_masks] = resized_masks
  130. # Transform groundtruth classes to one hot encodings.
  131. label_offset = 1
  132. zero_indexed_groundtruth_classes = tensor_dict[
  133. fields.InputDataFields.groundtruth_classes] - label_offset
  134. tensor_dict[fields.InputDataFields.groundtruth_classes] = tf.one_hot(
  135. zero_indexed_groundtruth_classes, num_classes)
  136. if use_multiclass_scores:
  137. tensor_dict[fields.InputDataFields.groundtruth_classes] = tensor_dict[
  138. fields.InputDataFields.multiclass_scores]
  139. tensor_dict.pop(fields.InputDataFields.multiclass_scores, None)
  140. if fields.InputDataFields.groundtruth_confidences in tensor_dict:
  141. groundtruth_confidences = tensor_dict[
  142. fields.InputDataFields.groundtruth_confidences]
  143. # Map the confidences to the one-hot encoding of classes
  144. tensor_dict[fields.InputDataFields.groundtruth_confidences] = (
  145. tf.reshape(groundtruth_confidences, [-1, 1]) *
  146. tensor_dict[fields.InputDataFields.groundtruth_classes])
  147. else:
  148. groundtruth_confidences = tf.ones_like(
  149. zero_indexed_groundtruth_classes, dtype=tf.float32)
  150. tensor_dict[fields.InputDataFields.groundtruth_confidences] = (
  151. tensor_dict[fields.InputDataFields.groundtruth_classes])
  152. if merge_multiple_boxes:
  153. merged_boxes, merged_classes, merged_confidences, _ = (
  154. util_ops.merge_boxes_with_multiple_labels(
  155. tensor_dict[fields.InputDataFields.groundtruth_boxes],
  156. zero_indexed_groundtruth_classes,
  157. groundtruth_confidences,
  158. num_classes))
  159. merged_classes = tf.cast(merged_classes, tf.float32)
  160. tensor_dict[fields.InputDataFields.groundtruth_boxes] = merged_boxes
  161. tensor_dict[fields.InputDataFields.groundtruth_classes] = merged_classes
  162. tensor_dict[fields.InputDataFields.groundtruth_confidences] = (
  163. merged_confidences)
  164. if fields.InputDataFields.groundtruth_boxes in tensor_dict:
  165. tensor_dict[fields.InputDataFields.num_groundtruth_boxes] = tf.shape(
  166. tensor_dict[fields.InputDataFields.groundtruth_boxes])[0]
  167. return tensor_dict
  168. def pad_input_data_to_static_shapes(tensor_dict, max_num_boxes, num_classes,
  169. spatial_image_shape=None):
  170. """Pads input tensors to static shapes.
  171. In case num_additional_channels > 0, we assume that the additional channels
  172. have already been concatenated to the base image.
  173. Args:
  174. tensor_dict: Tensor dictionary of input data
  175. max_num_boxes: Max number of groundtruth boxes needed to compute shapes for
  176. padding.
  177. num_classes: Number of classes in the dataset needed to compute shapes for
  178. padding.
  179. spatial_image_shape: A list of two integers of the form [height, width]
  180. containing expected spatial shape of the image.
  181. Returns:
  182. A dictionary keyed by fields.InputDataFields containing padding shapes for
  183. tensors in the dataset.
  184. Raises:
  185. ValueError: If groundtruth classes is neither rank 1 nor rank 2, or if we
  186. detect that additional channels have not been concatenated yet.
  187. """
  188. if not spatial_image_shape or spatial_image_shape == [-1, -1]:
  189. height, width = None, None
  190. else:
  191. height, width = spatial_image_shape # pylint: disable=unpacking-non-sequence
  192. num_additional_channels = 0
  193. if fields.InputDataFields.image_additional_channels in tensor_dict:
  194. num_additional_channels = tensor_dict[
  195. fields.InputDataFields.image_additional_channels].shape[2].value
  196. # We assume that if num_additional_channels > 0, then it has already been
  197. # concatenated to the base image (but not the ground truth).
  198. num_channels = 3
  199. if fields.InputDataFields.image in tensor_dict:
  200. num_channels = tensor_dict[fields.InputDataFields.image].shape[2].value
  201. if num_additional_channels:
  202. if num_additional_channels >= num_channels:
  203. raise ValueError(
  204. 'Image must be already concatenated with additional channels.')
  205. if (fields.InputDataFields.original_image in tensor_dict and
  206. tensor_dict[fields.InputDataFields.original_image].shape[2].value ==
  207. num_channels):
  208. raise ValueError(
  209. 'Image must be already concatenated with additional channels.')
  210. padding_shapes = {
  211. fields.InputDataFields.image: [
  212. height, width, num_channels
  213. ],
  214. fields.InputDataFields.original_image_spatial_shape: [2],
  215. fields.InputDataFields.image_additional_channels: [
  216. height, width, num_additional_channels
  217. ],
  218. fields.InputDataFields.source_id: [],
  219. fields.InputDataFields.filename: [],
  220. fields.InputDataFields.key: [],
  221. fields.InputDataFields.groundtruth_difficult: [max_num_boxes],
  222. fields.InputDataFields.groundtruth_boxes: [max_num_boxes, 4],
  223. fields.InputDataFields.groundtruth_classes: [max_num_boxes, num_classes],
  224. fields.InputDataFields.groundtruth_instance_masks: [
  225. max_num_boxes, height, width
  226. ],
  227. fields.InputDataFields.groundtruth_is_crowd: [max_num_boxes],
  228. fields.InputDataFields.groundtruth_group_of: [max_num_boxes],
  229. fields.InputDataFields.groundtruth_area: [max_num_boxes],
  230. fields.InputDataFields.groundtruth_weights: [max_num_boxes],
  231. fields.InputDataFields.groundtruth_confidences: [
  232. max_num_boxes, num_classes
  233. ],
  234. fields.InputDataFields.num_groundtruth_boxes: [],
  235. fields.InputDataFields.groundtruth_label_types: [max_num_boxes],
  236. fields.InputDataFields.groundtruth_label_weights: [max_num_boxes],
  237. fields.InputDataFields.true_image_shape: [3],
  238. fields.InputDataFields.groundtruth_image_classes: [num_classes],
  239. fields.InputDataFields.groundtruth_image_confidences: [num_classes],
  240. }
  241. if fields.InputDataFields.original_image in tensor_dict:
  242. padding_shapes[fields.InputDataFields.original_image] = [
  243. height, width, tensor_dict[fields.InputDataFields.
  244. original_image].shape[2].value
  245. ]
  246. if fields.InputDataFields.groundtruth_keypoints in tensor_dict:
  247. tensor_shape = (
  248. tensor_dict[fields.InputDataFields.groundtruth_keypoints].shape)
  249. padding_shape = [max_num_boxes, tensor_shape[1].value,
  250. tensor_shape[2].value]
  251. padding_shapes[fields.InputDataFields.groundtruth_keypoints] = padding_shape
  252. if fields.InputDataFields.groundtruth_keypoint_visibilities in tensor_dict:
  253. tensor_shape = tensor_dict[fields.InputDataFields.
  254. groundtruth_keypoint_visibilities].shape
  255. padding_shape = [max_num_boxes, tensor_shape[1].value]
  256. padding_shapes[fields.InputDataFields.
  257. groundtruth_keypoint_visibilities] = padding_shape
  258. padded_tensor_dict = {}
  259. for tensor_name in tensor_dict:
  260. padded_tensor_dict[tensor_name] = shape_utils.pad_or_clip_nd(
  261. tensor_dict[tensor_name], padding_shapes[tensor_name])
  262. # Make sure that the number of groundtruth boxes now reflects the
  263. # padded/clipped tensors.
  264. if fields.InputDataFields.num_groundtruth_boxes in padded_tensor_dict:
  265. padded_tensor_dict[fields.InputDataFields.num_groundtruth_boxes] = (
  266. tf.minimum(
  267. padded_tensor_dict[fields.InputDataFields.num_groundtruth_boxes],
  268. max_num_boxes))
  269. return padded_tensor_dict
  270. def augment_input_data(tensor_dict, data_augmentation_options):
  271. """Applies data augmentation ops to input tensors.
  272. Args:
  273. tensor_dict: A dictionary of input tensors keyed by fields.InputDataFields.
  274. data_augmentation_options: A list of tuples, where each tuple contains a
  275. function and a dictionary that contains arguments and their values.
  276. Usually, this is the output of core/preprocessor.build.
  277. Returns:
  278. A dictionary of tensors obtained by applying data augmentation ops to the
  279. input tensor dictionary.
  280. """
  281. tensor_dict[fields.InputDataFields.image] = tf.expand_dims(
  282. tf.to_float(tensor_dict[fields.InputDataFields.image]), 0)
  283. include_instance_masks = (fields.InputDataFields.groundtruth_instance_masks
  284. in tensor_dict)
  285. include_keypoints = (fields.InputDataFields.groundtruth_keypoints
  286. in tensor_dict)
  287. include_label_weights = (fields.InputDataFields.groundtruth_weights
  288. in tensor_dict)
  289. include_label_confidences = (fields.InputDataFields.groundtruth_confidences
  290. in tensor_dict)
  291. include_multiclass_scores = (fields.InputDataFields.multiclass_scores in
  292. tensor_dict)
  293. tensor_dict = preprocessor.preprocess(
  294. tensor_dict, data_augmentation_options,
  295. func_arg_map=preprocessor.get_default_func_arg_map(
  296. include_label_weights=include_label_weights,
  297. include_label_confidences=include_label_confidences,
  298. include_multiclass_scores=include_multiclass_scores,
  299. include_instance_masks=include_instance_masks,
  300. include_keypoints=include_keypoints))
  301. tensor_dict[fields.InputDataFields.image] = tf.squeeze(
  302. tensor_dict[fields.InputDataFields.image], axis=0)
  303. return tensor_dict
  304. def _get_labels_dict(input_dict):
  305. """Extracts labels dict from input dict."""
  306. required_label_keys = [
  307. fields.InputDataFields.num_groundtruth_boxes,
  308. fields.InputDataFields.groundtruth_boxes,
  309. fields.InputDataFields.groundtruth_classes,
  310. fields.InputDataFields.groundtruth_weights,
  311. ]
  312. labels_dict = {}
  313. for key in required_label_keys:
  314. labels_dict[key] = input_dict[key]
  315. optional_label_keys = [
  316. fields.InputDataFields.groundtruth_confidences,
  317. fields.InputDataFields.groundtruth_keypoints,
  318. fields.InputDataFields.groundtruth_instance_masks,
  319. fields.InputDataFields.groundtruth_area,
  320. fields.InputDataFields.groundtruth_is_crowd,
  321. fields.InputDataFields.groundtruth_difficult
  322. ]
  323. for key in optional_label_keys:
  324. if key in input_dict:
  325. labels_dict[key] = input_dict[key]
  326. if fields.InputDataFields.groundtruth_difficult in labels_dict:
  327. labels_dict[fields.InputDataFields.groundtruth_difficult] = tf.cast(
  328. labels_dict[fields.InputDataFields.groundtruth_difficult], tf.int32)
  329. return labels_dict
  330. def _replace_empty_string_with_random_number(string_tensor):
  331. """Returns string unchanged if non-empty, and random string tensor otherwise.
  332. The random string is an integer 0 and 2**63 - 1, casted as string.
  333. Args:
  334. string_tensor: A tf.tensor of dtype string.
  335. Returns:
  336. out_string: A tf.tensor of dtype string. If string_tensor contains the empty
  337. string, out_string will contain a random integer casted to a string.
  338. Otherwise string_tensor is returned unchanged.
  339. """
  340. empty_string = tf.constant('', dtype=tf.string, name='EmptyString')
  341. random_source_id = tf.as_string(
  342. tf.random_uniform(shape=[], maxval=2**63 - 1, dtype=tf.int64))
  343. out_string = tf.cond(
  344. tf.equal(string_tensor, empty_string),
  345. true_fn=lambda: random_source_id,
  346. false_fn=lambda: string_tensor)
  347. return out_string
  348. def _get_features_dict(input_dict):
  349. """Extracts features dict from input dict."""
  350. source_id = _replace_empty_string_with_random_number(
  351. input_dict[fields.InputDataFields.source_id])
  352. hash_from_source_id = tf.string_to_hash_bucket_fast(source_id, HASH_BINS)
  353. features = {
  354. fields.InputDataFields.image:
  355. input_dict[fields.InputDataFields.image],
  356. HASH_KEY: tf.cast(hash_from_source_id, tf.int32),
  357. fields.InputDataFields.true_image_shape:
  358. input_dict[fields.InputDataFields.true_image_shape],
  359. fields.InputDataFields.original_image_spatial_shape:
  360. input_dict[fields.InputDataFields.original_image_spatial_shape]
  361. }
  362. if fields.InputDataFields.original_image in input_dict:
  363. features[fields.InputDataFields.original_image] = input_dict[
  364. fields.InputDataFields.original_image]
  365. return features
  366. def create_train_input_fn(train_config, train_input_config,
  367. model_config):
  368. """Creates a train `input` function for `Estimator`.
  369. Args:
  370. train_config: A train_pb2.TrainConfig.
  371. train_input_config: An input_reader_pb2.InputReader.
  372. model_config: A model_pb2.DetectionModel.
  373. Returns:
  374. `input_fn` for `Estimator` in TRAIN mode.
  375. """
  376. def _train_input_fn(params=None):
  377. """Returns `features` and `labels` tensor dictionaries for training.
  378. Args:
  379. params: Parameter dictionary passed from the estimator.
  380. Returns:
  381. A tf.data.Dataset that holds (features, labels) tuple.
  382. features: Dictionary of feature tensors.
  383. features[fields.InputDataFields.image] is a [batch_size, H, W, C]
  384. float32 tensor with preprocessed images.
  385. features[HASH_KEY] is a [batch_size] int32 tensor representing unique
  386. identifiers for the images.
  387. features[fields.InputDataFields.true_image_shape] is a [batch_size, 3]
  388. int32 tensor representing the true image shapes, as preprocessed
  389. images could be padded.
  390. features[fields.InputDataFields.original_image] (optional) is a
  391. [batch_size, H, W, C] float32 tensor with original images.
  392. labels: Dictionary of groundtruth tensors.
  393. labels[fields.InputDataFields.num_groundtruth_boxes] is a [batch_size]
  394. int32 tensor indicating the number of groundtruth boxes.
  395. labels[fields.InputDataFields.groundtruth_boxes] is a
  396. [batch_size, num_boxes, 4] float32 tensor containing the corners of
  397. the groundtruth boxes.
  398. labels[fields.InputDataFields.groundtruth_classes] is a
  399. [batch_size, num_boxes, num_classes] float32 one-hot tensor of
  400. classes.
  401. labels[fields.InputDataFields.groundtruth_weights] is a
  402. [batch_size, num_boxes] float32 tensor containing groundtruth weights
  403. for the boxes.
  404. -- Optional --
  405. labels[fields.InputDataFields.groundtruth_instance_masks] is a
  406. [batch_size, num_boxes, H, W] float32 tensor containing only binary
  407. values, which represent instance masks for objects.
  408. labels[fields.InputDataFields.groundtruth_keypoints] is a
  409. [batch_size, num_boxes, num_keypoints, 2] float32 tensor containing
  410. keypoints for each box.
  411. Raises:
  412. TypeError: if the `train_config`, `train_input_config` or `model_config`
  413. are not of the correct type.
  414. """
  415. if not isinstance(train_config, train_pb2.TrainConfig):
  416. raise TypeError('For training mode, the `train_config` must be a '
  417. 'train_pb2.TrainConfig.')
  418. if not isinstance(train_input_config, input_reader_pb2.InputReader):
  419. raise TypeError('The `train_input_config` must be a '
  420. 'input_reader_pb2.InputReader.')
  421. if not isinstance(model_config, model_pb2.DetectionModel):
  422. raise TypeError('The `model_config` must be a '
  423. 'model_pb2.DetectionModel.')
  424. def transform_and_pad_input_data_fn(tensor_dict):
  425. """Combines transform and pad operation."""
  426. data_augmentation_options = [
  427. preprocessor_builder.build(step)
  428. for step in train_config.data_augmentation_options
  429. ]
  430. data_augmentation_fn = functools.partial(
  431. augment_input_data,
  432. data_augmentation_options=data_augmentation_options)
  433. model = model_builder.build(model_config, is_training=True)
  434. image_resizer_config = config_util.get_image_resizer_config(model_config)
  435. image_resizer_fn = image_resizer_builder.build(image_resizer_config)
  436. transform_data_fn = functools.partial(
  437. transform_input_data, model_preprocess_fn=model.preprocess,
  438. image_resizer_fn=image_resizer_fn,
  439. num_classes=config_util.get_number_of_classes(model_config),
  440. data_augmentation_fn=data_augmentation_fn,
  441. merge_multiple_boxes=train_config.merge_multiple_label_boxes,
  442. retain_original_image=train_config.retain_original_images,
  443. use_multiclass_scores=train_config.use_multiclass_scores,
  444. use_bfloat16=train_config.use_bfloat16)
  445. tensor_dict = pad_input_data_to_static_shapes(
  446. tensor_dict=transform_data_fn(tensor_dict),
  447. max_num_boxes=train_input_config.max_number_of_boxes,
  448. num_classes=config_util.get_number_of_classes(model_config),
  449. spatial_image_shape=config_util.get_spatial_image_size(
  450. image_resizer_config))
  451. return (_get_features_dict(tensor_dict), _get_labels_dict(tensor_dict))
  452. dataset = INPUT_BUILDER_UTIL_MAP['dataset_build'](
  453. train_input_config,
  454. transform_input_data_fn=transform_and_pad_input_data_fn,
  455. batch_size=params['batch_size'] if params else train_config.batch_size)
  456. return dataset
  457. return _train_input_fn
  458. def create_eval_input_fn(eval_config, eval_input_config, model_config):
  459. """Creates an eval `input` function for `Estimator`.
  460. Args:
  461. eval_config: An eval_pb2.EvalConfig.
  462. eval_input_config: An input_reader_pb2.InputReader.
  463. model_config: A model_pb2.DetectionModel.
  464. Returns:
  465. `input_fn` for `Estimator` in EVAL mode.
  466. """
  467. def _eval_input_fn(params=None):
  468. """Returns `features` and `labels` tensor dictionaries for evaluation.
  469. Args:
  470. params: Parameter dictionary passed from the estimator.
  471. Returns:
  472. A tf.data.Dataset that holds (features, labels) tuple.
  473. features: Dictionary of feature tensors.
  474. features[fields.InputDataFields.image] is a [1, H, W, C] float32 tensor
  475. with preprocessed images.
  476. features[HASH_KEY] is a [1] int32 tensor representing unique
  477. identifiers for the images.
  478. features[fields.InputDataFields.true_image_shape] is a [1, 3]
  479. int32 tensor representing the true image shapes, as preprocessed
  480. images could be padded.
  481. features[fields.InputDataFields.original_image] is a [1, H', W', C]
  482. float32 tensor with the original image.
  483. labels: Dictionary of groundtruth tensors.
  484. labels[fields.InputDataFields.groundtruth_boxes] is a [1, num_boxes, 4]
  485. float32 tensor containing the corners of the groundtruth boxes.
  486. labels[fields.InputDataFields.groundtruth_classes] is a
  487. [num_boxes, num_classes] float32 one-hot tensor of classes.
  488. labels[fields.InputDataFields.groundtruth_area] is a [1, num_boxes]
  489. float32 tensor containing object areas.
  490. labels[fields.InputDataFields.groundtruth_is_crowd] is a [1, num_boxes]
  491. bool tensor indicating if the boxes enclose a crowd.
  492. labels[fields.InputDataFields.groundtruth_difficult] is a [1, num_boxes]
  493. int32 tensor indicating if the boxes represent difficult instances.
  494. -- Optional --
  495. labels[fields.InputDataFields.groundtruth_instance_masks] is a
  496. [1, num_boxes, H, W] float32 tensor containing only binary values,
  497. which represent instance masks for objects.
  498. Raises:
  499. TypeError: if the `eval_config`, `eval_input_config` or `model_config`
  500. are not of the correct type.
  501. """
  502. params = params or {}
  503. if not isinstance(eval_config, eval_pb2.EvalConfig):
  504. raise TypeError('For eval mode, the `eval_config` must be a '
  505. 'train_pb2.EvalConfig.')
  506. if not isinstance(eval_input_config, input_reader_pb2.InputReader):
  507. raise TypeError('The `eval_input_config` must be a '
  508. 'input_reader_pb2.InputReader.')
  509. if not isinstance(model_config, model_pb2.DetectionModel):
  510. raise TypeError('The `model_config` must be a '
  511. 'model_pb2.DetectionModel.')
  512. def transform_and_pad_input_data_fn(tensor_dict):
  513. """Combines transform and pad operation."""
  514. num_classes = config_util.get_number_of_classes(model_config)
  515. model = model_builder.build(model_config, is_training=False)
  516. image_resizer_config = config_util.get_image_resizer_config(model_config)
  517. image_resizer_fn = image_resizer_builder.build(image_resizer_config)
  518. transform_data_fn = functools.partial(
  519. transform_input_data, model_preprocess_fn=model.preprocess,
  520. image_resizer_fn=image_resizer_fn,
  521. num_classes=num_classes,
  522. data_augmentation_fn=None,
  523. retain_original_image=eval_config.retain_original_images)
  524. tensor_dict = pad_input_data_to_static_shapes(
  525. tensor_dict=transform_data_fn(tensor_dict),
  526. max_num_boxes=eval_input_config.max_number_of_boxes,
  527. num_classes=config_util.get_number_of_classes(model_config),
  528. spatial_image_shape=config_util.get_spatial_image_size(
  529. image_resizer_config))
  530. return (_get_features_dict(tensor_dict), _get_labels_dict(tensor_dict))
  531. dataset = INPUT_BUILDER_UTIL_MAP['dataset_build'](
  532. eval_input_config,
  533. batch_size=params['batch_size'] if params else eval_config.batch_size,
  534. transform_input_data_fn=transform_and_pad_input_data_fn)
  535. return dataset
  536. return _eval_input_fn
  537. def create_predict_input_fn(model_config, predict_input_config):
  538. """Creates a predict `input` function for `Estimator`.
  539. Args:
  540. model_config: A model_pb2.DetectionModel.
  541. predict_input_config: An input_reader_pb2.InputReader.
  542. Returns:
  543. `input_fn` for `Estimator` in PREDICT mode.
  544. """
  545. def _predict_input_fn(params=None):
  546. """Decodes serialized tf.Examples and returns `ServingInputReceiver`.
  547. Args:
  548. params: Parameter dictionary passed from the estimator.
  549. Returns:
  550. `ServingInputReceiver`.
  551. """
  552. del params
  553. example = tf.placeholder(dtype=tf.string, shape=[], name='tf_example')
  554. num_classes = config_util.get_number_of_classes(model_config)
  555. model = model_builder.build(model_config, is_training=False)
  556. image_resizer_config = config_util.get_image_resizer_config(model_config)
  557. image_resizer_fn = image_resizer_builder.build(image_resizer_config)
  558. transform_fn = functools.partial(
  559. transform_input_data, model_preprocess_fn=model.preprocess,
  560. image_resizer_fn=image_resizer_fn,
  561. num_classes=num_classes,
  562. data_augmentation_fn=None)
  563. decoder = tf_example_decoder.TfExampleDecoder(
  564. load_instance_masks=False,
  565. num_additional_channels=predict_input_config.num_additional_channels)
  566. input_dict = transform_fn(decoder.decode(example))
  567. images = tf.to_float(input_dict[fields.InputDataFields.image])
  568. images = tf.expand_dims(images, axis=0)
  569. true_image_shape = tf.expand_dims(
  570. input_dict[fields.InputDataFields.true_image_shape], axis=0)
  571. return tf.estimator.export.ServingInputReceiver(
  572. features={
  573. fields.InputDataFields.image: images,
  574. fields.InputDataFields.true_image_shape: true_image_shape},
  575. receiver_tensors={SERVING_FED_EXAMPLE_KEY: example})
  576. return _predict_input_fn