In this section, we discuss some of the abstractions that we use for defining detection models. If you would like to define a new model architecture for detection and use it in the Tensorflow Detection API, then this section should also serve as a high level guide to the files that you will need to edit to get your new model working.
DetectionModels (object_detection/core/model.py)

In order to be trained, evaluated, and exported for serving using our
provided binaries, all models under the Tensorflow Object Detection API must
implement the DetectionModel
interface (see the full definition in object_detection/core/model.py
). In particular,
each of these models is responsible for implementing 5 functions:

* preprocess: Run any preprocessing (e.g., scaling/shifting/reshaping) of
  input values that is necessary prior to running the detector on an input
  image.
* predict: Produce “raw” prediction tensors that can be passed to loss or
  postprocess functions.
* postprocess: Convert predicted output tensors to final detections.
* loss: Compute scalar loss tensors with respect to provided groundtruth.
* restore: Load a checkpoint into the Tensorflow graph.

Given a DetectionModel
at training time, we pass each image batch through
the following sequence of functions to compute a loss which can be optimized via
SGD:
inputs (images tensor) -> preprocess -> predict -> loss -> outputs (loss tensor)
And at eval time, we pass each image batch through the following sequence of functions to produce a set of detections:
inputs (images tensor) -> preprocess -> predict -> postprocess ->
outputs (boxes tensor, scores tensor, classes tensor, num_detections tensor)
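To make these call sequences concrete, here is a rough sketch (not one of our provided binaries); detection_model and the input tensors are hypothetical placeholders, and the argument lists are simplified --- see object_detection/core/model.py for the exact signatures:

```python
import tensorflow as tf

# `detection_model` is some DetectionModel implementation and `images` is a
# batched image tensor from your input pipeline (both hypothetical here).

# Train time: preprocess -> predict -> loss.
preprocessed_inputs = detection_model.preprocess(images)
prediction_dict = detection_model.predict(preprocessed_inputs)
# Groundtruth must be handed to the model before computing the loss; the
# interface exposes a provide_groundtruth method for this.
losses_dict = detection_model.loss(prediction_dict)
total_loss = tf.add_n(list(losses_dict.values()))  # scalar loss for SGD

# Eval time: preprocess -> predict -> postprocess.
detections = detection_model.postprocess(prediction_dict)
# `detections` holds boxes (normalized [y_min, x_min, y_max, x_max]), scores,
# classes (integers in [0, num_classes)) and num_detections tensors.
```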
Some conventions to be aware of:
* DetectionModels should make no assumptions about the input size or aspect
  ratio --- they are responsible for doing any resize/reshaping necessary
  (see docstring for the preprocess function).
* Output classes are always integers in the range [0, num_classes).
  Any mapping of these integers to semantic labels is to be handled outside
  of this class. We never explicitly emit a “background class” --- thus 0 is
  the first non-background class and any logic of predicting and removing
  implicit background classes must be handled internally by the implementation.
* Detected boxes are to be interpreted as being in [y_min, x_min, y_max, x_max]
  format and normalized relative to the image window.

In most cases, you probably will not implement a DetectionModel
from scratch
--- instead you might create a new feature extractor to be used by one of the
SSD or Faster R-CNN meta-architectures. (We think of meta-architectures as
classes that define entire families of models using the DetectionModel
abstraction).
Note: For the following discussion to make sense, we recommend first becoming familiar with the Faster R-CNN paper.
Let’s now imagine that you have invented a brand new network architecture (say, “InceptionV100”) for classification and want to see how InceptionV100 would behave as a feature extractor for detection (say, with Faster R-CNN). A similar procedure would hold for SSD models, but we’ll discuss Faster R-CNN.
To use InceptionV100, we will have to define a new
FasterRCNNFeatureExtractor
and pass it to our FasterRCNNMetaArch
constructor as input. See
object_detection/meta_architectures/faster_rcnn_meta_arch.py
for definitions
of FasterRCNNFeatureExtractor
and FasterRCNNMetaArch
, respectively.
A FasterRCNNFeatureExtractor
must define a few
functions:
* preprocess: Run any preprocessing of input values that is necessary prior
  to running the detector on an input image.
* _extract_proposal_features: Extract first stage Region Proposal Network
  (RPN) features.
* _extract_box_classifier_features: Extract second stage Box Classifier
  features.
* restore_from_classification_checkpoint_fn: Load a checkpoint into the
  Tensorflow graph.

See the object_detection/models/faster_rcnn_resnet_v1_feature_extractor.py
definition as one example. Some remarks:
* In the Resnet example, the “full” classification network is cut into two
  pieces --- all but the last “resnet block” goes into the
  _extract_proposal_features function and the final block is separately
  defined in the _extract_box_classifier_features function. In general,
  some experimentation may be required to decide on an optimal layer at
  which to “cut” your feature extractor into these two pieces for Faster R-CNN.
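Putting this together, a new feature extractor for our hypothetical InceptionV100 might be sketched as below. The method signatures follow the resnet_v1 example at the time of writing (double-check against the base class in faster_rcnn_meta_arch.py), and inception_v100_base / inception_v100_head are placeholders for your own network definition:

```python
import tensorflow as tf

from object_detection.meta_architectures import faster_rcnn_meta_arch


class FasterRCNNInceptionV100FeatureExtractor(
    faster_rcnn_meta_arch.FasterRCNNFeatureExtractor):
  """Hypothetical Faster R-CNN feature extractor for 'InceptionV100'."""

  def preprocess(self, resized_inputs):
    # Replicate whatever preprocessing the InceptionV100 classification
    # checkpoint was trained with (here: scale pixel values to [-1, 1]).
    return (2.0 / 255.0) * resized_inputs - 1.0

  def _extract_proposal_features(self, preprocessed_inputs, scope):
    # Run all but the last block of InceptionV100 and return the feature map
    # that the first stage Region Proposal Network operates on.
    with tf.variable_scope(scope):
      return inception_v100_base(preprocessed_inputs)

  def _extract_box_classifier_features(self, proposal_feature_maps, scope):
    # Run the remaining block(s) of InceptionV100 on the cropped proposal
    # features to produce second stage Box Classifier features.
    with tf.variable_scope(scope):
      return inception_v100_head(proposal_feature_maps)

  def restore_from_classification_checkpoint_fn(
      self, first_stage_feature_extractor_scope,
      second_stage_feature_extractor_scope):
    # Map variables in the detection graph back to their names in the original
    # InceptionV100 classification checkpoint so that checkpoint can be used
    # to initialize the feature extractor.
    variables_to_restore = {}
    for variable in tf.global_variables():
      for scope_name in [first_stage_feature_extractor_scope,
                         second_stage_feature_extractor_scope]:
        if variable.op.name.startswith(scope_name):
          var_name = variable.op.name.replace(scope_name + '/', '')
          variables_to_restore[var_name] = variable
    return variables_to_restore
```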
Assuming that your new feature extractor does not require nonstandard
configuration, ideally you will be able to simply change the
“feature_extractor.type” fields in your configuration protos to point to the
new feature extractor. In order for our API to understand this new type,
however, you will first have to register your new feature extractor with the
model builder (object_detection/builders/model_builder.py), whose job is to
create models from config protos.
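For example, once the new extractor is registered (next paragraph), a Faster R-CNN pipeline config would refer to it purely by its type string; the name faster_rcnn_inception_v100 below is a made-up example:

```
model {
  faster_rcnn {
    feature_extractor {
      type: 'faster_rcnn_inception_v100'
    }
    # ... rest of the faster_rcnn config unchanged ...
  }
}
```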
Registration is simple --- just add a pointer to the new Feature Extractor
class that you have defined in one of the SSD or Faster R-CNN Feature
Extractor Class maps at the top of the
object_detection/builders/model_builder.py
file.
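As a sketch (the exact name of the class map and of your new module are assumptions here), the new entry would look something like:

```python
# In object_detection/builders/model_builder.py --- check the actual class map
# name in your copy of the file.
from object_detection.models import faster_rcnn_inception_v100_feature_extractor

FASTER_RCNN_FEATURE_EXTRACTOR_CLASS_MAP = {
    # ... existing feature extractor entries ...
    # New entry mapping the config proto type string to your new class:
    'faster_rcnn_inception_v100':
        faster_rcnn_inception_v100_feature_extractor
        .FasterRCNNInceptionV100FeatureExtractor,
}
```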
We recommend adding a test in object_detection/builders/model_builder_test.py
to make sure that parsing your proto will work as expected.
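Such a test usually parses a text-format config that points at the new type and checks that the builder produces the expected meta-architecture. The sketch below assumes the hypothetical InceptionV100 extractor from above; the elided faster_rcnn fields need to be filled in (e.g., copied from an existing test case) for the build to succeed:

```python
import tensorflow as tf
from google.protobuf import text_format

from object_detection.builders import model_builder
from object_detection.meta_architectures import faster_rcnn_meta_arch
from object_detection.protos import model_pb2


class FasterRCNNInceptionV100BuilderTest(tf.test.TestCase):

  def test_create_model_from_config(self):
    model_text_proto = """
      faster_rcnn {
        num_classes: 3
        feature_extractor {
          type: 'faster_rcnn_inception_v100'
        }
        # ... remaining faster_rcnn fields (image_resizer, anchor generator,
        # box predictor, post-processing, ...) omitted in this sketch ...
      }"""
    model_proto = model_pb2.DetectionModel()
    text_format.Merge(model_text_proto, model_proto)
    model = model_builder.build(model_proto, is_training=True)
    self.assertIsInstance(model, faster_rcnn_meta_arch.FasterRCNNMetaArch)


if __name__ == '__main__':
  tf.test.main()
```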
After registration you are ready to go with your model! Some final tips: