You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

334 lines
8.8 KiB

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for object_detection.utils.label_map_util."""
import os
import tensorflow as tf
from google.protobuf import text_format
from object_detection.protos import string_int_label_map_pb2
from object_detection.utils import label_map_util
class LabelMapUtilTest(tf.test.TestCase):
  """Unit tests for the helpers in `label_map_util`."""

  def _generate_label_map(self, num_classes):
    """Builds an in-memory label map with ids 1..num_classes.

    Item `i` gets name 'label_<i>' and display_name '<i>'.

    Args:
      num_classes: number of items to add.

    Returns:
      A `StringIntLabelMap` proto.
    """
    label_map_proto = string_int_label_map_pb2.StringIntLabelMap()
    for i in range(1, num_classes + 1):
      item = label_map_proto.item.add()
      item.id = i
      item.name = 'label_' + str(i)
      item.display_name = str(i)
    return label_map_proto

  def _write_label_map(self, label_map_string):
    """Writes a text-format label map into this test's temp directory.

    Shared by every test that needs a label map on disk, so the
    path/open/write sequence lives in one place.

    Args:
      label_map_string: label map contents in protobuf text format.

    Returns:
      Path of the written 'label_map.pbtxt' file.
    """
    label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
    with tf.gfile.Open(label_map_path, 'wb') as f:
      f.write(label_map_string)
    return label_map_path

  def test_get_label_map_dict(self):
    """Keys are item `name`s and values are item `id`s."""
    label_map_string = """
item {
id:2
name:'cat'
}
item {
id:1
name:'dog'
}
"""
    label_map_path = self._write_label_map(label_map_string)
    label_map_dict = label_map_util.get_label_map_dict(label_map_path)
    self.assertEqual(label_map_dict['dog'], 1)
    self.assertEqual(label_map_dict['cat'], 2)

  def test_get_label_map_dict_display(self):
    """With use_display_name=True, keys come from `display_name`."""
    label_map_string = """
item {
id:2
display_name:'cat'
}
item {
id:1
display_name:'dog'
}
"""
    label_map_path = self._write_label_map(label_map_string)
    label_map_dict = label_map_util.get_label_map_dict(
        label_map_path, use_display_name=True)
    self.assertEqual(label_map_dict['dog'], 1)
    self.assertEqual(label_map_dict['cat'], 2)

  def test_load_bad_label_map(self):
    """load_labelmap raises ValueError when a regular class uses id 0."""
    label_map_string = """
item {
id:0
name:'class that should not be indexed at zero'
}
item {
id:2
name:'cat'
}
item {
id:1
name:'dog'
}
"""
    label_map_path = self._write_label_map(label_map_string)
    with self.assertRaises(ValueError):
      label_map_util.load_labelmap(label_map_path)

  def test_load_label_map_with_background(self):
    """An item named 'background' is accepted at id 0."""
    label_map_string = """
item {
id:0
name:'background'
}
item {
id:2
name:'cat'
}
item {
id:1
name:'dog'
}
"""
    label_map_path = self._write_label_map(label_map_string)
    label_map_dict = label_map_util.get_label_map_dict(label_map_path)
    self.assertEqual(label_map_dict['background'], 0)
    self.assertEqual(label_map_dict['dog'], 1)
    self.assertEqual(label_map_dict['cat'], 2)

  def test_get_label_map_dict_with_fill_in_gaps_and_background(self):
    """Missing ids are named after their number and 'background' gets 0."""
    label_map_string = """
item {
id:3
name:'cat'
}
item {
id:1
name:'dog'
}
"""
    label_map_path = self._write_label_map(label_map_string)
    label_map_dict = label_map_util.get_label_map_dict(
        label_map_path, fill_in_gaps_and_background=True)
    self.assertEqual(label_map_dict['background'], 0)
    self.assertEqual(label_map_dict['dog'], 1)
    self.assertEqual(label_map_dict['2'], 2)
    self.assertEqual(label_map_dict['cat'], 3)
    # Every id from 0 to the max must be present (no gaps left).
    self.assertEqual(len(label_map_dict), max(label_map_dict.values()) + 1)

  def test_keep_categories_with_unique_id(self):
    """For a repeated id only the first item is kept; input order is kept."""
    label_map_proto = string_int_label_map_pb2.StringIntLabelMap()
    label_map_string = """
