# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tensorflow Example proto decoder for object detection.

A decoder to decode string tensors containing serialized tensorflow.Example
protos for object detection.
"""
import tensorflow as tf

from object_detection.core import data_decoder
from object_detection.core import standard_fields as fields
from object_detection.protos import input_reader_pb2
from object_detection.utils import label_map_util

slim_example_decoder = tf.contrib.slim.tfexample_decoder


class _ClassTensorHandler(slim_example_decoder.Tensor):
  """An ItemHandler to fetch class ids from class text."""

  def __init__(self,
               tensor_key,
               label_map_proto_file,
               shape_keys=None,
               shape=None,
               default_value=''):
    """Initializes the _ClassTensorHandler.

    Simply calls a vocabulary (most often, a label mapping) lookup.

    Args:
      tensor_key: the name of the `TFExample` feature to read the tensor from.
      label_map_proto_file: File path to a text format LabelMapProto message
        mapping class text to id.
      shape_keys: Optional name or list of names of the TF-Example feature in
        which the tensor shape is stored. If a list, then each corresponds to
        one dimension of the shape.
      shape: Optional output shape of the `Tensor`. If provided, the `Tensor`
        is reshaped accordingly.
      default_value: The value used when the `tensor_key` is not found in a
        particular `TFExample`.

    Raises:
      ValueError: if both `shape_keys` and `shape` are specified.
    """
    name_to_id = label_map_util.get_label_map_dict(
        label_map_proto_file, use_display_name=False)
    # We use a default_value of -1, but we expect all labels to be contained
    # in the label map.
    name_to_id_table = tf.contrib.lookup.HashTable(
        initializer=tf.contrib.lookup.KeyValueTensorInitializer(
            keys=tf.constant(list(name_to_id.keys())),
            values=tf.constant(list(name_to_id.values()), dtype=tf.int64)),
        default_value=-1)
    display_name_to_id = label_map_util.get_label_map_dict(
        label_map_proto_file, use_display_name=True)
    # We use a default_value of -1, but we expect all labels to be contained
    # in the label map.
    display_name_to_id_table = tf.contrib.lookup.HashTable(
        initializer=tf.contrib.lookup.KeyValueTensorInitializer(
            keys=tf.constant(list(display_name_to_id.keys())),
            values=tf.constant(
                list(display_name_to_id.values()), dtype=tf.int64)),
        default_value=-1)

    self._name_to_id_table = name_to_id_table
    self._display_name_to_id_table = display_name_to_id_table
    super(_ClassTensorHandler, self).__init__(tensor_key, shape_keys, shape,
                                              default_value)

  def tensors_to_item(self, keys_to_tensors):
    unmapped_tensor = super(_ClassTensorHandler,
                            self).tensors_to_item(keys_to_tensors)
    return tf.maximum(self._name_to_id_table.lookup(unmapped_tensor),
                      self._display_name_to_id_table.lookup(unmapped_tensor))
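

# Illustrative note (not part of the original file): the two lookup tables
# above are built from a StringIntLabelMap text proto. A typical entry in such
# a label map file looks like the following (values are placeholders); the
# handler maps either the `name` or the `display_name` string to the integer
# `id`:
#
#   item {
#     name: "/m/01g317"
#     id: 1
#     display_name: "person"
#   }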


class _BackupHandler(slim_example_decoder.ItemHandler):
  """An ItemHandler that tries two ItemHandlers in order."""

  def __init__(self, handler, backup):
    """Initializes the BackupHandler handler.

    If the first Handler's tensors_to_item returns a Tensor with no elements,
    the second Handler is used.

    Args:
      handler: The primary ItemHandler.
      backup: The backup ItemHandler.

    Raises:
      ValueError: if either is not an ItemHandler.
    """
    if not isinstance(handler, slim_example_decoder.ItemHandler):
      raise ValueError('Primary handler is of type %s instead of ItemHandler' %
                       type(handler))
    if not isinstance(backup, slim_example_decoder.ItemHandler):
      raise ValueError(
          'Backup handler is of type %s instead of ItemHandler' % type(backup))
    self._handler = handler
    self._backup = backup
    super(_BackupHandler, self).__init__(handler.keys + backup.keys)

  def tensors_to_item(self, keys_to_tensors):
    item = self._handler.tensors_to_item(keys_to_tensors)
    return tf.cond(
        pred=tf.equal(tf.reduce_prod(tf.shape(item)), 0),
        true_fn=lambda: self._backup.tensors_to_item(keys_to_tensors),
        false_fn=lambda: item)


class TfExampleDecoder(data_decoder.DataDecoder):
  """Tensorflow Example proto decoder."""

  def __init__(self,
               load_instance_masks=False,
               instance_mask_type=input_reader_pb2.NUMERICAL_MASKS,
               label_map_proto_file=None,
               use_display_name=False,
               dct_method='',
               num_keypoints=0,
               num_additional_channels=0,
               load_multiclass_scores=False):
    """Constructor sets keys_to_features and items_to_handlers.

    Args:
      load_instance_masks: whether or not to load and handle instance masks.
      instance_mask_type: type of instance masks. Options are provided in
        input_reader.proto. This is only used if `load_instance_masks` is True.
      label_map_proto_file: a file path to an
        object_detection.protos.StringIntLabelMap proto. If provided, then the
        mapped IDs of 'image/object/class/text' will take precedence over the
        existing 'image/object/class/label' ID. Also, if provided, it is
        assumed that 'image/object/class/text' will be in the data.
      use_display_name: whether or not to use the `display_name` for label
        mapping (instead of `name`). Only used if label_map_proto_file is
        provided.
      dct_method: An optional string. Defaults to the empty string. It only
        takes effect when the image format is jpeg and is used to specify a
        hint about the algorithm to use for jpeg decompression. Currently
        valid values are ['INTEGER_FAST', 'INTEGER_ACCURATE']. The hint may be
        ignored, for example, when the jpeg library does not have that
        specific option.
      num_keypoints: the number of keypoints per object.
      num_additional_channels: how many additional channels to use.
      load_multiclass_scores: Whether to load multiclass scores associated with
        boxes.

    Raises:
      ValueError: If `instance_mask_type` option is not one of
        input_reader_pb2.DEFAULT, input_reader_pb2.NUMERICAL_MASKS, or
        input_reader_pb2.PNG_MASKS.
    """
    # TODO(rathodv): delete unused `use_display_name` argument once we change
    # other decoders to handle label maps similarly.
    del use_display_name
    self.keys_to_features = {
        'image/encoded':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/format':
            tf.FixedLenFeature((), tf.string, default_value='jpeg'),
        'image/filename':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/key/sha256':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/source_id':
            tf.FixedLenFeature((), tf.string, default_value=''),
        'image/height':
            tf.FixedLenFeature((), tf.int64, default_value=1),
        'image/width':
            tf.FixedLenFeature((), tf.int64, default_value=1),
        # Image-level labels.
        'image/class/text':
            tf.VarLenFeature(tf.string),
        'image/class/label':
            tf.VarLenFeature(tf.int64),
        # Object boxes and classes.
        'image/object/bbox/xmin':
            tf.VarLenFeature(tf.float32),
        'image/object/bbox/xmax':
            tf.VarLenFeature(tf.float32),
        'image/object/bbox/ymin':
            tf.VarLenFeature(tf.float32),
        'image/object/bbox/ymax':
            tf.VarLenFeature(tf.float32),
        'image/object/class/label':
            tf.VarLenFeature(tf.int64),
        'image/object/class/text':
            tf.VarLenFeature(tf.string),
        'image/object/area':
            tf.VarLenFeature(tf.float32),
        'image/object/is_crowd':
            tf.VarLenFeature(tf.int64),
        'image/object/difficult':
            tf.VarLenFeature(tf.int64),
        'image/object/group_of':
            tf.VarLenFeature(tf.int64),
        'image/object/weight':
            tf.VarLenFeature(tf.float32),
    }
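    # Illustrative sketch (assumption, not part of the original file): when
    # writing TFRecords, a tf.train.Example matching the features declared
    # above could be assembled as follows; `encoded_jpeg`, `height`, `width`
    # and `xmins` are hypothetical placeholder values.
    #
    #   example = tf.train.Example(features=tf.train.Features(feature={
    #       'image/encoded': tf.train.Feature(
    #           bytes_list=tf.train.BytesList(value=[encoded_jpeg])),
    #       'image/format': tf.train.Feature(
    #           bytes_list=tf.train.BytesList(value=[b'jpeg'])),
    #       'image/height': tf.train.Feature(
    #           int64_list=tf.train.Int64List(value=[height])),
    #       'image/width': tf.train.Feature(
    #           int64_list=tf.train.Int64List(value=[width])),
    #       'image/object/bbox/xmin': tf.train.Feature(
    #           float_list=tf.train.FloatList(value=xmins)),
    #   }))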
    # We are checking `dct_method` instead of passing it directly in order to
    # ensure TF version 1.6 compatibility.
    if dct_method:
      image = slim_example_decoder.Image(
          image_key='image/encoded',
          format_key='image/format',
          channels=3,
          dct_method=dct_method)
      additional_channel_image = slim_example_decoder.Image(
          image_key='image/additional_channels/encoded',
          format_key='image/format',
          channels=1,
          repeated=True,
          dct_method=dct_method)
    else:
      image = slim_example_decoder.Image(
          image_key='image/encoded', format_key='image/format', channels=3)
      additional_channel_image = slim_example_decoder.Image(
          image_key='image/additional_channels/encoded',
          format_key='image/format',
          channels=1,
          repeated=True)
    self.items_to_handlers = {
        fields.InputDataFields.image:
            image,
        fields.InputDataFields.source_id: (
            slim_example_decoder.Tensor('image/source_id')),
        fields.InputDataFields.key: (
            slim_example_decoder.Tensor('image/key/sha256')),
        fields.InputDataFields.filename: (
            slim_example_decoder.Tensor('image/filename')),
        # Object boxes and classes.
        fields.InputDataFields.groundtruth_boxes: (
            slim_example_decoder.BoundingBox(['ymin', 'xmin', 'ymax', 'xmax'],
                                             'image/object/bbox/')),
        fields.InputDataFields.groundtruth_area:
            slim_example_decoder.Tensor('image/object/area'),
        fields.InputDataFields.groundtruth_is_crowd: (
            slim_example_decoder.Tensor('image/object/is_crowd')),
        fields.InputDataFields.groundtruth_difficult: (
            slim_example_decoder.Tensor('image/object/difficult')),
        fields.InputDataFields.groundtruth_group_of: (
            slim_example_decoder.Tensor('image/object/group_of')),
        fields.InputDataFields.groundtruth_weights: (
            slim_example_decoder.Tensor('image/object/weight')),
    }
    if load_multiclass_scores:
      self.keys_to_features[
          'image/object/class/multiclass_scores'] = tf.VarLenFeature(
              tf.float32)
      self.items_to_handlers[fields.InputDataFields.multiclass_scores] = (
          slim_example_decoder.Tensor('image/object/class/multiclass_scores'))
    if num_additional_channels > 0:
      self.keys_to_features[
          'image/additional_channels/encoded'] = tf.FixedLenFeature(
              (num_additional_channels,), tf.string)
      self.items_to_handlers[
          fields.InputDataFields.
          image_additional_channels] = additional_channel_image
    self._num_keypoints = num_keypoints
    if num_keypoints > 0:
      self.keys_to_features['image/object/keypoint/x'] = (
          tf.VarLenFeature(tf.float32))
      self.keys_to_features['image/object/keypoint/y'] = (
          tf.VarLenFeature(tf.float32))
      self.items_to_handlers[fields.InputDataFields.groundtruth_keypoints] = (
          slim_example_decoder.ItemHandlerCallback(
              ['image/object/keypoint/y', 'image/object/keypoint/x'],
              self._reshape_keypoints))
    if load_instance_masks:
      if instance_mask_type in (input_reader_pb2.DEFAULT,
                                input_reader_pb2.NUMERICAL_MASKS):
        self.keys_to_features['image/object/mask'] = (
            tf.VarLenFeature(tf.float32))
        self.items_to_handlers[
            fields.InputDataFields.groundtruth_instance_masks] = (
                slim_example_decoder.ItemHandlerCallback(
                    ['image/object/mask', 'image/height', 'image/width'],
                    self._reshape_instance_masks))
      elif instance_mask_type == input_reader_pb2.PNG_MASKS:
        self.keys_to_features['image/object/mask'] = tf.VarLenFeature(
            tf.string)
        self.items_to_handlers[
            fields.InputDataFields.groundtruth_instance_masks] = (
                slim_example_decoder.ItemHandlerCallback(
                    ['image/object/mask', 'image/height', 'image/width'],
                    self._decode_png_instance_masks))
      else:
        raise ValueError('Did not recognize the `instance_mask_type` option.')
    if label_map_proto_file:
      # If the label_map_proto is provided, try to use it in conjunction with
      # the class text, and fall back to a materialized ID.
      label_handler = _BackupHandler(
          _ClassTensorHandler(
              'image/object/class/text', label_map_proto_file,
              default_value=''),
          slim_example_decoder.Tensor('image/object/class/label'))
      image_label_handler = _BackupHandler(
          _ClassTensorHandler(
              fields.TfExampleFields.image_class_text,
              label_map_proto_file,
              default_value=''),
          slim_example_decoder.Tensor(
              fields.TfExampleFields.image_class_label))
    else:
      label_handler = slim_example_decoder.Tensor('image/object/class/label')
      image_label_handler = slim_example_decoder.Tensor(
          fields.TfExampleFields.image_class_label)
    self.items_to_handlers[
        fields.InputDataFields.groundtruth_classes] = label_handler
    self.items_to_handlers[
        fields.InputDataFields.groundtruth_image_classes] = image_label_handler

  def decode(self, tf_example_string_tensor):
    """Decodes serialized tensorflow example and returns a tensor dictionary.

    Args:
      tf_example_string_tensor: a string tensor holding a serialized tensorflow
        example proto.

    Returns:
      A dictionary of the following tensors.
      fields.InputDataFields.image - 3D uint8 tensor of shape [None, None, 3]
        containing image.
      fields.InputDataFields.original_image_spatial_shape - 1D int32 tensor of
        shape [2] containing shape of the image.
      fields.InputDataFields.source_id - string tensor containing original
        image id.
      fields.InputDataFields.key - string tensor with unique sha256 hash key.
      fields.InputDataFields.filename - string tensor with original dataset
        filename.
      fields.InputDataFields.groundtruth_boxes - 2D float32 tensor of shape
        [None, 4] containing box corners.
      fields.InputDataFields.groundtruth_classes - 1D int64 tensor of shape
        [None] containing classes for the boxes.
      fields.InputDataFields.groundtruth_weights - 1D float32 tensor of
        shape [None] indicating the weights of groundtruth boxes.
      fields.InputDataFields.groundtruth_area - 1D float32 tensor of shape
        [None] containing object mask area in pixel squared.
      fields.InputDataFields.groundtruth_is_crowd - 1D bool tensor of shape
        [None] indicating if the boxes enclose a crowd.

      Optional:
      fields.InputDataFields.image_additional_channels - 3D uint8 tensor of
        shape [None, None, num_additional_channels]. 1st dim is height; 2nd
        dim is width; 3rd dim is the number of additional channels.
      fields.InputDataFields.groundtruth_difficult - 1D bool tensor of shape
        [None] indicating if the boxes represent `difficult` instances.
      fields.InputDataFields.groundtruth_group_of - 1D bool tensor of shape
        [None] indicating if the boxes represent `group_of` instances.
      fields.InputDataFields.groundtruth_keypoints - 3D float32 tensor of
        shape [None, None, 2] containing keypoints, where the coordinates of
        the keypoints are ordered (y, x).
      fields.InputDataFields.groundtruth_instance_masks - 3D float32 tensor of
        shape [None, None, None] containing instance masks.
      fields.InputDataFields.groundtruth_image_classes - 1D int64 tensor of
        shape [None] containing the image-level class labels.
      fields.InputDataFields.multiclass_scores - 1D float32 tensor of shape
        [None * num_classes] containing flattened multiclass scores for
        groundtruth boxes.
    """
    serialized_example = tf.reshape(tf_example_string_tensor, shape=[])
    decoder = slim_example_decoder.TFExampleDecoder(self.keys_to_features,
                                                    self.items_to_handlers)
    keys = decoder.list_items()
    tensors = decoder.decode(serialized_example, items=keys)
    tensor_dict = dict(zip(keys, tensors))
    is_crowd = fields.InputDataFields.groundtruth_is_crowd
    tensor_dict[is_crowd] = tf.cast(tensor_dict[is_crowd], dtype=tf.bool)
    tensor_dict[fields.InputDataFields.image].set_shape([None, None, 3])
    tensor_dict[
        fields.InputDataFields.original_image_spatial_shape] = tf.shape(
            tensor_dict[fields.InputDataFields.image])[:2]

    if fields.InputDataFields.image_additional_channels in tensor_dict:
      channels = tensor_dict[fields.InputDataFields.image_additional_channels]
      channels = tf.squeeze(channels, axis=3)
      channels = tf.transpose(channels, perm=[1, 2, 0])
      tensor_dict[fields.InputDataFields.image_additional_channels] = channels

    def default_groundtruth_weights():
      return tf.ones(
          [tf.shape(tensor_dict[fields.InputDataFields.groundtruth_boxes])[0]],
          dtype=tf.float32)

    tensor_dict[fields.InputDataFields.groundtruth_weights] = tf.cond(
        tf.greater(
            tf.shape(
                tensor_dict[fields.InputDataFields.groundtruth_weights])[0],
            0), lambda: tensor_dict[fields.InputDataFields.groundtruth_weights],
        default_groundtruth_weights)

    return tensor_dict

  def _reshape_keypoints(self, keys_to_tensors):
    """Reshape keypoints.

    The keypoints are reshaped to [num_instances, num_keypoints, 2].

    Args:
      keys_to_tensors: a dictionary from keys to tensors.

    Returns:
      A 3-D float tensor of shape [num_instances, num_keypoints, 2] with the
        keypoint coordinates ordered (y, x).
    """
    y = keys_to_tensors['image/object/keypoint/y']
    if isinstance(y, tf.SparseTensor):
      y = tf.sparse_tensor_to_dense(y)
    y = tf.expand_dims(y, 1)
    x = keys_to_tensors['image/object/keypoint/x']
    if isinstance(x, tf.SparseTensor):
      x = tf.sparse_tensor_to_dense(x)
    x = tf.expand_dims(x, 1)
    keypoints = tf.concat([y, x], 1)
    keypoints = tf.reshape(keypoints, [-1, self._num_keypoints, 2])
    return keypoints

  def _reshape_instance_masks(self, keys_to_tensors):
    """Reshape instance segmentation masks.

    The instance segmentation masks are reshaped to [num_instances, height,
    width].

    Args:
      keys_to_tensors: a dictionary from keys to tensors.

    Returns:
      A 3-D float tensor of shape [num_instances, height, width] with values
        in {0, 1}.
    """
    height = keys_to_tensors['image/height']
    width = keys_to_tensors['image/width']
    to_shape = tf.cast(tf.stack([-1, height, width]), tf.int32)
    masks = keys_to_tensors['image/object/mask']
    if isinstance(masks, tf.SparseTensor):
      masks = tf.sparse_tensor_to_dense(masks)
    masks = tf.reshape(tf.to_float(tf.greater(masks, 0.0)), to_shape)
    return tf.cast(masks, tf.float32)

  def _decode_png_instance_masks(self, keys_to_tensors):
    """Decode PNG instance segmentation masks and stack into dense tensor.

    The instance segmentation masks are reshaped to [num_instances, height,
    width].

    Args:
      keys_to_tensors: a dictionary from keys to tensors.

    Returns:
      A 3-D float tensor of shape [num_instances, height, width] with values
        in {0, 1}.
    """

    def decode_png_mask(image_buffer):
      image = tf.squeeze(
          tf.image.decode_image(image_buffer, channels=1), axis=2)
      image.set_shape([None, None])
      image = tf.to_float(tf.greater(image, 0))
      return image

    png_masks = keys_to_tensors['image/object/mask']
    height = keys_to_tensors['image/height']
    width = keys_to_tensors['image/width']
    if isinstance(png_masks, tf.SparseTensor):
      png_masks = tf.sparse_tensor_to_dense(png_masks, default_value='')
    return tf.cond(
        tf.greater(tf.size(png_masks), 0),
        lambda: tf.map_fn(decode_png_mask, png_masks, dtype=tf.float32),
        lambda: tf.zeros(tf.to_int32(tf.stack([0, height, width]))))
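

# Illustrative usage sketch (assumption, not part of the original file): in a
# TF1-style input pipeline the decoder is typically applied to each serialized
# record read from a TFRecord file. The paths below are hypothetical
# placeholders.
#
#   decoder = TfExampleDecoder(
#       label_map_proto_file='data/label_map.pbtxt',
#       load_instance_masks=True,
#       instance_mask_type=input_reader_pb2.PNG_MASKS)
#   dataset = tf.data.TFRecordDataset('data/train.record')
#   dataset = dataset.map(decoder.decode)
#   # Each element of `dataset` is now the tensor dictionary documented in
#   # TfExampleDecoder.decode above.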