|
|
- # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- """Tests for input_reader_builder."""
-
- import os
- import numpy as np
- import tensorflow as tf
-
- from google.protobuf import text_format
-
- from object_detection.builders import input_reader_builder
- from object_detection.core import standard_fields as fields
- from object_detection.protos import input_reader_pb2
- from object_detection.utils import dataset_util
-
-
class InputReaderBuilderTest(tf.test.TestCase):
  """Tests for `input_reader_builder.build` on TFRecord-backed readers."""

  def create_tf_record(self):
    """Writes a one-example TFRecord file and returns its path.

    The example holds a random 4x5x3 uint8 image encoded as JPEG, a single
    box spanning [0, 1] on both axes with class label 2, and a flat mask of
    4 * 5 ones.

    Returns:
      Path to the TFRecord file, located in the test's temp directory.
    """
    path = os.path.join(self.get_temp_dir(), 'tfrecord')

    image_tensor = np.random.randint(255, size=(4, 5, 3)).astype(np.uint8)
    flat_mask = (4 * 5) * [1.0]
    with self.test_session():
      encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).eval()
    example = tf.train.Example(features=tf.train.Features(feature={
        'image/encoded': dataset_util.bytes_feature(encoded_jpeg),
        'image/format': dataset_util.bytes_feature('jpeg'.encode('utf8')),
        'image/height': dataset_util.int64_feature(4),
        'image/width': dataset_util.int64_feature(5),
        'image/object/bbox/xmin': dataset_util.float_list_feature([0.0]),
        'image/object/bbox/xmax': dataset_util.float_list_feature([1.0]),
        'image/object/bbox/ymin': dataset_util.float_list_feature([0.0]),
        'image/object/bbox/ymax': dataset_util.float_list_feature([1.0]),
        'image/object/class/label': dataset_util.int64_list_feature([2]),
        'image/object/mask': dataset_util.float_list_feature(flat_mask),
    }))
    # Open the writer only once the example is ready, and let the context
    # manager close it so the file handle is released even if the write fails.
    with tf.python_io.TFRecordWriter(path) as writer:
      writer.write(example.SerializeToString())

    return path

  def test_build_tf_record_input_reader(self):
    """Checks image, class and box outputs when masks are not requested."""
    tf_record_path = self.create_tf_record()

    input_reader_text_proto = """
      shuffle: false
      num_readers: 1
      tf_record_input_reader {{
        input_path: '{0}'
      }}
    """.format(tf_record_path)
    input_reader_proto = input_reader_pb2.InputReader()
    text_format.Merge(input_reader_text_proto, input_reader_proto)
    tensor_dict = input_reader_builder.build(input_reader_proto)

    with tf.train.MonitoredSession() as sess:
      output_dict = sess.run(tensor_dict)

    # load_instance_masks is not set here, so no mask entry is expected.
    self.assertNotIn(
        fields.InputDataFields.groundtruth_instance_masks, output_dict)
    self.assertEqual(
        (4, 5, 3), output_dict[fields.InputDataFields.image].shape)
    # assertAllEqual compares array contents element-wise instead of relying
    # on the truthiness of `list == ndarray`.
    self.assertAllEqual(
        [2], output_dict[fields.InputDataFields.groundtruth_classes])
    self.assertEqual(
        (1, 4), output_dict[fields.InputDataFields.groundtruth_boxes].shape)
    self.assertAllEqual(
        [0.0, 0.0, 1.0, 1.0],
        output_dict[fields.InputDataFields.groundtruth_boxes][0])

  def test_build_tf_record_input_reader_and_load_instance_masks(self):
    """Checks that `load_instance_masks: true` adds a (1, 4, 5) mask output."""
    tf_record_path = self.create_tf_record()

    input_reader_text_proto = """
      shuffle: false
      num_readers: 1
      load_instance_masks: true
      tf_record_input_reader {{
        input_path: '{0}'
      }}
    """.format(tf_record_path)
    input_reader_proto = input_reader_pb2.InputReader()
    text_format.Merge(input_reader_text_proto, input_reader_proto)
    tensor_dict = input_reader_builder.build(input_reader_proto)

    with tf.train.MonitoredSession() as sess:
      output_dict = sess.run(tensor_dict)

    self.assertEqual(
        (4, 5, 3), output_dict[fields.InputDataFields.image].shape)
    self.assertAllEqual(
        [2], output_dict[fields.InputDataFields.groundtruth_classes])
    self.assertEqual(
        (1, 4), output_dict[fields.InputDataFields.groundtruth_boxes].shape)
    self.assertAllEqual(
        [0.0, 0.0, 1.0, 1.0],
        output_dict[fields.InputDataFields.groundtruth_boxes][0])
    self.assertAllEqual(
        (1, 4, 5),
        output_dict[fields.InputDataFields.groundtruth_instance_masks].shape)

  def test_raises_error_with_no_input_paths(self):
    """Expects `build` to raise ValueError when no input reader is configured."""
    input_reader_text_proto = """
      shuffle: false
      num_readers: 1
      load_instance_masks: true
    """
    input_reader_proto = input_reader_pb2.InputReader()
    text_format.Merge(input_reader_text_proto, input_reader_proto)
    with self.assertRaises(ValueError):
      input_reader_builder.build(input_reader_proto)
-
# Run every test case in this module when the file is executed as a script.
if __name__ == '__main__':
  tf.test.main()
|