# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Provides functions to prefetch tensors to feed into models."""
import tensorflow as tf


def prefetch(tensor_dict, capacity):
  """Creates a prefetch queue for tensors.

  Creates a FIFO queue to asynchronously enqueue tensor_dicts and returns a
  dequeue op that evaluates to a tensor_dict. This function is useful in
  prefetching preprocessed tensors so that the data is readily available for
  consumers.

  Example input pipeline when you don't need batching:
  ----------------------------------------------------
  key, string_tensor = slim.parallel_reader.parallel_read(...)
  tensor_dict = decoder.decode(string_tensor)
  tensor_dict = preprocessor.preprocess(tensor_dict, ...)
  prefetch_queue = prefetcher.prefetch(tensor_dict, capacity=20)
  tensor_dict = prefetch_queue.dequeue()
  outputs = Model(tensor_dict)
  ...
  ----------------------------------------------------

  For input pipelines with batching, refer to core/batcher.py

  Args:
    tensor_dict: a dictionary of tensors to prefetch.
    capacity: the size of the prefetch queue.

  Returns:
    a FIFO prefetcher queue
  """
  names = list(tensor_dict.keys())
  dtypes = [t.dtype for t in tensor_dict.values()]
  shapes = [t.get_shape() for t in tensor_dict.values()]
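  # A PaddingFIFOQueue with `names` set enqueues and dequeues whole tensor
  # dictionaries, and tolerates partially defined (dynamic) shapes.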
  prefetch_queue = tf.PaddingFIFOQueue(capacity, dtypes=dtypes,
                                       shapes=shapes,
                                       names=names,
                                       name='prefetch_queue')
  enqueue_op = prefetch_queue.enqueue(tensor_dict)
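  # Register a QueueRunner so that, once queue runners are started in the
  # session, a background thread keeps running the enqueue op.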
  tf.train.queue_runner.add_queue_runner(tf.train.queue_runner.QueueRunner(
      prefetch_queue, [enqueue_op]))
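  # Export a summary tracking how full the prefetch queue is, which helps spot
  # input-bound training.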
  tf.summary.scalar('queue/%s/fraction_of_%d_full' % (prefetch_queue.name,
                                                      capacity),
                    tf.to_float(prefetch_queue.size()) * (1. / capacity))
  return prefetch_queue
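

# Usage note (a minimal sketch, not part of this module): because enqueueing is
# driven by a QueueRunner, the queue only fills once queue runners are started
# in the session. Assuming `queue` is the value returned by prefetch(...):
#
#   with tf.Session() as sess:
#     coord = tf.train.Coordinator()
#     threads = tf.train.start_queue_runners(sess=sess, coord=coord)
#     tensors = sess.run(queue.dequeue())
#     coord.request_stop()
#     coord.join(threads)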