You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

334 lines
8.8 KiB

6 years ago
  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for object_detection.utils.label_map_util."""
  16. import os
  17. import tensorflow as tf
  18. from google.protobuf import text_format
  19. from object_detection.protos import string_int_label_map_pb2
  20. from object_detection.utils import label_map_util
  21. class LabelMapUtilTest(tf.test.TestCase):
  22. def _generate_label_map(self, num_classes):
  23. label_map_proto = string_int_label_map_pb2.StringIntLabelMap()
  24. for i in range(1, num_classes + 1):
  25. item = label_map_proto.item.add()
  26. item.id = i
  27. item.name = 'label_' + str(i)
  28. item.display_name = str(i)
  29. return label_map_proto
  30. def test_get_label_map_dict(self):
  31. label_map_string = """
  32. item {
  33. id:2
  34. name:'cat'
  35. }
  36. item {
  37. id:1
  38. name:'dog'
  39. }
  40. """
  41. label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
  42. with tf.gfile.Open(label_map_path, 'wb') as f:
  43. f.write(label_map_string)
  44. label_map_dict = label_map_util.get_label_map_dict(label_map_path)
  45. self.assertEqual(label_map_dict['dog'], 1)
  46. self.assertEqual(label_map_dict['cat'], 2)
  47. def test_get_label_map_dict_display(self):
  48. label_map_string = """
  49. item {
  50. id:2
  51. display_name:'cat'
  52. }
  53. item {
  54. id:1
  55. display_name:'dog'
  56. }
  57. """
  58. label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
  59. with tf.gfile.Open(label_map_path, 'wb') as f:
  60. f.write(label_map_string)
  61. label_map_dict = label_map_util.get_label_map_dict(
  62. label_map_path, use_display_name=True)
  63. self.assertEqual(label_map_dict['dog'], 1)
  64. self.assertEqual(label_map_dict['cat'], 2)
  65. def test_load_bad_label_map(self):
  66. label_map_string = """
  67. item {
  68. id:0
  69. name:'class that should not be indexed at zero'
  70. }
  71. item {
  72. id:2
  73. name:'cat'
  74. }
  75. item {
  76. id:1
  77. name:'dog'
  78. }
  79. """
  80. label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
  81. with tf.gfile.Open(label_map_path, 'wb') as f:
  82. f.write(label_map_string)
  83. with self.assertRaises(ValueError):
  84. label_map_util.load_labelmap(label_map_path)
  85. def test_load_label_map_with_background(self):
  86. label_map_string = """
  87. item {
  88. id:0
  89. name:'background'
  90. }
  91. item {
  92. id:2
  93. name:'cat'
  94. }
  95. item {
  96. id:1
  97. name:'dog'
  98. }
  99. """
  100. label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
  101. with tf.gfile.Open(label_map_path, 'wb') as f:
  102. f.write(label_map_string)
  103. label_map_dict = label_map_util.get_label_map_dict(label_map_path)
  104. self.assertEqual(label_map_dict['background'], 0)
  105. self.assertEqual(label_map_dict['dog'], 1)
  106. self.assertEqual(label_map_dict['cat'], 2)
  107. def test_get_label_map_dict_with_fill_in_gaps_and_background(self):
  108. label_map_string = """
  109. item {
  110. id:3
  111. name:'cat'
  112. }
  113. item {
  114. id:1
  115. name:'dog'
  116. }
  117. """
  118. label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
  119. with tf.gfile.Open(label_map_path, 'wb') as f:
  120. f.write(label_map_string)
  121. label_map_dict = label_map_util.get_label_map_dict(
  122. label_map_path, fill_in_gaps_and_background=True)
  123. self.assertEqual(label_map_dict['background'], 0)
  124. self.assertEqual(label_map_dict['dog'], 1)
  125. self.assertEqual(label_map_dict['2'], 2)
  126. self.assertEqual(label_map_dict['cat'], 3)
  127. self.assertEqual(len(label_map_dict), max(label_map_dict.values()) + 1)
  128. def test_keep_categories_with_unique_id(self):
  129. label_map_proto = string_int_label_map_pb2.StringIntLabelMap()
  130. label_map_string = """
  131. item {
  132. id:2
  133. name:'cat'
  134. }
  135. item {
  136. id:1
  137. name:'child'
  138. }
  139. item {
  140. id:1
  141. name:'person'
  142. }
  143. item {
  144. id:1
  145. name:'n00007846'
  146. }
  147. """
  148. text_format.Merge(label_map_string, label_map_proto)
  149. categories = label_map_util.convert_label_map_to_categories(
  150. label_map_proto, max_num_classes=3)
  151. self.assertListEqual([{
  152. 'id': 2,
  153. 'name': u'cat'
  154. }, {
  155. 'id': 1,
  156. 'name': u'child'
  157. }], categories)
  158. def test_convert_label_map_to_categories_no_label_map(self):
  159. categories = label_map_util.convert_label_map_to_categories(
  160. None, max_num_classes=3)
  161. expected_categories_list = [{
  162. 'name': u'category_1',
  163. 'id': 1
  164. }, {
  165. 'name': u'category_2',
  166. 'id': 2
  167. }, {
  168. 'name': u'category_3',
  169. 'id': 3
  170. }]
  171. self.assertListEqual(expected_categories_list, categories)
  172. def test_convert_label_map_to_categories(self):
  173. label_map_proto = self._generate_label_map(num_classes=4)
  174. categories = label_map_util.convert_label_map_to_categories(
  175. label_map_proto, max_num_classes=3)
  176. expected_categories_list = [{
  177. 'name': u'1',
  178. 'id': 1
  179. }, {
  180. 'name': u'2',
  181. 'id': 2
  182. }, {
  183. 'name': u'3',
  184. 'id': 3
  185. }]
  186. self.assertListEqual(expected_categories_list, categories)
  187. def test_convert_label_map_to_categories_with_few_classes(self):
  188. label_map_proto = self._generate_label_map(num_classes=4)
  189. cat_no_offset = label_map_util.convert_label_map_to_categories(
  190. label_map_proto, max_num_classes=2)
  191. expected_categories_list = [{
  192. 'name': u'1',
  193. 'id': 1
  194. }, {
  195. 'name': u'2',
  196. 'id': 2
  197. }]
  198. self.assertListEqual(expected_categories_list, cat_no_offset)
  199. def test_get_max_label_map_index(self):
  200. num_classes = 4
  201. label_map_proto = self._generate_label_map(num_classes=num_classes)
  202. max_index = label_map_util.get_max_label_map_index(label_map_proto)
  203. self.assertEqual(num_classes, max_index)
  204. def test_create_category_index(self):
  205. categories = [{'name': u'1', 'id': 1}, {'name': u'2', 'id': 2}]
  206. category_index = label_map_util.create_category_index(categories)
  207. self.assertDictEqual({
  208. 1: {
  209. 'name': u'1',
  210. 'id': 1
  211. },
  212. 2: {
  213. 'name': u'2',
  214. 'id': 2
  215. }
  216. }, category_index)
  217. def test_create_categories_from_labelmap(self):
  218. label_map_string = """
  219. item {
  220. id:1
  221. name:'dog'
  222. }
  223. item {
  224. id:2
  225. name:'cat'
  226. }
  227. """
  228. label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
  229. with tf.gfile.Open(label_map_path, 'wb') as f:
  230. f.write(label_map_string)
  231. categories = label_map_util.create_categories_from_labelmap(label_map_path)
  232. self.assertListEqual([{
  233. 'name': u'dog',
  234. 'id': 1
  235. }, {
  236. 'name': u'cat',
  237. 'id': 2
  238. }], categories)
  239. def test_create_category_index_from_labelmap(self):
  240. label_map_string = """
  241. item {
  242. id:2
  243. name:'cat'
  244. }
  245. item {
  246. id:1
  247. name:'dog'
  248. }
  249. """
  250. label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
  251. with tf.gfile.Open(label_map_path, 'wb') as f:
  252. f.write(label_map_string)
  253. category_index = label_map_util.create_category_index_from_labelmap(
  254. label_map_path)
  255. self.assertDictEqual({
  256. 1: {
  257. 'name': u'dog',
  258. 'id': 1
  259. },
  260. 2: {
  261. 'name': u'cat',
  262. 'id': 2
  263. }
  264. }, category_index)
  265. def test_create_category_index_from_labelmap_display(self):
  266. label_map_string = """
  267. item {
  268. id:2
  269. name:'cat'
  270. display_name:'meow'
  271. }
  272. item {
  273. id:1
  274. name:'dog'
  275. display_name:'woof'
  276. }
  277. """
  278. label_map_path = os.path.join(self.get_temp_dir(), 'label_map.pbtxt')
  279. with tf.gfile.Open(label_map_path, 'wb') as f:
  280. f.write(label_map_string)
  281. self.assertDictEqual({
  282. 1: {
  283. 'name': u'dog',
  284. 'id': 1
  285. },
  286. 2: {
  287. 'name': u'cat',
  288. 'id': 2
  289. }
  290. }, label_map_util.create_category_index_from_labelmap(
  291. label_map_path, False))
  292. self.assertDictEqual({
  293. 1: {
  294. 'name': u'woof',
  295. 'id': 1
  296. },
  297. 2: {
  298. 'name': u'meow',
  299. 'id': 2
  300. }
  301. }, label_map_util.create_category_index_from_labelmap(label_map_path))
  302. if __name__ == '__main__':
  303. tf.test.main()