# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for object_detection.utils.label_map_util."""

import os
import tensorflow as tf

from google.protobuf import text_format
from object_detection.protos import string_int_label_map_pb2
from object_detection.utils import label_map_util


class LabelMapUtilTest(tf.test.TestCase):
  """Tests for label map loading, dict conversion and category helpers."""

  def _generate_label_map(self, num_classes):
    """Builds an in-memory label map with ids 1..num_classes.

    Args:
      num_classes: number of items to generate.

    Returns:
      A StringIntLabelMap whose i-th item has id i, name 'label_<i>' and
      display_name '<i>'.
    """
    label_map_proto = string_int_label_map_pb2.StringIntLabelMap()
    for i in range(1, num_classes + 1):
      item = label_map_proto.item.add()
      item.id = i
      item.name = 'label_' + str(i)
      item.display_name = str(i)
    return label_map_proto

  def _write_label_map(self, label_map_string):
    """Writes a text-format label map to this test's temp dir.

    Args:
      label_map_string: label map contents in protobuf text format.

    Returns:
      Path of the written 'label_map.pbtxt' file.
    """
    label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
    with tf.gfile.Open(label_map_path, 'wb') as f:
      f.write(label_map_string)
    return label_map_path

  def test_get_label_map_dict(self):
    label_map_string = """
      item {
        id:2
        name:'cat'
      }
      item {
        id:1
        name:'dog'
      }
    """
    label_map_path = self._write_label_map(label_map_string)

    label_map_dict = label_map_util.get_label_map_dict(label_map_path)
    self.assertEqual(label_map_dict['dog'], 1)
    self.assertEqual(label_map_dict['cat'], 2)

  def test_get_label_map_dict_display(self):
    label_map_string = """
      item {
        id:2
        display_name:'cat'
      }
      item {
        id:1
        display_name:'dog'
      }
    """
    label_map_path = self._write_label_map(label_map_string)

    label_map_dict = label_map_util.get_label_map_dict(
        label_map_path, use_display_name=True)
    self.assertEqual(label_map_dict['dog'], 1)
    self.assertEqual(label_map_dict['cat'], 2)

  def test_load_bad_label_map(self):
    # Id 0 is only accepted for the 'background' class (see the
    # with_background test below); any other name at id 0 must be rejected.
    label_map_string = """
      item {
        id:0
        name:'class that should not be indexed at zero'
      }
      item {
        id:2
        name:'cat'
      }
      item {
        id:1
        name:'dog'
      }
    """
    label_map_path = self._write_label_map(label_map_string)

    with self.assertRaises(ValueError):
      label_map_util.load_labelmap(label_map_path)

  def test_load_label_map_with_background(self):
    label_map_string = """
      item {
        id:0
        name:'background'
      }
      item {
        id:2
        name:'cat'
      }
      item {
        id:1
        name:'dog'
      }
    """
    label_map_path = self._write_label_map(label_map_string)

    label_map_dict = label_map_util.get_label_map_dict(label_map_path)
    self.assertEqual(label_map_dict['background'], 0)
    self.assertEqual(label_map_dict['dog'], 1)
    self.assertEqual(label_map_dict['cat'], 2)

  def test_get_label_map_dict_with_fill_in_gaps_and_background(self):
    # Id 2 is intentionally missing: it should be filled with its own string
    # id as the name, and a 'background' entry should be added at id 0.
    label_map_string = """
      item {
        id:3
        name:'cat'
      }
      item {
        id:1
        name:'dog'
      }
    """
    label_map_path = self._write_label_map(label_map_string)

    label_map_dict = label_map_util.get_label_map_dict(
        label_map_path, fill_in_gaps_and_background=True)

    self.assertEqual(label_map_dict['background'], 0)
    self.assertEqual(label_map_dict['dog'], 1)
    self.assertEqual(label_map_dict['2'], 2)
    self.assertEqual(label_map_dict['cat'], 3)
    # Every id from 0 to the max must be present exactly once.
    self.assertEqual(len(label_map_dict), max(label_map_dict.values()) + 1)

  def test_keep_categories_with_unique_id(self):
    # Several items share id 1; only the first one ('child') is expected to
    # survive the conversion.
    label_map_proto = string_int_label_map_pb2.StringIntLabelMap()
    label_map_string = """
      item {
        id:2
        name:'cat'
      }
      item {
        id:1
        name:'child'
      }
      item {
        id:1
        name:'person'
      }
      item {
        id:1
        name:'n00007846'
      }
    """
    text_format.Merge(label_map_string, label_map_proto)
    categories = label_map_util.convert_label_map_to_categories(
        label_map_proto, max_num_classes=3)
    self.assertListEqual([{
        'id': 2,
        'name': u'cat'
    }, {
        'id': 1,
        'name': u'child'
    }], categories)

  def test_convert_label_map_to_categories_no_label_map(self):
    # Without a label map, placeholder categories 'category_<i>' are expected.
    categories = label_map_util.convert_label_map_to_categories(
        None, max_num_classes=3)
    expected_categories_list = [{
        'name': u'category_1',
        'id': 1
    }, {
        'name': u'category_2',
        'id': 2
    }, {
        'name': u'category_3',
        'id': 3
    }]
    self.assertListEqual(expected_categories_list, categories)

  def test_convert_label_map_to_categories(self):
    # _generate_label_map sets display_name to str(i); the expected names show
    # that display_name is used by default, and max_num_classes truncates.
    label_map_proto = self._generate_label_map(num_classes=4)
    categories = label_map_util.convert_label_map_to_categories(
        label_map_proto, max_num_classes=3)
    expected_categories_list = [{
        'name': u'1',
        'id': 1
    }, {
        'name': u'2',
        'id': 2
    }, {
        'name': u'3',
        'id': 3
    }]
    self.assertListEqual(expected_categories_list, categories)

  def test_convert_label_map_to_categories_with_few_classes(self):
    label_map_proto = self._generate_label_map(num_classes=4)
    categories = label_map_util.convert_label_map_to_categories(
        label_map_proto, max_num_classes=2)
    expected_categories_list = [{
        'name': u'1',
        'id': 1
    }, {
        'name': u'2',
        'id': 2
    }]
    self.assertListEqual(expected_categories_list, categories)

  def test_get_max_label_map_index(self):
    num_classes = 4
    label_map_proto = self._generate_label_map(num_classes=num_classes)
    max_index = label_map_util.get_max_label_map_index(label_map_proto)
    self.assertEqual(num_classes, max_index)

  def test_create_category_index(self):
    categories = [{'name': u'1', 'id': 1}, {'name': u'2', 'id': 2}]
    category_index = label_map_util.create_category_index(categories)
    self.assertDictEqual({
        1: {
            'name': u'1',
            'id': 1
        },
        2: {
            'name': u'2',
            'id': 2
        }
    }, category_index)

  def test_create_categories_from_labelmap(self):
    label_map_string = """
      item {
        id:1
        name:'dog'
      }
      item {
        id:2
        name:'cat'
      }
    """
    label_map_path = self._write_label_map(label_map_string)

    categories = label_map_util.create_categories_from_labelmap(label_map_path)
    self.assertListEqual([{
        'name': u'dog',
        'id': 1
    }, {
        'name': u'cat',
        'id': 2
    }], categories)

  def test_create_category_index_from_labelmap(self):
    label_map_string = """
      item {
        id:2
        name:'cat'
      }
      item {
        id:1
        name:'dog'
      }
    """
    label_map_path = self._write_label_map(label_map_string)

    category_index = label_map_util.create_category_index_from_labelmap(
        label_map_path)
    self.assertDictEqual({
        1: {
            'name': u'dog',
            'id': 1
        },
        2: {
            'name': u'cat',
            'id': 2
        }
    }, category_index)

  def test_create_category_index_from_labelmap_display(self):
    label_map_string = """
      item {
        id:2
        name:'cat'
        display_name:'meow'
      }
      item {
        id:1
        name:'dog'
        display_name:'woof'
      }
    """
    label_map_path = self._write_label_map(label_map_string)

    # With use_display_name disabled, the plain 'name' field is used.
    self.assertDictEqual({
        1: {
            'name': u'dog',
            'id': 1
        },
        2: {
            'name': u'cat',
            'id': 2
        }
    }, label_map_util.create_category_index_from_labelmap(
        label_map_path, use_display_name=False))

    # By default, display_name takes precedence over name.
    self.assertDictEqual({
        1: {
            'name': u'woof',
            'id': 1
        },
        2: {
            'name': u'meow',
            'id': 2
        }
    }, label_map_util.create_category_index_from_labelmap(label_map_path))


if __name__ == '__main__':
  tf.test.main()