You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

156 lines
7.1 KiB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
  1. from flask import Flask, request, Response
  2. from flask_restful import Resource, Api
  3. import os
  4. from object_detection.utils import label_map_util
  5. from object_detection.utils import visualization_utils as vis_util
  6. from object_detection.utils import ops as utils_ops
  7. from PIL import Image
  8. import base64
  9. import io
  10. import json
  11. import tensorflow as tf
  12. import sys
  13. import numpy as np
  14. from flask import Flask, send_from_directory
  15. from flask_restful import Api
  16. from flask_cors import CORS, cross_origin
  17. app = Flask(__name__)
  18. api = Api(app)
  19. app.config['SECRET_KEY'] = 'the quick brown fox jumps over the lazy dog'
  20. app.config['CORS_HEADERS'] = 'Content-Type'
  21. cors = CORS(app, resources={r"/foo": {"origins": "*"}})
  22. switches = {"coco":1, "damage":1}
  23. COCO_MODEL_NAME = "rfcn_resnet101_coco_2018_01_28"
  24. PATH_TO_FROZEN_COCO_GRAPH = 'modules/'+COCO_MODEL_NAME + '/frozen_inference_graph.pb'
  25. PATH_TO_FROZEN_DAMAGE_GRAPH = 'modules/trainedModels/ssd_mobilenet_RoadDamageDetector.pb'
  26. linux_def = {"detection_boxes":[(106, 188, 480, 452)],"detection_scores":[0.99],"detection_classes":[1]}
  27. detection_graph_coco = None
  28. detection_graph_damage = None
  29. if sys.platform == "win32":
  30. detection_graph_coco = tf.Graph()
  31. detection_graph_damage = tf.Graph()
  32. with detection_graph_coco.as_default():
  33. od_graph_def = tf.GraphDef()
  34. with tf.gfile.GFile(PATH_TO_FROZEN_COCO_GRAPH, 'rb') as fid:
  35. serialized_graph = fid.read()
  36. od_graph_def.ParseFromString(serialized_graph)
  37. tf.import_graph_def(od_graph_def, name='')
  38. with detection_graph_damage.as_default():
  39. od_graph_def = tf.GraphDef()
  40. with tf.gfile.GFile(PATH_TO_FROZEN_DAMAGE_GRAPH, 'rb') as fid:
  41. serialized_graph = fid.read()
  42. od_graph_def.ParseFromString(serialized_graph)
  43. tf.import_graph_def(od_graph_def, name='')
  44. def load_image_into_numpy_array(image):
  45. (im_width, im_height) = image.size
  46. return np.array(image.getdata()).reshape(
  47. (im_height, im_width, 3)).astype(np.uint8)
  48. def run_inference_for_single_image(image, graph,type):
  49. global switches
  50. global sess_coco
  51. global sess_damage
  52. if not sys.platform == "win32":
  53. return linux_def
  54. with graph.as_default():
  55. if(switches[type]):
  56. if type == "coco":
  57. sess_coco = tf.Session()
  58. elif type == "damage":
  59. sess_damage = tf.Session()
  60. switches[type] = 0
  61. if type == "coco":
  62. ops = tf.get_default_graph().get_operations()
  63. all_tensor_names = {output.name for op in ops for output in op.outputs}
  64. tensor_dict = {}
  65. for key in [
  66. 'num_detections', 'detection_boxes', 'detection_scores',
  67. 'detection_classes', 'detection_masks'
  68. ]:
  69. tensor_name = key + ':0'
  70. if tensor_name in all_tensor_names:
  71. tensor_dict[key] = tf.get_default_graph().get_tensor_by_name(
  72. tensor_name)
  73. if 'detection_masks' in tensor_dict:
  74. # The following processing is only for single image
  75. detection_boxes = tf.squeeze(tensor_dict['detection_boxes'], [0])
  76. detection_masks = tf.squeeze(tensor_dict['detection_masks'], [0])
  77. # Reframe is required to translate mask from box coordinates to image coordinates and fit the image size.
  78. real_num_detection = tf.cast(tensor_dict['num_detections'][0], tf.int32)
  79. detection_boxes = tf.slice(detection_boxes, [0, 0], [real_num_detection, -1])
  80. detection_masks = tf.slice(detection_masks, [0, 0, 0], [real_num_detection, -1, -1])
  81. detection_masks_reframed = utils_ops.reframe_box_masks_to_image_masks(
  82. detection_masks, detection_boxes, image.shape[1], image.shape[2])
  83. detection_masks_reframed = tf.cast(
  84. tf.greater(detection_masks_reframed, 0.5), tf.uint8)
  85. # Follow the convention by adding back the batch dimension
  86. tensor_dict['detection_masks'] = tf.expand_dims(
  87. detection_masks_reframed, 0)
  88. image_tensor = tf.get_default_graph().get_tensor_by_name('image_tensor:0')
  89. # Run inference
  90. output_dict = sess_coco.run(tensor_dict,
  91. feed_dict={image_tensor: image})
  92. # all outputs are float32 numpy arrays, so convert types as appropriate
  93. output_dict['num_detections'] = int(output_dict['num_detections'][0])
  94. output_dict['detection_classes'] = output_dict[
  95. 'detection_classes'][0].astype(np.int64)
  96. output_dict['detection_boxes'] = output_dict['detection_boxes'][0]
  97. output_dict['detection_scores'] = output_dict['detection_scores'][0]
  98. if 'detection_masks' in output_dict:
  99. output_dict['detection_masks'] = output_dict['detection_masks'][0]
  100. elif type=="damage":
  101. image_tensor = graph.get_tensor_by_name('image_tensor:0')
  102. # Each box represents a part of the image where a particular object was detected.
  103. detection_boxes = graph.get_tensor_by_name('detection_boxes:0')
  104. # Each score represent how level of confidence for each of the objects.
  105. # Score is shown on the result image, together with the class label.
  106. detection_scores = graph.get_tensor_by_name('detection_scores:0')
  107. detection_classes = graph.get_tensor_by_name('detection_classes:0')
  108. num_detections = graph.get_tensor_by_name('num_detections:0')
  109. # Actual detection.
  110. (boxes, scores, classes, num) = sess_damage.run(
  111. [detection_boxes, detection_scores, detection_classes, num_detections],
  112. feed_dict={image_tensor: image})
  113. output_dict = {'detection_classes': np.squeeze(classes).astype(np.int32), 'detection_scores': np.squeeze(scores)}
  114. return output_dict
  115. class Process(Resource):
  116. def post(self):
  117. base64_img = request.form['img']
  118. image = Image.open(io.BytesIO(base64.b64decode(base64_img)))
  119. type = request.form["type"]
  120. image_np = load_image_into_numpy_array(image)
  121. image_np_expanded = np.expand_dims(image_np, axis=0)
  122. if type == "coco":
  123. output_dict = run_inference_for_single_image(image_np_expanded, detection_graph_coco,type)
  124. elif type == "damage":
  125. output_dict = run_inference_for_single_image(image_np_expanded, detection_graph_damage,type)
  126. return json.dumps(output_dict,cls=NumpyEncoder)
  127. class NumpyEncoder(json.JSONEncoder):
  128. def default(self, obj):
  129. if isinstance(obj, np.ndarray):
  130. return obj.tolist()
  131. return json.JSONEncoder.default(self, obj)
  132. if __name__ == '__main__':
  133. context = ('encryption/mycity.crt', 'encryption/mycity-decrypted.key')
  134. api.add_resource(Process, '/ai', '/ai/')
  135. app.run(host='0.0.0.0', port=5001, ssl_context=context, debug=False)