# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Tool to export an object detection model for inference.

Prepares an object detection TensorFlow graph for inference using the model
configuration and a trained checkpoint. Outputs the inference graph, the
associated checkpoint files, a frozen inference graph and a SavedModel
(https://tensorflow.github.io/serving/serving_basic.html).

The inference graph contains one of three input nodes, depending on the
user-specified `--input_type`:
  * `image_tensor`: Accepts a uint8 4-D tensor of shape [None, None, None, 3].
  * `encoded_image_string_tensor`: Accepts a 1-D string tensor of shape [None]
    containing encoded PNG or JPEG images. If more than one image is provided,
    all images are expected to have the same resolution.
  * `tf_example`: Accepts a 1-D string tensor of shape [None] containing
    serialized TFExample protos. If more than one image is provided, all
    images are expected to have the same resolution.
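
Because a batch fed to `encoded_image_string_tensor` must share one
resolution, it can help to check this before calling the exported graph. The
sketch below is a hypothetical pre-flight check, not part of this tool, and it
handles PNG only: it assumes each item is a complete PNG byte string and reads
the width and height from the IHDR chunk, which the PNG specification places
immediately after the 8-byte signature. JPEG input would need its own parser.

```python
import struct

PNG_SIGNATURE = b'\x89PNG\r\n\x1a\n'


def png_size(data):
  # Hypothetical helper: return (width, height) of a PNG byte string.
  # Bytes 8-11 hold the IHDR chunk length, 12-15 the chunk type, and
  # 16-23 the big-endian width and height.
  if len(data) < 24 or data[:8] != PNG_SIGNATURE or data[12:16] != b'IHDR':
    raise ValueError('not a PNG stream')
  width, height = struct.unpack('>II', data[16:24])
  return width, height


def check_same_resolution(encoded_pngs):
  # Hypothetical pre-flight check: raise if the batch mixes resolutions,
  # otherwise return the shared (width, height), or None for an empty batch.
  sizes = {png_size(d) for d in encoded_pngs}
  if len(sizes) > 1:
    raise ValueError('mixed resolutions in batch: %s' % sorted(sizes))
  return sizes.pop() if sizes else None
```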

The inference graph also contains the following output nodes, returned by
model.postprocess():
  * `num_detections`: float32 tensor of shape [batch] giving the number of
    valid boxes per image in the batch.
  * `detection_boxes`: float32 tensor of shape [batch, num_boxes, 4]
    containing the detected boxes.
  * `detection_scores`: float32 tensor of shape [batch, num_boxes] containing
    the class scores for the detections.
  * `detection_classes`: float32 tensor of shape [batch, num_boxes] containing
    the classes of the detections.
  * `raw_detection_boxes`: float32 tensor of shape
    [batch, raw_num_boxes, 4] containing detection boxes without
    post-processing.
  * `raw_detection_scores`: float32 tensor of shape
    [batch, raw_num_boxes, num_classes_with_background] containing class score
    logits for the raw detection boxes.
  * `detection_masks`: float32 tensor of shape
    [batch, num_boxes, mask_height, mask_width] containing predicted instance
    masks for each box. This output is present only if masks are included in
    the dictionary of postprocessed tensors returned by the model.
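
Since `num_detections` gives the count of valid boxes while the box, score and
class outputs are padded to `num_boxes`, a client will usually slice each
image's results to that count. A minimal sketch, assuming the fetched outputs
have been converted to nested Python lists (for example with NumPy's
`tolist()`) and collected in a dict keyed by output-node name;
`unpack_detections` and `min_score` are hypothetical names, not part of this
tool.

```python
def unpack_detections(outputs, min_score=0.0):
  # Hypothetical client-side helper. `outputs` maps output-node names to
  # batch-major nested lists. Returns, per image, a list of
  # (box, score, class_id) tuples for the valid, non-padded detections.
  results = []
  for i, num in enumerate(outputs['num_detections']):
    n = int(num)  # the count arrives as float32
    boxes = outputs['detection_boxes'][i][:n]
    scores = outputs['detection_scores'][i][:n]
    classes = outputs['detection_classes'][i][:n]
    results.append([
        (box, score, int(cls))  # classes also arrive as float32
        for box, score, cls in zip(boxes, scores, classes)
        if score >= min_score
    ])
  return results
```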

Notes:
  * This tool uses `use_moving_averages` from eval_config to decide which
    weights to freeze.

Example Usage:
--------------
python export_inference_graph.py \
    --input_type image_tensor \
    --pipeline_config_path path/to/ssd_inception_v2.config \
    --trained_checkpoint_prefix path/to/model.ckpt \
    --output_directory path/to/exported_model_directory
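
The same invocation can be assembled from Python. This is a minimal sketch,
not part of this tool: `export_command` is a hypothetical helper, and its
`script` argument is assumed to be the path to this file. Building an argument
list instead of one shell string avoids shell quoting and line continuations.

```python
import sys


def export_command(pipeline_config_path, trained_checkpoint_prefix,
                   output_directory, input_type='image_tensor',
                   script='export_inference_graph.py'):
  # Hypothetical helper: return the argv for the export shown above.
  # `script` is an assumed path to this file; adjust it to your checkout.
  return [
      sys.executable, script,
      '--input_type', input_type,
      '--pipeline_config_path', pipeline_config_path,
      '--trained_checkpoint_prefix', trained_checkpoint_prefix,
      '--output_directory', output_directory,
  ]
```

For example, `subprocess.run(export_command('path/to/ssd_inception_v2.config',
'path/to/model.ckpt', 'path/to/exported_model_directory'), check=True)` would
run the export and raise `subprocess.CalledProcessError` if it fails.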

The outputs are written to path/to/exported_model_directory (created if it
does not exist), which will contain:
  - inference_graph.pbtxt (only when --write_inference_graph is set)
  - model.ckpt.data-00000-of-00001
  - model.ckpt.index
  - model.ckpt.meta
  - frozen_inference_graph.pb
  + saved_model (a directory)
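
After an export finishes, checking that the expected files were written can
catch a failed run early. A minimal sketch using only the standard library;
`missing_outputs` is a hypothetical helper that checks a few of the entries
listed above and skips inference_graph.pbtxt, which is written only when
--write_inference_graph is set.

```python
import os

# A subset of the export listing; 'saved_model' is expected to be a directory.
EXPECTED_FILES = ('frozen_inference_graph.pb', 'model.ckpt.meta')
EXPECTED_DIRS = ('saved_model',)


def missing_outputs(output_directory):
  # Hypothetical helper: return the expected entries absent from
  # output_directory, files first, then directories.
  missing = [name for name in EXPECTED_FILES
             if not os.path.isfile(os.path.join(output_directory, name))]
  missing += [name for name in EXPECTED_DIRS
              if not os.path.isdir(os.path.join(output_directory, name))]
  return missing
```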

Config overrides (see the `config_override` flag) are text protobufs
(also of type pipeline_pb2.TrainEvalPipelineConfig) that override certain
fields in the file given by pipeline_config_path. They are useful for making
small changes to the inference graph that differ from the training or eval
config.

Example Usage (changing the second-stage post-processing score threshold of a
Faster R-CNN model to 0.5):
python export_inference_graph.py \
    --input_type image_tensor \
    --pipeline_config_path path/to/faster_rcnn_resnet101_coco.config \
    --trained_checkpoint_prefix path/to/model.ckpt \
    --output_directory path/to/exported_model_directory \
    --config_override " \
            model{ \
              faster_rcnn { \
                second_stage_post_processing { \
                  batch_non_max_suppression { \
                    score_threshold: 0.5 \
                  } \
                } \
              } \
            }"
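
Hand-writing nested overrides with shell line continuations is error-prone.
One option is to generate the text proto from a nested Python dict and pass it
as a single argument. The sketch below is a hypothetical helper, not part of
this tool or of protobuf, and it handles only nested messages and int, float
and bool fields. When the `object_detection` package is importable, building a
`pipeline_pb2.TrainEvalPipelineConfig` directly and serializing it with
`text_format.MessageToString` is the more robust route.

```python
def to_text_proto(fields, indent=0):
  # Hypothetical serializer: nested dicts become `name { ... }` blocks and
  # scalars become `name: value` lines. Strings, enums and repeated fields
  # are deliberately not handled.
  pad = '  ' * indent
  lines = []
  for name, value in fields.items():
    if isinstance(value, dict):
      lines.append('%s%s {' % (pad, name))
      lines.append(to_text_proto(value, indent + 1))
      lines.append('%s}' % pad)
    elif isinstance(value, bool):  # test bool first: bool subclasses int
      lines.append('%s%s: %s' % (pad, name, 'true' if value else 'false'))
    elif isinstance(value, (int, float)):
      lines.append('%s%s: %r' % (pad, name, value))
    else:
      raise TypeError('unsupported value for %s: %r' % (name, value))
  return '\n'.join(lines)


override = to_text_proto({'model': {'faster_rcnn': {
    'second_stage_post_processing': {
        'batch_non_max_suppression': {'score_threshold': 0.5}}}}})
```

The resulting string can be passed as the value following `--config_override`
in an argument list, so no shell escaping is needed.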
"""

import tensorflow as tf
from google.protobuf import text_format
from object_detection import exporter
from object_detection.protos import pipeline_pb2

slim = tf.contrib.slim
flags = tf.app.flags

flags.DEFINE_string('input_type', 'image_tensor', 'Type of input node. Can be '
                    'one of [`image_tensor`, `encoded_image_string_tensor`, '
                    '`tf_example`]')
flags.DEFINE_string('input_shape', None,
                    'If input_type is `image_tensor`, this can explicitly set '
                    'the shape of this input tensor to a fixed size. The '
                    'dimensions are to be provided as a comma-separated list '
                    'of integers. A value of -1 can be used for unknown '
                    'dimensions. If not specified, for an `image_tensor`, the '
                    'default shape will be partially specified as '
                    '`[None, None, None, 3]`.')
flags.DEFINE_string('pipeline_config_path', None,
                    'Path to a pipeline_pb2.TrainEvalPipelineConfig config '
                    'file.')
flags.DEFINE_string('trained_checkpoint_prefix', None,
                    'Path to trained checkpoint, typically of the form '
                    'path/to/model.ckpt')
flags.DEFINE_string('output_directory', None, 'Path to write outputs.')
flags.DEFINE_string('config_override', '',
                    'pipeline_pb2.TrainEvalPipelineConfig '
                    'text proto to override pipeline_config_path.')
flags.DEFINE_boolean('write_inference_graph', False,
                     'If true, writes inference graph to disk.')
tf.app.flags.mark_flag_as_required('pipeline_config_path')
tf.app.flags.mark_flag_as_required('trained_checkpoint_prefix')
tf.app.flags.mark_flag_as_required('output_directory')
FLAGS = flags.FLAGS


def main(_):
  # Load the pipeline config, then apply any command-line overrides on top.
  pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
  with tf.gfile.GFile(FLAGS.pipeline_config_path, 'r') as f:
    text_format.Merge(f.read(), pipeline_config)
  text_format.Merge(FLAGS.config_override, pipeline_config)
  if FLAGS.input_shape:
    # e.g. '1,-1,-1,3' -> [1, None, None, 3]; -1 marks an unknown dimension.
    input_shape = [
        int(dim) if dim != '-1' else None
        for dim in FLAGS.input_shape.split(',')
    ]
  else:
    input_shape = None
  exporter.export_inference_graph(
      FLAGS.input_type, pipeline_config, FLAGS.trained_checkpoint_prefix,
      FLAGS.output_directory, input_shape=input_shape,
      write_inference_graph=FLAGS.write_inference_graph)


if __name__ == '__main__':
  tf.app.run()