You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

136 lines
5.5 KiB

6 years ago
  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Provides functions to batch a dictionary of input tensors."""
  16. import collections
  17. import tensorflow as tf
  18. from object_detection.core import prefetcher
# Suffix appended to each tensor_dict key to name the companion entry that
# carries that tensor's runtime (pre-padding) shape through the batch queue;
# BatchQueue.dequeue strips this suffix to pair each shape with its tensor.
rt_shape_str = '_runtime_shapes'
  20. class BatchQueue(object):
  21. """BatchQueue class.
  22. This class creates a batch queue to asynchronously enqueue tensors_dict.
  23. It also adds a FIFO prefetcher so that the batches are readily available
  24. for the consumers. Dequeue ops for a BatchQueue object can be created via
  25. the Dequeue method which evaluates to a batch of tensor_dict.
  26. Example input pipeline with batching:
  27. ------------------------------------
  28. key, string_tensor = slim.parallel_reader.parallel_read(...)
  29. tensor_dict = decoder.decode(string_tensor)
  30. tensor_dict = preprocessor.preprocess(tensor_dict, ...)
  31. batch_queue = batcher.BatchQueue(tensor_dict,
  32. batch_size=32,
  33. batch_queue_capacity=2000,
  34. num_batch_queue_threads=8,
  35. prefetch_queue_capacity=20)
  36. tensor_dict = batch_queue.dequeue()
  37. outputs = Model(tensor_dict)
  38. ...
  39. -----------------------------------
  40. Notes:
  41. -----
  42. This class batches tensors of unequal sizes by zero padding and unpadding
  43. them after generating a batch. This can be computationally expensive when
  44. batching tensors (such as images) that are of vastly different sizes. So it is
  45. recommended that the shapes of such tensors be fully defined in tensor_dict
  46. while other lightweight tensors such as bounding box corners and class labels
  47. can be of varying sizes. Use either crop or resize operations to fully define
  48. the shape of an image in tensor_dict.
  49. It is also recommended to perform any preprocessing operations on tensors
  50. before passing to BatchQueue and subsequently calling the Dequeue method.
  51. Another caveat is that this class does not read the last batch if it is not
  52. full. The current implementation makes it hard to support that use case. So,
  53. for evaluation, when it is critical to run all the examples through your
  54. network use the input pipeline example mentioned in core/prefetcher.py.
  55. """
  56. def __init__(self, tensor_dict, batch_size, batch_queue_capacity,
  57. num_batch_queue_threads, prefetch_queue_capacity):
  58. """Constructs a batch queue holding tensor_dict.
  59. Args:
  60. tensor_dict: dictionary of tensors to batch.
  61. batch_size: batch size.
  62. batch_queue_capacity: max capacity of the queue from which the tensors are
  63. batched.
  64. num_batch_queue_threads: number of threads to use for batching.
  65. prefetch_queue_capacity: max capacity of the queue used to prefetch
  66. assembled batches.
  67. """
  68. # Remember static shapes to set shapes of batched tensors.
  69. static_shapes = collections.OrderedDict(
  70. {key: tensor.get_shape() for key, tensor in tensor_dict.items()})
  71. # Remember runtime shapes to unpad tensors after batching.
  72. runtime_shapes = collections.OrderedDict(
  73. {(key + rt_shape_str): tf.shape(tensor)
  74. for key, tensor in tensor_dict.items()})
  75. all_tensors = tensor_dict
  76. all_tensors.update(runtime_shapes)
  77. batched_tensors = tf.train.batch(
  78. all_tensors,
  79. capacity=batch_queue_capacity,
  80. batch_size=batch_size,
  81. dynamic_pad=True,
  82. num_threads=num_batch_queue_threads)
  83. self._queue = prefetcher.prefetch(batched_tensors,
  84. prefetch_queue_capacity)
  85. self._static_shapes = static_shapes
  86. self._batch_size = batch_size
  87. def dequeue(self):
  88. """Dequeues a batch of tensor_dict from the BatchQueue.
  89. TODO: use allow_smaller_final_batch to allow running over the whole eval set
  90. Returns:
  91. A list of tensor_dicts of the requested batch_size.
  92. """
  93. batched_tensors = self._queue.dequeue()
  94. # Separate input tensors from tensors containing their runtime shapes.
  95. tensors = {}
  96. shapes = {}
  97. for key, batched_tensor in batched_tensors.items():
  98. unbatched_tensor_list = tf.unstack(batched_tensor)
  99. for i, unbatched_tensor in enumerate(unbatched_tensor_list):
  100. if rt_shape_str in key:
  101. shapes[(key[:-len(rt_shape_str)], i)] = unbatched_tensor
  102. else:
  103. tensors[(key, i)] = unbatched_tensor
  104. # Undo that padding using shapes and create a list of size `batch_size` that
  105. # contains tensor dictionaries.
  106. tensor_dict_list = []
  107. batch_size = self._batch_size
  108. for batch_id in range(batch_size):
  109. tensor_dict = {}
  110. for key in self._static_shapes:
  111. tensor_dict[key] = tf.slice(tensors[(key, batch_id)],
  112. tf.zeros_like(shapes[(key, batch_id)]),
  113. shapes[(key, batch_id)])
  114. tensor_dict[key].set_shape(self._static_shapes[key])
  115. tensor_dict_list.append(tensor_dict)
  116. return tensor_dict_list