You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

142 lines
5.5 KiB

6 years ago
  1. # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. r"""Creates and runs `Estimator` for object detection model on TPUs.
  16. This uses the TPUEstimator API to define and run a model in TRAIN/EVAL modes.
  17. """
  18. # pylint: enable=line-too-long
  19. from __future__ import absolute_import
  20. from __future__ import division
  21. from __future__ import print_function
  22. from absl import flags
  23. import tensorflow as tf
  24. from object_detection import model_hparams
  25. from object_detection import model_lib
  26. tf.flags.DEFINE_bool('use_tpu', True, 'Use TPUs rather than plain CPUs')
  27. # Cloud TPU Cluster Resolvers
  28. flags.DEFINE_string(
  29. 'gcp_project',
  30. default=None,
  31. help='Project name for the Cloud TPU-enabled project. If not specified, we '
  32. 'will attempt to automatically detect the GCE project from metadata.')
  33. flags.DEFINE_string(
  34. 'tpu_zone',
  35. default=None,
  36. help='GCE zone where the Cloud TPU is located in. If not specified, we '
  37. 'will attempt to automatically detect the GCE project from metadata.')
  38. flags.DEFINE_string(
  39. 'tpu_name',
  40. default=None,
  41. help='Name of the Cloud TPU for Cluster Resolvers.')
  42. flags.DEFINE_integer('num_shards', 8, 'Number of shards (TPU cores).')
  43. flags.DEFINE_integer('iterations_per_loop', 100,
  44. 'Number of iterations per TPU training loop.')
  45. # For mode=train_and_eval, evaluation occurs after training is finished.
  46. # Note: independently of steps_per_checkpoint, estimator will save the most
  47. # recent checkpoint every 10 minutes by default for train_and_eval
  48. flags.DEFINE_string('mode', 'train',
  49. 'Mode to run: train, eval')
  50. flags.DEFINE_integer('train_batch_size', None, 'Batch size for training. If '
  51. 'this is not provided, batch size is read from training '
  52. 'config.')
  53. flags.DEFINE_string(
  54. 'hparams_overrides', None, 'Comma-separated list of '
  55. 'hyperparameters to override defaults.')
  56. flags.DEFINE_integer('num_train_steps', None, 'Number of train steps.')
  57. flags.DEFINE_boolean('eval_training_data', False,
  58. 'If training data should be evaluated for this job.')
  59. flags.DEFINE_integer('sample_1_of_n_eval_examples', 1, 'Will sample one of '
  60. 'every n eval input examples, where n is provided.')
  61. flags.DEFINE_integer('sample_1_of_n_eval_on_train_examples', 5, 'Will sample '
  62. 'one of every n train input examples for evaluation, '
  63. 'where n is provided. This is only used if '
  64. '`eval_training_data` is True.')
  65. flags.DEFINE_string(
  66. 'model_dir', None, 'Path to output model directory '
  67. 'where event and checkpoint files will be written.')
  68. flags.DEFINE_string('pipeline_config_path', None, 'Path to pipeline config '
  69. 'file.')
  70. FLAGS = tf.flags.FLAGS
  71. def main(unused_argv):
  72. flags.mark_flag_as_required('model_dir')
  73. flags.mark_flag_as_required('pipeline_config_path')
  74. tpu_cluster_resolver = (
  75. tf.contrib.cluster_resolver.TPUClusterResolver(
  76. tpu=[FLAGS.tpu_name],
  77. zone=FLAGS.tpu_zone,
  78. project=FLAGS.gcp_project))
  79. tpu_grpc_url = tpu_cluster_resolver.get_master()
  80. config = tf.contrib.tpu.RunConfig(
  81. master=tpu_grpc_url,
  82. evaluation_master=tpu_grpc_url,
  83. model_dir=FLAGS.model_dir,
  84. tpu_config=tf.contrib.tpu.TPUConfig(
  85. iterations_per_loop=FLAGS.iterations_per_loop,
  86. num_shards=FLAGS.num_shards))
  87. kwargs = {}
  88. if FLAGS.train_batch_size:
  89. kwargs['batch_size'] = FLAGS.train_batch_size
  90. train_and_eval_dict = model_lib.create_estimator_and_inputs(
  91. run_config=config,
  92. hparams=model_hparams.create_hparams(FLAGS.hparams_overrides),
  93. pipeline_config_path=FLAGS.pipeline_config_path,
  94. train_steps=FLAGS.num_train_steps,
  95. sample_1_of_n_eval_examples=FLAGS.sample_1_of_n_eval_examples,
  96. sample_1_of_n_eval_on_train_examples=(
  97. FLAGS.sample_1_of_n_eval_on_train_examples),
  98. use_tpu_estimator=True,
  99. use_tpu=FLAGS.use_tpu,
  100. num_shards=FLAGS.num_shards,
  101. save_final_config=FLAGS.mode == 'train',
  102. **kwargs)
  103. estimator = train_and_eval_dict['estimator']
  104. train_input_fn = train_and_eval_dict['train_input_fn']
  105. eval_input_fns = train_and_eval_dict['eval_input_fns']
  106. eval_on_train_input_fn = train_and_eval_dict['eval_on_train_input_fn']
  107. train_steps = train_and_eval_dict['train_steps']
  108. if FLAGS.mode == 'train':
  109. estimator.train(input_fn=train_input_fn, max_steps=train_steps)
  110. # Continuously evaluating.
  111. if FLAGS.mode == 'eval':
  112. if FLAGS.eval_training_data:
  113. name = 'training_data'
  114. input_fn = eval_on_train_input_fn
  115. else:
  116. name = 'validation_data'
  117. # Currently only a single eval input is allowed.
  118. input_fn = eval_input_fns[0]
  119. model_lib.continuous_eval(estimator, FLAGS.model_dir, input_fn, train_steps,
  120. name)
  121. if __name__ == '__main__':
  122. tf.app.run()