- # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- r"""Infers detections on a TFRecord of TFExamples given an inference graph.
- Example usage:
- ./infer_detections \
- --input_tfrecord_paths=/path/to/input/tfrecord1,/path/to/input/tfrecord2 \
- --output_tfrecord_path=/path/to/output/detections.tfrecord \
- --inference_graph=/path/to/frozen_weights_inference_graph.pb
- The output is a TFRecord of TFExamples. Each TFExample from the input is first
- augmented with detections from the inference graph and then copied to the
- output.
- The input and output nodes of the inference graph are expected to have the same
- types, shapes, and semantics, as the input and output nodes of graphs produced
- by export_inference_graph.py, when run with --input_type=image_tensor.
- The script can also discard the image pixels in the output. This greatly
- reduces the output size and can potentially accelerate reading data in
- subsequent processing steps that don't require the images (e.g. computing
- metrics).
- """
- import itertools
- import tensorflow as tf
- from object_detection.inference import detection_inference
# Command-line flags. The three string flags default to None; main() treats
# each of them as required and raises ValueError when one is left unset.
tf.flags.DEFINE_string(
    'input_tfrecord_paths', None,
    'A comma separated list of paths to input TFRecords.')
tf.flags.DEFINE_string(
    'output_tfrecord_path', None, 'Path to the output TFRecord.')
tf.flags.DEFINE_string(
    'inference_graph', None,
    'Path to the inference graph with embedded weights.')
# Optional switch, off by default.
tf.flags.DEFINE_boolean(
    'discard_image_pixels', False,
    "Discards the images in the output TFExamples. This significantly "
    "reduces the output size and is useful if the subsequent tools don't "
    "need access to the images (e.g. when computing evaluation measures).")

FLAGS = tf.flags.FLAGS
def main(_):
  """Runs detection inference over the input TFRecords and writes the output.

  Validates the required flags, builds the input pipeline and the inference
  graph through `detection_inference`, then repeatedly fetches one augmented
  TFExample and appends its serialized form to the output TFRecord until the
  input raises `tf.errors.OutOfRangeError`.

  Args:
    _: Unused positional argument.

  Raises:
    ValueError: If one of --input_tfrecord_paths, --output_tfrecord_path or
      --inference_graph is unset or empty, or if --input_tfrecord_paths
      contains no non-empty path (for example the value ",").
  """
  tf.logging.set_verbosity(tf.logging.INFO)

  required_flags = ['input_tfrecord_paths', 'output_tfrecord_path',
                    'inference_graph']
  for flag_name in required_flags:
    if not getattr(FLAGS, flag_name):
      raise ValueError('Flag --{} is required'.format(flag_name))

  # Empty entries (stray or trailing commas) are dropped. A value made only of
  # separators passes the check above but names no file, so reject it here
  # instead of building an input pipeline over an empty file list.
  input_tfrecord_paths = [
      v for v in FLAGS.input_tfrecord_paths.split(',') if v]
  if not input_tfrecord_paths:
    raise ValueError(
        'Flag --input_tfrecord_paths must contain at least one path')

  with tf.Session() as sess:
    tf.logging.info('Reading input from %d files', len(input_tfrecord_paths))
    serialized_example_tensor, image_tensor = detection_inference.build_input(
        input_tfrecord_paths)
    tf.logging.info('Reading graph and building model...')
    (detected_boxes_tensor, detected_scores_tensor,
     detected_labels_tensor) = detection_inference.build_inference_graph(
         image_tensor, FLAGS.inference_graph)

    tf.logging.info('Running inference and writing output to %s',
                    FLAGS.output_tfrecord_path)
    sess.run(tf.local_variables_initializer())

    # Start the queue runner threads under a coordinator so that they are
    # always asked to stop and joined before the session is closed, and so
    # that an error raised inside a runner thread is reported by join()
    # rather than being lost.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    with tf.python_io.TFRecordWriter(
        FLAGS.output_tfrecord_path) as tf_record_writer:
      try:
        # The loop has no explicit bound: it ends when fetching the next
        # example raises OutOfRangeError (presumably once the input queue is
        # exhausted -- confirm against detection_inference.build_input).
        for counter in itertools.count():
          tf.logging.log_every_n(tf.logging.INFO, 'Processed %d images...', 10,
                                 counter)
          tf_example = detection_inference.infer_detections_and_add_to_example(
              serialized_example_tensor, detected_boxes_tensor,
              detected_scores_tensor, detected_labels_tensor,
              FLAGS.discard_image_pixels)
          tf_record_writer.write(tf_example.SerializeToString())
      except tf.errors.OutOfRangeError:
        tf.logging.info('Finished processing records')
      finally:
        coord.request_stop()
        coord.join(threads)
# Script entry point: hand control to the TensorFlow app launcher, which
# invokes main().
if __name__ == '__main__':
  tf.app.run(main=main)