|
|
- # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- """Utility functions for detection inference."""
- from __future__ import division
-
- import tensorflow as tf
-
- from object_detection.core import standard_fields
-
-
def build_input(tfrecord_paths):
  """Creates the ops that read and decode one example at a time from TFRecords.

  The record files are consumed in the order given (no shuffling) for a single
  epoch. Only the encoded-image feature is parsed out of each example; the
  serialized example itself is returned untouched alongside the decoded image.

  Args:
    tfrecord_paths: List of paths to the input TFRecords

  Returns:
    serialized_example_tensor: The next serialized example. String scalar Tensor
    image_tensor: The decoded image of the example. Uint8 tensor,
        shape=[1, None, None, 3]
  """
  image_key = standard_fields.TfExampleFields.image_encoded

  path_queue = tf.train.string_input_producer(
      tfrecord_paths, shuffle=False, num_epochs=1)
  _, serialized_example_tensor = tf.TFRecordReader().read(path_queue)

  parsed = tf.parse_single_example(
      serialized_example_tensor,
      features={image_key: tf.FixedLenFeature([], tf.string)})

  decoded = tf.image.decode_image(parsed[image_key], channels=3)
  # Record the static shape: height and width unknown, three channels.
  decoded.set_shape([None, None, 3])
  # Prepend a dimension of size 1 so the image has shape [1, None, None, 3].
  batched_image = tf.expand_dims(decoded, 0)

  return serialized_example_tensor, batched_image
-
-
def build_inference_graph(image_tensor, inference_graph_path):
  """Imports a serialized inference graph and wires the image into it.

  The GraphDef stored at `inference_graph_path` is imported into the default
  graph with an empty name scope, and its 'image_tensor' input is remapped to
  the supplied `image_tensor`. Each output is fetched by name, has its leading
  dimension squeezed away, and is truncated to the first `num_detections`
  entries.

  Args:
    image_tensor: The input image. uint8 tensor, shape=[1, None, None, 3]
    inference_graph_path: Path to the inference graph with embedded weights

  Returns:
    detected_boxes_tensor: Detected boxes. Float tensor,
        shape=[num_detections, 4]
    detected_scores_tensor: Detected scores. Float tensor,
        shape=[num_detections]
    detected_labels_tensor: Detected labels. Int64 tensor,
        shape=[num_detections]
  """
  with tf.gfile.Open(inference_graph_path, 'rb') as graph_file:
    serialized_graph = graph_file.read()
  graph_def = tf.GraphDef()
  graph_def.MergeFromString(serialized_graph)

  tf.import_graph_def(
      graph_def, name='', input_map={'image_tensor': image_tensor})

  graph = tf.get_default_graph()

  def _unbatched(tensor_name):
    """Looks up a tensor by name and squeezes out its dimension 0."""
    return tf.squeeze(graph.get_tensor_by_name(tensor_name), 0)

  # Cast the detection count to int32 so it can be used as a slice bound.
  num_detections = tf.cast(_unbatched('num_detections:0'), tf.int32)

  boxes = _unbatched('detection_boxes:0')[:num_detections]
  scores = _unbatched('detection_scores:0')[:num_detections]
  labels = tf.cast(_unbatched('detection_classes:0'), tf.int64)
  labels = labels[:num_detections]

  return boxes, scores, labels
-
-
def infer_detections_and_add_to_example(
    serialized_example_tensor, detected_boxes_tensor, detected_scores_tensor,
    detected_labels_tensor, discard_image_pixels):
  """Evaluates the detection tensors and stores the results in the example.

  All four tensors are evaluated in a single `run` call on the default
  session. The serialized example is then parsed into a `tf.train.Example`
  and the detection scores, box coordinates and class labels are written
  into its feature map, replacing any values already stored under those keys.

  Args:
    serialized_example_tensor: Serialized TF example. Scalar string tensor
    detected_boxes_tensor: Detected boxes. Float tensor,
        shape=[num_detections, 4]
    detected_scores_tensor: Detected scores. Float tensor,
        shape=[num_detections]
    detected_labels_tensor: Detected labels. Int64 tensor,
        shape=[num_detections]
    discard_image_pixels: If true, discards the image from the result

  Returns:
    The de-serialized TF example augmented with the inferred detections.
  """
  fields = standard_fields.TfExampleFields

  fetches = [
      serialized_example_tensor, detected_boxes_tensor, detected_scores_tensor,
      detected_labels_tensor
  ]
  serialized, boxes, scores, labels = tf.get_default_session().run(fetches)

  example = tf.train.Example()
  example.ParseFromString(serialized)
  feature_map = example.features.feature

  # Transpose so that each row holds one coordinate across all detections;
  # rows 0..3 are written to ymin, xmin, ymax, xmax respectively.
  coords = boxes.T
  bbox_keys = (
      fields.detection_bbox_ymin,
      fields.detection_bbox_xmin,
      fields.detection_bbox_ymax,
      fields.detection_bbox_xmax,
  )

  feature_map[fields.detection_score].float_list.value[:] = scores
  for row, key in enumerate(bbox_keys):
    feature_map[key].float_list.value[:] = coords[row]
  feature_map[fields.detection_class_label].int64_list.value[:] = labels

  if discard_image_pixels:
    # Remove the encoded image feature from the returned example.
    del feature_map[fields.image_encoded]

  return example
|