|
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
|
|
"""Base head class.
|
|
|
|
All the different kinds of prediction heads in different models will inherit
|
|
from this class. What is in common between all head classes is that they have a
|
|
`predict` function that receives `features` as its first argument.
|
|
|
|
How to add a new prediction head to an existing meta architecture?
|
|
For example, how can we add a `3d shape` prediction head to Mask RCNN?
|
|
|
|
We have to take the following steps to add a new prediction head to an
|
|
existing meta arch:
|
|
(a) Add a class that implements the new head. This class should inherit from the
|
|
`Head` class below and have a `predict` function that receives the features
|
|
and predicts the output. The output is always a tf.float32 tensor.
|
|
(b) Add the head to the meta architecture. For example in case of Mask RCNN,
|
|
go to box_predictor_builder and put in the logic for adding the new head to the
|
|
Mask RCNN box predictor.
|
|
(c) Add the logic for computing the loss for the new head.
|
|
(d) Add the necessary metrics for the new head.
|
|
(e) (optional) Add visualization for the new head.
|
|
"""
|
|
from abc import abstractmethod
|
|
|
|
import tensorflow as tf
|
|
|
|
|
|
class Head(object):
  """Common base class for Mask RCNN prediction heads.

  Concrete heads supply the prediction logic by overriding `predict`.

  NOTE: `predict` is marked with `@abstractmethod`, but this class does not use
  `abc.ABCMeta` as its metaclass, so the override is a convention only:
  instances of this class can still be created, and the base `predict` simply
  returns None.
  """

  def __init__(self):
    """Creates the head; the base class sets up no state."""

  @abstractmethod
  def predict(self, features, num_predictions_per_location):
    """Computes this head's predictions from the given features.

    Args:
      features: A float tensor of features.
      num_predictions_per_location: Int giving the number of predictions made
        at each location.

    Returns:
      A tf.float32 tensor when implemented by a subclass. The base
      implementation has no body and therefore returns None.
    """
|
|
|
|
|
|
class KerasHead(tf.keras.Model):
  """Base class for prediction heads implemented as Keras models.

  Subclasses provide the prediction logic in `_predict`; `call` only forwards
  its input to that method.

  NOTE(review): whether `@abstractmethod` on `_predict` is actually enforced
  depends on the metaclass of `tf.keras.Model` -- confirm before relying on it.
  """

  def call(self, features):
    """Runs the head by forwarding `features` to the `_predict` method."""
    predictions = self._predict(features)
    return predictions

  @abstractmethod
  def _predict(self, features):
    """Computes this head's predictions from the given features.

    Args:
      features: A float tensor of features.

    Returns:
      A tf.float32 tensor when implemented by a subclass. The base
      implementation has no body and therefore returns None.
    """
|