# Preparing Inputs

[TOC]

To use your own dataset in the Tensorflow Object Detection API, you must
convert it into the
[TFRecord file format](https://www.tensorflow.org/api_guides/python/python_io#tfrecords_format_details).
This document outlines how to write a script to generate the TFRecord file.

## Label Maps

Each dataset is required to have a label map associated with it. This label map
defines a mapping from string class names to integer class ids. The label map
should be a `StringIntLabelMap` text protobuf. Sample label maps can be found in
object_detection/data. Label maps should always start from id 1.
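To check that a label map parses correctly, you can load it with the utilities
bundled in object_detection/utils. Below is a minimal sketch; the path
`data/my_label_map.pbtxt` is hypothetical:

```python
from object_detection.utils import label_map_util

# Hypothetical path to a StringIntLabelMap text protobuf.
label_map_path = 'data/my_label_map.pbtxt'

# Returns a dict mapping string class names to integer ids,
# e.g. {'Cat': 1, 'Dog': 2}.
label_map_dict = label_map_util.get_label_map_dict(label_map_path)
```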
## Dataset Requirements

For every example in your dataset, you should have the following information:

1.  An RGB image for the dataset encoded as jpeg or png.
2.  A list of bounding boxes for the image. Each bounding box should contain:
    1.  Bounding box coordinates (with origin in the top left corner) defined
        by 4 floating point numbers [ymin, xmin, ymax, xmax]. Note that we
        store the _normalized_ coordinates (x / width, y / height) in the
        TFRecord dataset, as shown in the sketch after this list.
    2.  The class of the object in the bounding box.
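To make the normalization concrete, the following sketch converts absolute
pixel coordinates into the normalized values stored in the TFRecord (the
numbers are taken from the cat example later in this document):

```python
# Pixel-space box and image dimensions from the cat example below.
width, height = 1200.0, 1032.0
xmin_px, ymin_px, xmax_px, ymax_px = 322.0, 174.0, 1062.0, 761.0

# Normalized coordinates, each in [0, 1], as stored in the TFRecord.
ymin = ymin_px / height
xmin = xmin_px / width
ymax = ymax_px / height
xmax = xmax_px / width
```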
## Example Image

Consider the following image:

![Example Image](img/example_cat.jpg "Example Image")

with the following label map:

```
item {
  id: 1
  name: 'Cat'
}

item {
  id: 2
  name: 'Dog'
}
```
We can generate a tf.Example proto for this image using the following code:

```python
import tensorflow as tf

from object_detection.utils import dataset_util


def create_cat_tf_example(encoded_cat_image_data):
  """Creates a tf.Example proto from sample cat image.

  Args:
    encoded_cat_image_data: The jpg encoded data of the cat image.

  Returns:
    example: The created tf.Example.
  """

  height = 1032
  width = 1200
  filename = b'example_cat.jpg'
  image_format = b'jpg'

  xmins = [322.0 / 1200.0]
  xmaxs = [1062.0 / 1200.0]
  ymins = [174.0 / 1032.0]
  ymaxs = [761.0 / 1032.0]
  classes_text = [b'Cat']
  classes = [1]

  tf_example = tf.train.Example(features=tf.train.Features(feature={
      'image/height': dataset_util.int64_feature(height),
      'image/width': dataset_util.int64_feature(width),
      'image/filename': dataset_util.bytes_feature(filename),
      'image/source_id': dataset_util.bytes_feature(filename),
      'image/encoded': dataset_util.bytes_feature(encoded_cat_image_data),
      'image/format': dataset_util.bytes_feature(image_format),
      'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
      'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
      'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
      'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
      'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
      'image/object/class/label': dataset_util.int64_list_feature(classes),
  }))
  return tf_example
```
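As a usage sketch, continuing from the example above, you could read the
encoded image bytes and serialize the resulting proto to a TFRecord as follows
(the output path `cat.record` is hypothetical):

```python
with tf.gfile.GFile('img/example_cat.jpg', 'rb') as fid:
  encoded_cat_image_data = fid.read()

writer = tf.python_io.TFRecordWriter('cat.record')
writer.write(create_cat_tf_example(encoded_cat_image_data).SerializeToString())
writer.close()
```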
## Conversion Script Outline {#conversion-script-outline}

A typical conversion script will look like the following:

```python
import tensorflow as tf

from object_detection.utils import dataset_util


flags = tf.app.flags
flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
FLAGS = flags.FLAGS


def create_tf_example(example):
  # TODO(user): Populate the following variables from your example.
  height = None  # Image height
  width = None  # Image width
  filename = None  # Filename of the image. Empty if image is not from file
  encoded_image_data = None  # Encoded image bytes
  image_format = None  # b'jpeg' or b'png'

  xmins = []  # List of normalized left x coordinates in bounding box
              # (1 per box)
  xmaxs = []  # List of normalized right x coordinates in bounding box
              # (1 per box)
  ymins = []  # List of normalized top y coordinates in bounding box
              # (1 per box)
  ymaxs = []  # List of normalized bottom y coordinates in bounding box
              # (1 per box)
  classes_text = []  # List of string class name of bounding box (1 per box)
  classes = []  # List of integer class id of bounding box (1 per box)

  tf_example = tf.train.Example(features=tf.train.Features(feature={
      'image/height': dataset_util.int64_feature(height),
      'image/width': dataset_util.int64_feature(width),
      'image/filename': dataset_util.bytes_feature(filename),
      'image/source_id': dataset_util.bytes_feature(filename),
      'image/encoded': dataset_util.bytes_feature(encoded_image_data),
      'image/format': dataset_util.bytes_feature(image_format),
      'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
      'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
      'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
      'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
      'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
      'image/object/class/label': dataset_util.int64_list_feature(classes),
  }))
  return tf_example


def main(_):
  writer = tf.python_io.TFRecordWriter(FLAGS.output_path)

  # TODO(user): Write code to read in your dataset to examples variable

  for example in examples:
    tf_example = create_tf_example(example)
    writer.write(tf_example.SerializeToString())

  writer.close()


if __name__ == '__main__':
  tf.app.run()
```
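If you save the outline above under a hypothetical name such as
`create_tf_record.py`, you can run it along these lines:

```bash
python create_tf_record.py --output_path=/path/to/train_dataset.record
```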
Note: You may notice additional fields in some other datasets. They are
currently unused by the API and are optional.

Note: Please refer to the section on [Running an Instance Segmentation
Model](instance_segmentation.md) for instructions on how to configure a model
that predicts masks in addition to object bounding boxes.
## Sharding datasets

When you have more than a few thousand examples, it is beneficial to shard your
dataset into multiple files:

*   The tf.data.Dataset API can read input examples in parallel, improving
    throughput (see the reading sketch after this list).
*   The tf.data.Dataset API can shuffle the examples better with sharded files,
    which slightly improves the performance of the model.
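To illustrate the first point, here is a minimal sketch of reading sharded
files with tf.data; the glob pattern matches the sharded output shown at the
end of this section:

```python
import tensorflow as tf

# Match all shards; list_files also shuffles the shard order by default,
# which helps with example-level shuffling.
filenames = tf.data.Dataset.list_files(
    '/path/to/train_dataset.record-?????-of-00010')

# Cycle through 10 shard files at a time, interleaving their records.
# Newer TF versions also accept num_parallel_calls here for parallel reads.
dataset = filenames.interleave(tf.data.TFRecordDataset, cycle_length=10)
```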
Instead of writing all tf.Example protos to a single file as shown in the
[conversion script outline](#conversion-script-outline), use the snippet below.

```python
import contextlib2
from object_detection.dataset_tools import tf_record_creation_util

num_shards = 10
output_filebase = '/path/to/train_dataset.record'

with contextlib2.ExitStack() as tf_record_close_stack:
  output_tfrecords = tf_record_creation_util.open_sharded_output_tfrecords(
      tf_record_close_stack, output_filebase, num_shards)
  for index, example in enumerate(examples):
    tf_example = create_tf_example(example)
    output_shard_index = index % num_shards
    output_tfrecords[output_shard_index].write(tf_example.SerializeToString())
```
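Here `examples` and `create_tf_example` are the same as in the
[conversion script outline](#conversion-script-outline) above; only the writing
loop changes, with `index % num_shards` distributing the records round-robin
across the shard files.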
This will produce the following output files:

```bash
/path/to/train_dataset.record-00000-of-00010
/path/to/train_dataset.record-00001-of-00010
...
/path/to/train_dataset.record-00009-of-00010
```

which can then be used in the config file as below.

```bash
tf_record_input_reader {
  input_path: "/path/to/train_dataset.record-?????-of-00010"
}
```