|
|
- # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- """Functions for reading and updating configuration files."""
-
- import os
- import tensorflow as tf
-
- from google.protobuf import text_format
-
- from tensorflow.python.lib.io import file_io
-
- from object_detection.protos import eval_pb2
- from object_detection.protos import graph_rewriter_pb2
- from object_detection.protos import input_reader_pb2
- from object_detection.protos import model_pb2
- from object_detection.protos import pipeline_pb2
- from object_detection.protos import train_pb2
-
-
def get_image_resizer_config(model_config):
  """Looks up the image resizer of the configured meta-architecture.

  Args:
    model_config: A model_pb2.DetectionModel.

  Returns:
    The `image_resizer` sub-message of whichever `model` oneof member is set
    (an image_resizer_pb2.ImageResizer).

  Raises:
    ValueError: If the `model` oneof is neither `faster_rcnn` nor `ssd`.
  """
  architecture = model_config.WhichOneof("model")
  if architecture not in ("faster_rcnn", "ssd"):
    raise ValueError("Unknown model type: {}".format(architecture))
  # The oneof member name doubles as the attribute holding its sub-message.
  return getattr(model_config, architecture).image_resizer
-
-
def get_spatial_image_size(image_resizer_config):
  """Derives the static [height, width] produced by an image resizer config.

  Args:
    image_resizer_config: An image_resizer_pb2.ImageResizer.

  Returns:
    A two-element list `[height, width]`. Both entries are -1 when the size
    cannot be determined during graph construction.

  Raises:
    ValueError: If none of the known image resizer fields is set.
  """
  has_field = image_resizer_config.HasField

  if has_field("fixed_shape_resizer"):
    fixed = image_resizer_config.fixed_shape_resizer
    return [fixed.height, fixed.width]

  if has_field("keep_aspect_ratio_resizer"):
    keep_ratio = image_resizer_config.keep_aspect_ratio_resizer
    # Only a padded output has a known (square) size.
    if keep_ratio.pad_to_max_dimension:
      return [keep_ratio.max_dimension, keep_ratio.max_dimension]
    return [-1, -1]

  if has_field("identity_resizer") or has_field("conditional_shape_resizer"):
    return [-1, -1]

  raise ValueError("Unknown image resizer type.")
-
-
def get_configs_from_pipeline_file(pipeline_config_path, config_override=None):
  """Loads a pipeline text proto from disk and splits it into a configs dict.

  Args:
    pipeline_config_path: Path to a pipeline_pb2.TrainEvalPipelineConfig text
      proto.
    config_override: Optional pipeline_pb2.TrainEvalPipelineConfig text proto
      that is merged on top of the file contents.

  Returns:
    Dictionary of configuration objects as built by
    create_configs_from_pipeline_proto().
  """
  with tf.gfile.GFile(pipeline_config_path, "r") as config_file:
    base_text = config_file.read()

  merged_config = pipeline_pb2.TrainEvalPipelineConfig()
  text_format.Merge(base_text, merged_config)
  # The override is merged second so that it is applied on top of the file.
  if config_override:
    text_format.Merge(config_override, merged_config)
  return create_configs_from_pipeline_proto(merged_config)
-
-
def create_configs_from_pipeline_proto(pipeline_config):
  """Splits a pipeline_pb2.TrainEvalPipelineConfig into a configs dictionary.

  Args:
    pipeline_config: pipeline_pb2.TrainEvalPipelineConfig proto object.

  Returns:
    Dictionary with keys `model`, `train_config`, `train_input_config`,
    `eval_config` and `eval_input_configs` (the repeated eval input readers).
    `eval_input_config` (first eval input reader) is added when at least one
    eval input reader exists, and `graph_rewriter_config` is added when the
    pipeline has a `graph_rewriter` field.
  """
  configs = {
      "model": pipeline_config.model,
      "train_config": pipeline_config.train_config,
      "train_input_config": pipeline_config.train_input_reader,
      "eval_config": pipeline_config.eval_config,
      "eval_input_configs": pipeline_config.eval_input_reader,
  }

  # `eval_input_config` exists only for backwards compatibility; clients
  # should read `eval_input_configs` instead.
  eval_inputs = configs["eval_input_configs"]
  if eval_inputs:
    configs["eval_input_config"] = eval_inputs[0]

  if pipeline_config.HasField("graph_rewriter"):
    configs["graph_rewriter_config"] = pipeline_config.graph_rewriter

  return configs
