|
|
- # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- """Provides functions to prefetch tensors to feed into models."""
- import tensorflow as tf
-
-
def prefetch(tensor_dict, capacity):
  """Creates a prefetch queue for tensors.

  Builds a `tf.PaddingFIFOQueue` whose components are named after the keys of
  `tensor_dict`, creates an op that enqueues `tensor_dict` into it, and
  registers a `QueueRunner` for that op so the queue is filled asynchronously
  once queue runners are started. Callers obtain the prefetched tensors by
  calling `dequeue()` on the returned queue, which evaluates to a tensor_dict.
  This is useful for prefetching preprocessed tensors so that the data is
  readily available for consumers.

  Example input pipeline when you don't need batching:
  ----------------------------------------------------
  key, string_tensor = slim.parallel_reader.parallel_read(...)
  tensor_dict = decoder.decode(string_tensor)
  tensor_dict = preprocessor.preprocess(tensor_dict, ...)
  prefetch_queue = prefetcher.prefetch(tensor_dict, capacity=20)
  tensor_dict = prefetch_queue.dequeue()
  outputs = Model(tensor_dict)
  ...
  ----------------------------------------------------

  For input pipelines with batching, refer to core/batcher.py

  Side effects:
    * Registers a QueueRunner via `tf.train.queue_runner.add_queue_runner`.
    * Adds a scalar summary named
      'queue/<queue name>/fraction_of_<capacity>_full' that reports how full
      the queue is.

  Args:
    tensor_dict: a dictionary of tensors to prefetch. Each value must expose
      `dtype` and `get_shape()`.
    capacity: the size of the prefetch queue.

  Returns:
    a FIFO prefetcher queue (`tf.PaddingFIFOQueue`).
  """
  # Keys and values are read from the same, unmodified dict, so the three
  # lists below are index-aligned with one another.
  names = list(tensor_dict.keys())
  tensors = list(tensor_dict.values())
  dtypes = [t.dtype for t in tensors]
  shapes = [t.get_shape() for t in tensors]

  prefetch_queue = tf.PaddingFIFOQueue(
      capacity,
      dtypes=dtypes,
      shapes=shapes,
      names=names,
      name='prefetch_queue')

  # A single enqueue op is handed to the queue runner that fills the queue.
  enqueue_op = prefetch_queue.enqueue(tensor_dict)
  tf.train.queue_runner.add_queue_runner(
      tf.train.queue_runner.QueueRunner(prefetch_queue, [enqueue_op]))

  # tf.cast(..., tf.float32) replaces the deprecated tf.to_float alias, which
  # is not available in the top-level namespace of newer TensorFlow releases.
  fraction_full = (
      tf.cast(prefetch_queue.size(), tf.float32) * (1. / capacity))
  tf.summary.scalar(
      'queue/%s/fraction_of_%d_full' % (prefetch_queue.name, capacity),
      fraction_full)
  return prefetch_queue
|