You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

122 lines
5.3 KiB

  1. syntax = "proto2";
  2. package object_detection.protos;
  3. import "object_detection/protos/optimizer.proto";
  4. import "object_detection/protos/preprocessor.proto";
  // Message for configuring DetectionModel training jobs (train.py).
  // Next id: 28
  message TrainConfig {
    // Effective batch size to use for training.
    // For TPU (or sync SGD jobs), the batch size per core (or GPU) is
    // `batch_size` / number of cores (or `batch_size` / number of GPUs).
    optional uint32 batch_size = 1 [default=32];

    // Data augmentation options.
    repeated PreprocessingStep data_augmentation_options = 2;

    // Whether to synchronize replicas during training.
    optional bool sync_replicas = 3 [default=false];

    // How frequently, in hours, to keep checkpoints.
    optional float keep_checkpoint_every_n_hours = 4 [default=10000.0];

    // Optimizer used to train the DetectionModel.
    optional Optimizer optimizer = 5;

    // If greater than 0, clips gradients by this value.
    optional float gradient_clipping_by_norm = 6 [default=0.0];

    // Checkpoint to restore variables from. Typically used to load feature
    // extractor variables trained outside of object detection.
    optional string fine_tune_checkpoint = 7 [default=""];

    // Type of checkpoint to restore variables from, e.g. 'classification' or
    // 'detection'. Supersedes the deprecated `from_detection_checkpoint`.
    // Typically used to load feature extractor variables from trained models.
    optional string fine_tune_checkpoint_type = 22 [default=""];

    // [Deprecated]: use `fine_tune_checkpoint_type` instead.
    // Specifies if the finetune checkpoint is from an object detection model.
    // If from an object detection model, the model being trained should have
    // the same parameters with the exception of the num_classes parameter.
    // If false, it assumes the checkpoint was an object classification model.
    optional bool from_detection_checkpoint = 8
        [default=false, deprecated=true];

    // Whether to load all checkpoint vars that match model variable names and
    // sizes. This option is only available if `from_detection_checkpoint` is
    // True.
    // NOTE(review): `from_detection_checkpoint` is deprecated in favor of
    // `fine_tune_checkpoint_type`; presumably the equivalent precondition is
    // `fine_tune_checkpoint_type: 'detection'` -- confirm against the trainer.
    optional bool load_all_detection_checkpoint_vars = 19 [default=false];

    // Number of steps to train the DetectionModel for. If 0, will train the
    // model indefinitely.
    optional uint32 num_steps = 9 [default=0];

    // Number of training steps between replica startup.
    // This flag must be set to 0 if `sync_replicas` is set to true.
    // Declared as float even though it counts steps. Do not change the type:
    // float uses fixed32 wire encoding and integer types use varint, so a
    // type change would break compatibility with existing serialized configs.
    optional float startup_delay_steps = 10 [default=15];

    // If greater than 0, multiplies the gradient of bias variables by this
    // amount.
    optional float bias_grad_multiplier = 11 [default=0];

    // Variables that should be updated during training. Note that variables
    // which also match the patterns in `freeze_variables` will be excluded.
    repeated string update_trainable_variables = 25;

    // Variables that should not be updated during training. If
    // `update_trainable_variables` is not empty, only eliminates the included
    // variables according to `freeze_variables` patterns.
    repeated string freeze_variables = 12;

    // Number of replicas to aggregate before making parameter updates.
    optional int32 replicas_to_aggregate = 13 [default=1];

    // [Deprecated] Maximum number of elements to store within a queue.
    optional int32 batch_queue_capacity = 14 [default=150, deprecated=true];

    // [Deprecated] Number of threads to use for batching.
    optional int32 num_batch_queue_threads = 15 [default=8, deprecated=true];

    // [Deprecated] Maximum capacity of the queue used to prefetch assembled
    // batches.
    optional int32 prefetch_queue_capacity = 16 [default=5, deprecated=true];

    // If true, boxes with the same coordinates will be merged together.
    // This is useful when each box can have multiple labels.
    // Note that only Sigmoid classification losses should be used.
    optional bool merge_multiple_label_boxes = 17 [default=false];

    // If true, will use multiclass scores from object annotations as ground
    // truth. Currently only compatible with annotated image inputs.
    optional bool use_multiclass_scores = 24 [default=false];

    // Whether to add regularization loss to `total_loss`. This is true by
    // default and adds all regularization losses defined in the model to
    // `total_loss`.
    // Setting this option to false is very useful while debugging the model
    // and losses.
    optional bool add_regularization_loss = 18 [default=true];

    // [Deprecated] Maximum number of boxes used during training.
    // Set this to at least the maximum amount of boxes in the input data.
    // Otherwise, it may cause "Data loss: Attempted to pad to a smaller size
    // than the input element" errors.
    optional int32 max_number_of_boxes = 20 [default=100, deprecated=true];

    // Whether to remove padding along `num_boxes` dimension of the
    // groundtruth tensors.
    optional bool unpad_groundtruth_tensors = 21 [default=true];

    // Whether to retain original images (i.e. not pre-processed) in the
    // tensor dictionary, so that they can be displayed in Tensorboard. Note
    // that this will lead to a larger memory footprint.
    optional bool retain_original_images = 23 [default=false];

    // Whether to use bfloat16 for training. This is currently only supported
    // for TPUs.
    optional bool use_bfloat16 = 26 [default=false];

    // Whether to summarize gradients.
    optional bool summarize_gradients = 27 [default=false];
  }