You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

143 lines
5.7 KiB

6 years ago
  1. # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. r"""Exports an SSD detection model to use with tf-lite.
  16. Outputs file:
  17. * A tflite compatible frozen graph - $output_directory/tflite_graph.pb
  18. The exported graph has the following input and output nodes.
  19. Inputs:
  20. 'normalized_input_image_tensor': a float32 tensor of shape
  21. [1, height, width, 3] containing the normalized input image. Note that the
  22. height and width must be compatible with the height and width configured in
  23. the fixed_shape_image resizer options in the pipeline config proto.
  24. In floating point Mobilenet model, 'normalized_input_image_tensor' has values
  25. between [-1,1). This typically means mapping each pixel (linearly)
  26. to a value between [-1, 1). Input image
  27. values between 0 and 255 are scaled by (1/128.0) and then a value of
  28. -1 is added to them to ensure the range is [-1,1).
  29. In quantized Mobilenet model, 'normalized_input_image_tensor' has values
  30. between [0, 255].
  31. In general, see the `preprocess` function defined in the feature extractor class
  32. in the object_detection/models directory.
  33. Outputs:
  34. If add_postprocessing_op is true: the frozen graph adds a
  35. TFLite_Detection_PostProcess custom op node that has four outputs:
  36. detection_boxes: a float32 tensor of shape [1, num_boxes, 4] with box
  37. locations
  38. detection_classes: a float32 tensor of shape [1, num_boxes]
  39. with class indices
  40. detection_scores: a float32 tensor of shape [1, num_boxes]
  41. with class scores
  42. num_boxes: a float32 tensor of size 1 containing the number of detected boxes
  43. else:
  44. the graph has two outputs:
  45. 'raw_outputs/box_encodings': a float32 tensor of shape [1, num_anchors, 4]
  46. containing the encoded box predictions.
  47. 'raw_outputs/class_predictions': a float32 tensor of shape
  48. [1, num_anchors, num_classes] containing the class scores for each anchor
  49. after applying score conversion.
  50. Example Usage:
  51. --------------
  52. python object_detection/export_tflite_ssd_graph.py \
  53. --pipeline_config_path path/to/ssd_mobilenet.config \
  54. --trained_checkpoint_prefix path/to/model.ckpt \
  55. --output_directory path/to/exported_model_directory
  56. The expected output would be in the directory
  57. path/to/exported_model_directory (which is created if it does not exist)
  58. with contents:
  59. - tflite_graph.pbtxt
  60. - tflite_graph.pb
  61. Config overrides (see the `config_override` flag) are text protobufs
  62. (also of type pipeline_pb2.TrainEvalPipelineConfig) which are used to override
  63. certain fields in the provided pipeline_config_path. These are useful for
  64. making small changes to the inference graph that differ from the training or
  65. eval config.
  66. Example Usage (in which we change the NMS iou_threshold to be 0.5 and
  67. NMS score_threshold to be 0.0):
  68. python object_detection/export_tflite_ssd_graph.py \
  69. --pipeline_config_path path/to/ssd_mobilenet.config \
  70. --trained_checkpoint_prefix path/to/model.ckpt \
  71. --output_directory path/to/exported_model_directory \
  72. --config_override " \
  73. model{ \
  74. ssd{ \
  75. post_processing { \
  76. batch_non_max_suppression { \
  77. score_threshold: 0.0 \
  78. iou_threshold: 0.5 \
  79. } \
  80. } \
  81. } \
  82. } \
  83. "
  84. """
  85. import tensorflow as tf
  86. from google.protobuf import text_format
  87. from object_detection import export_tflite_ssd_graph_lib
  88. from object_detection.protos import pipeline_pb2
  89. flags = tf.app.flags
  90. flags.DEFINE_string('output_directory', None, 'Path to write outputs.')
  91. flags.DEFINE_string(
  92. 'pipeline_config_path', None,
  93. 'Path to a pipeline_pb2.TrainEvalPipelineConfig config '
  94. 'file.')
  95. flags.DEFINE_string('trained_checkpoint_prefix', None, 'Checkpoint prefix.')
  96. flags.DEFINE_integer('max_detections', 10,
  97. 'Maximum number of detections (boxes) to show.')
  98. flags.DEFINE_integer('max_classes_per_detection', 1,
  99. 'Number of classes to display per detection box.')
  100. flags.DEFINE_integer(
  101. 'detections_per_class', 100,
  102. 'Number of anchors used per class in Regular Non-Max-Suppression.')
  103. flags.DEFINE_bool('add_postprocessing_op', True,
  104. 'Add TFLite custom op for postprocessing to the graph.')
  105. flags.DEFINE_bool(
  106. 'use_regular_nms', False,
  107. 'Flag to set postprocessing op to use Regular NMS instead of Fast NMS.')
  108. flags.DEFINE_string(
  109. 'config_override', '', 'pipeline_pb2.TrainEvalPipelineConfig '
  110. 'text proto to override pipeline_config_path.')
  111. FLAGS = flags.FLAGS
  112. def main(argv):
  113. del argv # Unused.
  114. flags.mark_flag_as_required('output_directory')
  115. flags.mark_flag_as_required('pipeline_config_path')
  116. flags.mark_flag_as_required('trained_checkpoint_prefix')
  117. pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
  118. with tf.gfile.GFile(FLAGS.pipeline_config_path, 'r') as f:
  119. text_format.Merge(f.read(), pipeline_config)
  120. text_format.Merge(FLAGS.config_override, pipeline_config)
  121. export_tflite_ssd_graph_lib.export_tflite_graph(
  122. pipeline_config, FLAGS.trained_checkpoint_prefix, FLAGS.output_directory,
  123. FLAGS.add_postprocessing_op, FLAGS.max_detections,
  124. FLAGS.max_classes_per_detection, FLAGS.use_regular_nms)
  125. if __name__ == '__main__':
  126. tf.app.run(main)