|
|
- # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- """Input reader builder.
-
- Creates data sources for DetectionModels from an InputReader config. See
- input_reader.proto for options.
-
- Note: If users wish to also use their own InputReaders with the Object
- Detection configuration framework, they should define their own builder function
- that wraps the build function.
- """
-
- import tensorflow as tf
-
- from object_detection.data_decoders import tf_example_decoder
- from object_detection.protos import input_reader_pb2
-
- parallel_reader = tf.contrib.slim.parallel_reader
-
-
def build(input_reader_config):
  """Creates the tensor dictionary described by an InputReader config.

  Args:
    input_reader_config: An input_reader_pb2.InputReader object.

  Returns:
    A tensor dict produced by decoding the string tensor read from the
    configured input paths.

  Raises:
    ValueError: If `input_reader_config` is not an
      input_reader_pb2.InputReader, if its `input_reader` oneof is not
      `tf_record_input_reader`, or if no input path is specified.
  """
  if not isinstance(input_reader_config, input_reader_pb2.InputReader):
    raise ValueError('input_reader_config not of type '
                     'input_reader_pb2.InputReader.')

  # Only the TFRecord-backed reader is handled; reject anything else up front.
  reader_kind = input_reader_config.WhichOneof('input_reader')
  if reader_kind != 'tf_record_input_reader':
    raise ValueError('Unsupported input_reader_config.')

  record_config = input_reader_config.tf_record_input_reader
  if not record_config.input_path:
    raise ValueError('At least one input path must be specified in '
                     '`input_reader_config`.')

  # Slicing converts the `RepeatedScalarContainer` into a plain list.
  input_paths = record_config.input_path[:]
  # A falsy `num_epochs` (zero or unset) is forwarded as None.
  epoch_limit = input_reader_config.num_epochs or None

  read_result = parallel_reader.parallel_read(
      input_paths,
      reader_class=tf.TFRecordReader,
      num_epochs=epoch_limit,
      num_readers=input_reader_config.num_readers,
      shuffle=input_reader_config.shuffle,
      dtypes=[tf.string, tf.string],
      capacity=input_reader_config.queue_capacity,
      min_after_dequeue=input_reader_config.min_after_dequeue)
  # parallel_read yields a pair; only the second element is decoded.
  string_tensor = read_result[1]

  # The label map is optional: pass it on only when the field is set.
  if input_reader_config.HasField('label_map_path'):
    label_map_proto_file = input_reader_config.label_map_path
  else:
    label_map_proto_file = None

  example_decoder = tf_example_decoder.TfExampleDecoder(
      load_instance_masks=input_reader_config.load_instance_masks,
      instance_mask_type=input_reader_config.mask_type,
      label_map_proto_file=label_map_proto_file)
  return example_decoder.decode(string_tensor)
|