|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
"""tf.data.Dataset builder.
|
|
|
|
Creates data sources for DetectionModels from an InputReader config. See
|
|
input_reader.proto for options.
|
|
|
|
Note: If users wishes to also use their own InputReaders with the Object
|
|
Detection configuration framework, they should define their own builder function
|
|
that wraps the build function.
|
|
"""
|
|
import functools
|
|
import tensorflow as tf
|
|
|
|
from object_detection.data_decoders import tf_example_decoder
|
|
from object_detection.protos import input_reader_pb2
|
|
|
|
|
|
def make_initializable_iterator(dataset):
|
|
"""Creates an iterator, and initializes tables.
|
|
|
|
This is useful in cases where make_one_shot_iterator wouldn't work because
|
|
the graph contains a hash table that needs to be initialized.
|
|
|
|
Args:
|
|
dataset: A `tf.data.Dataset` object.
|
|
|
|
Returns:
|
|
A `tf.data.Iterator`.
|
|
"""
|
|
iterator = dataset.make_initializable_iterator()
|
|
tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
|
|
return iterator
|
|
|
|
|
|
def read_dataset(file_read_func, input_files, config):
|
|
"""Reads a dataset, and handles repetition and shuffling.
|
|
|
|
Args:
|
|
file_read_func: Function to use in tf.contrib.data.parallel_interleave, to
|
|
read every individual file into a tf.data.Dataset.
|
|
input_files: A list of file paths to read.
|
|
config: A input_reader_builder.InputReader object.
|
|
|
|
Returns:
|
|
A tf.data.Dataset of (undecoded) tf-records based on config.
|
|
|
|
Raises:
|
|
RuntimeError: If no files are found at the supplied path(s).
|
|
"""
|
|
# Shard, shuffle, and read files.
|
|
filenames = tf.gfile.Glob(input_files)
|
|
if not filenames:
|
|
raise RuntimeError('Did not find any input files matching the glob pattern '
|
|
'{}'.format(input_files))
|
|
num_readers = config.num_readers
|
|
if num_readers > len(filenames):
|
|
num_readers = len(filenames)
|
|
tf.logging.warning('num_readers has been reduced to %d to match input file '
|
|
'shards.' % num_readers)
|
|
filename_dataset = tf.data.Dataset.from_tensor_slices(filenames)
|
|
if config.shuffle:
|
|
filename_dataset = filename_dataset.shuffle(
|
|
config.filenames_shuffle_buffer_size)
|
|
elif num_readers > 1:
|
|
tf.logging.warning('`shuffle` is false, but the input data stream is '
|
|
'still slightly shuffled since `num_readers` > 1.')
|
|
filename_dataset = filename_dataset.repeat(config.num_epochs or None)
|
|
records_dataset = filename_dataset.apply(
|
|
tf.contrib.data.parallel_interleave(
|
|
file_read_func,
|
|
cycle_length=num_readers,
|
|
block_length=config.read_block_length,
|
|
sloppy=config.shuffle))
|
|
if config.shuffle:
|
|
records_dataset = records_dataset.shuffle(config.shuffle_buffer_size)
|
|
return records_dataset
|
|
|
|
|
|
def build(input_reader_config, batch_size=None, transform_input_data_fn=None):
|
|
"""Builds a tf.data.Dataset.
|
|
|
|
Builds a tf.data.Dataset by applying the `transform_input_data_fn` on all
|
|
records. Applies a padded batch to the resulting dataset.
|
|
|
|
Args:
|
|
input_reader_config: A input_reader_pb2.InputReader object.
|
|
batch_size: Batch size. If batch size is None, no batching is performed.
|
|
transform_input_data_fn: Function to apply transformation to all records,
|
|
or None if no extra decoding is required.
|
|
|
|
Returns:
|
|
A tf.data.Dataset based on the input_reader_config.
|
|
|
|
Raises:
|
|
ValueError: On invalid input reader proto.
|
|
ValueError: If no input paths are specified.
|
|
"""
|
|
if not isinstance(input_reader_config, input_reader_pb2.InputReader):
|
|
raise ValueError('input_reader_config not of type '
|
|
'input_reader_pb2.InputReader.')
|
|
|
|
if input_reader_config.WhichOneof('input_reader') == 'tf_record_input_reader':
|
|
config = input_reader_config.tf_record_input_reader
|
|
if not config.input_path:
|
|
raise ValueError('At least one input path must be specified in '
|
|
'`input_reader_config`.')
|
|
|
|
label_map_proto_file = None
|
|
if input_reader_config.HasField('label_map_path'):
|
|
label_map_proto_file = input_reader_config.label_map_path
|
|
decoder = tf_example_decoder.TfExampleDecoder(
|
|
load_instance_masks=input_reader_config.load_instance_masks,
|
|
load_multiclass_scores=input_reader_config.load_multiclass_scores,
|
|
instance_mask_type=input_reader_config.mask_type,
|
|
label_map_proto_file=label_map_proto_file,
|
|
use_display_name=input_reader_config.use_display_name,
|
|
num_additional_channels=input_reader_config.num_additional_channels)
|
|
|
|
def process_fn(value):
|
|
"""Sets up tf graph that decodes, transforms and pads input data."""
|
|
processed_tensors = decoder.decode(value)
|
|
if transform_input_data_fn is not None:
|
|
processed_tensors = transform_input_data_fn(processed_tensors)
|
|
return processed_tensors
|
|
|
|
dataset = read_dataset(
|
|
functools.partial(tf.data.TFRecordDataset, buffer_size=8 * 1000 * 1000),
|
|
config.input_path[:], input_reader_config)
|
|
if input_reader_config.sample_1_of_n_examples > 1:
|
|
dataset = dataset.shard(input_reader_config.sample_1_of_n_examples, 0)
|
|
# TODO(rathodv): make batch size a required argument once the old binaries
|
|
# are deleted.
|
|
if batch_size:
|
|
num_parallel_calls = batch_size * input_reader_config.num_parallel_batches
|
|
else:
|
|
num_parallel_calls = input_reader_config.num_parallel_map_calls
|
|
# TODO(b/123952794): Migrate to V2 function.
|
|
if hasattr(dataset, 'map_with_legacy_function'):
|
|
data_map_fn = dataset.map_with_legacy_function
|
|
else:
|
|
data_map_fn = dataset.map
|
|
dataset = data_map_fn(process_fn, num_parallel_calls=num_parallel_calls)
|
|
if batch_size:
|
|
dataset = dataset.apply(
|
|
tf.contrib.data.batch_and_drop_remainder(batch_size))
|
|
dataset = dataset.prefetch(input_reader_config.num_prefetch_batches)
|
|
return dataset
|
|
|
|
raise ValueError('Unsupported input_reader_config.')
|