# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. r"""Training executable for detection models.
  16. This executable is used to train DetectionModels. There are two ways of
  17. configuring the training job:
  18. 1) A single pipeline_pb2.TrainEvalPipelineConfig configuration file
  19. can be specified by --pipeline_config_path.
  20. Example usage:
  21. ./train \
  22. --logtostderr \
  23. --train_dir=path/to/train_dir \
  24. --pipeline_config_path=pipeline_config.pbtxt
  25. 2) Three configuration files can be provided: a model_pb2.DetectionModel
  26. configuration file to define what type of DetectionModel is being trained, an
  27. input_reader_pb2.InputReader file to specify what training data will be used and
  28. a train_pb2.TrainConfig file to configure training parameters.
  29. Example usage:
  30. ./train \
  31. --logtostderr \
  32. --train_dir=path/to/train_dir \
  33. --model_config_path=model_config.pbtxt \
  34. --train_config_path=train_config.pbtxt \
  35. --input_config_path=train_input_config.pbtxt
  36. """

import functools
import json
import os

import tensorflow as tf

from object_detection.builders import dataset_builder
from object_detection.builders import graph_rewriter_builder
from object_detection.builders import model_builder
from object_detection.legacy import trainer
from object_detection.utils import config_util

tf.logging.set_verbosity(tf.logging.INFO)
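
# Command-line flags. `master`, `task`, `worker_replicas` and `ps_tasks` only
# matter for distributed runs; the config-path flags select between the two
# configuration styles described in the module docstring.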
flags = tf.app.flags
flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')
flags.DEFINE_integer('task', 0, 'task id')
flags.DEFINE_integer('num_clones', 1, 'Number of clones to deploy per worker.')
flags.DEFINE_boolean('clone_on_cpu', False,
                     'Force clones to be deployed on CPU. Note that even if '
                     'set to False (allowing ops to run on gpu), some ops may '
                     'still be run on the CPU if they have no GPU kernel.')
flags.DEFINE_integer('worker_replicas', 1, 'Number of worker+trainer '
                     'replicas.')
flags.DEFINE_integer('ps_tasks', 0,
                     'Number of parameter server tasks. If None, does not use '
                     'a parameter server.')
flags.DEFINE_string('train_dir', '',
                    'Directory to save the checkpoints and training summaries.')
flags.DEFINE_string('pipeline_config_path', '',
                    'Path to a pipeline_pb2.TrainEvalPipelineConfig config '
                    'file. If provided, other configs are ignored')
flags.DEFINE_string('train_config_path', '',
                    'Path to a train_pb2.TrainConfig config file.')
flags.DEFINE_string('input_config_path', '',
                    'Path to an input_reader_pb2.InputReader config file.')
flags.DEFINE_string('model_config_path', '',
                    'Path to a model_pb2.DetectionModel config file.')

FLAGS = flags.FLAGS


@tf.contrib.framework.deprecated(None, 'Use object_detection/model_main.py.')
def main(_):
  assert FLAGS.train_dir, '`train_dir` is missing.'
  if FLAGS.task == 0: tf.gfile.MakeDirs(FLAGS.train_dir)
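
  # Assemble the configs dict either from a single pipeline proto or from the
  # three separate protos. The chief (task 0) also copies the configs into
  # train_dir so the run can be reproduced from its artifacts.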
  if FLAGS.pipeline_config_path:
    configs = config_util.get_configs_from_pipeline_file(
        FLAGS.pipeline_config_path)
    if FLAGS.task == 0:
      tf.gfile.Copy(FLAGS.pipeline_config_path,
                    os.path.join(FLAGS.train_dir, 'pipeline.config'),
                    overwrite=True)
  else:
    configs = config_util.get_configs_from_multiple_files(
        model_config_path=FLAGS.model_config_path,
        train_config_path=FLAGS.train_config_path,
        train_input_config_path=FLAGS.input_config_path)
    if FLAGS.task == 0:
      for name, config in [('model.config', FLAGS.model_config_path),
                           ('train.config', FLAGS.train_config_path),
                           ('input.config', FLAGS.input_config_path)]:
        tf.gfile.Copy(config, os.path.join(FLAGS.train_dir, name),
                      overwrite=True)

  model_config = configs['model']
  train_config = configs['train_config']
  input_config = configs['train_input_config']
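
  # The trainer invokes these thunks at graph-construction time: `model_fn`
  # builds the DetectionModel, and `create_input_dict_fn` builds the input
  # pipeline and returns the next tensor dictionary from its iterator.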
  model_fn = functools.partial(
      model_builder.build,
      model_config=model_config,
      is_training=True)

  def get_next(config):
    return dataset_builder.make_initializable_iterator(
        dataset_builder.build(config)).get_next()

  create_input_dict_fn = functools.partial(get_next, input_config)
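
  # TF_CONFIG, when set by the cluster manager, is a JSON blob such as
  # (illustrative values):
  #   {"cluster": {"master": ["host0:2222"], "worker": ["host1:2222"],
  #                "ps": ["host2:2222"]},
  #    "task": {"type": "worker", "index": 0}}
  # type('TaskSpec', ...) turns the task dict into an object exposing `.type`
  # and `.index` attributes.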
  env = json.loads(os.environ.get('TF_CONFIG', '{}'))
  cluster_data = env.get('cluster', None)
  cluster = tf.train.ClusterSpec(cluster_data) if cluster_data else None
  task_data = env.get('task', None) or {'type': 'master', 'index': 0}
  task_info = type('TaskSpec', (object,), task_data)

  # Parameters for a single worker.
  ps_tasks = 0
  worker_replicas = 1
  worker_job_name = 'lonely_worker'
  task = 0
  is_chief = True
  master = ''

  if cluster_data and 'worker' in cluster_data:
    # The total number of worker replicas includes the "worker"s and the
    # "master".
    worker_replicas = len(cluster_data['worker']) + 1
  if cluster_data and 'ps' in cluster_data:
    ps_tasks = len(cluster_data['ps'])

  if worker_replicas > 1 and ps_tasks < 1:
    raise ValueError('At least 1 ps task is needed for distributed training.')
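
  # In a distributed run every process starts a tf.train.Server. Parameter
  # server tasks block in server.join() and only serve variables; worker and
  # master tasks fall through and run the training loop against the
  # in-process server's target.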
  if worker_replicas >= 1 and ps_tasks > 0:
    # Set up distributed training.
    server = tf.train.Server(tf.train.ClusterSpec(cluster), protocol='grpc',
                             job_name=task_info.type,
                             task_index=task_info.index)
    if task_info.type == 'ps':
      server.join()
      return

    worker_job_name = '%s/task:%d' % (task_info.type, task_info.index)
    task = task_info.index
    is_chief = (task_info.type == 'master')
    master = server.target
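
  # An optional graph rewriter (e.g. quantization-aware training) can be
  # specified in the pipeline config; when present it is applied to the
  # training graph via `graph_hook_fn` below.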
  graph_rewriter_fn = None
  if 'graph_rewriter_config' in configs:
    graph_rewriter_fn = graph_rewriter_builder.build(
        configs['graph_rewriter_config'], is_training=True)
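
  # Hand everything to the legacy training loop (built on tf.contrib.slim).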
  trainer.train(
      create_input_dict_fn,
      model_fn,
      train_config,
      master,
      task,
      FLAGS.num_clones,
      worker_replicas,
      FLAGS.clone_on_cpu,
      ps_tasks,
      worker_job_name,
      is_chief,
      FLAGS.train_dir,
      graph_hook_fn=graph_rewriter_fn)


if __name__ == '__main__':
  tf.app.run()