# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
r"""Tests for detection_inference.py."""
|
|
|
|
import os
|
|
import StringIO
|
|
|
|
import numpy as np
|
|
from PIL import Image
|
|
import tensorflow as tf
|
|
|
|
from object_detection.core import standard_fields
|
|
from object_detection.inference import detection_inference
|
|
from object_detection.utils import dataset_util
|
|
|
|
|
|
def get_mock_tfrecord_path():
  """Returns the location of the mock TFRecord file in the test temp dir."""
  temp_dir = tf.test.get_temp_dir()
  return os.path.join(temp_dir, 'mock.tfrec')
|
|
|
|
|
|
def create_mock_tfrecord():
  """Writes a one-example TFRecord file to `get_mock_tfrecord_path()`.

  The single tf.train.Example contains:
    * under `standard_fields.TfExampleFields.image_encoded`: a PNG encoding
      of a 1x1 RGB image whose only pixel is (123, 0, 0);
    * under 'test_field': the float list [1, 2, 3, 4].
  """
  # PNG data is binary, so it must be buffered in a bytes stream.
  # `StringIO.StringIO` only exists on Python 2 and is a text buffer on
  # Python 3 (`io.StringIO`), which cannot hold PNG bytes.
  # NOTE(review): with this change the module-level `import StringIO` appears
  # to be unused and would still fail on Python 3; consider removing it.
  import io  # pylint: disable=g-import-not-at-top

  pil_image = Image.fromarray(np.array([[[123, 0, 0]]], dtype=np.uint8), 'RGB')
  with io.BytesIO() as image_output_stream:
    pil_image.save(image_output_stream, format='png')
    # Read the bytes before the buffer is closed at the end of the `with`.
    encoded_image = image_output_stream.getvalue()

  feature_map = {
      'test_field':
          dataset_util.float_list_feature([1, 2, 3, 4]),
      standard_fields.TfExampleFields.image_encoded:
          dataset_util.bytes_feature(encoded_image),
  }

  tf_example = tf.train.Example(features=tf.train.Features(feature=feature_map))
  with tf.python_io.TFRecordWriter(get_mock_tfrecord_path()) as writer:
    writer.write(tf_example.SerializeToString())
|
|
|
|
|
|
def get_mock_graph_path():
  """Returns the location of the serialized mock GraphDef in the test temp dir."""
  temp_dir = tf.test.get_temp_dir()
  return os.path.join(temp_dir, 'mock_graph.pb')
|
|
|
|
|
|
def create_mock_graph():
  """Serializes a tiny fake detection graph to `get_mock_graph_path()`.

  Inputs:
    * `image_tensor`: a uint8 placeholder of shape [1, None, None, 3].

  Constant outputs:
    * `num_detections`: [2.0].
    * `detection_boxes`: three fixed boxes.
    * `detection_scores`: [[0.1, 0.2, 0.3]].

  Image-dependent output:
    * `detection_classes`: [[1, 2, 3]] multiplied by the sum of all input
      pixel values.
  """
  mock_graph = tf.Graph()
  with mock_graph.as_default():
    image_placeholder = tf.placeholder(
        tf.uint8, shape=[1, None, None, 3], name='image_tensor')

    # Fixed outputs that ignore the fed image.
    tf.constant([2.0], name='num_detections')
    tf.constant(
        [[[0, 0.8, 0.7, 1], [0.1, 0.2, 0.8, 0.9], [0.2, 0.3, 0.4, 0.5]]],
        name='detection_boxes')
    tf.constant([[0.1, 0.2, 0.3]], name='detection_scores')

    # The one output that depends on the input. It lets tests check that the
    # real image was actually fed through the graph.
    class_multipliers = tf.constant([[1.0, 2.0, 3.0]])
    pixel_sum = tf.reduce_sum(tf.cast(image_placeholder, dtype=tf.float32))
    tf.identity(class_multipliers * pixel_sum, name='detection_classes')

  serialized_graph = mock_graph.as_graph_def().SerializeToString()
  with tf.gfile.Open(get_mock_graph_path(), 'w') as graph_file:
    graph_file.write(serialized_graph)
