# TPU compatible detection pipelines

[TOC]

The TensorFlow Object Detection API supports TPU training for some models. To
make a model TPU compatible, you need to make a few changes to its config, as
described below. We also provide several sample configs that you can use as
templates.

## TPU compatibility

### Static shaped tensors

TPU training currently requires all tensors in the TensorFlow graph to have
static shapes. However, most of the sample configs in the Object Detection API
produce a few tensors that are dynamically shaped. Fortunately, the model
configuration provides simple alternatives that give these tensors static
shapes:

*   **Image tensors with static shape** - This can be achieved either by using a
    `fixed_shape_resizer`, which resizes images to a fixed spatial shape, or by
    setting `pad_to_max_dimension: true` in `keep_aspect_ratio_resizer`, which
    pads the resized images with zeros at the bottom and right. Padded image
    tensors are correctly handled internally within the model.

    ```
    image_resizer {
      fixed_shape_resizer {
        height: 640
        width: 640
      }
    }
    ```

    or

    ```
    image_resizer {
      keep_aspect_ratio_resizer {
        min_dimension: 640
        max_dimension: 640
        pad_to_max_dimension: true
      }
    }
    ```

*   **Groundtruth tensors with static shape** - Images in a typical detection
    dataset have a variable number of groundtruth boxes and associated classes.
    Setting `max_number_of_boxes` to a large enough number in the
    `train_input_reader` and `eval_input_reader` pads the groundtruth tensors
    with zeros to a static shape. Padded groundtruth tensors are correctly
    handled internally within the model.

    ```
    train_input_reader: {
      tf_record_input_reader {
        input_path: "PATH_TO_BE_CONFIGURED/mscoco_train.record-?????-of-00100"
      }
      label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"
      max_number_of_boxes: 200
    }

    eval_input_reader: {
      tf_record_input_reader {
        input_path: "PATH_TO_BE_CONFIGURED/mscoco_val.record-?????-of-00010"
      }
      label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"
      max_number_of_boxes: 200
    }
    ```

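To make the two padding steps above concrete, here is a minimal pure-Python sketch of the shape arithmetic. It is an illustration under stated assumptions, not the library's implementation: `padded_image_geometry` and `pad_boxes` are hypothetical helper names, the resize rule is an approximation of an aspect-preserving resizer, and raising an error on too many boxes is a choice made for this sketch only.

```python
def padded_image_geometry(height, width, min_dimension=640, max_dimension=640):
    """Return (resized_hw, padded_hw) for an aspect-preserving resize plus padding.

    Hypothetical helper: approximates the resize rule for illustration only.
    """
    # Scale so the shorter side reaches min_dimension ...
    scale = min_dimension / min(height, width)
    # ... unless that would push the longer side past max_dimension.
    if round(max(height, width) * scale) > max_dimension:
        scale = max_dimension / max(height, width)
    resized = (round(height * scale), round(width * scale))
    # Zero padding goes at the bottom and right, so every image ends up with
    # the same static spatial shape regardless of its aspect ratio.
    padded = (max_dimension, max_dimension)
    return resized, padded


def pad_boxes(boxes, max_number_of_boxes=200):
    """Pad a variable-length list of [ymin, xmin, ymax, xmax] boxes with zeros.

    Returns the padded list and the number of real (unpadded) boxes.
    Assumption: this sketch raises if there are too many boxes; it does not
    claim to mirror what the library does in that case.
    """
    num_valid = len(boxes)
    if num_valid > max_number_of_boxes:
        raise ValueError("more boxes than max_number_of_boxes")
    padding = [[0.0, 0.0, 0.0, 0.0] for _ in range(max_number_of_boxes - num_valid)]
    return [list(box) for box in boxes] + padding, num_valid
```

For example, `padded_image_geometry(480, 640)` keeps the image at 480x640 and reports a padded shape of 640x640, which means 160 rows of zeros are added at the bottom. Tracking `num_valid` alongside the padded boxes is one way a model can ignore the padding rows later.
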
### TPU friendly ops

Although TPUs support a vast number of TensorFlow ops, a few ops used in the
TensorFlow Object Detection API are unsupported. We list these ops below and
recommend compatible substitutes.

*   **Anchor sampling** - Standard SSD pipelines typically use hard example
    mining to balance the positive and negative anchors that contribute to the
    loss. Hard example mining uses non-max suppression as a subroutine, and
    since non-max suppression is not currently supported on TPUs, hard example
    mining cannot be used. Fortunately, we provide an implementation of focal
    loss that can be used instead. Remove `hard_example_miner` from the config
    and replace the `weighted_sigmoid` classification loss with the
    `weighted_sigmoid_focal` loss.

    ```
    loss {
      classification_loss {
        weighted_sigmoid_focal {
          alpha: 0.25
          gamma: 2.0
        }
      }
      localization_loss {
        weighted_smooth_l1 {
        }
      }
      classification_weight: 1.0
      localization_weight: 1.0
    }
    ```

*   **Target matching** - The Object Detection API provides two choices of
    matcher for target assignment: `argmax_matcher` and `bipartite_matcher`.
    The bipartite matcher is not currently supported on TPUs, so the configs
    must be modified to use `argmax_matcher`. Additionally, set
    `use_matmul_gather: true` for efficiency on TPU.

    ```
    matcher {
      argmax_matcher {
        matched_threshold: 0.5
        unmatched_threshold: 0.5
        ignore_thresholds: false
        negatives_lower_than_unmatched: true
        force_match_for_each_row: true
        use_matmul_gather: true
      }
    }
    ```

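The two substitutions above can be illustrated with a short pure-Python sketch. It shows the general ideas only and is not the library's code: `sigmoid_focal_loss` follows the commonly used focal loss formulation with the `alpha` and `gamma` values from the config above, and `matmul_gather` shows how a gather can be rewritten as a multiplication by a one-hot matrix, which is an assumption about what `use_matmul_gather` does conceptually. Both function names are hypothetical.

```python
import math


def sigmoid_focal_loss(logit, target, alpha=0.25, gamma=2.0):
    """Focal loss for one anchor/class pair; target is 1.0 or 0.0.

    Hypothetical helper written for illustration.
    """
    p = 1.0 / (1.0 + math.exp(-logit))
    # Probability the model assigns to the true label.
    p_t = target * p + (1.0 - target) * (1.0 - p)
    cross_entropy = -math.log(p_t)
    # (1 - p_t)^gamma shrinks the loss of well-classified (easy) examples,
    # which plays the balancing role that hard example mining would play.
    modulating = (1.0 - p_t) ** gamma
    # alpha weights positives, (1 - alpha) weights negatives.
    alpha_t = target * alpha + (1.0 - target) * (1.0 - alpha)
    return alpha_t * modulating * cross_entropy


def matmul_gather(params, indices):
    """Gather rows of `params` by multiplying with a one-hot matrix.

    Equivalent to [params[i] for i in indices], but expressed as a dense
    matrix product, the kind of operation accelerators handle well.
    """
    num_rows, num_cols = len(params), len(params[0])
    one_hot = [[1.0 if k == i else 0.0 for k in range(num_rows)] for i in indices]
    return [
        [sum(row[k] * params[k][c] for k in range(num_rows)) for c in range(num_cols)]
        for row in one_hot
    ]
```

With these definitions, a confidently correct negative such as `sigmoid_focal_loss(-4.0, 0.0)` contributes several orders of magnitude less loss than a confidently wrong one such as `sigmoid_focal_loss(4.0, 0.0)`, so the many easy background anchors no longer dominate training.
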
### TPU training hyperparameters

Object detection training on TPU uses synchronous SGD. On a typical Cloud TPU
with 8 cores, we recommend batch sizes 8x larger than those of a GPU config
that uses asynchronous SGD. Because of the large batch size, we also use far
fewer training steps (roughly 1/100x). This requires careful tuning of some
other training parameters, as listed below.

*   **Batch size** - Use the largest batch size that can fit on a Cloud TPU.

    ```
    train_config {
      batch_size: 1024
    }
    ```

*   **Training steps** - Typically only tens of thousands.

    ```
    train_config {
      num_steps: 25000
    }
    ```

*   **Batch norm decay** - Use smaller decay constants (0.97 or 0.997) since we
    take fewer training steps.

    ```
    batch_norm {
      scale: true,
      decay: 0.97,
      epsilon: 0.001,
    }
    ```

*   **Learning rate** - Use a large learning rate with warmup. Scale the
    learning rate linearly with batch size. See `cosine_decay_learning_rate` or
    `manual_step_learning_rate` for examples.

    ```
    learning_rate: {
      cosine_decay_learning_rate {
        learning_rate_base: .04
        total_steps: 25000
        warmup_learning_rate: .013333
        warmup_steps: 2000
      }
    }
    ```

    or

    ```
    learning_rate: {
      manual_step_learning_rate {
        warmup: true
        initial_learning_rate: .01333
        schedule {
          step: 2000
          learning_rate: 0.04
        }
        schedule {
          step: 15000
          learning_rate: 0.004
        }
      }
    }
    ```

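As a rough illustration of the learning rate guidance above, the sketch below computes a linearly scaled learning rate and evaluates a warmup-plus-cosine schedule at a given step. It assumes linear warmup from `warmup_learning_rate` to `learning_rate_base` followed by half-cosine decay to zero; the library's exact schedule may differ in details. The function names are hypothetical, and the default values are taken from the `cosine_decay_learning_rate` example above.

```python
import math


def scaled_learning_rate(reference_lr, reference_batch_size, batch_size):
    """Linear scaling rule: the learning rate grows in proportion to batch size."""
    return reference_lr * batch_size / reference_batch_size


def cosine_decay_with_warmup(step, learning_rate_base=0.04, total_steps=25000,
                             warmup_learning_rate=0.013333, warmup_steps=2000):
    """Learning rate at `step`: linear warmup, then cosine decay towards zero.

    Assumed shape of the schedule, written for illustration only.
    """
    if step < warmup_steps:
        # Ramp linearly from warmup_learning_rate up to learning_rate_base.
        slope = (learning_rate_base - warmup_learning_rate) / warmup_steps
        return warmup_learning_rate + slope * step
    # Fraction of the post-warmup schedule completed, capped at 1.
    progress = min((step - warmup_steps) / (total_steps - warmup_steps), 1.0)
    return 0.5 * learning_rate_base * (1.0 + math.cos(math.pi * progress))
```

Under these assumptions the schedule starts at 0.013333, reaches 0.04 at step 2000, and decays to zero at step 25000. As a purely hypothetical use of the scaling rule, a config tuned at batch size 32 with a learning rate of 0.004 would be scaled by `scaled_learning_rate(0.004, 32, 1024)`; treat the result as a starting point for tuning rather than a recommended value.
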
## Example TPU compatible configs

We provide example config files that you can use to train your own models on
TPU:

*   <a href='https://github.com/tensorflow/models/blob/master/research/object_detection/samples/configs/ssd_mobilenet_v1_300x300_coco14_sync.config'>ssd_mobilenet_v1_300x300</a> <br>
*   <a href='https://github.com/tensorflow/models/blob/master/research/object_detection/samples/configs/ssd_mobilenet_v1_ppn_shared_box_predictor_300x300_coco14_sync.config'>ssd_mobilenet_v1_ppn_300x300</a> <br>
*   <a href='https://github.com/tensorflow/models/blob/master/research/object_detection/samples/configs/ssd_mobilenet_v1_fpn_shared_box_predictor_640x640_coco14_sync.config'>ssd_mobilenet_v1_fpn_640x640
    (mobilenet based retinanet)</a> <br>
*   <a href='https://github.com/tensorflow/models/blob/master/research/object_detection/samples/configs/ssd_resnet50_v1_fpn_shared_box_predictor_640x640_coco14_sync.config'>ssd_resnet50_v1_fpn_640x640
    (retinanet)</a> <br>

## Supported Meta architectures

Currently, `SSDMetaArch` models are supported on TPUs. `FasterRCNNMetaArch` is
going to be supported soon.