# So you want to create a new model!

In this section, we discuss some of the abstractions that we use
for defining detection models. If you would like to define a new model
architecture for detection and use it in the Tensorflow Object Detection API,
then this section should also serve as a high-level guide to the files that you
will need to edit to get your new model working.
## DetectionModels (`object_detection/core/model.py`)
In order to be trained, evaluated, and exported for serving using our
provided binaries, all models under the Tensorflow Object Detection API must
implement the `DetectionModel` interface (see the full definition in
`object_detection/core/model.py`). In particular, each of these models is
responsible for implementing 5 functions:
* `preprocess`: Run any preprocessing (e.g., scaling/shifting/reshaping) of
  input values that is necessary prior to running the detector on an input
  image.
* `predict`: Produce “raw” prediction tensors that can be passed to loss or
  postprocess functions.
* `postprocess`: Convert predicted output tensors to final detections.
* `loss`: Compute scalar loss tensors with respect to provided groundtruth.
* `restore`: Load a checkpoint into the Tensorflow graph.
Given a `DetectionModel` at training time, we pass each image batch through
the following sequence of functions to compute a loss which can be optimized
via SGD:

```
inputs (images tensor) -> preprocess -> predict -> loss -> outputs (loss tensor)
```
And at eval time, we pass each image batch through the following sequence of
functions to produce a set of detections:

```
inputs (images tensor) -> preprocess -> predict -> postprocess ->
outputs (boxes tensor, scores tensor, classes tensor, num_detections tensor)
```
Some conventions to be aware of:

* `DetectionModel`s should make no assumptions about the input size or aspect
  ratio --- they are responsible for doing any resize/reshaping necessary
  (see docstring for the `preprocess` function).
* Output classes are always integers in the range `[0, num_classes)`.
  Any mapping of these integers to semantic labels is to be handled outside
  of this class. We never explicitly emit a “background class” --- thus 0 is
  the first non-background class and any logic of predicting and removing
  implicit background classes must be handled internally by the implementation.
* Detected boxes are to be interpreted as being in
  `[y_min, x_min, y_max, x_max]` format and normalized relative to the
  image window.
* We do not specifically assume any kind of probabilistic interpretation of the
  scores --- the only important thing is their relative ordering. Thus
  implementations of the postprocess function are free to output logits,
  probabilities, calibrated probabilities, or anything else.
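To make these conventions concrete, here is a minimal skeleton of a
`DetectionModel` subclass. This is a sketch only: `MyDetectionModel` is
hypothetical, the `...` bodies are placeholders, and the exact signatures
(including the name of the restore hook, which is `restore_map` in recent
versions of the code) should be checked against `object_detection/core/model.py`.

```
import tensorflow as tf

from object_detection.core import model


class MyDetectionModel(model.DetectionModel):
  """Hypothetical DetectionModel skeleton (illustration, not a working model)."""

  def preprocess(self, inputs):
    # Models must not assume a fixed input size or aspect ratio, so any
    # resizing/scaling happens here.
    preprocessed_inputs = (2.0 / 255.0) * tf.cast(inputs, tf.float32) - 1.0
    true_image_shapes = ...  # shapes of the images before any padding
    return preprocessed_inputs, true_image_shapes

  def predict(self, preprocessed_inputs, true_image_shapes):
    # "Raw" prediction tensors that loss() and postprocess() can consume.
    return {'box_encodings': ..., 'class_predictions': ...}

  def postprocess(self, prediction_dict, true_image_shapes, **params):
    # Final detections: boxes in normalized [y_min, x_min, y_max, x_max]
    # format, classes as integers in [0, num_classes), and scores whose
    # only guaranteed property is their relative ordering.
    return {'detection_boxes': ..., 'detection_scores': ...,
            'detection_classes': ..., 'num_detections': ...}

  def loss(self, prediction_dict, true_image_shapes):
    # Scalar loss tensors computed against groundtruth provided via
    # provide_groundtruth().
    return {'localization_loss': ..., 'classification_loss': ...}

  def restore_map(self, fine_tune_checkpoint_type='detection'):
    # Map checkpoint variable names to variables in the Tensorflow graph.
    return {var.op.name: var for var in tf.global_variables()}
```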
## Defining a new Faster R-CNN or SSD Feature Extractor

In most cases, you probably will not implement a `DetectionModel` from scratch
--- instead you might create a new feature extractor to be used by one of the
SSD or Faster R-CNN meta-architectures. (We think of meta-architectures as
classes that define entire families of models using the `DetectionModel`
abstraction).

Note: For the following discussion to make sense, we recommend first becoming
familiar with the [Faster R-CNN](https://arxiv.org/abs/1506.01497) paper.
Let’s now imagine that you have invented a brand new network architecture
(say, “InceptionV100”) for classification and want to see how InceptionV100
would behave as a feature extractor for detection (say, with Faster R-CNN).
A similar procedure holds for SSD models, but here we’ll discuss Faster R-CNN.
To use InceptionV100, we will have to define a new
`FasterRCNNFeatureExtractor` and pass it to our `FasterRCNNMetaArch`
constructor as input. See
`object_detection/meta_architectures/faster_rcnn_meta_arch.py` for the
definitions of `FasterRCNNFeatureExtractor` and `FasterRCNNMetaArch`.
A `FasterRCNNFeatureExtractor` must define a few functions:

* `preprocess`: Run any preprocessing of input values that is necessary prior
  to running the detector on an input image.
* `_extract_proposal_features`: Extract first stage Region Proposal Network
  (RPN) features.
* `_extract_box_classifier_features`: Extract second stage Box Classifier
  features.
* `restore_from_classification_checkpoint_fn`: Load a checkpoint into the
  Tensorflow graph.
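For the hypothetical InceptionV100, the skeleton might look like this (a
sketch: `inception_v100`, `inception_v100_tail`, and the endpoint name are
made up for illustration, and the method bodies gloss over details such as
variable scoping and the exact return values expected by the
meta-architecture):

```
import tensorflow as tf

from object_detection.meta_architectures import faster_rcnn_meta_arch


class FasterRCNNInceptionV100FeatureExtractor(
    faster_rcnn_meta_arch.FasterRCNNFeatureExtractor):
  """Hypothetical InceptionV100 feature extractor (sketch only)."""

  def preprocess(self, resized_inputs):
    # Mirror whatever preprocessing the classification checkpoint used;
    # see the channel-mean example below.
    ...

  def _extract_proposal_features(self, preprocessed_inputs, scope):
    # First stage (RPN) features: run the network up to the "cut" layer.
    # `inception_v100` is a hypothetical network-building function.
    return inception_v100(preprocessed_inputs, final_endpoint='Mixed_99')

  def _extract_box_classifier_features(self, proposal_feature_maps, scope):
    # Second stage features: the remainder of the network, applied to
    # cropped proposal feature maps.
    return inception_v100_tail(proposal_feature_maps)

  def restore_from_classification_checkpoint_fn(
      self, first_stage_feature_extractor_scope,
      second_stage_feature_extractor_scope):
    # Map names in the classification checkpoint to variables in the
    # detection graph, dropping the detection-specific scope prefixes.
    variables_to_restore = {}
    for variable in tf.global_variables():
      for scope_name in [first_stage_feature_extractor_scope,
                         second_stage_feature_extractor_scope]:
        if variable.op.name.startswith(scope_name):
          var_name = variable.op.name.replace(scope_name + '/', '')
          variables_to_restore[var_name] = variable
    return variables_to_restore
```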
See the `object_detection/models/faster_rcnn_resnet_v1_feature_extractor.py`
definition as one example. Some remarks:

* We typically initialize the weights of this feature extractor
  using those from the
  [Slim Resnet-101 classification checkpoint](https://github.com/tensorflow/models/tree/master/research/slim#pre-trained-models),
  and we know that images were preprocessed when training this checkpoint
  by subtracting a channel mean from each input image. Thus, we implement
  the preprocess function to replicate the same channel mean subtraction
  behavior (see the sketch after this list).
* The “full” resnet classification network defined in slim is cut into two
  parts --- all but the last “resnet block” is put into the
  `_extract_proposal_features` function and the final block is separately
  defined in the `_extract_box_classifier_features` function. In general,
  some experimentation may be required to decide on an optimal layer at
  which to “cut” your feature extractor into these two pieces for Faster R-CNN.
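Concretely, the resnet extractor’s `preprocess` boils down to a per-channel
mean subtraction along these lines (a paraphrase; the channel means are the
standard RGB means used by the Slim checkpoints):

```
def preprocess(self, resized_inputs):
  """Subtract the per-channel means the classification checkpoint expects."""
  # RGB channel means used when the Slim Resnet checkpoint was trained.
  channel_means = [123.68, 116.779, 103.939]
  return resized_inputs - [[channel_means]]
```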
## Register your model for configuration

Assuming that your new feature extractor does not require nonstandard
configuration, ideally you will be able to simply change the
“feature_extractor.type” fields in your configuration protos to point to
your new feature extractor. For our API to understand this new type,
though, you will first have to register your new feature extractor with the
model builder (`object_detection/builders/model_builder.py`),
whose job is to create models from config protos.
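For example, if the new extractor were registered under the (hypothetical)
type name `faster_rcnn_inception_v100`, the relevant part of the model config
would read:

```
model {
  faster_rcnn {
    feature_extractor {
      type: 'faster_rcnn_inception_v100'
    }
    # ... the remaining faster_rcnn fields are unchanged ...
  }
}
```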
Registration is simple --- just add a pointer to the new Feature Extractor
class that you have defined in one of the SSD or Faster R-CNN Feature
Extractor Class maps at the top of the
`object_detection/builders/model_builder.py` file.
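Concretely, this amounts to one new dictionary entry along these lines (a
sketch; the map and module names follow the pattern used in
`model_builder.py`, but the new entry is hypothetical):

```
# In object_detection/builders/model_builder.py (sketch):
from object_detection.models import faster_rcnn_inception_v100_feature_extractor as frcnn_inc_v100

FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP = {
    # ... existing entries, e.g. 'faster_rcnn_resnet101' ...
    # Hypothetical new entry:
    'faster_rcnn_inception_v100':
        frcnn_inc_v100.FasterRCNNInceptionV100FeatureExtractor,
}
```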
We recommend adding a test in `object_detection/builders/model_builder_test.py`
to make sure that parsing your proto will work as expected.
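Such a test might look roughly like the following (a sketch in the style of
the existing tests; a real config needs many more `faster_rcnn` fields than
shown here, and the type name is again hypothetical):

```
import tensorflow as tf
from google.protobuf import text_format

from object_detection.builders import model_builder
from object_detection.meta_architectures import faster_rcnn_meta_arch
from object_detection.protos import model_pb2


class ModelBuilderTest(tf.test.TestCase):

  def test_create_faster_rcnn_inception_v100_from_config(self):
    # Hypothetical type name; remaining required fields elided for brevity.
    model_text_proto = """
      faster_rcnn {
        num_classes: 3
        feature_extractor {
          type: 'faster_rcnn_inception_v100'
        }
      }"""
    model_proto = model_pb2.DetectionModel()
    text_format.Merge(model_text_proto, model_proto)
    model = model_builder.build(model_proto, is_training=True)
    self.assertIsInstance(model, faster_rcnn_meta_arch.FasterRCNNMetaArch)


if __name__ == '__main__':
  tf.test.main()
```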
## Taking your new model for a spin

After registration you are ready to go with your model! Some final tips:
* To save time debugging, try running your configuration file locally first
  (both training and evaluation).
* Do a sweep of learning rates to figure out which learning rate is best
  for your model.
* A small but often important detail: you may find it necessary to disable
  batch norm training (that is, load the batch norm parameters from the
  classification checkpoint, but do not update them during gradient descent);
  see the config sketch after this list.
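In the configuration protos, one place this is typically controlled is the
`train` flag of a `batch_norm` block inside the relevant hyperparams (a
sketch; whether and where this flag is available depends on your proto
version and meta-architecture):

```
conv_hyperparams {
  batch_norm {
    decay: 0.9997
    epsilon: 0.001
    scale: true
    # Use the checkpoint's batch norm parameters in the forward pass,
    # but never update them during training.
    train: false
  }
  # ... regularizer and initializer fields as usual ...
}
```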