You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

109 lines
4.4 KiB

  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Binary to run train and evaluation on object detection model."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. from absl import flags
  20. import tensorflow as tf
  21. from object_detection import model_hparams
  22. from object_detection import model_lib
  23. flags.DEFINE_string(
  24. 'model_dir', None, 'Path to output model directory '
  25. 'where event and checkpoint files will be written.')
  26. flags.DEFINE_string('pipeline_config_path', None, 'Path to pipeline config '
  27. 'file.')
  28. flags.DEFINE_integer('num_train_steps', None, 'Number of train steps.')
  29. flags.DEFINE_boolean('eval_training_data', False,
  30. 'If training data should be evaluated for this job. Note '
  31. 'that one call only use this in eval-only mode, and '
  32. '`checkpoint_dir` must be supplied.')
  33. flags.DEFINE_integer('sample_1_of_n_eval_examples', 1, 'Will sample one of '
  34. 'every n eval input examples, where n is provided.')
  35. flags.DEFINE_integer('sample_1_of_n_eval_on_train_examples', 5, 'Will sample '
  36. 'one of every n train input examples for evaluation, '
  37. 'where n is provided. This is only used if '
  38. '`eval_training_data` is True.')
  39. flags.DEFINE_string(
  40. 'hparams_overrides', None, 'Hyperparameter overrides, '
  41. 'represented as a string containing comma-separated '
  42. 'hparam_name=value pairs.')
  43. flags.DEFINE_string(
  44. 'checkpoint_dir', None, 'Path to directory holding a checkpoint. If '
  45. '`checkpoint_dir` is provided, this binary operates in eval-only mode, '
  46. 'writing resulting metrics to `model_dir`.')
  47. flags.DEFINE_boolean(
  48. 'run_once', False, 'If running in eval-only mode, whether to run just '
  49. 'one round of eval vs running continuously (default).'
  50. )
  51. FLAGS = flags.FLAGS
  52. def main(unused_argv):
  53. flags.mark_flag_as_required('model_dir')
  54. flags.mark_flag_as_required('pipeline_config_path')
  55. config = tf.estimator.RunConfig(model_dir=FLAGS.model_dir)
  56. train_and_eval_dict = model_lib.create_estimator_and_inputs(
  57. run_config=config,
  58. hparams=model_hparams.create_hparams(FLAGS.hparams_overrides),
  59. pipeline_config_path=FLAGS.pipeline_config_path,
  60. train_steps=FLAGS.num_train_steps,
  61. sample_1_of_n_eval_examples=FLAGS.sample_1_of_n_eval_examples,
  62. sample_1_of_n_eval_on_train_examples=(
  63. FLAGS.sample_1_of_n_eval_on_train_examples))
  64. estimator = train_and_eval_dict['estimator']
  65. train_input_fn = train_and_eval_dict['train_input_fn']
  66. eval_input_fns = train_and_eval_dict['eval_input_fns']
  67. eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
  68. predict_input_fn = train_and_eval_dict['predict_input_fn']
  69. train_steps = train_and_eval_dict['train_steps']
  70. if FLAGS.checkpoint_dir:
  71. if FLAGS.eval_training_data:
  72. name = 'training_data'
  73. input_fn = eval_on_train_input_fn
  74. else:
  75. name = 'validation_data'
  76. # The first eval input will be evaluated.
  77. input_fn = eval_input_fns[0]
  78. if FLAGS.run_once:
  79. estimator.evaluate(input_fn,
  80. steps=None,
  81. checkpoint_path=tf.train.latest_checkpoint(
  82. FLAGS.checkpoint_dir))
  83. else:
  84. model_lib.continuous_eval(estimator, FLAGS.checkpoint_dir, input_fn,
  85. train_steps, name)
  86. else:
  87. train_spec, eval_specs = model_lib.create_train_and_eval_specs(
  88. train_input_fn,
  89. eval_input_fns,
  90. eval_on_train_input_fn,
  91. predict_input_fn,
  92. train_steps,
  93. eval_on_train_data=False)
  94. # Currently only a single Eval Spec is allowed.
  95. tf.estimator.train_and_evaluate(estimator, train_spec, eval_specs[0])
  96. if __name__ == '__main__':
  97. tf.app.run()