-
-
def get_graph_rewriter_config_from_file(graph_rewriter_config_file):
  """Reads a graph rewriter text proto from a file.

  Args:
    graph_rewriter_config_file: File path to the graph rewriter config.

  Returns:
    A graph_rewriter_pb2.GraphRewriter proto populated from the file.
  """
  rewriter_proto = graph_rewriter_pb2.GraphRewriter()
  with tf.gfile.GFile(graph_rewriter_config_file, "r") as config_file:
    config_text = config_file.read()
    text_format.Merge(config_text, rewriter_proto)
  return rewriter_proto
-
-
def create_pipeline_proto_from_configs(configs):
  """Assembles a pipeline_pb2.TrainEvalPipelineConfig from a configs dict.

  This is the inverse of create_configs_from_pipeline_proto().

  Args:
    configs: Dictionary of configs. See get_configs_from_pipeline_file().

  Returns:
    A pipeline_pb2.TrainEvalPipelineConfig holding copies of the entries in
    `configs`.
  """
  pipeline_proto = pipeline_pb2.TrainEvalPipelineConfig()

  # (pipeline proto field, configs key) pairs for the singular sub-messages.
  singular_fields = (
      ("model", "model"),
      ("train_config", "train_config"),
      ("train_input_reader", "train_input_config"),
      ("eval_config", "eval_config"),
  )
  for proto_field, configs_key in singular_fields:
    getattr(pipeline_proto, proto_field).CopyFrom(configs[configs_key])

  pipeline_proto.eval_input_reader.extend(configs["eval_input_configs"])

  if "graph_rewriter_config" in configs:
    pipeline_proto.graph_rewriter.CopyFrom(configs["graph_rewriter_config"])
  return pipeline_proto
-
-
def save_pipeline_config(pipeline_config, directory):
  """Writes a pipeline config as `pipeline.config` inside `directory`.

  The directory is created (recursively) when it does not exist yet.

  Args:
    pipeline_config: A pipeline_pb2.TrainEvalPipelineConfig.
    directory: The model directory into which the pipeline config file will be
      saved.
  """
  directory_missing = not file_io.file_exists(directory)
  if directory_missing:
    file_io.recursive_create_dir(directory)

  output_path = os.path.join(directory, "pipeline.config")
  serialized_config = text_format.MessageToString(pipeline_config)
  with tf.gfile.Open(output_path, "wb") as output_file:
    tf.logging.info("Writing pipeline config file to %s", output_path)
    output_file.write(serialized_config)
-
-
def get_configs_from_multiple_files(model_config_path="",
                                    train_config_path="",
                                    train_input_config_path="",
                                    eval_config_path="",
                                    eval_input_config_path="",
                                    graph_rewriter_config_path=""):
  """Reads training configuration from separate per-section config files.

  Args:
    model_config_path: Path to model_pb2.DetectionModel.
    train_config_path: Path to train_pb2.TrainConfig.
    train_input_config_path: Path to input_reader_pb2.InputReader.
    eval_config_path: Path to eval_pb2.EvalConfig.
    eval_input_config_path: Path to input_reader_pb2.InputReader.
    graph_rewriter_config_path: Path to graph_rewriter_pb2.GraphRewriter.

  Returns:
    Dictionary of configuration objects. Possible keys are `model`,
    `train_config`, `train_input_config`, `eval_config`, `eval_input_configs`
    (a single-element list) and `graph_rewriter_config`. A key is present only
    when its path argument is a non-empty string.
  """

  def _parse_text_proto(path, message):
    """Merges the text proto stored at `path` into `message`; returns it."""
    with tf.gfile.GFile(path, "r") as config_file:
      text_format.Merge(config_file.read(), message)
    return message

  configs = {}

  if model_config_path:
    configs["model"] = _parse_text_proto(
        model_config_path, model_pb2.DetectionModel())

  if train_config_path:
    configs["train_config"] = _parse_text_proto(
        train_config_path, train_pb2.TrainConfig())

  if train_input_config_path:
    configs["train_input_config"] = _parse_text_proto(
        train_input_config_path, input_reader_pb2.InputReader())

  if eval_config_path:
    configs["eval_config"] = _parse_text_proto(
        eval_config_path, eval_pb2.EvalConfig())

  if eval_input_config_path:
    # Wrapped in a list to mirror the repeated eval_input_reader field.
    configs["eval_input_configs"] = [
        _parse_text_proto(eval_input_config_path,
                          input_reader_pb2.InputReader())
    ]

  if graph_rewriter_config_path:
    configs["graph_rewriter_config"] = get_graph_rewriter_config_from_file(
        graph_rewriter_config_path)

  return configs
