# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Provides functions to batch a dictionary of input tensors."""
import collections

import tensorflow as tf

from object_detection.core import prefetcher

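# Suffix appended to each tensor_dict key to name the companion tensor that
# records the tensor's runtime (pre-padding) shape.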
rt_shape_str = '_runtime_shapes'


class BatchQueue(object):
  """BatchQueue class.

  This class creates a batch queue to asynchronously enqueue tensor_dicts.
  It also adds a FIFO prefetcher so that the batches are readily available
  for the consumers. Dequeue ops for a BatchQueue object can be created via
  the dequeue method, which evaluates to a batch of tensor_dict.

  Example input pipeline with batching:
  ------------------------------------
  key, string_tensor = slim.parallel_reader.parallel_read(...)
  tensor_dict = decoder.decode(string_tensor)
  tensor_dict = preprocessor.preprocess(tensor_dict, ...)
  batch_queue = batcher.BatchQueue(tensor_dict,
                                   batch_size=32,
                                   batch_queue_capacity=2000,
                                   num_batch_queue_threads=8,
                                   prefetch_queue_capacity=20)
  tensor_dict = batch_queue.dequeue()
  outputs = Model(tensor_dict)
  ...
  -----------------------------------

  Notes:
  -----
  This class batches tensors of unequal sizes by zero-padding them and
  unpadding them after generating a batch. This can be computationally
  expensive when batching tensors (such as images) that are of vastly
  different sizes, so it is recommended that the shapes of such tensors be
  fully defined in tensor_dict, while lightweight tensors such as bounding
  box corners and class labels can be of varying sizes. Use either crop or
  resize operations to fully define the shape of an image in tensor_dict.

  It is also recommended to perform any preprocessing operations on tensors
  before passing them to BatchQueue and subsequently calling the dequeue
  method.

  Another caveat is that this class does not read the last batch if it is
  not full. The current implementation makes it hard to support that use
  case, so for evaluation, when it is critical to run all the examples
  through your network, use the input pipeline example in core/prefetcher.py.
  """

  def __init__(self, tensor_dict, batch_size, batch_queue_capacity,
               num_batch_queue_threads, prefetch_queue_capacity):
"""Constructs a batch queue holding tensor_dict.
|
|
|
|
Args:
|
|
tensor_dict: dictionary of tensors to batch.
|
|
batch_size: batch size.
|
|
batch_queue_capacity: max capacity of the queue from which the tensors are
|
|
batched.
|
|
num_batch_queue_threads: number of threads to use for batching.
|
|
prefetch_queue_capacity: max capacity of the queue used to prefetch
|
|
assembled batches.
|
|
"""
|
    # Remember static shapes to set shapes of batched tensors.
    # Build the OrderedDicts from (key, value) pairs rather than from a dict
    # comprehension, so that insertion order is actually preserved.
    static_shapes = collections.OrderedDict(
        (key, tensor.get_shape()) for key, tensor in tensor_dict.items())
    # Remember runtime shapes to unpad tensors after batching.
    runtime_shapes = collections.OrderedDict(
        (key + rt_shape_str, tf.shape(tensor))
        for key, tensor in tensor_dict.items())

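    # tf.train.batch with dynamic_pad=True zero-pads each tensor up to the
    # largest shape present in its batch (e.g. images of shape [300, 400, 3]
    # and [200, 500, 3] both become [300, 500, 3]); the runtime shapes queued
    # alongside let dequeue() strip this padding off again.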
    # Copy so that merging in the runtime shapes does not mutate the
    # caller's tensor_dict.
    all_tensors = dict(tensor_dict)
    all_tensors.update(runtime_shapes)
    batched_tensors = tf.train.batch(
        all_tensors,
        capacity=batch_queue_capacity,
        batch_size=batch_size,
        dynamic_pad=True,
        num_threads=num_batch_queue_threads)

    self._queue = prefetcher.prefetch(batched_tensors,
                                      prefetch_queue_capacity)
    self._static_shapes = static_shapes
    self._batch_size = batch_size

  def dequeue(self):
    """Dequeues a batch of tensor_dict from the BatchQueue.

    TODO: use allow_smaller_final_batch to allow running over the whole eval
    set.

    Returns:
      A list of tensor_dicts of length batch_size.
    """
    batched_tensors = self._queue.dequeue()
    # Separate input tensors from the tensors containing their runtime shapes.
    tensors = {}
    shapes = {}
    for key, batched_tensor in batched_tensors.items():
      unbatched_tensor_list = tf.unstack(batched_tensor)
      for i, unbatched_tensor in enumerate(unbatched_tensor_list):
        if key.endswith(rt_shape_str):
          shapes[(key[:-len(rt_shape_str)], i)] = unbatched_tensor
        else:
          tensors[(key, i)] = unbatched_tensor

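    # Note: tf.slice with a begin of all zeros and a size equal to the
    # runtime shape trims a tensor back to its pre-padding extent.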
    # Undo the padding using the runtime shapes and create a list of size
    # `batch_size` that contains tensor dictionaries.
    tensor_dict_list = []
    batch_size = self._batch_size
    for batch_id in range(batch_size):
      tensor_dict = {}
      for key in self._static_shapes:
        tensor_dict[key] = tf.slice(tensors[(key, batch_id)],
                                    tf.zeros_like(shapes[(key, batch_id)]),
                                    shapes[(key, batch_id)])
        tensor_dict[key].set_shape(self._static_shapes[key])
      tensor_dict_list.append(tensor_dict)

    return tensor_dict_list