# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""tf.data.Dataset builder.

Creates data sources for DetectionModels from an InputReader config. See
input_reader.proto for options.

Note: If users wish to use their own InputReaders with the Object Detection
configuration framework, they should define their own builder function that
wraps the build function (an illustrative sketch of such a wrapper appears at
the bottom of this module).
"""
import functools
import tensorflow as tf

from object_detection.data_decoders import tf_example_decoder
from object_detection.protos import input_reader_pb2


def make_initializable_iterator(dataset):
  """Creates an iterator, and initializes tables.

  This is useful in cases where make_one_shot_iterator wouldn't work because
  the graph contains a hash table that needs to be initialized.

  Args:
    dataset: A `tf.data.Dataset` object.

  Returns:
    A `tf.data.Iterator`.
  """
  iterator = dataset.make_initializable_iterator()
  tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
  return iterator
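# Example usage of `make_initializable_iterator` (illustrative sketch): because
# the iterator's initializer is added to the TABLE_INITIALIZERS collection,
# running `tf.tables_initializer()` in a session initializes both any hash
# tables and the iterator itself.
#
#   dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
#   iterator = make_initializable_iterator(dataset)
#   next_element = iterator.get_next()
#   with tf.Session() as sess:
#     sess.run(tf.tables_initializer())
#     first_value = sess.run(next_element)  # -> 1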


def read_dataset(file_read_func, input_files, config):
  """Reads a dataset, and handles repetition and shuffling.

  Args:
    file_read_func: Function to use in tf.contrib.data.parallel_interleave, to
      read every individual file into a tf.data.Dataset.
    input_files: A list of file paths to read.
    config: An input_reader_pb2.InputReader object.

  Returns:
    A tf.data.Dataset of (undecoded) tf-records based on config.
  """
  # Shard, shuffle, and read files.
  filenames = tf.gfile.Glob(input_files)
  num_readers = config.num_readers
  if num_readers > len(filenames):
    num_readers = len(filenames)
    tf.logging.warning('num_readers has been reduced to %d to match input file '
                       'shards.' % num_readers)
  filename_dataset = tf.data.Dataset.from_tensor_slices(filenames)
  if config.shuffle:
    filename_dataset = filename_dataset.shuffle(
        config.filenames_shuffle_buffer_size)
  elif num_readers > 1:
    tf.logging.warning('`shuffle` is false, but the input data stream is '
                       'still slightly shuffled since `num_readers` > 1.')
  filename_dataset = filename_dataset.repeat(config.num_epochs or None)
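  # Interleave records from up to `num_readers` files at a time. When
  # `config.shuffle` is true, `sloppy=True` lets the interleave emit elements
  # in a nondeterministic order, which adds further shuffling.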
  records_dataset = filename_dataset.apply(
      tf.contrib.data.parallel_interleave(
          file_read_func,
          cycle_length=num_readers,
          block_length=config.read_block_length,
          sloppy=config.shuffle))
  if config.shuffle:
    records_dataset = records_dataset.shuffle(config.shuffle_buffer_size)
  return records_dataset
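# Example usage of `read_dataset` (illustrative sketch; the record pattern is a
# hypothetical placeholder, and `input_reader_config` is assumed to come from
# the caller, mirroring how `build` below invokes this function):
#
#   records = read_dataset(
#       functools.partial(tf.data.TFRecordDataset, buffer_size=8 * 1000 * 1000),
#       ['/path/to/examples-*.record'],
#       input_reader_config)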


def build(input_reader_config, batch_size=None, transform_input_data_fn=None):
  """Builds a tf.data.Dataset.

  Builds a tf.data.Dataset by applying `transform_input_data_fn` to all
  records. If `batch_size` is provided, batches the resulting dataset and
  drops any remainder.

  Args:
    input_reader_config: An input_reader_pb2.InputReader object.
    batch_size: Batch size. If batch size is None, no batching is performed.
    transform_input_data_fn: Function to apply transformation to all records,
      or None if no extra decoding is required.

  Returns:
    A tf.data.Dataset based on the input_reader_config.

  Raises:
    ValueError: On invalid input reader proto.
    ValueError: If no input paths are specified.
  """
  if not isinstance(input_reader_config, input_reader_pb2.InputReader):
    raise ValueError('input_reader_config not of type '
                     'input_reader_pb2.InputReader.')

  if input_reader_config.WhichOneof('input_reader') == 'tf_record_input_reader':
    config = input_reader_config.tf_record_input_reader
    if not config.input_path:
      raise ValueError('At least one input path must be specified in '
                       '`input_reader_config`.')

    label_map_proto_file = None
    if input_reader_config.HasField('label_map_path'):
      label_map_proto_file = input_reader_config.label_map_path
    decoder = tf_example_decoder.TfExampleDecoder(
        load_instance_masks=input_reader_config.load_instance_masks,
        load_multiclass_scores=input_reader_config.load_multiclass_scores,
        instance_mask_type=input_reader_config.mask_type,
        label_map_proto_file=label_map_proto_file,
        use_display_name=input_reader_config.use_display_name,
        num_additional_channels=input_reader_config.num_additional_channels)

    def process_fn(value):
      """Sets up tf graph that decodes and transforms input data."""
      processed_tensors = decoder.decode(value)
      if transform_input_data_fn is not None:
        processed_tensors = transform_input_data_fn(processed_tensors)
      return processed_tensors

    dataset = read_dataset(
        functools.partial(tf.data.TFRecordDataset, buffer_size=8 * 1000 * 1000),
        config.input_path[:], input_reader_config)
    if input_reader_config.sample_1_of_n_examples > 1:
      dataset = dataset.shard(input_reader_config.sample_1_of_n_examples, 0)
    # TODO(rathodv): make batch size a required argument once the old binaries
    # are deleted.
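    # Choose map parallelism: scale with the batch size when batching,
    # otherwise fall back to the configured number of parallel map calls.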
    if batch_size:
      num_parallel_calls = batch_size * input_reader_config.num_parallel_batches
    else:
      num_parallel_calls = input_reader_config.num_parallel_map_calls
    # TODO(b/123952794): Migrate to V2 function.
    if hasattr(dataset, 'map_with_legacy_function'):
      data_map_fn = dataset.map_with_legacy_function
    else:
      data_map_fn = dataset.map
    dataset = data_map_fn(process_fn, num_parallel_calls=num_parallel_calls)
    if batch_size:
      dataset = dataset.apply(
          tf.contrib.data.batch_and_drop_remainder(batch_size))
    dataset = dataset.prefetch(input_reader_config.num_prefetch_batches)
    return dataset

  raise ValueError('Unsupported input_reader_config.')
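# Example usage of `build` (illustrative sketch; the paths in the text proto
# are hypothetical placeholders):
#
#   from google.protobuf import text_format
#
#   input_reader_proto = input_reader_pb2.InputReader()
#   text_format.Merge("""
#     shuffle: true
#     num_readers: 4
#     label_map_path: '/path/to/label_map.pbtxt'
#     tf_record_input_reader {
#       input_path: '/path/to/examples.record'
#     }
#   """, input_reader_proto)
#   dataset = build(input_reader_proto, batch_size=32)
#
# As noted in the module docstring, users who want their own InputReader can
# define a builder function that wraps `build`; a minimal sketch (with a
# hypothetical `my_transform_fn` preprocessing function) could be:
#
#   def my_dataset_builder(input_reader_config, batch_size=None):
#     return build(input_reader_config, batch_size=batch_size,
#                  transform_input_data_fn=my_transform_fn)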