You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

81 lines
2.6 KiB

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base head class.
All the different kinds of prediction heads in different models will inherit
from this class. What is in common between all head classes is that they have a
`predict` function that receives `features` as its first argument.
How to add a new prediction head to an existing meta architecture?
For example, how can we add a `3d shape` prediction head to Mask RCNN?
We have to take the following steps to add a new prediction head to an
existing meta arch:
(a) Add a class for predicting the head. This class should inherit from the
`Head` class below and have a `predict` function that receives the features
and predicts the output. The output is always a tf.float32 tensor.
(b) Add the head to the meta architecture. For example in case of Mask RCNN,
go to box_predictor_builder and put in the logic for adding the new head to the
Mask RCNN box predictor.
(c) Add the logic for computing the loss for the new head.
(d) Add the necessary metrics for the new head.
(e) (optional) Add visualization for the new head.
"""
from abc import abstractmethod
import tensorflow as tf
class Head(object):
"""Mask RCNN head base class."""
def __init__(self):
"""Constructor."""
pass
@abstractmethod
def predict(self, features, num_predictions_per_location):
"""Returns the head's predictions.
Args:
features: A float tensor of features.
num_predictions_per_location: Int containing number of predictions per
location.
Returns:
A tf.float32 tensor.
"""
pass
class KerasHead(tf.keras.Model):
"""Keras head base class."""
def call(self, features):
"""The Keras model call will delegate to the `_predict` method."""
return self._predict(features)
@abstractmethod
def _predict(self, features):
"""Returns the head's predictions.
Args:
features: A float tensor of features.
Returns:
A tf.float32 tensor.
"""
pass