-
-
def get_number_of_classes(model_config):
  """Reads `num_classes` from the configured meta-architecture.

  Args:
    model_config: A model_pb2.DetectionModel.

  Returns:
    Number of classes.

  Raises:
    ValueError: If the `model` oneof is neither `faster_rcnn` nor `ssd`.
  """
  architecture = model_config.WhichOneof("model")
  if architecture not in ("faster_rcnn", "ssd"):
    raise ValueError("Expected the model to be one of 'faster_rcnn' or 'ssd'.")
  return getattr(model_config, architecture).num_classes
-
-
def get_optimizer_type(train_config):
  """Reports which member of the `optimizer` oneof is set.

  Args:
    train_config: A train_pb2.TrainConfig.

  Returns:
    The result of `WhichOneof("optimizer")` on `train_config.optimizer`.
  """
  optimizer = train_config.optimizer
  return optimizer.WhichOneof("optimizer")
-
-
def get_learning_rate_type(optimizer_config):
  """Reports which member of the `learning_rate` oneof is set.

  Args:
    optimizer_config: An optimizer_pb2.Optimizer.

  Returns:
    The result of `WhichOneof("learning_rate")` on
    `optimizer_config.learning_rate`.
  """
  learning_rate = optimizer_config.learning_rate
  return learning_rate.WhichOneof("learning_rate")
-
-
- def _is_generic_key(key):
- """Determines whether the key starts with a generic config dictionary key."""
- for prefix in [
- "graph_rewriter_config",
- "model",
- "train_input_config",
- "train_config",
- "eval_config"]:
- if key.startswith(prefix + "."):
- return True
- return False
-
-
- def _check_and_convert_legacy_input_config_key(key):
- """Checks key and converts legacy input config update to specific update.
-
- Args:
- key: string indicates the target of update operation.
-
- Returns:
- is_valid_input_config_key: A boolean indicating whether the input key is to
- update input config(s).
- key_name: 'eval_input_configs' or 'train_input_config' string if
- is_valid_input_config_key is true. None if is_valid_input_config_key is
- false.
- input_name: always returns None since legacy input config key never
- specifies the target input config. Keeping this output only to match the
- output form defined for input config update.
- field_name: the field name in input config. `key` itself if
- is_valid_input_config_key is false.
- """
- key_name = None
- input_name = None
- field_name = key
- is_valid_input_config_key = True
- if field_name == "train_shuffle":
- key_name = "train_input_config"
- field_name = "shuffle"
- elif field_name == "eval_shuffle":
- key_name = "eval_input_configs"
- field_name = "shuffle"
- elif field_name == "train_input_path":
- key_name = "train_input_config"
- field_name = "input_path"
- elif field_name == "eval_input_path":
- key_name = "eval_input_configs"
- field_name = "input_path"
- elif field_name == "append_train_input_path":
- key_name = "train_input_config"
- field_name = "input_path"
- elif field_name == "append_eval_input_path":
- key_name = "eval_input_configs"
- field_name = "input_path"
- else:
- is_valid_input_config_key = False
-
- return is_valid_input_config_key, key_name, input_name, field_name
-
-
def check_and_parse_input_config_key(configs, key):
  """Parses an override key and validates it if it targets an input config.

  A key without colons is handed to the legacy key converter. A key with
  exactly three colon-separated parts is read as
  `key_name:input_name:field_name` and validated against `configs`.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    key: String naming the target of the update operation.

  Returns:
    is_valid_input_config_key: Whether the key updates input config(s).
    key_name: 'eval_input_configs' or 'train_input_config' when
      is_valid_input_config_key is true, otherwise None.
    input_name: Name of the input config to update. None for legacy keys and
      when is_valid_input_config_key is false.
    field_name: Field name in the input config; `key` itself when
      is_valid_input_config_key is false.

  Raises:
    ValueError: when the key has neither one nor three colon-separated parts.
    ValueError: if key_name is not 'eval_input_configs' or
      'train_input_config'.
    ValueError: if input_name matches no name in the selected input config(s).
    ValueError: if field_name is not one of the supported fields.
  """
  parts = key.split(":")
  if len(parts) == 1:
    return _check_and_convert_legacy_input_config_key(key)
  if len(parts) != 3:
    raise ValueError("Invalid key format when overriding configs.")
  key_name, input_name, field_name = parts

  if key_name not in ("eval_input_configs", "train_input_config"):
    raise ValueError("Invalid key_name when overriding input config.")

  # A single InputReader is matched on its own name; otherwise the entry is
  # treated as a collection of input configs and any member may match.
  target = configs[key_name]
  if isinstance(target, input_reader_pb2.InputReader):
    name_found = target.name == input_name
  else:
    name_found = input_name in [candidate.name for candidate in target]
  if not name_found:
    raise ValueError("Invalid input_name when overriding input config.")

  supported_fields = ("input_path", "label_map_path", "shuffle", "mask_type",
                      "sample_1_of_n_examples")
  if field_name not in supported_fields:
    raise ValueError("Invalid field_name when overriding input config.")

  return True, key_name, input_name, field_name