|
|
|
|
|
|
class InferDetectionsTests(tf.test.TestCase):
  """End-to-end tests of detection_inference against a mock graph and record.

  Expected values follow from the mock fixtures:
    * `num_detections` is 2, so only the first two boxes and scores appear.
    * Labels are 123 and 246, i.e. [1, 2, 3] scaled by the pixel sum
      123 + 0 + 0 of the mock image, truncated to two entries.
  """

  def _build_mock_input(self):
    """Writes the mock graph and TFRecord, then builds the input pipeline.

    Returns:
      The (serialized_example_tensor, image_tensor) pair returned by
      `detection_inference.build_input` for the mock TFRecord.
    """
    create_mock_graph()
    create_mock_tfrecord()
    return detection_inference.build_input([get_mock_tfrecord_path()])

  def _run_inference(self, serialized_example_tensor, image_tensor,
                     discard_image_pixels):
    """Attaches the mock detector and runs inference on the single example.

    Args:
      serialized_example_tensor: Serialized example tensor from build_input.
      image_tensor: Image tensor from build_input, fed to the mock graph.
      discard_image_pixels: Passed through unchanged as the last positional
        argument of `infer_detections_and_add_to_example`.

    Returns:
      Whatever `infer_detections_and_add_to_example` returns; the tests
      compare it against a tf.train.Example text proto.
    """
    boxes_tensor, scores_tensor, labels_tensor = (
        detection_inference.build_inference_graph(
            image_tensor, get_mock_graph_path()))

    with self.test_session(use_gpu=False) as sess:
      sess.run(tf.global_variables_initializer())
      # NOTE(review): local variables are presumably required by the input
      # pipeline built in build_input; confirm there.
      sess.run(tf.local_variables_initializer())
      tf.train.start_queue_runners()
      return detection_inference.infer_detections_and_add_to_example(
          serialized_example_tensor, boxes_tensor, scores_tensor,
          labels_tensor, discard_image_pixels)

  def test_simple(self):
    serialized_example_tensor, image_tensor = self._build_mock_input()
    self.assertAllEqual(image_tensor.get_shape().as_list(), [1, None, None, 3])

    tf_example = self._run_inference(
        serialized_example_tensor, image_tensor, False)

    # Pixels are kept, so the original encoded PNG is still present.
    self.assertProtoEquals(r"""
        features {
          feature {
            key: "image/detection/bbox/ymin"
            value { float_list { value: [0.0, 0.1] } } }
          feature {
            key: "image/detection/bbox/xmin"
            value { float_list { value: [0.8, 0.2] } } }
          feature {
            key: "image/detection/bbox/ymax"
            value { float_list { value: [0.7, 0.8] } } }
          feature {
            key: "image/detection/bbox/xmax"
            value { float_list { value: [1.0, 0.9] } } }
          feature {
            key: "image/detection/label"
            value { int64_list { value: [123, 246] } } }
          feature {
            key: "image/detection/score"
            value { float_list { value: [0.1, 0.2] } } }
          feature {
            key: "image/encoded"
            value { bytes_list { value:
              "\211PNG\r\n\032\n\000\000\000\rIHDR\000\000\000\001\000\000"
              "\000\001\010\002\000\000\000\220wS\336\000\000\000\022IDATx"
              "\234b\250f`\000\000\000\000\377\377\003\000\001u\000|gO\242"
              "\213\000\000\000\000IEND\256B`\202" } } }
          feature {
            key: "test_field"
            value { float_list { value: [1.0, 2.0, 3.0, 4.0] } } } }
        """, tf_example)

  def test_discard_image(self):
    serialized_example_tensor, image_tensor = self._build_mock_input()

    tf_example = self._run_inference(
        serialized_example_tensor, image_tensor, True)

    # Pixels are discarded, so "image/encoded" must be absent.
    self.assertProtoEquals(r"""
        features {
          feature {
            key: "image/detection/bbox/ymin"
            value { float_list { value: [0.0, 0.1] } } }
          feature {
            key: "image/detection/bbox/xmin"
            value { float_list { value: [0.8, 0.2] } } }
          feature {
            key: "image/detection/bbox/ymax"
            value { float_list { value: [0.7, 0.8] } } }
          feature {
            key: "image/detection/bbox/xmax"
            value { float_list { value: [1.0, 0.9] } } }
          feature {
            key: "image/detection/label"
            value { int64_list { value: [123, 246] } } }
          feature {
            key: "image/detection/score"
            value { float_list { value: [0.1, 0.2] } } }
          feature {
            key: "test_field"
            value { float_list { value: [1.0, 2.0, 3.0, 4.0] } } } }
        """, tf_example)
|
|
|
|
|
|
if __name__ == '__main__':
  # Run the test cases defined in this module through TensorFlow's runner.
  tf.test.main()
|