# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base head class.

All prediction heads, across the different models, inherit from one of the
base classes in this module: `Head`, or `KerasHead` for Keras-based heads.
What all head classes have in common is a prediction function that receives
`features` as its first argument (`predict` for `Head`, `_predict` for
`KerasHead`).

How do we add a new prediction head to an existing meta architecture? For
example, how can we add a `3d shape` prediction head to Mask RCNN? The
following steps are required:
(a) Add a class for predicting the head. This class should inherit from the
`Head` class below and have a `predict` function that receives the features
and predicts the output. The output is always a tf.float32 tensor.
(b) Add the head to the meta architecture. For example, in the case of Mask
RCNN, go to box_predictor_builder and add the logic that attaches the new head
to the Mask RCNN box predictor.
(c) Add the logic for computing the loss for the new head.
(d) Add the necessary metrics for the new head.
(e) (optional) Add visualization for the new head.
"""

from abc import abstractmethod

import tensorflow as tf


class Head(object):
  """Mask RCNN head base class."""

  def __init__(self):
    """Constructor."""
    pass

  @abstractmethod
  def predict(self, features, num_predictions_per_location):
    """Returns the head's predictions.

    Args:
      features: A float tensor of features.
      num_predictions_per_location: Int containing number of predictions per
        location.

    Returns:
      A tf.float32 tensor.
    """
    pass


class KerasHead(tf.keras.Model):
  """Keras head base class."""

  def call(self, features):
    """The Keras model call will delegate to the `_predict` method."""
    return self._predict(features)

  @abstractmethod
  def _predict(self, features):
    """Returns the head's predictions.

    Args:
      features: A float tensor of features.

    Returns:
      A tf.float32 tensor.
    """
    pass