You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

142 lines
5.3 KiB

  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. r"""Evaluation executable for detection models.
  16. This executable is used to evaluate DetectionModels. There are two ways of
  17. configuring the eval job.
  18. 1) A single pipeline_pb2.TrainEvalPipelineConfig file may be specified.
  19. In this mode, the --eval_training_data flag may be given to force the pipeline
  20. to evaluate on training data instead.
  21. Example usage:
  22. ./eval \
  23. --logtostderr \
  24. --checkpoint_dir=path/to/checkpoint_dir \
  25. --eval_dir=path/to/eval_dir \
  26. --pipeline_config_path=pipeline_config.pbtxt
  27. 2) Three configuration files may be provided: a model_pb2.DetectionModel
  28. configuration file to define what type of DetectionModel is being evaluated, an
  29. input_reader_pb2.InputReader file to specify what data the model is evaluating
  30. and an eval_pb2.EvalConfig file to configure evaluation parameters.
  31. Example usage:
  32. ./eval \
  33. --logtostderr \
  34. --checkpoint_dir=path/to/checkpoint_dir \
  35. --eval_dir=path/to/eval_dir \
  36. --eval_config_path=eval_config.pbtxt \
  37. --model_config_path=model_config.pbtxt \
  38. --input_config_path=eval_input_config.pbtxt
  39. """
  40. import functools
  41. import os
  42. import tensorflow as tf
  43. from object_detection.builders import dataset_builder
  44. from object_detection.builders import graph_rewriter_builder
  45. from object_detection.builders import model_builder
  46. from object_detection.legacy import evaluator
  47. from object_detection.utils import config_util
  48. from object_detection.utils import label_map_util
  49. tf.logging.set_verbosity(tf.logging.INFO)
  50. flags = tf.app.flags
  51. flags.DEFINE_boolean('eval_training_data', False,
  52. 'If training data should be evaluated for this job.')
  53. flags.DEFINE_string(
  54. 'checkpoint_dir', '',
  55. 'Directory containing checkpoints to evaluate, typically '
  56. 'set to `train_dir` used in the training job.')
  57. flags.DEFINE_string('eval_dir', '', 'Directory to write eval summaries to.')
  58. flags.DEFINE_string(
  59. 'pipeline_config_path', '',
  60. 'Path to a pipeline_pb2.TrainEvalPipelineConfig config '
  61. 'file. If provided, other configs are ignored')
  62. flags.DEFINE_string('eval_config_path', '',
  63. 'Path to an eval_pb2.EvalConfig config file.')
  64. flags.DEFINE_string('input_config_path', '',
  65. 'Path to an input_reader_pb2.InputReader config file.')
  66. flags.DEFINE_string('model_config_path', '',
  67. 'Path to a model_pb2.DetectionModel config file.')
  68. flags.DEFINE_boolean(
  69. 'run_once', False, 'Option to only run a single pass of '
  70. 'evaluation. Overrides the `max_evals` parameter in the '
  71. 'provided config.')
  72. FLAGS = flags.FLAGS
  73. @tf.contrib.framework.deprecated(None, 'Use object_detection/model_main.py.')
  74. def main(unused_argv):
  75. assert FLAGS.checkpoint_dir, '`checkpoint_dir` is missing.'
  76. assert FLAGS.eval_dir, '`eval_dir` is missing.'
  77. tf.gfile.MakeDirs(FLAGS.eval_dir)
  78. if FLAGS.pipeline_config_path:
  79. configs = config_util.get_configs_from_pipeline_file(
  80. FLAGS.pipeline_config_path)
  81. tf.gfile.Copy(
  82. FLAGS.pipeline_config_path,
  83. os.path.join(FLAGS.eval_dir, 'pipeline.config'),
  84. overwrite=True)
  85. else:
  86. configs = config_util.get_configs_from_multiple_files(
  87. model_config_path=FLAGS.model_config_path,
  88. eval_config_path=FLAGS.eval_config_path,
  89. eval_input_config_path=FLAGS.input_config_path)
  90. for name, config in [('model.config', FLAGS.model_config_path),
  91. ('eval.config', FLAGS.eval_config_path),
  92. ('input.config', FLAGS.input_config_path)]:
  93. tf.gfile.Copy(config, os.path.join(FLAGS.eval_dir, name), overwrite=True)
  94. model_config = configs['model']
  95. eval_config = configs['eval_config']
  96. input_config = configs['eval_input_config']
  97. if FLAGS.eval_training_data:
  98. input_config = configs['train_input_config']
  99. model_fn = functools.partial(
  100. model_builder.build, model_config=model_config, is_training=False)
  101. def get_next(config):
  102. return dataset_builder.make_initializable_iterator(
  103. dataset_builder.build(config)).get_next()
  104. create_input_dict_fn = functools.partial(get_next, input_config)
  105. categories = label_map_util.create_categories_from_labelmap(
  106. input_config.label_map_path)
  107. if FLAGS.run_once:
  108. eval_config.max_evals = 1
  109. graph_rewriter_fn = None
  110. if 'graph_rewriter_config' in configs:
  111. graph_rewriter_fn = graph_rewriter_builder.build(
  112. configs['graph_rewriter_config'], is_training=False)
  113. evaluator.evaluate(
  114. create_input_dict_fn,
  115. model_fn,
  116. eval_config,
  117. categories,
  118. FLAGS.checkpoint_dir,
  119. FLAGS.eval_dir,
  120. graph_hook_fn=graph_rewriter_fn)
  121. if __name__ == '__main__':
  122. tf.app.run()