-
-
def merge_external_params_with_configs(configs, hparams=None, kwargs_dict=None):
  """Updates `configs` dictionary based on supplied parameters.

  This utility is for modifying specific fields in the object detection configs.
  Say that one would like to experiment with different learning rates, momentum
  values, or batch sizes. Rather than creating a new config text file for each
  experiment, one can use a single base config file, and update particular
  values.

  There are two types of field overrides:
  1. Strategy-based overrides, which update multiple relevant configuration
    options. For example, updating `learning_rate` will update both the warmup
    and final learning rates.
    In this case key can be one of the following formats:
      1. legacy update: single string that indicates the attribute to be
        updated. E.g. 'label_map_path', 'eval_input_path', 'eval_shuffle'.
        Note that when updating fields (e.g. eval_input_path, eval_shuffle) in
        eval_input_configs, the override will only be applied when
        eval_input_configs has exactly 1 element.
      2. specific update: colon separated string that indicates which field in
        which input_config to update. It should have 3 fields:
        - key_name: Name of the input config we should update, either
          'train_input_config' or 'eval_input_configs'
        - input_name: a 'name' that can be used to identify elements, especially
          when configs[key_name] is a repeated field.
        - field_name: name of the field that you want to override.
        For example, given configs dict as below:
          configs = {
            'model': {...}
            'train_config': {...}
            'train_input_config': {...}
            'eval_config': {...}
            'eval_input_configs': [{ name:"eval_coco", ...},
                                   { name:"eval_voc", ... }]
          }
        Assume we want to update the input_path of the eval_input_config
        whose name is 'eval_coco'. The `key` would then be:
        'eval_input_configs:eval_coco:input_path'
  2. Generic key/value, which update a specific parameter based on namespaced
    configuration keys. For example,
    `model.ssd.loss.hard_example_miner.max_negatives_per_positive` will update
    the hard example miner configuration for an SSD model config. Generic
    overrides are automatically detected based on the namespaced keys.

  Overrides whose value is `None` or the empty string are skipped. Keys that
  match neither override type are logged and ignored.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    hparams: A `HParams`.
    kwargs_dict: Extra keyword arguments that are treated the same way as
      attribute/value pairs in `hparams`. Note that hyperparameters with the
      same names will override keyword arguments. The passed dictionary itself
      is never modified.

  Returns:
    `configs` dictionary (the same object, updated in place).

  Raises:
    ValueError: when the key string doesn't match any of its allowed formats.
  """
  # Merge into a private copy so that applying `hparams` does not leak the
  # hyperparameter values into the caller's `kwargs_dict`.
  overrides = dict(kwargs_dict) if kwargs_dict else {}
  if hparams:
    overrides.update(hparams.values())

  for key, value in overrides.items():
    tf.logging.info("Maybe overwriting %s: %s", key, value)
    # pylint: disable=g-explicit-bool-comparison
    if value == "" or value is None:
      continue
    # pylint: enable=g-explicit-bool-comparison
    elif _maybe_update_config_with_key_value(configs, key, value):
      continue
    elif _is_generic_key(key):
      _update_generic(configs, key, value)
    else:
      tf.logging.info("Ignoring config override key: %s", key)
  return configs
