You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

561 lines
18 KiB

  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for losses_builder."""
  16. import tensorflow as tf
  17. from google.protobuf import text_format
  18. from object_detection.builders import losses_builder
  19. from object_detection.core import losses
  20. from object_detection.protos import losses_pb2
  21. from object_detection.utils import ops
  22. class LocalizationLossBuilderTest(tf.test.TestCase):
  23. def test_build_weighted_l2_localization_loss(self):
  24. losses_text_proto = """
  25. localization_loss {
  26. weighted_l2 {
  27. }
  28. }
  29. classification_loss {
  30. weighted_softmax {
  31. }
  32. }
  33. """
  34. losses_proto = losses_pb2.Loss()
  35. text_format.Merge(losses_text_proto, losses_proto)
  36. _, localization_loss, _, _, _, _, _ = losses_builder.build(losses_proto)
  37. self.assertTrue(isinstance(localization_loss,
  38. losses.WeightedL2LocalizationLoss))
  39. def test_build_weighted_smooth_l1_localization_loss_default_delta(self):
  40. losses_text_proto = """
  41. localization_loss {
  42. weighted_smooth_l1 {
  43. }
  44. }
  45. classification_loss {
  46. weighted_softmax {
  47. }
  48. }
  49. """
  50. losses_proto = losses_pb2.Loss()
  51. text_format.Merge(losses_text_proto, losses_proto)
  52. _, localization_loss, _, _, _, _, _ = losses_builder.build(losses_proto)
  53. self.assertTrue(isinstance(localization_loss,
  54. losses.WeightedSmoothL1LocalizationLoss))
  55. self.assertAlmostEqual(localization_loss._delta, 1.0)
  56. def test_build_weighted_smooth_l1_localization_loss_non_default_delta(self):
  57. losses_text_proto = """
  58. localization_loss {
  59. weighted_smooth_l1 {
  60. delta: 0.1
  61. }
  62. }
  63. classification_loss {
  64. weighted_softmax {
  65. }
  66. }
  67. """
  68. losses_proto = losses_pb2.Loss()
  69. text_format.Merge(losses_text_proto, losses_proto)
  70. _, localization_loss, _, _, _, _, _ = losses_builder.build(losses_proto)
  71. self.assertTrue(isinstance(localization_loss,
  72. losses.WeightedSmoothL1LocalizationLoss))
  73. self.assertAlmostEqual(localization_loss._delta, 0.1)
  74. def test_build_weighted_iou_localization_loss(self):
  75. losses_text_proto = """
  76. localization_loss {
  77. weighted_iou {
  78. }
  79. }
  80. classification_loss {
  81. weighted_softmax {
  82. }
  83. }
  84. """
  85. losses_proto = losses_pb2.Loss()
  86. text_format.Merge(losses_text_proto, losses_proto)
  87. _, localization_loss, _, _, _, _, _ = losses_builder.build(losses_proto)
  88. self.assertTrue(isinstance(localization_loss,
  89. losses.WeightedIOULocalizationLoss))
  90. def test_anchorwise_output(self):
  91. losses_text_proto = """
  92. localization_loss {
  93. weighted_smooth_l1 {
  94. }
  95. }
  96. classification_loss {
  97. weighted_softmax {
  98. }
  99. }
  100. """
  101. losses_proto = losses_pb2.Loss()
  102. text_format.Merge(losses_text_proto, losses_proto)
  103. _, localization_loss, _, _, _, _, _ = losses_builder.build(losses_proto)
  104. self.assertTrue(isinstance(localization_loss,
  105. losses.WeightedSmoothL1LocalizationLoss))
  106. predictions = tf.constant([[[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]]])
  107. targets = tf.constant([[[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 1.0, 1.0]]])
  108. weights = tf.constant([[1.0, 1.0]])
  109. loss = localization_loss(predictions, targets, weights=weights)
  110. self.assertEqual(loss.shape, [1, 2])
  111. def test_raise_error_on_empty_localization_config(self):
  112. losses_text_proto = """
  113. classification_loss {
  114. weighted_softmax {
  115. }
  116. }
  117. """
  118. losses_proto = losses_pb2.Loss()
  119. text_format.Merge(losses_text_proto, losses_proto)
  120. with self.assertRaises(ValueError):
  121. losses_builder._build_localization_loss(losses_proto)
  122. class ClassificationLossBuilderTest(tf.test.TestCase):
  123. def test_build_weighted_sigmoid_classification_loss(self):
  124. losses_text_proto = """
  125. classification_loss {
  126. weighted_sigmoid {
  127. }
  128. }
  129. localization_loss {
  130. weighted_l2 {
  131. }
  132. }
  133. """
  134. losses_proto = losses_pb2.Loss()
  135. text_format.Merge(losses_text_proto, losses_proto)
  136. classification_loss, _, _, _, _, _, _ = losses_builder.build(losses_proto)
  137. self.assertTrue(isinstance(classification_loss,
  138. losses.WeightedSigmoidClassificationLoss))
  139. def test_build_weighted_sigmoid_focal_classification_loss(self):
  140. losses_text_proto = """
  141. classification_loss {
  142. weighted_sigmoid_focal {
  143. }
  144. }
  145. localization_loss {
  146. weighted_l2 {
  147. }
  148. }
  149. """
  150. losses_proto = losses_pb2.Loss()
  151. text_format.Merge(losses_text_proto, losses_proto)
  152. classification_loss, _, _, _, _, _, _ = losses_builder.build(losses_proto)
  153. self.assertTrue(isinstance(classification_loss,
  154. losses.SigmoidFocalClassificationLoss))
  155. self.assertAlmostEqual(classification_loss._alpha, None)
  156. self.assertAlmostEqual(classification_loss._gamma, 2.0)
  157. def test_build_weighted_sigmoid_focal_loss_non_default(self):
  158. losses_text_proto = """
  159. classification_loss {
  160. weighted_sigmoid_focal {
  161. alpha: 0.25
  162. gamma: 3.0
  163. }
  164. }
  165. localization_loss {
  166. weighted_l2 {
  167. }
  168. }
  169. """
  170. losses_proto = losses_pb2.Loss()
  171. text_format.Merge(losses_text_proto, losses_proto)
  172. classification_loss, _, _, _, _, _, _ = losses_builder.build(losses_proto)
  173. self.assertTrue(isinstance(classification_loss,
  174. losses.SigmoidFocalClassificationLoss))
  175. self.assertAlmostEqual(classification_loss._alpha, 0.25)
  176. self.assertAlmostEqual(classification_loss._gamma, 3.0)
  177. def test_build_weighted_softmax_classification_loss(self):
  178. losses_text_proto = """
  179. classification_loss {
  180. weighted_softmax {
  181. }
  182. }
  183. localization_loss {
  184. weighted_l2 {
  185. }
  186. }
  187. """
  188. losses_proto = losses_pb2.Loss()
  189. text_format.Merge(losses_text_proto, losses_proto)
  190. classification_loss, _, _, _, _, _, _ = losses_builder.build(losses_proto)
  191. self.assertTrue(isinstance(classification_loss,
  192. losses.WeightedSoftmaxClassificationLoss))
  193. def test_build_weighted_logits_softmax_classification_loss(self):
  194. losses_text_proto = """
  195. classification_loss {
  196. weighted_logits_softmax {
  197. }
  198. }
  199. localization_loss {
  200. weighted_l2 {
  201. }
  202. }
  203. """
  204. losses_proto = losses_pb2.Loss()
  205. text_format.Merge(losses_text_proto, losses_proto)
  206. classification_loss, _, _, _, _, _, _ = losses_builder.build(losses_proto)
  207. self.assertTrue(
  208. isinstance(classification_loss,
  209. losses.WeightedSoftmaxClassificationAgainstLogitsLoss))
  210. def test_build_weighted_softmax_classification_loss_with_logit_scale(self):
  211. losses_text_proto = """
  212. classification_loss {
  213. weighted_softmax {
  214. logit_scale: 2.0
  215. }
  216. }
  217. localization_loss {
  218. weighted_l2 {
  219. }
  220. }
  221. """
  222. losses_proto = losses_pb2.Loss()
  223. text_format.Merge(losses_text_proto, losses_proto)
  224. classification_loss, _, _, _, _, _, _ = losses_builder.build(losses_proto)
  225. self.assertTrue(isinstance(classification_loss,
  226. losses.WeightedSoftmaxClassificationLoss))
  227. def test_build_bootstrapped_sigmoid_classification_loss(self):
  228. losses_text_proto = """
  229. classification_loss {
  230. bootstrapped_sigmoid {
  231. alpha: 0.5
  232. }
  233. }
  234. localization_loss {
  235. weighted_l2 {
  236. }
  237. }
  238. """
  239. losses_proto = losses_pb2.Loss()
  240. text_format.Merge(losses_text_proto, losses_proto)
  241. classification_loss, _, _, _, _, _, _ = losses_builder.build(losses_proto)
  242. self.assertTrue(isinstance(classification_loss,
  243. losses.BootstrappedSigmoidClassificationLoss))
  244. def test_anchorwise_output(self):
  245. losses_text_proto = """
  246. classification_loss {
  247. weighted_sigmoid {
  248. anchorwise_output: true
  249. }
  250. }
  251. localization_loss {
  252. weighted_l2 {
  253. }
  254. }
  255. """
  256. losses_proto = losses_pb2.Loss()
  257. text_format.Merge(losses_text_proto, losses_proto)
  258. classification_loss, _, _, _, _, _, _ = losses_builder.build(losses_proto)
  259. self.assertTrue(isinstance(classification_loss,
  260. losses.WeightedSigmoidClassificationLoss))
  261. predictions = tf.constant([[[0.0, 1.0, 0.0], [0.0, 0.5, 0.5]]])
  262. targets = tf.constant([[[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]])
  263. weights = tf.constant([[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]])
  264. loss = classification_loss(predictions, targets, weights=weights)
  265. self.assertEqual(loss.shape, [1, 2, 3])
  266. def test_raise_error_on_empty_config(self):
  267. losses_text_proto = """
  268. localization_loss {
  269. weighted_l2 {
  270. }
  271. }
  272. """
  273. losses_proto = losses_pb2.Loss()
  274. text_format.Merge(losses_text_proto, losses_proto)
  275. with self.assertRaises(ValueError):
  276. losses_builder.build(losses_proto)
  277. class HardExampleMinerBuilderTest(tf.test.TestCase):
  278. def test_do_not_build_hard_example_miner_by_default(self):
  279. losses_text_proto = """
  280. localization_loss {
  281. weighted_l2 {
  282. }
  283. }
  284. classification_loss {
  285. weighted_softmax {
  286. }
  287. }
  288. """
  289. losses_proto = losses_pb2.Loss()
  290. text_format.Merge(losses_text_proto, losses_proto)
  291. _, _, _, _, hard_example_miner, _, _ = losses_builder.build(losses_proto)
  292. self.assertEqual(hard_example_miner, None)
  293. def test_build_hard_example_miner_for_classification_loss(self):
  294. losses_text_proto = """
  295. localization_loss {
  296. weighted_l2 {
  297. }
  298. }
  299. classification_loss {
  300. weighted_softmax {
  301. }
  302. }
  303. hard_example_miner {
  304. loss_type: CLASSIFICATION
  305. }
  306. """
  307. losses_proto = losses_pb2.Loss()
  308. text_format.Merge(losses_text_proto, losses_proto)
  309. _, _, _, _, hard_example_miner, _, _ = losses_builder.build(losses_proto)
  310. self.assertTrue(isinstance(hard_example_miner, losses.HardExampleMiner))
  311. self.assertEqual(hard_example_miner._loss_type, 'cls')
  312. def test_build_hard_example_miner_for_localization_loss(self):
  313. losses_text_proto = """
  314. localization_loss {
  315. weighted_l2 {
  316. }
  317. }
  318. classification_loss {
  319. weighted_softmax {
  320. }
  321. }
  322. hard_example_miner {
  323. loss_type: LOCALIZATION
  324. }
  325. """
  326. losses_proto = losses_pb2.Loss()
  327. text_format.Merge(losses_text_proto, losses_proto)
  328. _, _, _, _, hard_example_miner, _, _ = losses_builder.build(losses_proto)
  329. self.assertTrue(isinstance(hard_example_miner, losses.HardExampleMiner))
  330. self.assertEqual(hard_example_miner._loss_type, 'loc')
  331. def test_build_hard_example_miner_with_non_default_values(self):
  332. losses_text_proto = """
  333. localization_loss {
  334. weighted_l2 {
  335. }
  336. }
  337. classification_loss {
  338. weighted_softmax {
  339. }
  340. }
  341. hard_example_miner {
  342. num_hard_examples: 32
  343. iou_threshold: 0.5
  344. loss_type: LOCALIZATION
  345. max_negatives_per_positive: 10
  346. min_negatives_per_image: 3
  347. }
  348. """
  349. losses_proto = losses_pb2.Loss()
  350. text_format.Merge(losses_text_proto, losses_proto)
  351. _, _, _, _, hard_example_miner, _, _ = losses_builder.build(losses_proto)
  352. self.assertTrue(isinstance(hard_example_miner, losses.HardExampleMiner))
  353. self.assertEqual(hard_example_miner._num_hard_examples, 32)
  354. self.assertAlmostEqual(hard_example_miner._iou_threshold, 0.5)
  355. self.assertEqual(hard_example_miner._max_negatives_per_positive, 10)
  356. self.assertEqual(hard_example_miner._min_negatives_per_image, 3)
  357. class LossBuilderTest(tf.test.TestCase):
  358. def test_build_all_loss_parameters(self):
  359. losses_text_proto = """
  360. localization_loss {
  361. weighted_l2 {
  362. }
  363. }
  364. classification_loss {
  365. weighted_softmax {
  366. }
  367. }
  368. hard_example_miner {
  369. }
  370. classification_weight: 0.8
  371. localization_weight: 0.2
  372. """
  373. losses_proto = losses_pb2.Loss()
  374. text_format.Merge(losses_text_proto, losses_proto)
  375. (classification_loss, localization_loss, classification_weight,
  376. localization_weight, hard_example_miner, _,
  377. _) = losses_builder.build(losses_proto)
  378. self.assertTrue(isinstance(hard_example_miner, losses.HardExampleMiner))
  379. self.assertTrue(isinstance(classification_loss,
  380. losses.WeightedSoftmaxClassificationLoss))
  381. self.assertTrue(isinstance(localization_loss,
  382. losses.WeightedL2LocalizationLoss))
  383. self.assertAlmostEqual(classification_weight, 0.8)
  384. self.assertAlmostEqual(localization_weight, 0.2)
  385. def test_build_expected_sampling(self):
  386. losses_text_proto = """
  387. localization_loss {
  388. weighted_l2 {
  389. }
  390. }
  391. classification_loss {
  392. weighted_softmax {
  393. }
  394. }
  395. hard_example_miner {
  396. }
  397. classification_weight: 0.8
  398. localization_weight: 0.2
  399. """
  400. losses_proto = losses_pb2.Loss()
  401. text_format.Merge(losses_text_proto, losses_proto)
  402. (classification_loss, localization_loss, classification_weight,
  403. localization_weight, hard_example_miner, _,
  404. _) = losses_builder.build(losses_proto)
  405. self.assertTrue(isinstance(hard_example_miner, losses.HardExampleMiner))
  406. self.assertTrue(
  407. isinstance(classification_loss,
  408. losses.WeightedSoftmaxClassificationLoss))
  409. self.assertTrue(
  410. isinstance(localization_loss, losses.WeightedL2LocalizationLoss))
  411. self.assertAlmostEqual(classification_weight, 0.8)
  412. self.assertAlmostEqual(localization_weight, 0.2)
  413. def test_build_reweighting_unmatched_anchors(self):
  414. losses_text_proto = """
  415. localization_loss {
  416. weighted_l2 {
  417. }
  418. }
  419. classification_loss {
  420. weighted_softmax {
  421. }
  422. }
  423. hard_example_miner {
  424. }
  425. classification_weight: 0.8
  426. localization_weight: 0.2
  427. """
  428. losses_proto = losses_pb2.Loss()
  429. text_format.Merge(losses_text_proto, losses_proto)
  430. (classification_loss, localization_loss, classification_weight,
  431. localization_weight, hard_example_miner, _,
  432. _) = losses_builder.build(losses_proto)
  433. self.assertTrue(isinstance(hard_example_miner, losses.HardExampleMiner))
  434. self.assertTrue(
  435. isinstance(classification_loss,
  436. losses.WeightedSoftmaxClassificationLoss))
  437. self.assertTrue(
  438. isinstance(localization_loss, losses.WeightedL2LocalizationLoss))
  439. self.assertAlmostEqual(classification_weight, 0.8)
  440. self.assertAlmostEqual(localization_weight, 0.2)
  441. def test_raise_error_when_both_focal_loss_and_hard_example_miner(self):
  442. losses_text_proto = """
  443. localization_loss {
  444. weighted_l2 {
  445. }
  446. }
  447. classification_loss {
  448. weighted_sigmoid_focal {
  449. }
  450. }
  451. hard_example_miner {
  452. }
  453. classification_weight: 0.8
  454. localization_weight: 0.2
  455. """
  456. losses_proto = losses_pb2.Loss()
  457. text_format.Merge(losses_text_proto, losses_proto)
  458. with self.assertRaises(ValueError):
  459. losses_builder.build(losses_proto)
  460. class FasterRcnnClassificationLossBuilderTest(tf.test.TestCase):
  461. def test_build_sigmoid_loss(self):
  462. losses_text_proto = """
  463. weighted_sigmoid {
  464. }
  465. """
  466. losses_proto = losses_pb2.ClassificationLoss()
  467. text_format.Merge(losses_text_proto, losses_proto)
  468. classification_loss = losses_builder.build_faster_rcnn_classification_loss(
  469. losses_proto)
  470. self.assertTrue(isinstance(classification_loss,
  471. losses.WeightedSigmoidClassificationLoss))
  472. def test_build_softmax_loss(self):
  473. losses_text_proto = """
  474. weighted_softmax {
  475. }
  476. """
  477. losses_proto = losses_pb2.ClassificationLoss()
  478. text_format.Merge(losses_text_proto, losses_proto)
  479. classification_loss = losses_builder.build_faster_rcnn_classification_loss(
  480. losses_proto)
  481. self.assertTrue(isinstance(classification_loss,
  482. losses.WeightedSoftmaxClassificationLoss))
  483. def test_build_logits_softmax_loss(self):
  484. losses_text_proto = """
  485. weighted_logits_softmax {
  486. }
  487. """
  488. losses_proto = losses_pb2.ClassificationLoss()
  489. text_format.Merge(losses_text_proto, losses_proto)
  490. classification_loss = losses_builder.build_faster_rcnn_classification_loss(
  491. losses_proto)
  492. self.assertTrue(
  493. isinstance(classification_loss,
  494. losses.WeightedSoftmaxClassificationAgainstLogitsLoss))
  495. def test_build_sigmoid_focal_loss(self):
  496. losses_text_proto = """
  497. weighted_sigmoid_focal {
  498. }
  499. """
  500. losses_proto = losses_pb2.ClassificationLoss()
  501. text_format.Merge(losses_text_proto, losses_proto)
  502. classification_loss = losses_builder.build_faster_rcnn_classification_loss(
  503. losses_proto)
  504. self.assertTrue(
  505. isinstance(classification_loss,
  506. losses.SigmoidFocalClassificationLoss))
  507. def test_build_softmax_loss_by_default(self):
  508. losses_text_proto = """
  509. """
  510. losses_proto = losses_pb2.ClassificationLoss()
  511. text_format.Merge(losses_text_proto, losses_proto)
  512. classification_loss = losses_builder.build_faster_rcnn_classification_loss(
  513. losses_proto)
  514. self.assertTrue(isinstance(classification_loss,
  515. losses.WeightedSoftmaxClassificationLoss))
  516. if __name__ == '__main__':
  517. tf.test.main()