# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
|
|
"""Tests for object_detection.utils.label_map_util."""
|
|
|
|
import os
|
|
import tensorflow as tf
|
|
|
|
from google.protobuf import text_format
|
|
from object_detection.protos import string_int_label_map_pb2
|
|
from object_detection.utils import label_map_util
|
|
|
|
|
|
class LabelMapUtilTest(tf.test.TestCase):
  """Unit tests for the helpers in `label_map_util`."""

  def _generate_label_map(self, num_classes):
    """Builds a StringIntLabelMap proto with ids 1..num_classes.

    Args:
      num_classes: number of items to add to the label map.

    Returns:
      A StringIntLabelMap whose i-th item has id `i`, name 'label_<i>' and
      display_name '<i>'.
    """
    label_map_proto = string_int_label_map_pb2.StringIntLabelMap()
    for i in range(1, num_classes + 1):
      item = label_map_proto.item.add()
      item.id = i
      item.name = 'label_' + str(i)
      item.display_name = str(i)
    return label_map_proto

  def _write_label_map(self, label_map_string):
    """Writes a text-format label map into the test temp directory.

    Args:
      label_map_string: text-format StringIntLabelMap contents.

    Returns:
      Path of the written 'label_map.pbtxt' file.
    """
    label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
    with tf.gfile.Open(label_map_path, 'wb') as f:
      f.write(label_map_string)
    return label_map_path

  def test_get_label_map_dict(self):
    """get_label_map_dict maps each item name to its id."""
    label_map_string = """
      item {
        id:2
        name:'cat'
      }
      item {
        id:1
        name:'dog'
      }
    """
    label_map_path = self._write_label_map(label_map_string)

    label_map_dict = label_map_util.get_label_map_dict(label_map_path)
    self.assertEqual(label_map_dict['dog'], 1)
    self.assertEqual(label_map_dict['cat'], 2)

  def test_get_label_map_dict_display(self):
    """With use_display_name=True the dict is keyed by display_name."""
    label_map_string = """
      item {
        id:2
        display_name:'cat'
      }
      item {
        id:1
        display_name:'dog'
      }
    """
    label_map_path = self._write_label_map(label_map_string)

    label_map_dict = label_map_util.get_label_map_dict(
        label_map_path, use_display_name=True)
    self.assertEqual(label_map_dict['dog'], 1)
    self.assertEqual(label_map_dict['cat'], 2)

  def test_load_bad_label_map(self):
    """load_labelmap raises ValueError for a non-background item at id 0."""
    label_map_string = """
      item {
        id:0
        name:'class that should not be indexed at zero'
      }
      item {
        id:2
        name:'cat'
      }
      item {
        id:1
        name:'dog'
      }
    """
    label_map_path = self._write_label_map(label_map_string)

    with self.assertRaises(ValueError):
      label_map_util.load_labelmap(label_map_path)

  def test_load_label_map_with_background(self):
    """An item named 'background' is accepted at id 0."""
    label_map_string = """
      item {
        id:0
        name:'background'
      }
      item {
        id:2
        name:'cat'
      }
      item {
        id:1
        name:'dog'
      }
    """
    label_map_path = self._write_label_map(label_map_string)

    label_map_dict = label_map_util.get_label_map_dict(label_map_path)
    self.assertEqual(label_map_dict['background'], 0)
    self.assertEqual(label_map_dict['dog'], 1)
    self.assertEqual(label_map_dict['cat'], 2)

  def test_get_label_map_dict_with_fill_in_gaps_and_background(self):
    """fill_in_gaps_and_background adds 'background' and missing ids."""
    label_map_string = """
      item {
        id:3
        name:'cat'
      }
      item {
        id:1
        name:'dog'
      }
    """
    label_map_path = self._write_label_map(label_map_string)

    label_map_dict = label_map_util.get_label_map_dict(
        label_map_path, fill_in_gaps_and_background=True)

    self.assertEqual(label_map_dict['background'], 0)
    self.assertEqual(label_map_dict['dog'], 1)
    # The unused id 2 is expected to appear under the key '2'.
    self.assertEqual(label_map_dict['2'], 2)
    self.assertEqual(label_map_dict['cat'], 3)
    # Ids 0..max must all be present, i.e. the mapping has no gaps left.
    self.assertEqual(len(label_map_dict), max(label_map_dict.values()) + 1)

  def test_keep_categories_with_unique_id(self):
    """Only the first item seen for a duplicated id becomes a category."""
    label_map_proto = string_int_label_map_pb2.StringIntLabelMap()
    label_map_string = """
      item {
        id:2
        name:'cat'
      }
      item {
        id:1
        name:'child'
      }
      item {
        id:1
        name:'person'
      }
      item {
        id:1
        name:'n00007846'
      }
    """
    text_format.Merge(label_map_string, label_map_proto)
    categories = label_map_util.convert_label_map_to_categories(
        label_map_proto, max_num_classes=3)
    self.assertListEqual([{
        'id': 2,
        'name': u'cat'
    }, {
        'id': 1,
        'name': u'child'
    }], categories)

  def test_convert_label_map_to_categories_no_label_map(self):
    """Without a label map, 'category_<i>' placeholders are generated."""
    categories = label_map_util.convert_label_map_to_categories(
        None, max_num_classes=3)
    expected_categories_list = [{
        'name': u'category_1',
        'id': 1
    }, {
        'name': u'category_2',
        'id': 2
    }, {
        'name': u'category_3',
        'id': 3
    }]
    self.assertListEqual(expected_categories_list, categories)

  def test_convert_label_map_to_categories(self):
    """Items with id above max_num_classes are left out of the categories."""
    label_map_proto = self._generate_label_map(num_classes=4)
    categories = label_map_util.convert_label_map_to_categories(
        label_map_proto, max_num_classes=3)
    expected_categories_list = [{
        'name': u'1',
        'id': 1
    }, {
        'name': u'2',
        'id': 2
    }, {
        'name': u'3',
        'id': 3
    }]
    self.assertListEqual(expected_categories_list, categories)

  def test_convert_label_map_to_categories_with_few_classes(self):
    """max_num_classes=2 keeps only the items with ids 1 and 2."""
    label_map_proto = self._generate_label_map(num_classes=4)
    cat_no_offset = label_map_util.convert_label_map_to_categories(
        label_map_proto, max_num_classes=2)
    expected_categories_list = [{
        'name': u'1',
        'id': 1
    }, {
        'name': u'2',
        'id': 2
    }]
    self.assertListEqual(expected_categories_list, cat_no_offset)

  def test_get_max_label_map_index(self):
    """get_max_label_map_index returns the largest id in the label map."""
    num_classes = 4
    label_map_proto = self._generate_label_map(num_classes=num_classes)
    max_index = label_map_util.get_max_label_map_index(label_map_proto)
    self.assertEqual(num_classes, max_index)

  def test_create_category_index(self):
    """create_category_index keys each category dict by its 'id'."""
    categories = [{'name': u'1', 'id': 1}, {'name': u'2', 'id': 2}]
    category_index = label_map_util.create_category_index(categories)
    self.assertDictEqual({
        1: {
            'name': u'1',
            'id': 1
        },
        2: {
            'name': u'2',
            'id': 2
        }
    }, category_index)

  def test_create_categories_from_labelmap(self):
    """Categories are read from a label map file in file order."""
    label_map_string = """
      item {
        id:1
        name:'dog'
      }
      item {
        id:2
        name:'cat'
      }
    """
    label_map_path = self._write_label_map(label_map_string)

    categories = label_map_util.create_categories_from_labelmap(label_map_path)
    self.assertListEqual([{
        'name': u'dog',
        'id': 1
    }, {
        'name': u'cat',
        'id': 2
    }], categories)

  def test_create_category_index_from_labelmap(self):
    """A category index can be built directly from a label map file."""
    label_map_string = """
      item {
        id:2
        name:'cat'
      }
      item {
        id:1
        name:'dog'
      }
    """
    label_map_path = self._write_label_map(label_map_string)

    category_index = label_map_util.create_category_index_from_labelmap(
        label_map_path)
    self.assertDictEqual({
        1: {
            'name': u'dog',
            'id': 1
        },
        2: {
            'name': u'cat',
            'id': 2
        }
    }, category_index)

  def test_create_category_index_from_labelmap_display(self):
    """The second positional argument toggles name vs. display_name."""
    label_map_string = """
      item {
        id:2
        name:'cat'
        display_name:'meow'
      }
      item {
        id:1
        name:'dog'
        display_name:'woof'
      }
    """
    label_map_path = self._write_label_map(label_map_string)

    # Passing False as the second argument yields the `name` field.
    self.assertDictEqual({
        1: {
            'name': u'dog',
            'id': 1
        },
        2: {
            'name': u'cat',
            'id': 2
        }
    }, label_map_util.create_category_index_from_labelmap(
        label_map_path, False))

    # With the default second argument the `display_name` field is used.
    self.assertDictEqual({
        1: {
            'name': u'woof',
            'id': 1
        },
        2: {
            'name': u'meow',
            'id': 2
        }
    }, label_map_util.create_category_index_from_labelmap(label_map_path))
|
|
|
|
|
|
if __name__ == '__main__':
  # Hand control to the TensorFlow test runner when executed as a script.
  tf.test.main()