from flask import Flask, request, Response
from flask_restful import Resource, Api
import os
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as vis_util
from object_detection.utils import ops as utils_ops
from PIL import Image
import base64
import io
import json
import tensorflow as tf
import sys
import numpy as np

# One switch per model: 1 means "no TF session has been created yet" for that graph.
switches = {"coco": 1, "damage": 1}

# Frozen inference graphs: a COCO-trained R-FCN model and the road-damage SSD model.
COCO_MODEL_NAME = "rfcn_resnet101_coco_2018_01_28"
PATH_TO_FROZEN_COCO_GRAPH = 'modules/' + COCO_MODEL_NAME + '/frozen_inference_graph.pb'
PATH_TO_FROZEN_DAMAGE_GRAPH = 'modules/trainedModels/ssd_mobilenet_RoadDamageDetector.pb'
if sys.platform == "win32":
    pass  # Windows-specific setup; the body of this branch is missing from the listing
# Build one graph per model and load the serialized GraphDefs into them.
detection_graph_coco = tf.Graph()
detection_graph_damage = tf.Graph()

with detection_graph_coco.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_FROZEN_COCO_GRAPH, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')

with detection_graph_damage.as_default():
    od_graph_def = tf.GraphDef()
    with tf.gfile.GFile(PATH_TO_FROZEN_DAMAGE_GRAPH, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')


def load_image_into_numpy_array(image):
    # Convert a PIL image into an (height, width, 3) uint8 numpy array.
    (im_width, im_height) = image.size
    return np.array(image.getdata()).reshape(
        (im_height, im_width, 3)).astype(np.uint8)
def run_inference_for_single_image(image, graph, type):
    """Run one of the two frozen graphs on a single (1, H, W, 3) image batch."""
    global switches
    global sess_coco
    global sess_damage
    with graph.as_default():
        if switches[type]:
            # Create the session for this graph once and reuse it for later requests.
            if type == "coco":
                sess_coco = tf.Session()
            elif type == "damage":
                sess_damage = tf.Session()
            switches[type] = 0
        if type == "coco":
            ops = tf.get_default_graph().get_operations()
            all_tensor_names = {output.name for op in ops for output in op.outputs}
            tensor_dict = {}
            for key in [
                'num_detections', 'detection_boxes', 'detection_scores',
                'detection_classes', 'detection_masks'
            ]:
                tensor_name = key + ':0'
                if tensor_name in all_tensor_names:
                    tensor_dict[key] = tf.get_default_graph().get_tensor_by_name(
                        tensor_name)
            if 'detection_masks' in tensor_dict:
                # The following processing is only for a single image.
                detection_boxes = tf.squeeze(tensor_dict['detection_boxes'], [0])
                detection_masks = tf.squeeze(tensor_dict['detection_masks'], [0])
                # Reframing is required to translate the masks from box coordinates
                # to image coordinates and fit the image size.
                real_num_detection = tf.cast(tensor_dict['num_detections'][0], tf.int32)
                detection_boxes = tf.slice(detection_boxes, [0, 0], [real_num_detection, -1])
                detection_masks = tf.slice(detection_masks, [0, 0, 0], [real_num_detection, -1, -1])
                detection_masks_reframed = utils_ops.reframe_box_masks_to_image_masks(
                    detection_masks, detection_boxes, image.shape[1], image.shape[2])
                detection_masks_reframed = tf.cast(
                    tf.greater(detection_masks_reframed, 0.5), tf.uint8)
                # Follow the convention by adding back the batch dimension.
                tensor_dict['detection_masks'] = tf.expand_dims(
                    detection_masks_reframed, 0)
            image_tensor = tf.get_default_graph().get_tensor_by_name('image_tensor:0')
            # Run inference.
            output_dict = sess_coco.run(tensor_dict,
                                        feed_dict={image_tensor: image})
            # All outputs are float32 numpy arrays, so convert types as appropriate.
            output_dict['num_detections'] = int(output_dict['num_detections'][0])
            output_dict['detection_classes'] = output_dict[
                'detection_classes'][0].astype(np.int64)
            output_dict['detection_boxes'] = output_dict['detection_boxes'][0]
            output_dict['detection_scores'] = output_dict['detection_scores'][0]
            if 'detection_masks' in output_dict:
                output_dict['detection_masks'] = output_dict['detection_masks'][0]
        elif type == "damage":
            image_tensor = graph.get_tensor_by_name('image_tensor:0')
            # Each box represents a part of the image where a particular object was detected.
            detection_boxes = graph.get_tensor_by_name('detection_boxes:0')
            # Each score represents the confidence for the corresponding object;
            # it is shown on the result image together with the class label.
            detection_scores = graph.get_tensor_by_name('detection_scores:0')
            detection_classes = graph.get_tensor_by_name('detection_classes:0')
            num_detections = graph.get_tensor_by_name('num_detections:0')
            # Actual detection.
            (boxes, scores, classes, num) = sess_damage.run(
                [detection_boxes, detection_scores, detection_classes, num_detections],
                feed_dict={image_tensor: image})
            output_dict = {'detection_classes': np.squeeze(classes).astype(np.int32),
                           'detection_scores': np.squeeze(scores)}
    return output_dict
class Process(Resource):
    def post(self):
        # The request carries a base64-encoded image and the model type ("coco" or "damage").
        base64_img = request.form['img']
        image = Image.open(io.BytesIO(base64.b64decode(base64_img)))
        type = request.form["type"]
        image_np = load_image_into_numpy_array(image)
        # The graph expects a batch, so add a leading batch dimension: (1, H, W, 3).
        image_np_expanded = np.expand_dims(image_np, axis=0)
        if type == "coco":
            output_dict = run_inference_for_single_image(image_np_expanded, detection_graph_coco, type)
        elif type == "damage":
            output_dict = run_inference_for_single_image(image_np_expanded, detection_graph_damage, type)
        return json.dumps(output_dict, cls=NumpyEncoder)


class NumpyEncoder(json.JSONEncoder):
    # json.dumps cannot serialize numpy arrays directly, so convert them to lists.
    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return json.JSONEncoder.default(self, obj)
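
# ------------------------------------------------------------------
# Hedged sketch (not part of the listing above): the file imports Flask and
# flask_restful.Api and defines the Process resource, but the application
# wiring itself is not shown. A minimal, assumed registration could look like
# the following; the '/process' route, host, and port are illustrative
# choices, not values taken from the source.
app = Flask(__name__)
api = Api(app)
api.add_resource(Process, '/process')  # route path is hypothetical

if __name__ == '__main__':
    # Host/port are assumptions for local testing only.
    app.run(host='0.0.0.0', port=5000)

# Example client call (assumed usage): POST a base64-encoded image together
# with the model type ("coco" or "damage") as form fields.
#
#   import base64, requests
#   with open('road.jpg', 'rb') as f:
#       img_b64 = base64.b64encode(f.read()).decode('ascii')
#   resp = requests.post('http://localhost:5000/process',
#                        data={'img': img_b64, 'type': 'damage'})
#   print(resp.text)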