- # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- r"""Training executable for detection models.
-
- This executable is used to train DetectionModels. There are two ways of
- configuring the training job:
-
- 1) A single pipeline_pb2.TrainEvalPipelineConfig configuration file
- can be specified by --pipeline_config_path.
-
- Example usage:
-     ./train \
-         --logtostderr \
-         --train_dir=path/to/train_dir \
-         --pipeline_config_path=pipeline_config.pbtxt
-
- 2) Three configuration files can be provided: a model_pb2.DetectionModel
- configuration file to define what type of DetectionModel is being trained, an
- input_reader_pb2.InputReader file to specify what training data will be used,
- and a train_pb2.TrainConfig file to configure training parameters.
-
- Example usage:
-     ./train \
-         --logtostderr \
-         --train_dir=path/to/train_dir \
-         --model_config_path=model_config.pbtxt \
-         --train_config_path=train_config.pbtxt \
-         --input_config_path=train_input_config.pbtxt
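-
- For reference, an illustrative and deliberately abbreviated sketch of a
- pipeline_pb2.TrainEvalPipelineConfig text proto; its top-level sections are
- roughly:
-
-     model { ... }
-     train_config { ... }
-     train_input_reader { ... }
-     eval_config { ... }
-     eval_input_reader { ... }
-
- See object_detection/protos/pipeline.proto for the authoritative schema.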
- """
-
- import functools
- import json
- import os
- import tensorflow as tf
-
- from object_detection.builders import dataset_builder
- from object_detection.builders import graph_rewriter_builder
- from object_detection.builders import model_builder
- from object_detection.legacy import trainer
- from object_detection.utils import config_util
-
- tf.logging.set_verbosity(tf.logging.INFO)
-
- flags = tf.app.flags
- flags.DEFINE_string('master', '', 'Name of the TensorFlow master to use.')
- flags.DEFINE_integer('task', 0, 'task id')
- flags.DEFINE_integer('num_clones', 1, 'Number of clones to deploy per worker.')
- flags.DEFINE_boolean('clone_on_cpu', False,
-                      'Force clones to be deployed on CPU. Note that even if '
-                      'set to False (allowing ops to run on gpu), some ops may '
-                      'still be run on the CPU if they have no GPU kernel.')
- flags.DEFINE_integer('worker_replicas', 1, 'Number of worker+trainer '
-                      'replicas.')
- flags.DEFINE_integer('ps_tasks', 0,
-                      'Number of parameter server tasks. If None, does not use '
-                      'a parameter server.')
- flags.DEFINE_string('train_dir', '',
-                     'Directory to save the checkpoints and training summaries.')
-
- flags.DEFINE_string('pipeline_config_path', '',
-                     'Path to a pipeline_pb2.TrainEvalPipelineConfig config '
-                     'file. If provided, other configs are ignored.')
-
- flags.DEFINE_string('train_config_path', '',
-                     'Path to a train_pb2.TrainConfig config file.')
- flags.DEFINE_string('input_config_path', '',
-                     'Path to an input_reader_pb2.InputReader config file.')
- flags.DEFINE_string('model_config_path', '',
-                     'Path to a model_pb2.DetectionModel config file.')
-
- FLAGS = flags.FLAGS
-
-
- @tf.contrib.framework.deprecated(None, 'Use object_detection/model_main.py.')
- def main(_):
-   assert FLAGS.train_dir, '`train_dir` is missing.'
-   if FLAGS.task == 0: tf.gfile.MakeDirs(FLAGS.train_dir)
-   if FLAGS.pipeline_config_path:
-     configs = config_util.get_configs_from_pipeline_file(
-         FLAGS.pipeline_config_path)
-     if FLAGS.task == 0:
-       tf.gfile.Copy(FLAGS.pipeline_config_path,
-                     os.path.join(FLAGS.train_dir, 'pipeline.config'),
-                     overwrite=True)
-   else:
-     configs = config_util.get_configs_from_multiple_files(
-         model_config_path=FLAGS.model_config_path,
-         train_config_path=FLAGS.train_config_path,
-         train_input_config_path=FLAGS.input_config_path)
-     if FLAGS.task == 0:
-       for name, config in [('model.config', FLAGS.model_config_path),
-                            ('train.config', FLAGS.train_config_path),
-                            ('input.config', FLAGS.input_config_path)]:
-         tf.gfile.Copy(config, os.path.join(FLAGS.train_dir, name),
-                       overwrite=True)
-
-   model_config = configs['model']
-   train_config = configs['train_config']
-   input_config = configs['train_input_config']
-
-   model_fn = functools.partial(
-       model_builder.build,
-       model_config=model_config,
-       is_training=True)
-
-   def get_next(config):
-     return dataset_builder.make_initializable_iterator(
-         dataset_builder.build(config)).get_next()
-
-   create_input_dict_fn = functools.partial(get_next, input_config)
-
-   env = json.loads(os.environ.get('TF_CONFIG', '{}'))
-   cluster_data = env.get('cluster', None)
-   cluster = tf.train.ClusterSpec(cluster_data) if cluster_data else None
-   task_data = env.get('task', None) or {'type': 'master', 'index': 0}
-   task_info = type('TaskSpec', (object,), task_data)
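-   # Illustrative example, not part of the original file: on a small cluster
-   # TF_CONFIG might hold something like
-   #
-   #   {"cluster": {"master": ["host0:2222"],
-   #                "worker": ["host1:2222", "host2:2222"],
-   #                "ps": ["host3:2222"]},
-   #    "task": {"type": "worker", "index": 1}}
-   #
-   # With that value the code below derives worker_replicas = 3 (two "worker"s
-   # plus the "master") and ps_tasks = 1; an empty or missing TF_CONFIG leaves
-   # the single-worker defaults in place.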
-
-   # Parameters for a single worker.
-   ps_tasks = 0
-   worker_replicas = 1
-   worker_job_name = 'lonely_worker'
-   task = 0
-   is_chief = True
-   master = ''
-
-   if cluster_data and 'worker' in cluster_data:
-     # The total number of worker replicas includes "worker"s and the "master".
-     worker_replicas = len(cluster_data['worker']) + 1
-   if cluster_data and 'ps' in cluster_data:
-     ps_tasks = len(cluster_data['ps'])
-
-   if worker_replicas > 1 and ps_tasks < 1:
-     raise ValueError('At least 1 ps task is needed for distributed training.')
-
-   if worker_replicas >= 1 and ps_tasks > 0:
-     # Set up distributed training.
-     server = tf.train.Server(tf.train.ClusterSpec(cluster), protocol='grpc',
-                              job_name=task_info.type,
-                              task_index=task_info.index)
-     if task_info.type == 'ps':
-       server.join()
-       return
-
-     worker_job_name = '%s/task:%d' % (task_info.type, task_info.index)
-     task = task_info.index
-     is_chief = (task_info.type == 'master')
-     master = server.target
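-     # For example, a worker with task index 1 ends up with worker_job_name =
-     # 'worker/task:1', and `master` is this server's gRPC target; both are
-     # handed to trainer.train() below.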
-
-   graph_rewriter_fn = None
-   if 'graph_rewriter_config' in configs:
-     graph_rewriter_fn = graph_rewriter_builder.build(
-         configs['graph_rewriter_config'], is_training=True)
-
-   trainer.train(
-       create_input_dict_fn,
-       model_fn,
-       train_config,
-       master,
-       task,
-       FLAGS.num_clones,
-       worker_replicas,
-       FLAGS.clone_on_cpu,
-       ps_tasks,
-       worker_job_name,
-       is_chief,
-       FLAGS.train_dir,
-       graph_hook_fn=graph_rewriter_fn)
-
-
- if __name__ == '__main__':
-   tf.app.run()