# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Constructs model, inputs, and training environment."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import functools
import os

import tensorflow as tf

from tensorflow.python.util import function_utils

from object_detection import eval_util
from object_detection import exporter as exporter_lib
from object_detection import inputs
from object_detection.builders import graph_rewriter_builder
from object_detection.builders import model_builder
from object_detection.builders import optimizer_builder
from object_detection.core import standard_fields as fields
from object_detection.utils import config_util
from object_detection.utils import label_map_util
from object_detection.utils import ops
from object_detection.utils import shape_utils
from object_detection.utils import variables_helper
from object_detection.utils import visualization_utils as vis_utils

# A map of names to methods that help build the model.
MODEL_BUILD_UTIL_MAP = {
    'get_configs_from_pipeline_file':
        config_util.get_configs_from_pipeline_file,
    'create_pipeline_proto_from_configs':
        config_util.create_pipeline_proto_from_configs,
    'merge_external_params_with_configs':
        config_util.merge_external_params_with_configs,
    'create_train_input_fn':
        inputs.create_train_input_fn,
    'create_eval_input_fn':
        inputs.create_eval_input_fn,
    'create_predict_input_fn':
        inputs.create_predict_input_fn,
    'detection_model_fn_base': model_builder.build,
}
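
# The map above lets callers and tests swap in alternate builders without
# monkey-patching module functions. An illustrative override (a sketch;
# `my_model_builder` is a hypothetical stand-in that must match the
# signature of model_builder.build):
#
#   def my_model_builder(model_config, is_training, add_summaries=True):
#     ...  # return a DetectionModel instance
#
#   MODEL_BUILD_UTIL_MAP['detection_model_fn_base'] = my_model_builder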


def _prepare_groundtruth_for_eval(detection_model, class_agnostic,
                                  max_number_of_boxes):
  """Extracts groundtruth data from detection_model and prepares it for eval.

  Args:
    detection_model: A `DetectionModel` object.
    class_agnostic: Whether the detections are class_agnostic.
    max_number_of_boxes: Max number of groundtruth boxes.

  Returns:
    groundtruth: Dictionary with the following fields:
      'groundtruth_boxes': [batch_size, num_boxes, 4] float32 tensor of boxes,
        in normalized coordinates.
      'groundtruth_classes': [batch_size, num_boxes] int64 tensor of 1-indexed
        classes.
      'groundtruth_masks': 4D float32 tensor of instance masks (if provided in
        groundtruth).
      'groundtruth_is_crowd': [batch_size, num_boxes] bool tensor indicating
        is_crowd annotations (if provided in groundtruth).
      'num_groundtruth_boxes': [batch_size] tensor containing the maximum
        number of groundtruth boxes per image.
  """
  input_data_fields = fields.InputDataFields()
  groundtruth_boxes = tf.stack(
      detection_model.groundtruth_lists(fields.BoxListFields.boxes))
  groundtruth_boxes_shape = tf.shape(groundtruth_boxes)
  # For class-agnostic models, groundtruth one-hot encodings collapse to all
  # ones.
  if class_agnostic:
    groundtruth_classes_one_hot = tf.ones(
        [groundtruth_boxes_shape[0], groundtruth_boxes_shape[1], 1])
  else:
    groundtruth_classes_one_hot = tf.stack(
        detection_model.groundtruth_lists(fields.BoxListFields.classes))
  label_id_offset = 1  # Applying label id offset (b/63711816)
  groundtruth_classes = (
      tf.argmax(groundtruth_classes_one_hot, axis=2) + label_id_offset)
  groundtruth = {
      input_data_fields.groundtruth_boxes: groundtruth_boxes,
      input_data_fields.groundtruth_classes: groundtruth_classes
  }
  if detection_model.groundtruth_has_field(fields.BoxListFields.masks):
    groundtruth[input_data_fields.groundtruth_instance_masks] = tf.stack(
        detection_model.groundtruth_lists(fields.BoxListFields.masks))
  if detection_model.groundtruth_has_field(fields.BoxListFields.is_crowd):
    groundtruth[input_data_fields.groundtruth_is_crowd] = tf.stack(
        detection_model.groundtruth_lists(fields.BoxListFields.is_crowd))
  groundtruth[input_data_fields.num_groundtruth_boxes] = (
      tf.tile([max_number_of_boxes], multiples=[groundtruth_boxes_shape[0]]))
  return groundtruth
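
# Illustrative shape sketch (not executed): for a two-image batch padded to
# three boxes per image, the returned dictionary would contain, e.g.,
#   {'groundtruth_boxes':     float32 tensor of shape [2, 3, 4],
#    'groundtruth_classes':   int64 tensor of shape [2, 3],
#    'num_groundtruth_boxes': int32 tensor of shape [2]}
# plus masks / is_crowd entries when the model carries those fields.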


def unstack_batch(tensor_dict, unpad_groundtruth_tensors=True):
  """Unstacks all tensors in `tensor_dict` along 0th dimension.

  Unstacks tensor from the tensor dict along 0th dimension and returns a
  tensor_dict containing values that are lists of unstacked, unpadded tensors.

  Tensors in the `tensor_dict` are expected to be of one of the three shapes:
  1. [batch_size]
  2. [batch_size, height, width, channels]
  3. [batch_size, num_boxes, d1, d2, ... dn]

  When unpad_groundtruth_tensors is set to true, unstacked tensors of form 3
  above are sliced along the `num_boxes` dimension using the value in tensor
  field fields.InputDataFields.num_groundtruth_boxes.

  Note that this function has a static list of input data fields and has to be
  kept in sync with the InputDataFields defined in core/standard_fields.py

  Args:
    tensor_dict: A dictionary of batched groundtruth tensors.
    unpad_groundtruth_tensors: Whether to remove padding along `num_boxes`
      dimension of the groundtruth tensors.

  Returns:
    A dictionary where the keys are from fields.InputDataFields and values are
    a list of unstacked (optionally unpadded) tensors.

  Raises:
    ValueError: If `unpad_groundtruth_tensors` is True and `tensor_dict` does
      not contain a `num_groundtruth_boxes` tensor.
  """
  unbatched_tensor_dict = {
      key: tf.unstack(tensor) for key, tensor in tensor_dict.items()
  }
  if unpad_groundtruth_tensors:
    if (fields.InputDataFields.num_groundtruth_boxes not in
        unbatched_tensor_dict):
      raise ValueError('`num_groundtruth_boxes` not found in tensor_dict. '
                       'Keys available: {}'.format(
                           unbatched_tensor_dict.keys()))
    unbatched_unpadded_tensor_dict = {}
    unpad_keys = set([
        # List of input data fields that are padded along the num_boxes
        # dimension. This list has to be kept in sync with InputDataFields in
        # standard_fields.py.
        fields.InputDataFields.groundtruth_instance_masks,
        fields.InputDataFields.groundtruth_classes,
        fields.InputDataFields.groundtruth_boxes,
        fields.InputDataFields.groundtruth_keypoints,
        fields.InputDataFields.groundtruth_group_of,
        fields.InputDataFields.groundtruth_difficult,
        fields.InputDataFields.groundtruth_is_crowd,
        fields.InputDataFields.groundtruth_area,
        fields.InputDataFields.groundtruth_weights
    ]).intersection(set(unbatched_tensor_dict.keys()))

    for key in unpad_keys:
      unpadded_tensor_list = []
      for num_gt, padded_tensor in zip(
          unbatched_tensor_dict[fields.InputDataFields.num_groundtruth_boxes],
          unbatched_tensor_dict[key]):
        tensor_shape = shape_utils.combined_static_and_dynamic_shape(
            padded_tensor)
        slice_begin = tf.zeros([len(tensor_shape)], dtype=tf.int32)
        slice_size = tf.stack(
            [num_gt] + [-1 if dim is None else dim for dim in tensor_shape[1:]])
        unpadded_tensor = tf.slice(padded_tensor, slice_begin, slice_size)
        unpadded_tensor_list.append(unpadded_tensor)
      unbatched_unpadded_tensor_dict[key] = unpadded_tensor_list
    unbatched_tensor_dict.update(unbatched_unpadded_tensor_dict)

  return unbatched_tensor_dict
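
# Illustrative sketch (not executed): given a two-image batch padded to three
# boxes per image,
#   tensor_dict = {
#       fields.InputDataFields.groundtruth_boxes: tf.zeros([2, 3, 4]),
#       fields.InputDataFields.num_groundtruth_boxes: tf.constant([1, 2]),
#   }
#   unpadded = unstack_batch(tensor_dict)
# unpadded[fields.InputDataFields.groundtruth_boxes] is then a list of two
# tensors of shapes [1, 4] and [2, 4]: unstacked along the batch dimension
# and sliced to the true number of boxes in each image.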


def _provide_groundtruth(model, labels):
  """Provides the labels to a model as groundtruth.

  This helper function extracts the corresponding boxes, classes,
  keypoints, weights, masks, etc. from the labels, and provides them
  as groundtruth to the model.

  Args:
    model: The detection model to provide groundtruth to.
    labels: The labels for the training or evaluation inputs.
  """
  gt_boxes_list = labels[fields.InputDataFields.groundtruth_boxes]
  gt_classes_list = labels[fields.InputDataFields.groundtruth_classes]
  gt_masks_list = None
  if fields.InputDataFields.groundtruth_instance_masks in labels:
    gt_masks_list = labels[
        fields.InputDataFields.groundtruth_instance_masks]
  gt_keypoints_list = None
  if fields.InputDataFields.groundtruth_keypoints in labels:
    gt_keypoints_list = labels[fields.InputDataFields.groundtruth_keypoints]
  gt_weights_list = None
  if fields.InputDataFields.groundtruth_weights in labels:
    gt_weights_list = labels[fields.InputDataFields.groundtruth_weights]
  gt_confidences_list = None
  if fields.InputDataFields.groundtruth_confidences in labels:
    gt_confidences_list = labels[
        fields.InputDataFields.groundtruth_confidences]
  gt_is_crowd_list = None
  if fields.InputDataFields.groundtruth_is_crowd in labels:
    gt_is_crowd_list = labels[fields.InputDataFields.groundtruth_is_crowd]
  model.provide_groundtruth(
      groundtruth_boxes_list=gt_boxes_list,
      groundtruth_classes_list=gt_classes_list,
      groundtruth_confidences_list=gt_confidences_list,
      groundtruth_masks_list=gt_masks_list,
      groundtruth_keypoints_list=gt_keypoints_list,
      groundtruth_weights_list=gt_weights_list,
      groundtruth_is_crowd_list=gt_is_crowd_list)
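
# Minimal illustrative call (a sketch; `model` is any DetectionModel, and the
# per-image lists would normally come from unstack_batch above):
#   labels = {
#       fields.InputDataFields.groundtruth_boxes: [tf.zeros([5, 4])],
#       fields.InputDataFields.groundtruth_classes: [tf.one_hot([0] * 5, 3)],
#   }
#   _provide_groundtruth(model, labels)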


def create_model_fn(detection_model_fn, configs, hparams, use_tpu=False,
                    postprocess_on_cpu=False):
  """Creates a model function for `Estimator`.

  Args:
    detection_model_fn: Function that returns a `DetectionModel` instance.
    configs: Dictionary of pipeline config objects.
    hparams: `HParams` object.
    use_tpu: Boolean indicating whether model should be constructed for
      use on TPU.
    postprocess_on_cpu: When use_tpu and postprocess_on_cpu are true,
      postprocessing is scheduled on the host CPU.

  Returns:
    `model_fn` for `Estimator`.
  """
  train_config = configs['train_config']
  eval_input_config = configs['eval_input_config']
  eval_config = configs['eval_config']

  def model_fn(features, labels, mode, params=None):
    """Constructs the object detection model.

    Args:
      features: Dictionary of feature tensors, returned from `input_fn`.
      labels: Dictionary of groundtruth tensors if mode is TRAIN or EVAL,
        otherwise None.
      mode: Mode key from tf.estimator.ModeKeys.
      params: Parameter dictionary passed from the estimator.

    Returns:
      An `EstimatorSpec` that encapsulates the model and its serving
        configurations.
    """
    params = params or {}
    total_loss, train_op, detections, export_outputs = None, None, None, None
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    # Make sure to set the Keras learning phase. True during training,
    # False for inference.
    tf.keras.backend.set_learning_phase(is_training)
    detection_model = detection_model_fn(
        is_training=is_training, add_summaries=(not use_tpu))
    scaffold_fn = None

    if mode == tf.estimator.ModeKeys.TRAIN:
      labels = unstack_batch(
          labels,
          unpad_groundtruth_tensors=train_config.unpad_groundtruth_tensors)
    elif mode == tf.estimator.ModeKeys.EVAL:
      # For evaluating on train data, it is necessary to check whether
      # groundtruth must be unpadded.
      boxes_shape = (
          labels[fields.InputDataFields.groundtruth_boxes].get_shape()
          .as_list())
      unpad_groundtruth_tensors = boxes_shape[1] is not None and not use_tpu
      labels = unstack_batch(
          labels, unpad_groundtruth_tensors=unpad_groundtruth_tensors)

    if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
      _provide_groundtruth(detection_model, labels)

    preprocessed_images = features[fields.InputDataFields.image]
    if use_tpu and train_config.use_bfloat16:
      with tf.contrib.tpu.bfloat16_scope():
        prediction_dict = detection_model.predict(
            preprocessed_images,
            features[fields.InputDataFields.true_image_shape])
        prediction_dict = ops.bfloat16_to_float32_nested(prediction_dict)
    else:
      prediction_dict = detection_model.predict(
          preprocessed_images,
          features[fields.InputDataFields.true_image_shape])

    def postprocess_wrapper(args):
      return detection_model.postprocess(args[0], args[1])

    if mode in (tf.estimator.ModeKeys.EVAL, tf.estimator.ModeKeys.PREDICT):
      if use_tpu and postprocess_on_cpu:
        detections = tf.contrib.tpu.outside_compilation(
            postprocess_wrapper,
            (prediction_dict,
             features[fields.InputDataFields.true_image_shape]))
      else:
        detections = postprocess_wrapper((
            prediction_dict,
            features[fields.InputDataFields.true_image_shape]))

    if mode == tf.estimator.ModeKeys.TRAIN:
      if train_config.fine_tune_checkpoint and hparams.load_pretrained:
        if not train_config.fine_tune_checkpoint_type:
          # train_config.from_detection_checkpoint field is deprecated. For
          # backward compatibility, set train_config.fine_tune_checkpoint_type
          # based on train_config.from_detection_checkpoint.
          if train_config.from_detection_checkpoint:
            train_config.fine_tune_checkpoint_type = 'detection'
          else:
            train_config.fine_tune_checkpoint_type = 'classification'
        asg_map = detection_model.restore_map(
            fine_tune_checkpoint_type=train_config.fine_tune_checkpoint_type,
            load_all_detection_checkpoint_vars=(
                train_config.load_all_detection_checkpoint_vars))
        available_var_map = (
            variables_helper.get_variables_available_in_checkpoint(
                asg_map,
                train_config.fine_tune_checkpoint,
                include_global_step=False))
        if use_tpu:

          def tpu_scaffold():
            tf.train.init_from_checkpoint(train_config.fine_tune_checkpoint,
                                          available_var_map)
            return tf.train.Scaffold()

          scaffold_fn = tpu_scaffold
        else:
          tf.train.init_from_checkpoint(train_config.fine_tune_checkpoint,
                                        available_var_map)

    if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
      losses_dict = detection_model.loss(
          prediction_dict, features[fields.InputDataFields.true_image_shape])
      losses = [loss_tensor for loss_tensor in losses_dict.values()]
      if train_config.add_regularization_loss:
        regularization_losses = detection_model.regularization_losses()
        if use_tpu and train_config.use_bfloat16:
          regularization_losses = ops.bfloat16_to_float32_nested(
              regularization_losses)
        if regularization_losses:
          regularization_loss = tf.add_n(
              regularization_losses, name='regularization_loss')
          losses.append(regularization_loss)
          losses_dict['Loss/regularization_loss'] = regularization_loss
      total_loss = tf.add_n(losses, name='total_loss')
      losses_dict['Loss/total_loss'] = total_loss

      if 'graph_rewriter_config' in configs:
        graph_rewriter_fn = graph_rewriter_builder.build(
            configs['graph_rewriter_config'], is_training=is_training)
        graph_rewriter_fn()

      # TODO(rathodv): Stop creating optimizer summary vars in EVAL mode once
      # we can write learning rate summaries on TPU without host calls.
      global_step = tf.train.get_or_create_global_step()
      training_optimizer, optimizer_summary_vars = optimizer_builder.build(
          train_config.optimizer)

    if mode == tf.estimator.ModeKeys.TRAIN:
      if use_tpu:
        training_optimizer = tf.contrib.tpu.CrossShardOptimizer(
            training_optimizer)

      # Optionally freeze some layers by setting their gradients to be zero.
      trainable_variables = None
      include_variables = (
          train_config.update_trainable_variables
          if train_config.update_trainable_variables else None)
      exclude_variables = (
          train_config.freeze_variables
          if train_config.freeze_variables else None)
      trainable_variables = tf.contrib.framework.filter_variables(
          tf.trainable_variables(),
          include_patterns=include_variables,
          exclude_patterns=exclude_variables)

      clip_gradients_value = None
      if train_config.gradient_clipping_by_norm > 0:
        clip_gradients_value = train_config.gradient_clipping_by_norm

      if not use_tpu:
        for var in optimizer_summary_vars:
          tf.summary.scalar(var.op.name, var)
      summaries = [] if use_tpu else None
      if train_config.summarize_gradients:
        summaries = ['gradients', 'gradient_norm', 'global_gradient_norm']
      train_op = tf.contrib.layers.optimize_loss(
          loss=total_loss,
          global_step=global_step,
          learning_rate=None,
          clip_gradients=clip_gradients_value,
          optimizer=training_optimizer,
          update_ops=detection_model.updates(),
          variables=trainable_variables,
          summaries=summaries,
          name='')  # Preventing scope prefix on all variables.

    if mode == tf.estimator.ModeKeys.PREDICT:
      exported_output = exporter_lib.add_output_tensor_nodes(detections)
      export_outputs = {
          tf.saved_model.signature_constants.PREDICT_METHOD_NAME:
              tf.estimator.export.PredictOutput(exported_output)
      }

    eval_metric_ops = None
    scaffold = None
    if mode == tf.estimator.ModeKeys.EVAL:
      class_agnostic = (
          fields.DetectionResultFields.detection_classes not in detections)
      groundtruth = _prepare_groundtruth_for_eval(
          detection_model, class_agnostic,
          eval_input_config.max_number_of_boxes)
      use_original_images = fields.InputDataFields.original_image in features
      if use_original_images:
        eval_images = features[fields.InputDataFields.original_image]
        true_image_shapes = tf.slice(
            features[fields.InputDataFields.true_image_shape], [0, 0], [-1, 3])
        original_image_spatial_shapes = features[fields.InputDataFields
                                                 .original_image_spatial_shape]
      else:
        eval_images = features[fields.InputDataFields.image]
        true_image_shapes = None
        original_image_spatial_shapes = None

      eval_dict = eval_util.result_dict_for_batched_example(
          eval_images,
          features[inputs.HASH_KEY],
          detections,
          groundtruth,
          class_agnostic=class_agnostic,
          scale_to_absolute=True,
          original_image_spatial_shapes=original_image_spatial_shapes,
          true_image_shapes=true_image_shapes)

      if class_agnostic:
        category_index = label_map_util.create_class_agnostic_category_index()
      else:
        category_index = label_map_util.create_category_index_from_labelmap(
            eval_input_config.label_map_path)
      vis_metric_ops = None
      if not use_tpu and use_original_images:
        eval_metric_op_vis = vis_utils.VisualizeSingleFrameDetections(
            category_index,
            max_examples_to_draw=eval_config.num_visualizations,
            max_boxes_to_draw=eval_config.max_num_boxes_to_visualize,
            min_score_thresh=eval_config.min_score_threshold,
            use_normalized_coordinates=False)
        vis_metric_ops = eval_metric_op_vis.get_estimator_eval_metric_ops(
            eval_dict)

      # Eval metrics on a single example.
      eval_metric_ops = eval_util.get_eval_metric_ops_for_evaluators(
          eval_config, list(category_index.values()), eval_dict)
      for loss_key, loss_tensor in iter(losses_dict.items()):
        eval_metric_ops[loss_key] = tf.metrics.mean(loss_tensor)
      for var in optimizer_summary_vars:
        eval_metric_ops[var.op.name] = (var, tf.no_op())
      if vis_metric_ops is not None:
        eval_metric_ops.update(vis_metric_ops)
      eval_metric_ops = {str(k): v for k, v in eval_metric_ops.items()}

      if eval_config.use_moving_averages:
        variable_averages = tf.train.ExponentialMovingAverage(0.0)
        variables_to_restore = variable_averages.variables_to_restore()
        keep_checkpoint_every_n_hours = (
            train_config.keep_checkpoint_every_n_hours)
        saver = tf.train.Saver(
            variables_to_restore,
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)
        scaffold = tf.train.Scaffold(saver=saver)

    # EVAL executes on CPU, so use regular non-TPU EstimatorSpec.
    if use_tpu and mode != tf.estimator.ModeKeys.EVAL:
      return tf.contrib.tpu.TPUEstimatorSpec(
          mode=mode,
          scaffold_fn=scaffold_fn,
          predictions=detections,
          loss=total_loss,
          train_op=train_op,
          eval_metrics=eval_metric_ops,
          export_outputs=export_outputs)
    else:
      if scaffold is None:
        keep_checkpoint_every_n_hours = (
            train_config.keep_checkpoint_every_n_hours)
        saver = tf.train.Saver(
            sharded=True,
            keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
            save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        scaffold = tf.train.Scaffold(saver=saver)
      return tf.estimator.EstimatorSpec(
          mode=mode,
          predictions=detections,
          loss=total_loss,
          train_op=train_op,
          eval_metric_ops=eval_metric_ops,
          export_outputs=export_outputs,
          scaffold=scaffold)

  return model_fn
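
# Illustrative wiring (a sketch; in practice create_estimator_and_inputs below
# does this for you, and the model_dir is hypothetical):
#   detection_model_fn = functools.partial(
#       model_builder.build, model_config=configs['model'])
#   model_fn = create_model_fn(detection_model_fn, configs, hparams)
#   estimator = tf.estimator.Estimator(
#       model_fn=model_fn,
#       config=tf.estimator.RunConfig(model_dir='/tmp/od_model'))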


def create_estimator_and_inputs(run_config,
                                hparams,
                                pipeline_config_path,
                                config_override=None,
                                train_steps=None,
                                sample_1_of_n_eval_examples=1,
                                sample_1_of_n_eval_on_train_examples=1,
                                model_fn_creator=create_model_fn,
                                use_tpu_estimator=False,
                                use_tpu=False,
                                num_shards=1,
                                params=None,
                                override_eval_num_epochs=True,
                                save_final_config=False,
                                postprocess_on_cpu=False,
                                export_to_tpu=None,
                                **kwargs):
  """Creates `Estimator`, input functions, and steps.

  Args:
    run_config: A `RunConfig`.
    hparams: A `HParams`.
    pipeline_config_path: A path to a pipeline config file.
    config_override: A pipeline_pb2.TrainEvalPipelineConfig text proto to
      override the config from `pipeline_config_path`.
    train_steps: Number of training steps. If None, the number of training
      steps is set from the `TrainConfig` proto.
    sample_1_of_n_eval_examples: Integer representing how often an eval example
      should be sampled. If 1, will sample all examples.
    sample_1_of_n_eval_on_train_examples: Similar to
      `sample_1_of_n_eval_examples`, except controls the sampling of training
      data for evaluation.
    model_fn_creator: A function that creates a `model_fn` for `Estimator`.
      Follows the signature:
      * Args:
        * `detection_model_fn`: Function that returns `DetectionModel`
          instance.
        * `configs`: Dictionary of pipeline config objects.
        * `hparams`: `HParams` object.
      * Returns:
        `model_fn` for `Estimator`.
    use_tpu_estimator: Whether a `TPUEstimator` should be returned. If False,
      an `Estimator` will be returned.
    use_tpu: Boolean, whether training and evaluation should run on TPU. Only
      used if `use_tpu_estimator` is True.
    num_shards: Number of shards (TPU cores). Only used if `use_tpu_estimator`
      is True.
    params: Parameter dictionary passed from the estimator. Only used if
      `use_tpu_estimator` is True.
    override_eval_num_epochs: Whether to overwrite the number of epochs to 1
      for eval_input.
    save_final_config: Whether to save final config (obtained after applying
      overrides) to `estimator.model_dir`.
    postprocess_on_cpu: When use_tpu and postprocess_on_cpu are true,
      postprocessing is scheduled on the host CPU.
    export_to_tpu: When use_tpu and export_to_tpu are true,
      `export_savedmodel()` exports a metagraph for serving on TPU besides the
      one on CPU.
    **kwargs: Additional keyword arguments for configuration override.

  Returns:
    A dictionary with the following fields:
    'estimator': An `Estimator` or `TPUEstimator`.
    'train_input_fn': A training input function.
    'eval_input_fns': A list of all evaluation input functions.
    'eval_input_names': A list of names for each evaluation input.
    'eval_on_train_input_fn': An evaluation-on-train input function.
    'predict_input_fn': A prediction input function.
    'train_steps': Number of training steps. Either directly from input or
      from configuration.
  """
  get_configs_from_pipeline_file = MODEL_BUILD_UTIL_MAP[
      'get_configs_from_pipeline_file']
  merge_external_params_with_configs = MODEL_BUILD_UTIL_MAP[
      'merge_external_params_with_configs']
  create_pipeline_proto_from_configs = MODEL_BUILD_UTIL_MAP[
      'create_pipeline_proto_from_configs']
  create_train_input_fn = MODEL_BUILD_UTIL_MAP['create_train_input_fn']
  create_eval_input_fn = MODEL_BUILD_UTIL_MAP['create_eval_input_fn']
  create_predict_input_fn = MODEL_BUILD_UTIL_MAP['create_predict_input_fn']
  detection_model_fn_base = MODEL_BUILD_UTIL_MAP['detection_model_fn_base']

  configs = get_configs_from_pipeline_file(
      pipeline_config_path, config_override=config_override)
  kwargs.update({
      'train_steps': train_steps,
      'sample_1_of_n_eval_examples': sample_1_of_n_eval_examples,
      'use_bfloat16': configs['train_config'].use_bfloat16 and use_tpu
  })
  if override_eval_num_epochs:
    kwargs.update({'eval_num_epochs': 1})
    tf.logging.warning(
        'Forced number of epochs for all eval validations to be 1.')
  configs = merge_external_params_with_configs(
      configs, hparams, kwargs_dict=kwargs)
  model_config = configs['model']
  train_config = configs['train_config']
  train_input_config = configs['train_input_config']
  eval_config = configs['eval_config']
  eval_input_configs = configs['eval_input_configs']
  eval_on_train_input_config = copy.deepcopy(train_input_config)
  eval_on_train_input_config.sample_1_of_n_examples = (
      sample_1_of_n_eval_on_train_examples)
  if override_eval_num_epochs and eval_on_train_input_config.num_epochs != 1:
    tf.logging.warning('Expected number of evaluation epochs is 1, but '
                       'instead encountered `eval_on_train_input_config'
                       '.num_epochs` = '
                       '{}. Overwriting `num_epochs` to 1.'.format(
                           eval_on_train_input_config.num_epochs))
    eval_on_train_input_config.num_epochs = 1

  # Update train_steps from the config, but only when a non-zero value is
  # provided.
  if train_steps is None and train_config.num_steps != 0:
    train_steps = train_config.num_steps

  detection_model_fn = functools.partial(
      detection_model_fn_base, model_config=model_config)

  # Create the input functions for TRAIN/EVAL/PREDICT.
  train_input_fn = create_train_input_fn(
      train_config=train_config,
      train_input_config=train_input_config,
      model_config=model_config)
  eval_input_fns = [
      create_eval_input_fn(
          eval_config=eval_config,
          eval_input_config=eval_input_config,
          model_config=model_config) for eval_input_config in eval_input_configs
  ]
  eval_input_names = [
      eval_input_config.name for eval_input_config in eval_input_configs
  ]
  eval_on_train_input_fn = create_eval_input_fn(
      eval_config=eval_config,
      eval_input_config=eval_on_train_input_config,
      model_config=model_config)
  predict_input_fn = create_predict_input_fn(
      model_config=model_config, predict_input_config=eval_input_configs[0])

  # Read export_to_tpu from hparams if not passed.
  if export_to_tpu is None:
    export_to_tpu = hparams.get('export_to_tpu', False)
  tf.logging.info('create_estimator_and_inputs: use_tpu %s, export_to_tpu %s',
                  use_tpu, export_to_tpu)
  model_fn = model_fn_creator(detection_model_fn, configs, hparams, use_tpu,
                              postprocess_on_cpu)
  if use_tpu_estimator:
    # Multicore inference disabled due to b/129367127.
    tpu_estimator_args = function_utils.fn_args(tf.contrib.tpu.TPUEstimator)
    kwargs = {}
    if 'experimental_export_device_assignment' in tpu_estimator_args:
      kwargs['experimental_export_device_assignment'] = True
    estimator = tf.contrib.tpu.TPUEstimator(
        model_fn=model_fn,
        train_batch_size=train_config.batch_size,
        # For each core, only batch size 1 is supported for eval.
        eval_batch_size=num_shards * 1 if use_tpu else 1,
        use_tpu=use_tpu,
        config=run_config,
        export_to_tpu=export_to_tpu,
        eval_on_tpu=False,  # Eval runs on CPU, so disable eval on TPU.
        params=params if params else {},
        **kwargs)
  else:
    estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)

  # Write the as-run pipeline config to disk.
  if run_config.is_chief and save_final_config:
    pipeline_config_final = create_pipeline_proto_from_configs(configs)
    config_util.save_pipeline_config(pipeline_config_final,
                                     estimator.model_dir)

  return dict(
      estimator=estimator,
      train_input_fn=train_input_fn,
      eval_input_fns=eval_input_fns,
      eval_input_names=eval_input_names,
      eval_on_train_input_fn=eval_on_train_input_fn,
      predict_input_fn=predict_input_fn,
      train_steps=train_steps)
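
# Typical usage (a sketch; the config path and model_dir are hypothetical, and
# `hparams` would come from the repo's model_hparams helper):
#   train_and_eval_dict = create_estimator_and_inputs(
#       run_config=tf.estimator.RunConfig(model_dir='/tmp/od_model'),
#       hparams=hparams,
#       pipeline_config_path='path/to/pipeline.config',
#       save_final_config=True)
#   estimator = train_and_eval_dict['estimator']
#   train_steps = train_and_eval_dict['train_steps']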


def create_train_and_eval_specs(train_input_fn,
                                eval_input_fns,
                                eval_on_train_input_fn,
                                predict_input_fn,
                                train_steps,
                                eval_on_train_data=False,
                                final_exporter_name='Servo',
                                eval_spec_names=None):
  """Creates a `TrainSpec` and `EvalSpec`s.

  Args:
    train_input_fn: Function that produces features and labels on train data.
    eval_input_fns: A list of functions that produce features and labels on
      eval data.
    eval_on_train_input_fn: Function that produces features and labels for
      evaluation on train data.
    predict_input_fn: Function that produces features for inference.
    train_steps: Number of training steps.
    eval_on_train_data: Whether to evaluate model on training data. Default is
      False.
    final_exporter_name: String name given to `FinalExporter`.
    eval_spec_names: A list of string names for each `EvalSpec`.

  Returns:
    Tuple of `TrainSpec` and list of `EvalSpec`s. If `eval_on_train_data` is
    True, the last `EvalSpec` in the list will correspond to training data.
    The remaining `EvalSpec`s correspond to the evaluation datasets.
  """
  train_spec = tf.estimator.TrainSpec(
      input_fn=train_input_fn, max_steps=train_steps)

  if eval_spec_names is None:
    eval_spec_names = [str(i) for i in range(len(eval_input_fns))]

  eval_specs = []
  for index, (eval_spec_name, eval_input_fn) in enumerate(
      zip(eval_spec_names, eval_input_fns)):
    # Uses final_exporter_name as exporter_name for the first eval spec for
    # backward compatibility.
    if index == 0:
      exporter_name = final_exporter_name
    else:
      exporter_name = '{}_{}'.format(final_exporter_name, eval_spec_name)
    exporter = tf.estimator.FinalExporter(
        name=exporter_name, serving_input_receiver_fn=predict_input_fn)
    eval_specs.append(
        tf.estimator.EvalSpec(
            name=eval_spec_name,
            input_fn=eval_input_fn,
            steps=None,
            exporters=exporter))

  if eval_on_train_data:
    eval_specs.append(
        tf.estimator.EvalSpec(
            name='eval_on_train', input_fn=eval_on_train_input_fn,
            steps=None))

  return train_spec, eval_specs
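
# Typical follow-up (a sketch mirroring model_main.py; assumes the dictionary
# returned by create_estimator_and_inputs above):
#   train_spec, eval_specs = create_train_and_eval_specs(
#       train_and_eval_dict['train_input_fn'],
#       train_and_eval_dict['eval_input_fns'],
#       train_and_eval_dict['eval_on_train_input_fn'],
#       train_and_eval_dict['predict_input_fn'],
#       train_and_eval_dict['train_steps'])
#   tf.estimator.train_and_evaluate(estimator, train_spec, eval_specs[0])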


def continuous_eval(estimator, model_dir, input_fn, train_steps, name):
  """Performs continuous evaluation on checkpoints written to a model
  directory.

  Args:
    estimator: Estimator object to use for evaluation.
    model_dir: Model directory to read checkpoints for continuous evaluation.
    input_fn: Input function to use for evaluation.
    train_steps: Number of training steps. This is used to infer the last
      checkpoint and stop the evaluation loop.
    name: Namescope for eval summary.
  """

  def terminate_eval():
    tf.logging.info('Terminating eval after 180 seconds of no checkpoints')
    return True

  for ckpt in tf.contrib.training.checkpoints_iterator(
      model_dir, min_interval_secs=180, timeout=None,
      timeout_fn=terminate_eval):

    tf.logging.info('Starting Evaluation.')
    try:
      eval_results = estimator.evaluate(
          input_fn=input_fn, steps=None, checkpoint_path=ckpt, name=name)
      tf.logging.info('Eval results: %s' % eval_results)

      # Terminate eval job when final checkpoint is reached.
      current_step = int(os.path.basename(ckpt).split('-')[1])
      if current_step >= train_steps:
        tf.logging.info(
            'Evaluation finished after training step %d' % current_step)
        break
    except tf.errors.NotFoundError:
      tf.logging.info(
          'Checkpoint %s no longer exists, skipping checkpoint' % ckpt)
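
# Example invocation (a sketch; assumes the objects produced by
# create_estimator_and_inputs above and a hypothetical model_dir):
#   continuous_eval(estimator, model_dir='/tmp/od_model',
#                   input_fn=train_and_eval_dict['eval_input_fns'][0],
#                   train_steps=train_and_eval_dict['train_steps'],
#                   name='validation_data')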


def populate_experiment(run_config,
                        hparams,
                        pipeline_config_path,
                        train_steps=None,
                        eval_steps=None,
                        model_fn_creator=create_model_fn,
                        **kwargs):
  """Populates an `Experiment` object.

  EXPERIMENT CLASS IS DEPRECATED. Please switch to
  tf.estimator.train_and_evaluate. As an example, see model_main.py.

  Args:
    run_config: A `RunConfig`.
    hparams: A `HParams`.
    pipeline_config_path: A path to a pipeline config file.
    train_steps: Number of training steps. If None, the number of training
      steps is set from the `TrainConfig` proto.
    eval_steps: Number of evaluation steps per evaluation cycle. If None, the
      number of evaluation steps is set from the `EvalConfig` proto.
    model_fn_creator: A function that creates a `model_fn` for `Estimator`.
      Follows the signature:
      * Args:
        * `detection_model_fn`: Function that returns `DetectionModel`
          instance.
        * `configs`: Dictionary of pipeline config objects.
        * `hparams`: `HParams` object.
      * Returns:
        `model_fn` for `Estimator`.
    **kwargs: Additional keyword arguments for configuration override.

  Returns:
    An `Experiment` that defines all aspects of training, evaluation, and
    export.
  """
  tf.logging.warning('Experiment is being deprecated. Please use '
                     'tf.estimator.train_and_evaluate(). See model_main.py '
                     'for an example.')
  train_and_eval_dict = create_estimator_and_inputs(
      run_config,
      hparams,
      pipeline_config_path,
      train_steps=train_steps,
      eval_steps=eval_steps,
      model_fn_creator=model_fn_creator,
      save_final_config=True,
      **kwargs)
  estimator = train_and_eval_dict['estimator']
  train_input_fn = train_and_eval_dict['train_input_fn']
  eval_input_fns = train_and_eval_dict['eval_input_fns']
  predict_input_fn = train_and_eval_dict['predict_input_fn']
  train_steps = train_and_eval_dict['train_steps']

  export_strategies = [
      tf.contrib.learn.utils.saved_model_export_utils.make_export_strategy(
          serving_input_fn=predict_input_fn)
  ]

  return tf.contrib.learn.Experiment(
      estimator=estimator,
      train_input_fn=train_input_fn,
      eval_input_fn=eval_input_fns[0],
      train_steps=train_steps,
      eval_steps=None,
      export_strategies=export_strategies,
      eval_delay_secs=120,
  )