You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

61 lines
2.5 KiB

6 years ago
  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Provides functions to prefetch tensors to feed into models."""
  16. import tensorflow as tf
  17. def prefetch(tensor_dict, capacity):
  18. """Creates a prefetch queue for tensors.
  19. Creates a FIFO queue to asynchronously enqueue tensor_dicts and returns a
  20. dequeue op that evaluates to a tensor_dict. This function is useful in
  21. prefetching preprocessed tensors so that the data is readily available for
  22. consumers.
  23. Example input pipeline when you don't need batching:
  24. ----------------------------------------------------
  25. key, string_tensor = slim.parallel_reader.parallel_read(...)
  26. tensor_dict = decoder.decode(string_tensor)
  27. tensor_dict = preprocessor.preprocess(tensor_dict, ...)
  28. prefetch_queue = prefetcher.prefetch(tensor_dict, capacity=20)
  29. tensor_dict = prefetch_queue.dequeue()
  30. outputs = Model(tensor_dict)
  31. ...
  32. ----------------------------------------------------
  33. For input pipelines with batching, refer to core/batcher.py
  34. Args:
  35. tensor_dict: a dictionary of tensors to prefetch.
  36. capacity: the size of the prefetch queue.
  37. Returns:
  38. a FIFO prefetcher queue
  39. """
  40. names = list(tensor_dict.keys())
  41. dtypes = [t.dtype for t in tensor_dict.values()]
  42. shapes = [t.get_shape() for t in tensor_dict.values()]
  43. prefetch_queue = tf.PaddingFIFOQueue(capacity, dtypes=dtypes,
  44. shapes=shapes,
  45. names=names,
  46. name='prefetch_queue')
  47. enqueue_op = prefetch_queue.enqueue(tensor_dict)
  48. tf.train.queue_runner.add_queue_runner(tf.train.queue_runner.QueueRunner(
  49. prefetch_queue, [enqueue_op]))
  50. tf.summary.scalar('queue/%s/fraction_of_%d_full' % (prefetch_queue.name,
  51. capacity),
  52. tf.to_float(prefetch_queue.size()) * (1. / capacity))
  53. return prefetch_queue