item {
id:2
name:'cat'
}
item {
id:1
name:'child'
}
item {
id:1
name:'person'
}
item {
id:1
name:'n00007846'
}
"""
    text_format.Merge(label_map_string, label_map_proto)
    categories = label_map_util.convert_label_map_to_categories(
        label_map_proto, max_num_classes=3)
    self.assertListEqual([{
        'id': 2,
        'name': u'cat'
    }, {
        'id': 1,
        'name': u'child'
    }], categories)

  def test_convert_label_map_to_categories_no_label_map(self):
    """Without a label map, placeholder 'category_<i>' entries are built."""
    categories = label_map_util.convert_label_map_to_categories(
        None, max_num_classes=3)
    expected_categories_list = [{
        'name': u'category_1',
        'id': 1
    }, {
        'name': u'category_2',
        'id': 2
    }, {
        'name': u'category_3',
        'id': 3
    }]
    self.assertListEqual(expected_categories_list, categories)

  def test_convert_label_map_to_categories(self):
    """By default display names are used and ids stop at max_num_classes."""
    label_map_proto = self._generate_label_map(num_classes=4)
    categories = label_map_util.convert_label_map_to_categories(
        label_map_proto, max_num_classes=3)
    expected_categories_list = [{
        'name': u'1',
        'id': 1
    }, {
        'name': u'2',
        'id': 2
    }, {
        'name': u'3',
        'id': 3
    }]
    self.assertListEqual(expected_categories_list, categories)

  def test_convert_label_map_to_categories_with_few_classes(self):
    """A small max_num_classes drops the higher ids."""
    label_map_proto = self._generate_label_map(num_classes=4)
    categories = label_map_util.convert_label_map_to_categories(
        label_map_proto, max_num_classes=2)
    expected_categories_list = [{
        'name': u'1',
        'id': 1
    }, {
        'name': u'2',
        'id': 2
    }]
    self.assertListEqual(expected_categories_list, categories)

  def test_get_max_label_map_index(self):
    """The maximum index equals the largest id in the label map."""
    num_classes = 4
    label_map_proto = self._generate_label_map(num_classes=num_classes)
    max_index = label_map_util.get_max_label_map_index(label_map_proto)
    self.assertEqual(num_classes, max_index)

  def test_create_category_index(self):
    """The category index is keyed by each category's 'id'."""
    categories = [{'name': u'1', 'id': 1}, {'name': u'2', 'id': 2}]
    category_index = label_map_util.create_category_index(categories)
    self.assertDictEqual({
        1: {
            'name': u'1',
            'id': 1
        },
        2: {
            'name': u'2',
            'id': 2
        }
    }, category_index)

  def test_create_categories_from_labelmap(self):
    """Categories are read straight from a label map file."""
    label_map_string = """
item {
id:1
name:'dog'
}
item {
id:2
name:'cat'
}
"""
    label_map_path = self._write_label_map(label_map_string)
    categories = label_map_util.create_categories_from_labelmap(label_map_path)
    self.assertListEqual([{
        'name': u'dog',
        'id': 1
    }, {
        'name': u'cat',
        'id': 2
    }], categories)

  def test_create_category_index_from_labelmap(self):
    """A category index keyed by id is built from a label map file."""
    label_map_string = """
item {
id:2
name:'cat'
}
item {
id:1
name:'dog'
}
"""
    label_map_path = self._write_label_map(label_map_string)
    category_index = label_map_util.create_category_index_from_labelmap(
        label_map_path)
    self.assertDictEqual({
        1: {
            'name': u'dog',
            'id': 1
        },
        2: {
            'name': u'cat',
            'id': 2
        }
    }, category_index)

  def test_create_category_index_from_labelmap_display(self):
    """display_name is used by default; passing False selects `name`."""
    label_map_string = """
item {
id:2
name:'cat'
display_name:'meow'
}
item {
id:1
name:'dog'
display_name:'woof'
}
"""
    label_map_path = self._write_label_map(label_map_string)
    # Second positional argument False: names come from `name`, not
    # `display_name` (compare with the default-argument call below).
    self.assertDictEqual({
        1: {
            'name': u'dog',
            'id': 1
        },
        2: {
            'name': u'cat',
            'id': 2
        }
    }, label_map_util.create_category_index_from_labelmap(
        label_map_path, False))
    self.assertDictEqual({
        1: {
            'name': u'woof',
            'id': 1
        },
        2: {
            'name': u'meow',
            'id': 2
        }
    }, label_map_util.create_category_index_from_labelmap(label_map_path))
if __name__ == '__main__':
  # Hand control to TensorFlow's test runner when executed as a script.
  tf.test.main()