-
-
def _maybe_update_config_with_key_value(configs, key, value):
  """Applies a strategy-based override to `configs` if `key` names one.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    key: String indicating the field(s) to be updated.
    value: Value used to override the existing field value.

  Returns:
    True if `key` was recognized and the override applied, False if `key` is
    not a strategy-based override key.

  Raises:
    ValueError: when check_and_parse_input_config_key() rejects the key.
  """
  is_input_config_key, key_name, input_name, field_name = (
      check_and_parse_input_config_key(configs, key))
  if is_input_config_key:
    update_input_reader_config(
        configs,
        key_name=key_name,
        input_name=input_name,
        field_name=field_name,
        value=value)
    return True

  # Keys that set a differently named field on every eval input config.
  eval_input_fields = {
      "sample_1_of_n_eval_examples": "sample_1_of_n_examples",
      "eval_num_epochs": "num_epochs",
  }
  if field_name in eval_input_fields:
    _update_all_eval_input_configs(
        configs, eval_input_fields[field_name], value)
    return True

  if field_name == "retain_original_images_in_eval":
    _update_retain_original_images(configs["eval_config"], value)
    return True

  # Remaining keys map onto helpers sharing the (configs, value) signature.
  config_updaters = {
      "learning_rate": _update_initial_learning_rate,
      "batch_size": _update_batch_size,
      "momentum_optimizer_value": _update_momentum_optimizer_value,
      # Localization weight is fixed to 1.0.
      "classification_localization_weight_ratio":
          _update_classification_localization_weight_ratio,
      "focal_loss_gamma": _update_focal_loss_gamma,
      "focal_loss_alpha": _update_focal_loss_alpha,
      "train_steps": _update_train_steps,
      "label_map_path": _update_label_map_path,
      "mask_type": _update_mask_type,
      "eval_with_moving_averages": _update_use_moving_averages,
      "use_bfloat16": _update_use_bfloat16,
  }
  updater = config_updaters.get(field_name)
  if updater is None:
    return False
  updater(configs, value)
  return True
