You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

846 lines
35 KiB

6 years ago
  1. # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. r"""Constructs model, inputs, and training environment."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import copy
  20. import functools
  21. import os
  22. import tensorflow as tf
  23. from object_detection import eval_util
  24. from object_detection import exporter as exporter_lib
  25. from object_detection import inputs
  26. from object_detection.builders import graph_rewriter_builder
  27. from object_detection.builders import model_builder
  28. from object_detection.builders import optimizer_builder
  29. from object_detection.core import standard_fields as fields
  30. from object_detection.utils import config_util
  31. from object_detection.utils import label_map_util
  32. from object_detection.utils import shape_utils
  33. from object_detection.utils import variables_helper
  34. from object_detection.utils import visualization_utils as vis_utils
  35. # A map of names to methods that help build the model.
  36. MODEL_BUILD_UTIL_MAP = {
  37. 'get_configs_from_pipeline_file':
  38. config_util.get_configs_from_pipeline_file,
  39. 'create_pipeline_proto_from_configs':
  40. config_util.create_pipeline_proto_from_configs,
  41. 'merge_external_params_with_configs':
  42. config_util.merge_external_params_with_configs,
  43. 'create_train_input_fn':
  44. inputs.create_train_input_fn,
  45. 'create_eval_input_fn':
  46. inputs.create_eval_input_fn,
  47. 'create_predict_input_fn':
  48. inputs.create_predict_input_fn,
  49. 'detection_model_fn_base': model_builder.build,
  50. }
  51. def _prepare_groundtruth_for_eval(detection_model, class_agnostic,
  52. max_number_of_boxes):
  53. """Extracts groundtruth data from detection_model and prepares it for eval.
  54. Args:
  55. detection_model: A `DetectionModel` object.
  56. class_agnostic: Whether the detections are class_agnostic.
  57. max_number_of_boxes: Max number of groundtruth boxes.
  58. Returns:
  59. A tuple of:
  60. groundtruth: Dictionary with the following fields:
  61. 'groundtruth_boxes': [batch_size, num_boxes, 4] float32 tensor of boxes,
  62. in normalized coordinates.
  63. 'groundtruth_classes': [batch_size, num_boxes] int64 tensor of 1-indexed
  64. classes.
  65. 'groundtruth_masks': 4D float32 tensor of instance masks (if provided in
  66. groundtruth)
  67. 'groundtruth_is_crowd': [batch_size, num_boxes] bool tensor indicating
  68. is_crowd annotations (if provided in groundtruth).
  69. 'num_groundtruth_boxes': [batch_size] tensor containing the maximum number
  70. of groundtruth boxes per image..
  71. class_agnostic: Boolean indicating whether detections are class agnostic.
  72. """
  73. input_data_fields = fields.InputDataFields()
  74. groundtruth_boxes = tf.stack(
  75. detection_model.groundtruth_lists(fields.BoxListFields.boxes))
  76. groundtruth_boxes_shape = tf.shape(groundtruth_boxes)
  77. # For class-agnostic models, groundtruth one-hot encodings collapse to all
  78. # ones.
  79. if class_agnostic:
  80. groundtruth_classes_one_hot = tf.ones(
  81. [groundtruth_boxes_shape[0], groundtruth_boxes_shape[1], 1])
  82. else:
  83. groundtruth_classes_one_hot = tf.stack(
  84. detection_model.groundtruth_lists(fields.BoxListFields.classes))
  85. label_id_offset = 1 # Applying label id offset (b/63711816)
  86. groundtruth_classes = (
  87. tf.argmax(groundtruth_classes_one_hot, axis=2) + label_id_offset)
  88. groundtruth = {
  89. input_data_fields.groundtruth_boxes: groundtruth_boxes,
  90. input_data_fields.groundtruth_classes: groundtruth_classes
  91. }
  92. if detection_model.groundtruth_has_field(fields.BoxListFields.masks):
  93. groundtruth[input_data_fields.groundtruth_instance_masks] = tf.stack(
  94. detection_model.groundtruth_lists(fields.BoxListFields.masks))
  95. if detection_model.groundtruth_has_field(fields.BoxListFields.is_crowd):
  96. groundtruth[input_data_fields.groundtruth_is_crowd] = tf.stack(
  97. detection_model.groundtruth_lists(fields.BoxListFields.is_crowd))
  98. groundtruth[input_data_fields.num_groundtruth_boxes] = (
  99. tf.tile([max_number_of_boxes], multiples=[groundtruth_boxes_shape[0]]))
  100. return groundtruth
  101. def unstack_batch(tensor_dict, unpad_groundtruth_tensors=True):
  102. """Unstacks all tensors in `tensor_dict` along 0th dimension.
  103. Unstacks tensor from the tensor dict along 0th dimension and returns a
  104. tensor_dict containing values that are lists of unstacked, unpadded tensors.
  105. Tensors in the `tensor_dict` are expected to be of one of the three shapes:
  106. 1. [batch_size]
  107. 2. [batch_size, height, width, channels]
  108. 3. [batch_size, num_boxes, d1, d2, ... dn]
  109. When unpad_groundtruth_tensors is set to true, unstacked tensors of form 3
  110. above are sliced along the `num_boxes` dimension using the value in tensor
  111. field.InputDataFields.num_groundtruth_boxes.
  112. Note that this function has a static list of input data fields and has to be
  113. kept in sync with the InputDataFields defined in core/standard_fields.py
  114. Args:
  115. tensor_dict: A dictionary of batched groundtruth tensors.
  116. unpad_groundtruth_tensors: Whether to remove padding along `num_boxes`
  117. dimension of the groundtruth tensors.
  118. Returns:
  119. A dictionary where the keys are from fields.InputDataFields and values are
  120. a list of unstacked (optionally unpadded) tensors.
  121. Raises:
  122. ValueError: If unpad_tensors is True and `tensor_dict` does not contain
  123. `num_groundtruth_boxes` tensor.
  124. """
  125. unbatched_tensor_dict = {
  126. key: tf.unstack(tensor) for key, tensor in tensor_dict.items()
  127. }
  128. if unpad_groundtruth_tensors:
  129. if (fields.InputDataFields.num_groundtruth_boxes not in
  130. unbatched_tensor_dict):
  131. raise ValueError('`num_groundtruth_boxes` not found in tensor_dict. '
  132. 'Keys available: {}'.format(
  133. unbatched_tensor_dict.keys()))
  134. unbatched_unpadded_tensor_dict = {}
  135. unpad_keys = set([
  136. # List of input data fields that are padded along the num_boxes
  137. # dimension. This list has to be kept in sync with InputDataFields in
  138. # standard_fields.py.
  139. fields.InputDataFields.groundtruth_instance_masks,
  140. fields.InputDataFields.groundtruth_classes,
  141. fields.InputDataFields.groundtruth_boxes,
  142. fields.InputDataFields.groundtruth_keypoints,
  143. fields.InputDataFields.groundtruth_group_of,
  144. fields.InputDataFields.groundtruth_difficult,
  145. fields.InputDataFields.groundtruth_is_crowd,
  146. fields.InputDataFields.groundtruth_area,
  147. fields.InputDataFields.groundtruth_weights
  148. ]).intersection(set(unbatched_tensor_dict.keys()))
  149. for key in unpad_keys:
  150. unpadded_tensor_list = []
  151. for num_gt, padded_tensor in zip(
  152. unbatched_tensor_dict[fields.InputDataFields.num_groundtruth_boxes],
  153. unbatched_tensor_dict[key]):
  154. tensor_shape = shape_utils.combined_static_and_dynamic_shape(
  155. padded_tensor)
  156. slice_begin = tf.zeros([len(tensor_shape)], dtype=tf.int32)
  157. slice_size = tf.stack(
  158. [num_gt] + [-1 if dim is None else dim for dim in tensor_shape[1:]])
  159. unpadded_tensor = tf.slice(padded_tensor, slice_begin, slice_size)
  160. unpadded_tensor_list.append(unpadded_tensor)
  161. unbatched_unpadded_tensor_dict[key] = unpadded_tensor_list
  162. unbatched_tensor_dict.update(unbatched_unpadded_tensor_dict)
  163. return unbatched_tensor_dict
  164. def create_model_fn(detection_model_fn, configs, hparams, use_tpu=False,
  165. postprocess_on_cpu=False):
  166. """Creates a model function for `Estimator`.
  167. Args:
  168. detection_model_fn: Function that returns a `DetectionModel` instance.
  169. configs: Dictionary of pipeline config objects.
  170. hparams: `HParams` object.
  171. use_tpu: Boolean indicating whether model should be constructed for
  172. use on TPU.
  173. postprocess_on_cpu: When use_tpu and postprocess_on_cpu is true, postprocess
  174. is scheduled on the host cpu.
  175. Returns:
  176. `model_fn` for `Estimator`.
  177. """
  178. train_config = configs['train_config']
  179. eval_input_config = configs['eval_input_config']
  180. eval_config = configs['eval_config']
  181. def model_fn(features, labels, mode, params=None):
  182. """Constructs the object detection model.
  183. Args:
  184. features: Dictionary of feature tensors, returned from `input_fn`.
  185. labels: Dictionary of groundtruth tensors if mode is TRAIN or EVAL,
  186. otherwise None.
  187. mode: Mode key from tf.estimator.ModeKeys.
  188. params: Parameter dictionary passed from the estimator.
  189. Returns:
  190. An `EstimatorSpec` that encapsulates the model and its serving
  191. configurations.
  192. """
  193. params = params or {}
  194. total_loss, train_op, detections, export_outputs = None, None, None, None
  195. is_training = mode == tf.estimator.ModeKeys.TRAIN
  196. # Make sure to set the Keras learning phase. True during training,
  197. # False for inference.
  198. tf.keras.backend.set_learning_phase(is_training)
  199. detection_model = detection_model_fn(
  200. is_training=is_training, add_summaries=(not use_tpu))
  201. scaffold_fn = None
  202. if mode == tf.estimator.ModeKeys.TRAIN:
  203. labels = unstack_batch(
  204. labels,
  205. unpad_groundtruth_tensors=train_config.unpad_groundtruth_tensors)
  206. elif mode == tf.estimator.ModeKeys.EVAL:
  207. # For evaling on train data, it is necessary to check whether groundtruth
  208. # must be unpadded.
  209. boxes_shape = (
  210. labels[fields.InputDataFields.groundtruth_boxes].get_shape()
  211. .as_list())
  212. unpad_groundtruth_tensors = boxes_shape[1] is not None and not use_tpu
  213. labels = unstack_batch(
  214. labels, unpad_groundtruth_tensors=unpad_groundtruth_tensors)
  215. if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
  216. gt_boxes_list = labels[fields.InputDataFields.groundtruth_boxes]
  217. gt_classes_list = labels[fields.InputDataFields.groundtruth_classes]
  218. gt_masks_list = None
  219. if fields.InputDataFields.groundtruth_instance_masks in labels:
  220. gt_masks_list = labels[
  221. fields.InputDataFields.groundtruth_instance_masks]
  222. gt_keypoints_list = None
  223. if fields.InputDataFields.groundtruth_keypoints in labels:
  224. gt_keypoints_list = labels[fields.InputDataFields.groundtruth_keypoints]
  225. gt_weights_list = None
  226. if fields.InputDataFields.groundtruth_weights in labels:
  227. gt_weights_list = labels[fields.InputDataFields.groundtruth_weights]
  228. gt_confidences_list = None
  229. if fields.InputDataFields.groundtruth_confidences in labels:
  230. gt_confidences_list = labels[
  231. fields.InputDataFields.groundtruth_confidences]
  232. gt_is_crowd_list = None
  233. if fields.InputDataFields.groundtruth_is_crowd in labels:
  234. gt_is_crowd_list = labels[fields.InputDataFields.groundtruth_is_crowd]
  235. detection_model.provide_groundtruth(
  236. groundtruth_boxes_list=gt_boxes_list,
  237. groundtruth_classes_list=gt_classes_list,
  238. groundtruth_confidences_list=gt_confidences_list,
  239. groundtruth_masks_list=gt_masks_list,
  240. groundtruth_keypoints_list=gt_keypoints_list,
  241. groundtruth_weights_list=gt_weights_list,
  242. groundtruth_is_crowd_list=gt_is_crowd_list)
  243. preprocessed_images = features[fields.InputDataFields.image]
  244. if use_tpu and train_config.use_bfloat16:
  245. with tf.contrib.tpu.bfloat16_scope():
  246. prediction_dict = detection_model.predict(
  247. preprocessed_images,
  248. features[fields.InputDataFields.true_image_shape])
  249. for k, v in prediction_dict.items():
  250. if v.dtype == tf.bfloat16:
  251. prediction_dict[k] = tf.cast(v, tf.float32)
  252. else:
  253. prediction_dict = detection_model.predict(
  254. preprocessed_images,
  255. features[fields.InputDataFields.true_image_shape])
  256. def postprocess_wrapper(args):
  257. return detection_model.postprocess(args[0], args[1])
  258. if mode in (tf.estimator.ModeKeys.EVAL, tf.estimator.ModeKeys.PREDICT):
  259. if use_tpu and postprocess_on_cpu:
  260. detections = tf.contrib.tpu.outside_compilation(
  261. postprocess_wrapper,
  262. (prediction_dict,
  263. features[fields.InputDataFields.true_image_shape]))
  264. else:
  265. detections = postprocess_wrapper((
  266. prediction_dict,
  267. features[fields.InputDataFields.true_image_shape]))
  268. if mode == tf.estimator.ModeKeys.TRAIN:
  269. if train_config.fine_tune_checkpoint and hparams.load_pretrained:
  270. if not train_config.fine_tune_checkpoint_type:
  271. # train_config.from_detection_checkpoint field is deprecated. For
  272. # backward compatibility, set train_config.fine_tune_checkpoint_type
  273. # based on train_config.from_detection_checkpoint.
  274. if train_config.from_detection_checkpoint:
  275. train_config.fine_tune_checkpoint_type = 'detection'
  276. else:
  277. train_config.fine_tune_checkpoint_type = 'classification'
  278. asg_map = detection_model.restore_map(
  279. fine_tune_checkpoint_type=train_config.fine_tune_checkpoint_type,
  280. load_all_detection_checkpoint_vars=(
  281. train_config.load_all_detection_checkpoint_vars))
  282. available_var_map = (
  283. variables_helper.get_variables_available_in_checkpoint(
  284. asg_map,
  285. train_config.fine_tune_checkpoint,
  286. include_global_step=False))
  287. if use_tpu:
  288. def tpu_scaffold():
  289. tf.train.init_from_checkpoint(train_config.fine_tune_checkpoint,
  290. available_var_map)
  291. return tf.train.Scaffold()
  292. scaffold_fn = tpu_scaffold
  293. else:
  294. tf.train.init_from_checkpoint(train_config.fine_tune_checkpoint,
  295. available_var_map)
  296. if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
  297. losses_dict = detection_model.loss(
  298. prediction_dict, features[fields.InputDataFields.true_image_shape])
  299. losses = [loss_tensor for loss_tensor in losses_dict.values()]
  300. if train_config.add_regularization_loss:
  301. regularization_losses = detection_model.regularization_losses()
  302. if regularization_losses:
  303. regularization_loss = tf.add_n(
  304. regularization_losses, name='regularization_loss')
  305. losses.append(regularization_loss)
  306. losses_dict['Loss/regularization_loss'] = regularization_loss
  307. total_loss = tf.add_n(losses, name='total_loss')
  308. losses_dict['Loss/total_loss'] = total_loss
  309. if 'graph_rewriter_config' in configs:
  310. graph_rewriter_fn = graph_rewriter_builder.build(
  311. configs['graph_rewriter_config'], is_training=is_training)
  312. graph_rewriter_fn()
  313. # TODO(rathodv): Stop creating optimizer summary vars in EVAL mode once we
  314. # can write learning rate summaries on TPU without host calls.
  315. global_step = tf.train.get_or_create_global_step()
  316. training_optimizer, optimizer_summary_vars = optimizer_builder.build(
  317. train_config.optimizer)
  318. if mode == tf.estimator.ModeKeys.TRAIN:
  319. if use_tpu:
  320. training_optimizer = tf.contrib.tpu.CrossShardOptimizer(
  321. training_optimizer)
  322. # Optionally freeze some layers by setting their gradients to be zero.
  323. trainable_variables = None
  324. include_variables = (
  325. train_config.update_trainable_variables
  326. if train_config.update_trainable_variables else None)
  327. exclude_variables = (
  328. train_config.freeze_variables
  329. if train_config.freeze_variables else None)
  330. trainable_variables = tf.contrib.framework.filter_variables(
  331. tf.trainable_variables(),
  332. include_patterns=include_variables,
  333. exclude_patterns=exclude_variables)
  334. clip_gradients_value = None
  335. if train_config.gradient_clipping_by_norm > 0:
  336. clip_gradients_value = train_config.gradient_clipping_by_norm
  337. if not use_tpu:
  338. for var in optimizer_summary_vars:
  339. tf.summary.scalar(var.op.name, var)
  340. summaries = [] if use_tpu else None
  341. if train_config.summarize_gradients:
  342. summaries = ['gradients', 'gradient_norm', 'global_gradient_norm']
  343. train_op = tf.contrib.layers.optimize_loss(
  344. loss=total_loss,
  345. global_step=global_step,
  346. learning_rate=None,
  347. clip_gradients=clip_gradients_value,
  348. optimizer=training_optimizer,
  349. update_ops=detection_model.updates(),
  350. variables=trainable_variables,
  351. summaries=summaries,
  352. name='') # Preventing scope prefix on all variables.
  353. if mode == tf.estimator.ModeKeys.PREDICT:
  354. exported_output = exporter_lib.add_output_tensor_nodes(detections)
  355. export_outputs = {
  356. tf.saved_model.signature_constants.PREDICT_METHOD_NAME:
  357. tf.estimator.export.PredictOutput(exported_output)
  358. }
  359. eval_metric_ops = None
  360. scaffold = None
  361. if mode == tf.estimator.ModeKeys.EVAL:
  362. class_agnostic = (
  363. fields.DetectionResultFields.detection_classes not in detections)
  364. groundtruth = _prepare_groundtruth_for_eval(
  365. detection_model, class_agnostic,
  366. eval_input_config.max_number_of_boxes)
  367. use_original_images = fields.InputDataFields.original_image in features
  368. if use_original_images:
  369. eval_images = features[fields.InputDataFields.original_image]
  370. true_image_shapes = tf.slice(
  371. features[fields.InputDataFields.true_image_shape], [0, 0], [-1, 3])
  372. original_image_spatial_shapes = features[fields.InputDataFields
  373. .original_image_spatial_shape]
  374. else:
  375. eval_images = features[fields.InputDataFields.image]
  376. true_image_shapes = None
  377. original_image_spatial_shapes = None
  378. eval_dict = eval_util.result_dict_for_batched_example(
  379. eval_images,
  380. features[inputs.HASH_KEY],
  381. detections,
  382. groundtruth,
  383. class_agnostic=class_agnostic,
  384. scale_to_absolute=True,
  385. original_image_spatial_shapes=original_image_spatial_shapes,
  386. true_image_shapes=true_image_shapes)
  387. if class_agnostic:
  388. category_index = label_map_util.create_class_agnostic_category_index()
  389. else:
  390. category_index = label_map_util.create_category_index_from_labelmap(
  391. eval_input_config.label_map_path)
  392. vis_metric_ops = None
  393. if not use_tpu and use_original_images:
  394. eval_metric_op_vis = vis_utils.VisualizeSingleFrameDetections(
  395. category_index,
  396. max_examples_to_draw=eval_config.num_visualizations,
  397. max_boxes_to_draw=eval_config.max_num_boxes_to_visualize,
  398. min_score_thresh=eval_config.min_score_threshold,
  399. use_normalized_coordinates=False)
  400. vis_metric_ops = eval_metric_op_vis.get_estimator_eval_metric_ops(
  401. eval_dict)
  402. # Eval metrics on a single example.
  403. eval_metric_ops = eval_util.get_eval_metric_ops_for_evaluators(
  404. eval_config, list(category_index.values()), eval_dict)
  405. for loss_key, loss_tensor in iter(losses_dict.items()):
  406. eval_metric_ops[loss_key] = tf.metrics.mean(loss_tensor)
  407. for var in optimizer_summary_vars:
  408. eval_metric_ops[var.op.name] = (var, tf.no_op())
  409. if vis_metric_ops is not None:
  410. eval_metric_ops.update(vis_metric_ops)
  411. eval_metric_ops = {str(k): v for k, v in eval_metric_ops.items()}
  412. if eval_config.use_moving_averages:
  413. variable_averages = tf.train.ExponentialMovingAverage(0.0)
  414. variables_to_restore = variable_averages.variables_to_restore()
  415. keep_checkpoint_every_n_hours = (
  416. train_config.keep_checkpoint_every_n_hours)
  417. saver = tf.train.Saver(
  418. variables_to_restore,
  419. keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)
  420. scaffold = tf.train.Scaffold(saver=saver)
  421. # EVAL executes on CPU, so use regular non-TPU EstimatorSpec.
  422. if use_tpu and mode != tf.estimator.ModeKeys.EVAL:
  423. return tf.contrib.tpu.TPUEstimatorSpec(
  424. mode=mode,
  425. scaffold_fn=scaffold_fn,
  426. predictions=detections,
  427. loss=total_loss,
  428. train_op=train_op,
  429. eval_metrics=eval_metric_ops,
  430. export_outputs=export_outputs)
  431. else:
  432. if scaffold is None:
  433. keep_checkpoint_every_n_hours = (
  434. train_config.keep_checkpoint_every_n_hours)
  435. saver = tf.train.Saver(
  436. sharded=True,
  437. keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
  438. save_relative_paths=True)
  439. tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
  440. scaffold = tf.train.Scaffold(saver=saver)
  441. return tf.estimator.EstimatorSpec(
  442. mode=mode,
  443. predictions=detections,
  444. loss=total_loss,
  445. train_op=train_op,
  446. eval_metric_ops=eval_metric_ops,
  447. export_outputs=export_outputs,
  448. scaffold=scaffold)
  449. return model_fn
  450. def create_estimator_and_inputs(run_config,
  451. hparams,
  452. pipeline_config_path,
  453. config_override=None,
  454. train_steps=None,
  455. sample_1_of_n_eval_examples=1,
  456. sample_1_of_n_eval_on_train_examples=1,
  457. model_fn_creator=create_model_fn,
  458. use_tpu_estimator=False,
  459. use_tpu=False,
  460. num_shards=1,
  461. params=None,
  462. override_eval_num_epochs=True,
  463. save_final_config=False,
  464. postprocess_on_cpu=False,
  465. export_to_tpu=None,
  466. **kwargs):
  467. """Creates `Estimator`, input functions, and steps.
  468. Args:
  469. run_config: A `RunConfig`.
  470. hparams: A `HParams`.
  471. pipeline_config_path: A path to a pipeline config file.
  472. config_override: A pipeline_pb2.TrainEvalPipelineConfig text proto to
  473. override the config from `pipeline_config_path`.
  474. train_steps: Number of training steps. If None, the number of training steps
  475. is set from the `TrainConfig` proto.
  476. sample_1_of_n_eval_examples: Integer representing how often an eval example
  477. should be sampled. If 1, will sample all examples.
  478. sample_1_of_n_eval_on_train_examples: Similar to
  479. `sample_1_of_n_eval_examples`, except controls the sampling of training
  480. data for evaluation.
  481. model_fn_creator: A function that creates a `model_fn` for `Estimator`.
  482. Follows the signature:
  483. * Args:
  484. * `detection_model_fn`: Function that returns `DetectionModel` instance.
  485. * `configs`: Dictionary of pipeline config objects.
  486. * `hparams`: `HParams` object.
  487. * Returns:
  488. `model_fn` for `Estimator`.
  489. use_tpu_estimator: Whether a `TPUEstimator` should be returned. If False,
  490. an `Estimator` will be returned.
  491. use_tpu: Boolean, whether training and evaluation should run on TPU. Only
  492. used if `use_tpu_estimator` is True.
  493. num_shards: Number of shards (TPU cores). Only used if `use_tpu_estimator`
  494. is True.
  495. params: Parameter dictionary passed from the estimator. Only used if
  496. `use_tpu_estimator` is True.
  497. override_eval_num_epochs: Whether to overwrite the number of epochs to 1 for
  498. eval_input.
  499. save_final_config: Whether to save final config (obtained after applying
  500. overrides) to `estimator.model_dir`.
  501. postprocess_on_cpu: When use_tpu and postprocess_on_cpu are true,
  502. postprocess is scheduled on the host cpu.
  503. export_to_tpu: When use_tpu and export_to_tpu are true,
  504. `export_savedmodel()` exports a metagraph for serving on TPU besides the
  505. one on CPU.
  506. **kwargs: Additional keyword arguments for configuration override.
  507. Returns:
  508. A dictionary with the following fields:
  509. 'estimator': An `Estimator` or `TPUEstimator`.
  510. 'train_input_fn': A training input function.
  511. 'eval_input_fns': A list of all evaluation input functions.
  512. 'eval_input_names': A list of names for each evaluation input.
  513. 'eval_on_train_input_fn': An evaluation-on-train input function.
  514. 'predict_input_fn': A prediction input function.
  515. 'train_steps': Number of training steps. Either directly from input or from
  516. configuration.
  517. """
  518. get_configs_from_pipeline_file = MODEL_BUILD_UTIL_MAP[
  519. 'get_configs_from_pipeline_file']
  520. merge_external_params_with_configs = MODEL_BUILD_UTIL_MAP[
  521. 'merge_external_params_with_configs']
  522. create_pipeline_proto_from_configs = MODEL_BUILD_UTIL_MAP[
  523. 'create_pipeline_proto_from_configs']
  524. create_train_input_fn = MODEL_BUILD_UTIL_MAP['create_train_input_fn']
  525. create_eval_input_fn = MODEL_BUILD_UTIL_MAP['create_eval_input_fn']
  526. create_predict_input_fn = MODEL_BUILD_UTIL_MAP['create_predict_input_fn']
  527. detection_model_fn_base = MODEL_BUILD_UTIL_MAP['detection_model_fn_base']
  528. configs = get_configs_from_pipeline_file(
  529. pipeline_config_path, config_override=config_override)
  530. kwargs.update({
  531. 'train_steps': train_steps,
  532. 'sample_1_of_n_eval_examples': sample_1_of_n_eval_examples,
  533. 'use_bfloat16': configs['train_config'].use_bfloat16 and use_tpu
  534. })
  535. if override_eval_num_epochs:
  536. kwargs.update({'eval_num_epochs': 1})
  537. tf.logging.warning(
  538. 'Forced number of epochs for all eval validations to be 1.')
  539. configs = merge_external_params_with_configs(
  540. configs, hparams, kwargs_dict=kwargs)
  541. model_config = configs['model']
  542. train_config = configs['train_config']
  543. train_input_config = configs['train_input_config']
  544. eval_config = configs['eval_config']
  545. eval_input_configs = configs['eval_input_configs']
  546. eval_on_train_input_config = copy.deepcopy(train_input_config)
  547. eval_on_train_input_config.sample_1_of_n_examples = (
  548. sample_1_of_n_eval_on_train_examples)
  549. if override_eval_num_epochs and eval_on_train_input_config.num_epochs != 1:
  550. tf.logging.warning('Expected number of evaluation epochs is 1, but '
  551. 'instead encountered `eval_on_train_input_config'
  552. '.num_epochs` = '
  553. '{}. Overwriting `num_epochs` to 1.'.format(
  554. eval_on_train_input_config.num_epochs))
  555. eval_on_train_input_config.num_epochs = 1
  556. # update train_steps from config but only when non-zero value is provided
  557. if train_steps is None and train_config.num_steps != 0:
  558. train_steps = train_config.num_steps
  559. detection_model_fn = functools.partial(
  560. detection_model_fn_base, model_config=model_config)
  561. # Create the input functions for TRAIN/EVAL/PREDICT.
  562. train_input_fn = create_train_input_fn(
  563. train_config=train_config,
  564. train_input_config=train_input_config,
  565. model_config=model_config)
  566. eval_input_fns = [
  567. create_eval_input_fn(
  568. eval_config=eval_config,
  569. eval_input_config=eval_input_config,
  570. model_config=model_config) for eval_input_config in eval_input_configs
  571. ]
  572. eval_input_names = [
  573. eval_input_config.name for eval_input_config in eval_input_configs
  574. ]
  575. eval_on_train_input_fn = create_eval_input_fn(
  576. eval_config=eval_config,
  577. eval_input_config=eval_on_train_input_config,
  578. model_config=model_config)
  579. predict_input_fn = create_predict_input_fn(
  580. model_config=model_config, predict_input_config=eval_input_configs[0])
  581. # Read export_to_tpu from hparams if not passed.
  582. if export_to_tpu is None:
  583. export_to_tpu = hparams.get('export_to_tpu', False)
  584. tf.logging.info('create_estimator_and_inputs: use_tpu %s, export_to_tpu %s',
  585. use_tpu, export_to_tpu)
  586. model_fn = model_fn_creator(detection_model_fn, configs, hparams, use_tpu,
  587. postprocess_on_cpu)
  588. if use_tpu_estimator:
  589. estimator = tf.contrib.tpu.TPUEstimator(
  590. model_fn=model_fn,
  591. train_batch_size=train_config.batch_size,
  592. # For each core, only batch size 1 is supported for eval.
  593. eval_batch_size=num_shards * 1 if use_tpu else 1,
  594. use_tpu=use_tpu,
  595. config=run_config,
  596. export_to_tpu=export_to_tpu,
  597. eval_on_tpu=False, # Eval runs on CPU, so disable eval on TPU
  598. params=params if params else {})
  599. else:
  600. estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)
  601. # Write the as-run pipeline config to disk.
  602. if run_config.is_chief and save_final_config:
  603. pipeline_config_final = create_pipeline_proto_from_configs(configs)
  604. config_util.save_pipeline_config(pipeline_config_final, estimator.model_dir)
  605. return dict(
  606. estimator=estimator,
  607. train_input_fn=train_input_fn,
  608. eval_input_fns=eval_input_fns,
  609. eval_input_names=eval_input_names,
  610. eval_on_train_input_fn=eval_on_train_input_fn,
  611. predict_input_fn=predict_input_fn,
  612. train_steps=train_steps)
  613. def create_train_and_eval_specs(train_input_fn,
  614. eval_input_fns,
  615. eval_on_train_input_fn,
  616. predict_input_fn,
  617. train_steps,
  618. eval_on_train_data=False,
  619. final_exporter_name='Servo',
  620. eval_spec_names=None):
  621. """Creates a `TrainSpec` and `EvalSpec`s.
  622. Args:
  623. train_input_fn: Function that produces features and labels on train data.
  624. eval_input_fns: A list of functions that produce features and labels on eval
  625. data.
  626. eval_on_train_input_fn: Function that produces features and labels for
  627. evaluation on train data.
  628. predict_input_fn: Function that produces features for inference.
  629. train_steps: Number of training steps.
  630. eval_on_train_data: Whether to evaluate model on training data. Default is
  631. False.
  632. final_exporter_name: String name given to `FinalExporter`.
  633. eval_spec_names: A list of string names for each `EvalSpec`.
  634. Returns:
  635. Tuple of `TrainSpec` and list of `EvalSpecs`. If `eval_on_train_data` is
  636. True, the last `EvalSpec` in the list will correspond to training data. The
  637. rest EvalSpecs in the list are evaluation datas.
  638. """
  639. train_spec = tf.estimator.TrainSpec(
  640. input_fn=train_input_fn, max_steps=train_steps)
  641. if eval_spec_names is None:
  642. eval_spec_names = [str(i) for i in range(len(eval_input_fns))]
  643. eval_specs = []
  644. for index, (eval_spec_name, eval_input_fn) in enumerate(
  645. zip(eval_spec_names, eval_input_fns)):
  646. # Uses final_exporter_name as exporter_name for the first eval spec for
  647. # backward compatibility.
  648. if index == 0:
  649. exporter_name = final_exporter_name
  650. else:
  651. exporter_name = '{}_{}'.format(final_exporter_name, eval_spec_name)
  652. exporter = tf.estimator.FinalExporter(
  653. name=exporter_name, serving_input_receiver_fn=predict_input_fn)
  654. eval_specs.append(
  655. tf.estimator.EvalSpec(
  656. name=eval_spec_name,
  657. input_fn=eval_input_fn,
  658. steps=None,
  659. exporters=exporter))
  660. if eval_on_train_data:
  661. eval_specs.append(
  662. tf.estimator.EvalSpec(
  663. name='eval_on_train', input_fn=eval_on_train_input_fn, steps=None))
  664. return train_spec, eval_specs
  665. def continuous_eval(estimator, model_dir, input_fn, train_steps, name):
  666. """Perform continuous evaluation on checkpoints written to a model directory.
  667. Args:
  668. estimator: Estimator object to use for evaluation.
  669. model_dir: Model directory to read checkpoints for continuous evaluation.
  670. input_fn: Input function to use for evaluation.
  671. train_steps: Number of training steps. This is used to infer the last
  672. checkpoint and stop evaluation loop.
  673. name: Namescope for eval summary.
  674. """
  675. def terminate_eval():
  676. tf.logging.info('Terminating eval after 180 seconds of no checkpoints')
  677. return True
  678. for ckpt in tf.contrib.training.checkpoints_iterator(
  679. model_dir, min_interval_secs=180, timeout=None,
  680. timeout_fn=terminate_eval):
  681. tf.logging.info('Starting Evaluation.')
  682. try:
  683. eval_results = estimator.evaluate(
  684. input_fn=input_fn, steps=None, checkpoint_path=ckpt, name=name)
  685. tf.logging.info('Eval results: %s' % eval_results)
  686. # Terminate eval job when final checkpoint is reached
  687. current_step = int(os.path.basename(ckpt).split('-')[1])
  688. if current_step >= train_steps:
  689. tf.logging.info(
  690. 'Evaluation finished after training step %d' % current_step)
  691. break
  692. except tf.errors.NotFoundError:
  693. tf.logging.info(
  694. 'Checkpoint %s no longer exists, skipping checkpoint' % ckpt)
  695. def populate_experiment(run_config,
  696. hparams,
  697. pipeline_config_path,
  698. train_steps=None,
  699. eval_steps=None,
  700. model_fn_creator=create_model_fn,
  701. **kwargs):
  702. """Populates an `Experiment` object.
  703. EXPERIMENT CLASS IS DEPRECATED. Please switch to
  704. tf.estimator.train_and_evaluate. As an example, see model_main.py.
  705. Args:
  706. run_config: A `RunConfig`.
  707. hparams: A `HParams`.
  708. pipeline_config_path: A path to a pipeline config file.
  709. train_steps: Number of training steps. If None, the number of training steps
  710. is set from the `TrainConfig` proto.
  711. eval_steps: Number of evaluation steps per evaluation cycle. If None, the
  712. number of evaluation steps is set from the `EvalConfig` proto.
  713. model_fn_creator: A function that creates a `model_fn` for `Estimator`.
  714. Follows the signature:
  715. * Args:
  716. * `detection_model_fn`: Function that returns `DetectionModel` instance.
  717. * `configs`: Dictionary of pipeline config objects.
  718. * `hparams`: `HParams` object.
  719. * Returns:
  720. `model_fn` for `Estimator`.
  721. **kwargs: Additional keyword arguments for configuration override.
  722. Returns:
  723. An `Experiment` that defines all aspects of training, evaluation, and
  724. export.
  725. """
  726. tf.logging.warning('Experiment is being deprecated. Please use '
  727. 'tf.estimator.train_and_evaluate(). See model_main.py for '
  728. 'an example.')
  729. train_and_eval_dict = create_estimator_and_inputs(
  730. run_config,
  731. hparams,
  732. pipeline_config_path,
  733. train_steps=train_steps,
  734. eval_steps=eval_steps,
  735. model_fn_creator=model_fn_creator,
  736. save_final_config=True,
  737. **kwargs)
  738. estimator = train_and_eval_dict['estimator']
  739. train_input_fn = train_and_eval_dict['train_input_fn']
  740. eval_input_fns = train_and_eval_dict['eval_input_fns']
  741. predict_input_fn = train_and_eval_dict['predict_input_fn']
  742. train_steps = train_and_eval_dict['train_steps']
  743. export_strategies = [
  744. tf.contrib.learn.utils.saved_model_export_utils.make_export_strategy(
  745. serving_input_fn=predict_input_fn)
  746. ]
  747. return tf.contrib.learn.Experiment(
  748. estimator=estimator,
  749. train_input_fn=train_input_fn,
  750. eval_input_fn=eval_input_fns[0],
  751. train_steps=train_steps,
  752. eval_steps=None,
  753. export_strategies=export_strategies,
  754. eval_delay_secs=120,
  755. )