|
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
# ==============================================================================
|
|
r"""Exports an SSD detection model to use with tf-lite.
|
|
|
|
Outputs file:
|
|
* A tflite compatible frozen graph - $output_directory/tflite_graph.pb
|
|
|
|
The exported graph has the following input and output nodes.
|
|
|
|
Inputs:
|
|
'normalized_input_image_tensor': a float32 tensor of shape
|
|
[1, height, width, 3] containing the normalized input image. Note that the
|
|
height and width must be compatible with the height and width configured in
|
|
the fixed_shape_image resizer options in the pipeline config proto.
|
|
|
|
In a floating point Mobilenet model, 'normalized_input_image_tensor' has values
|
|
between [-1,1). This typically means mapping each pixel (linearly)
|
|
to a value in [-1, 1). Input image
|
|
values between 0 and 255 are scaled by (1/128.0) and then a value of
|
|
-1 is added to them to ensure the range is [-1,1).
|
|
In a quantized Mobilenet model, 'normalized_input_image_tensor' has values between [0,
|
|
255].
|
|
In general, see the `preprocess` function defined in the feature extractor class
|
|
in the object_detection/models directory.
|
|
|
|
Outputs:
|
|
If add_postprocessing_op is true: frozen graph adds a
|
|
TFLite_Detection_PostProcess custom op node with four outputs:
|
|
detection_boxes: a float32 tensor of shape [1, num_boxes, 4] with box
|
|
locations
|
|
detection_classes: a float32 tensor of shape [1, num_boxes]
|
|
with class indices
|
|
detection_scores: a float32 tensor of shape [1, num_boxes]
|
|
with class scores
|
|
num_boxes: a float32 tensor of size 1 containing the number of detected boxes
|
|
else:
|
|
the graph has two outputs:
|
|
'raw_outputs/box_encodings': a float32 tensor of shape [1, num_anchors, 4]
|
|
containing the encoded box predictions.
|
|
'raw_outputs/class_predictions': a float32 tensor of shape
|
|
[1, num_anchors, num_classes] containing the class scores for each anchor
|
|
after applying score conversion.
|
|
|
|
Example Usage:
|
|
--------------
|
|
python object_detection/export_tflite_ssd_graph.py \
|
|
--pipeline_config_path path/to/ssd_mobilenet.config \
|
|
--trained_checkpoint_prefix path/to/model.ckpt \
|
|
--output_directory path/to/exported_model_directory
|
|
|
|
The expected output would be in the directory
|
|
path/to/exported_model_directory (which is created if it does not exist)
|
|
with contents:
|
|
- tflite_graph.pbtxt
|
|
- tflite_graph.pb
|
|
Config overrides (see the `config_override` flag) are text protobufs
|
|
(also of type pipeline_pb2.TrainEvalPipelineConfig) which are used to override
|
|
certain fields in the provided pipeline_config_path. These are useful for
|
|
making small changes to the inference graph that differ from the training or
|
|
eval config.
|
|
|
|
Example Usage (in which we change the NMS iou_threshold to be 0.5 and
|
|
NMS score_threshold to be 0.0):
|
|
python object_detection/export_tflite_ssd_graph.py \
|
|
--pipeline_config_path path/to/ssd_mobilenet.config \
|
|
--trained_checkpoint_prefix path/to/model.ckpt \
|
|
--output_directory path/to/exported_model_directory \
|
|
--config_override " \
|
|
model{ \
|
|
ssd{ \
|
|
post_processing { \
|
|
batch_non_max_suppression { \
|
|
score_threshold: 0.0 \
|
|
iou_threshold: 0.5 \
|
|
} \
|
|
} \
|
|
} \
|
|
} \
|
|
"
|
|
"""
|
|
|
|
import tensorflow as tf
|
|
from google.protobuf import text_format
|
|
from object_detection import export_tflite_ssd_graph_lib
|
|
from object_detection.protos import pipeline_pb2
|
|
|
|
flags = tf.app.flags

# Where to read the model from and where to write the exported graph.
flags.DEFINE_string('output_directory', None, 'Path to write outputs.')
flags.DEFINE_string(
    'pipeline_config_path',
    None,
    'Path to a pipeline_pb2.TrainEvalPipelineConfig config file.')
flags.DEFINE_string('trained_checkpoint_prefix', None, 'Checkpoint prefix.')

# Options controlling the postprocessing / Non-Max-Suppression stage.
flags.DEFINE_integer(
    'max_detections',
    10,
    'Maximum number of detections (boxes) to show.')
flags.DEFINE_integer(
    'max_classes_per_detection',
    1,
    'Number of classes to display per detection box.')
flags.DEFINE_integer(
    'detections_per_class',
    100,
    'Number of anchors used per class in Regular Non-Max-Suppression.')
flags.DEFINE_bool(
    'add_postprocessing_op',
    True,
    'Add TFLite custom op for postprocessing to the graph.')
flags.DEFINE_bool(
    'use_regular_nms',
    False,
    'Flag to set postprocessing op to use Regular NMS instead of Fast NMS.')

# Text proto merged over the contents of --pipeline_config_path.
flags.DEFINE_string(
    'config_override',
    '',
    'pipeline_pb2.TrainEvalPipelineConfig text proto to '
    'override pipeline_config_path.')

FLAGS = flags.FLAGS
|
|
|
|
|
|
def main(argv):
  """Exports a TFLite-compatible SSD graph as configured by command-line flags.

  Reads the pipeline config named by --pipeline_config_path, merges the
  --config_override text proto on top of it, and passes the result, together
  with the checkpoint, output directory and postprocessing flags, to
  export_tflite_ssd_graph_lib.export_tflite_graph.

  Args:
    argv: Unused; remaining command-line arguments forwarded by tf.app.run.

  Raises:
    ValueError: If --output_directory, --pipeline_config_path or
      --trained_checkpoint_prefix was not supplied.
  """
  del argv  # Unused.

  # flags.mark_flag_as_required only registers a validator that runs when
  # flags are parsed; by the time main() executes, tf.app.run has already
  # parsed them, so marking here is a no-op. Validate explicitly instead.
  for flag_name in ('output_directory', 'pipeline_config_path',
                    'trained_checkpoint_prefix'):
    if getattr(FLAGS, flag_name) is None:
      raise ValueError('Flag --%s must be specified.' % flag_name)

  pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
  with tf.gfile.GFile(FLAGS.pipeline_config_path, 'r') as f:
    text_format.Merge(f.read(), pipeline_config)
  # The override is merged after the file, so its fields take effect on top
  # of the file contents (an empty override is a no-op).
  text_format.Merge(FLAGS.config_override, pipeline_config)

  # Pass the NMS options by keyword. Passed positionally after
  # max_classes_per_detection, use_regular_nms would bind to whatever
  # parameter occupies the seventh slot, and --detections_per_class was
  # previously never forwarded at all.
  # NOTE(review): assumes export_tflite_graph accepts `detections_per_class`
  # and `use_regular_nms` keyword arguments -- confirm against
  # export_tflite_ssd_graph_lib.
  export_tflite_ssd_graph_lib.export_tflite_graph(
      pipeline_config, FLAGS.trained_checkpoint_prefix, FLAGS.output_directory,
      FLAGS.add_postprocessing_op, FLAGS.max_detections,
      FLAGS.max_classes_per_detection,
      detections_per_class=FLAGS.detections_per_class,
      use_regular_nms=FLAGS.use_regular_nms)
|
|
|
|
|
|
if __name__ == '__main__':
  # tf.app.run parses the command-line flags, then invokes main().
  tf.app.run(main=main)
|