-
-
- def _update_tf_record_input_path(input_config, input_path):
- """Updates input configuration to reflect a new input path.
-
- The input_config object is updated in place, and hence not returned.
-
- Args:
- input_config: A input_reader_pb2.InputReader.
- input_path: A path to data or list of paths.
-
- Raises:
- TypeError: if input reader type is not `tf_record_input_reader`.
- """
- input_reader_type = input_config.WhichOneof("input_reader")
- if input_reader_type == "tf_record_input_reader":
- input_config.tf_record_input_reader.ClearField("input_path")
- if isinstance(input_path, list):
- input_config.tf_record_input_reader.input_path.extend(input_path)
- else:
- input_config.tf_record_input_reader.input_path.append(input_path)
- else:
- raise TypeError("Input reader type must be `tf_record_input_reader`.")
-
-
def update_input_reader_config(configs,
                               key_name=None,
                               input_name=None,
                               field_name=None,
                               value=None,
                               path_updater=_update_tf_record_input_path):
  """Updates specified input reader config field.

  The target input config is selected as follows:
    * `configs[key_name]` is a single InputReader: that config is updated.
    * `input_name` is None and `configs[key_name]` holds exactly one input
      config: that only element is updated.
    * `input_name` is given: the element whose `name` equals `input_name` is
      updated.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    key_name: Name of the input config we should update, either
      'train_input_config' or 'eval_input_configs'
    input_name: String name used to identify input config to update with. Should
      be either None or value of the 'name' field in one of the input reader
      configs.
    field_name: Field name in input_reader_pb2.InputReader.
    value: Value used to override existing field value.
    path_updater: helper function used to update the input path. Only used when
      field_name is "input_path".

  Raises:
    ValueError: when `input_name` is given but matches no input config, or
      matches more than one. No input config is modified in either case.
    ValueError: when no selection rule applies, e.g. `input_name` is None and
      the number of input configs is not exactly 1.
  """

  def _apply_override(target_input_config):
    """Writes `value` into `field_name` of a single input config."""
    # `input_path` is delegated to `path_updater` (the default one writes it
    # on the nested tf_record_input_reader), so it cannot be a plain setattr.
    if field_name == "input_path":
      path_updater(input_config=target_input_config, input_path=value)
    else:
      setattr(target_input_config, field_name, value)

  input_configs = configs[key_name]
  if isinstance(input_configs, input_reader_pb2.InputReader):
    # Updates singular input_config object.
    _apply_override(input_configs)
  elif input_name is None and len(input_configs) == 1:
    # Updates first (and the only) object of input_config list.
    _apply_override(input_configs[0])
  elif input_name is not None and len(input_configs):
    # Updates input_config whose name matches input_name. Matches are
    # collected first so that an ambiguous name leaves every config untouched.
    matching_configs = [
        input_config for input_config in input_configs
        if input_config.name == input_name
    ]
    if not matching_configs:
      raise ValueError(
          "Input name {} not found when overriding.".format(input_name))
    if len(matching_configs) > 1:
      raise ValueError("Duplicate input name found when overriding.")
    _apply_override(matching_configs[0])
  else:
    key_name = "None" if key_name is None else key_name
    input_name = "None" if input_name is None else input_name
    field_name = "None" if field_name is None else field_name
    raise ValueError("Unknown input config overriding: "
                     "key_name:{}, input_name:{}, field_name:{}.".format(
                         key_name, input_name, field_name))
-
-
def _update_initial_learning_rate(configs, learning_rate):
  """Updates `configs` to reflect the new initial learning rate.

  For a constant or exponential-decay learning rate only the (initial) value
  is replaced. For a manual-step schedule every scheduled rate is scaled by
  new/old initial rate, and for cosine decay the warmup rate is rescaled so
  that its ratio to the base rate is kept.
  The configs dictionary is updated in place, and hence not returned.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    learning_rate: Initial learning rate for optimizer.

  Raises:
    TypeError: if optimizer type is not supported, or if learning rate type is
      not supported.
  """
  train_config = configs["train_config"]
  optimizer_type = get_optimizer_type(train_config)
  supported_optimizers = ("rms_prop_optimizer", "momentum_optimizer",
                          "adam_optimizer")
  if optimizer_type not in supported_optimizers:
    raise TypeError("Optimizer %s is not supported." % optimizer_type)
  # The oneof member name is also the attribute holding the optimizer config.
  optimizer_config = getattr(train_config.optimizer, optimizer_type)

  learning_rate_type = get_learning_rate_type(optimizer_config)
  lr_config = optimizer_config.learning_rate

  if learning_rate_type == "constant_learning_rate":
    lr_config.constant_learning_rate.learning_rate = learning_rate
    return

  if learning_rate_type == "exponential_decay_learning_rate":
    lr_config.exponential_decay_learning_rate.initial_learning_rate = (
        learning_rate)
    return

  if learning_rate_type == "manual_step_learning_rate":
    manual_step = lr_config.manual_step_learning_rate
    # The scale must be computed before the initial rate is overwritten.
    scale = float(learning_rate) / manual_step.initial_learning_rate
    manual_step.initial_learning_rate = learning_rate
    for step in manual_step.schedule:
      step.learning_rate *= scale
    return

  if learning_rate_type == "cosine_decay_learning_rate":
    cosine_decay = lr_config.cosine_decay_learning_rate
    previous_base = cosine_decay.learning_rate_base
    previous_warmup = cosine_decay.warmup_learning_rate
    warmup_ratio = previous_warmup / previous_base
    cosine_decay.learning_rate_base = learning_rate
    cosine_decay.warmup_learning_rate = warmup_ratio * learning_rate
    return

  raise TypeError("Learning rate %s is not supported." % learning_rate_type)
-
-
- def _update_batch_size(configs, batch_size):
- """Updates `configs` to reflect the new training batch size.
-
- The configs dictionary is updated in place, and hence not returned.
-
- Args:
- configs: Dictionary of configuration objects. See outputs from
- get_configs_from_pipeline_file() or get_configs_from_multiple_files().
- batch_size: Batch size to use for training (Ideally a power of 2). Inputs
- are rounded, and capped to be 1 or greater.
- """
- configs["train_config"].batch_size = max(1, int(round(batch_size)))
-
-
- def _validate_message_has_field(message, field):
- if not message.HasField(field):
- raise ValueError("Expecting message to have field %s" % field)
-
-
- def _update_generic(configs, key, value):
- """Update a pipeline configuration parameter based on a generic key/value.
-
- Args:
- configs: Dictionary of pipeline configuration protos.
- key: A string key, dot-delimited to represent the argument key.
- e.g. "model.ssd.train_config.batch_size"
- value: A value to set the argument to. The type of the value must match the
- type for the protocol buffer. Note that setting the wrong type will
- result in a TypeError.
- e.g. 42
-
- Raises:
- ValueError if the message key does not match the existing proto fields.
- TypeError the value type doesn't match the protobuf field type.
- """
- fields = key.split(".")
- first_field = fields.pop(0)
- last_field = fields.pop()
- message = configs[first_field]
- for field in fields:
- _validate_message_has_field(message, field)
- message = getattr(message, field)
- _validate_message_has_field(message, last_field)
- setattr(message, last_field, value)
-
-
def _update_momentum_optimizer_value(configs, momentum):
  """Updates `configs` to reflect the new momentum value.

  Momentum is only supported for RMSPropOptimizer and MomentumOptimizer; any
  other optimizer type is rejected with a TypeError. The configs dictionary is
  updated in place, and hence not returned.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    momentum: New momentum value. Values are clipped at 0.0 and 1.0.

  Raises:
    TypeError: If the optimizer type is not `rms_prop_optimizer` or
      `momentum_optimizer`.
  """
  train_config = configs["train_config"]
  optimizer_type = get_optimizer_type(train_config)
  if optimizer_type not in ("rms_prop_optimizer", "momentum_optimizer"):
    raise TypeError("Optimizer type must be one of `rms_prop_optimizer` or "
                    "`momentum_optimizer`.")

  optimizer_config = getattr(train_config.optimizer, optimizer_type)
  clipped_momentum = min(max(0.0, momentum), 1.0)
  optimizer_config.momentum_optimizer_value = clipped_momentum
-
-
- def _update_classification_localization_weight_ratio(configs, ratio):
- """Updates the classification/localization weight loss ratio.
-
- Detection models usually define a loss weight for both classification and
- objectness. This function updates the weights such that the ratio between
- classification weight to localization weight is the ratio provided.
- Arbitrarily, localization weight is set to 1.0.
-
- Note that in the case of Faster R-CNN, this same ratio is applied to the first
- stage objectness loss weight relative to localization loss weight.
-
- The configs dictionary is updated in place, and hence not returned.
-
- Args:
- configs: Dictionary of configuration objects. See outputs from
- get_configs_from_pipeline_file() or get_configs_from_multiple_files().
- ratio: Desired ratio of classification (and/or objectness) loss weight to
- localization loss weight.
- """
- meta_architecture = configs["model"].WhichOneof("model")
- if meta_architecture == "faster_rcnn":
- model = configs["model"].faster_rcnn
- model.first_stage_localization_loss_weight = 1.0
- model.first_stage_objectness_loss_weight = ratio
- model.second_stage_localization_loss_weight = 1.0
- model.second_stage_classification_loss_weight = ratio
- if meta_architecture == "ssd":
- model = configs["model"].ssd
- model.loss.localization_weight = 1.0
- model.loss.classification_weight = ratio
-
-
- def _get_classification_loss(model_config):
- """Returns the classification loss for a model."""
- meta_architecture = model_config.WhichOneof("model")
- if meta_architecture == "faster_rcnn":
- model = model_config.faster_rcnn
- classification_loss = model.second_stage_classification_loss
- elif meta_architecture == "ssd":
- model = model_config.ssd
- classification_loss = model.loss.classification_loss
- else:
- raise TypeError("Did not recognize the model architecture.")
- return classification_loss
-
-
def _update_focal_loss_gamma(configs, gamma):
  """Sets `gamma` on the model's weighted sigmoid focal loss.

  The configs dictionary is updated in place, and hence not returned.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    gamma: Exponent term in focal loss.

  Raises:
    TypeError: If the classification loss is not `weighted_sigmoid_focal`.
  """
  loss_config = _get_classification_loss(configs["model"])
  loss_type = loss_config.WhichOneof("classification_loss")
  if loss_type == "weighted_sigmoid_focal":
    loss_config.weighted_sigmoid_focal.gamma = gamma
    return
  raise TypeError("Classification loss must be `weighted_sigmoid_focal`.")
-
-
def _update_focal_loss_alpha(configs, alpha):
  """Sets `alpha` on the model's weighted sigmoid focal loss.

  The configs dictionary is updated in place, and hence not returned.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    alpha: Class weight multiplier for sigmoid loss.

  Raises:
    TypeError: If the classification loss is not `weighted_sigmoid_focal`.
  """
  loss_config = _get_classification_loss(configs["model"])
  loss_type = loss_config.WhichOneof("classification_loss")
  if loss_type == "weighted_sigmoid_focal":
    loss_config.weighted_sigmoid_focal.alpha = alpha
    return
  raise TypeError("Classification loss must be `weighted_sigmoid_focal`.")
-
-
- def _update_train_steps(configs, train_steps):
- """Updates `configs` to reflect new number of training steps."""
- configs["train_config"].num_steps = int(train_steps)
-
-
- def _update_all_eval_input_configs(configs, field, value):
- """Updates the content of `field` with `value` for all eval input configs."""
- for eval_input_config in configs["eval_input_configs"]:
- setattr(eval_input_config, field, value)
-
-
def _update_label_map_path(configs, label_map_path):
  """Sets the label map path on the train and all eval input readers.

  The configs dictionary is updated in place, and hence not returned.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    label_map_path: New path to `StringIntLabelMap` pbtxt file.
  """
  train_input_config = configs["train_input_config"]
  train_input_config.label_map_path = label_map_path
  _update_all_eval_input_configs(configs, "label_map_path", label_map_path)
-
-
def _update_mask_type(configs, mask_type):
  """Sets the mask type on the train and all eval input readers.

  The configs dictionary is updated in place, and hence not returned.

  Args:
    configs: Dictionary of configuration objects. See outputs from
      get_configs_from_pipeline_file() or get_configs_from_multiple_files().
    mask_type: A string name representing a value of
      input_reader_pb2.InstanceMaskType
  """
  train_input_config = configs["train_input_config"]
  train_input_config.mask_type = mask_type
  _update_all_eval_input_configs(configs, "mask_type", mask_type)
-
-
- def _update_use_moving_averages(configs, use_moving_averages):
- """Updates the eval config option to use or not use moving averages.
-
- The configs dictionary is updated in place, and hence not returned.
-
- Args:
- configs: Dictionary of configuration objects. See outputs from
- get_configs_from_pipeline_file() or get_configs_from_multiple_files().
- use_moving_averages: Boolean indicating whether moving average variables
- should be loaded during evaluation.
- """
- configs["eval_config"].use_moving_averages = use_moving_averages
-
-
- def _update_retain_original_images(eval_config, retain_original_images):
- """Updates eval config with option to retain original images.
-
- The eval_config object is updated in place, and hence not returned.
-
- Args:
- eval_config: A eval_pb2.EvalConfig.
- retain_original_images: Boolean indicating whether to retain original images
- in eval mode.
- """
- eval_config.retain_original_images = retain_original_images
-
-
- def _update_use_bfloat16(configs, use_bfloat16):
- """Updates `configs` to reflect the new setup on whether to use bfloat16.
-
- The configs dictionary is updated in place, and hence not returned.
-
- Args:
- configs: Dictionary of configuration objects. See outputs from
- get_configs_from_pipeline_file() or get_configs_from_multiple_files().
- use_bfloat16: A bool, indicating whether to use bfloat16 for training.
- """
- configs["train_config"].use_bfloat16 = use_bfloat16
|