You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

174 lines
7.9 KiB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
  1. from flask import Flask, request, Response
  2. from flask_restful import Resource, Api
  3. import os
  4. from object_detection.utils import label_map_util
  5. from object_detection.utils import visualization_utils as vis_util
  6. from object_detection.utils import ops as utils_ops
  7. from PIL import Image
  8. import base64
  9. import io
  10. import json
  11. import re
  12. import tensorflow as tf
  13. import sys,getpass
  14. import numpy as np
  15. from flask import Flask, send_from_directory
  16. from flask_restful import Api
  17. from flask_cors import CORS, cross_origin
  18. app = Flask(__name__)
  19. api = Api(app)
  20. app.config['SECRET_KEY'] = 'the quick brown fox jumps over the lazy dog'
  21. app.config['CORS_HEADERS'] = 'Content-Type'
  22. cors = CORS(app, resources={r"/foo": {"origins": "*"}})
  23. switches = {"coco":1, "damage":1}
  24. COCO_MODEL_NAME = "rfcn_resnet101_coco_2018_01_28"
  25. PATH_TO_FROZEN_COCO_GRAPH = 'modules/'+COCO_MODEL_NAME + '/frozen_inference_graph.pb'
  26. PATH_TO_FROZEN_DAMAGE_GRAPH = 'modules/trainedModels/ssd_mobilenet_RoadDamageDetector.pb'
  27. linux_def = {"detection_boxes":[(106, 188, 480, 452)],"detection_scores":[0.99],"detection_classes":[1]}
  28. detection_graph_coco = None
  29. detection_graph_damage = None
  30. if getpass.getuser() == "tedankara":
  31. detection_graph_coco = tf.Graph()
  32. detection_graph_damage = tf.Graph()
  33. with detection_graph_coco.as_default():
  34. od_graph_def = tf.GraphDef()
  35. with tf.gfile.GFile(PATH_TO_FROZEN_COCO_GRAPH, 'rb') as fid:
  36. serialized_graph = fid.read()
  37. od_graph_def.ParseFromString(serialized_graph)
  38. tf.import_graph_def(od_graph_def, name='')
  39. with detection_graph_damage.as_default():
  40. od_graph_def = tf.GraphDef()
  41. with tf.gfile.GFile(PATH_TO_FROZEN_DAMAGE_GRAPH, 'rb') as fid:
  42. serialized_graph = fid.read()
  43. od_graph_def.ParseFromString(serialized_graph)
  44. tf.import_graph_def(od_graph_def, name='')
  45. def load_image_into_numpy_array(image):
  46. (im_width, im_height) = image.size
  47. return np.array(image.getdata()).reshape(
  48. (im_height, im_width, 3)).astype(np.uint8)
  49. def decode_base64(data, altchars=b'+/'):
  50. """Decode base64, padding being optional.
  51. :param data: Base64 data as an ASCII byte string
  52. :returns: The decoded byte string.
  53. """
  54. data = re.sub(rb'[^a-zA-Z0-9%s]+' % altchars, b'', bytes(data,"utf-8")) # normalize
  55. missing_padding = len(data) % 4
  56. if missing_padding:
  57. data += b'='* (4 - missing_padding)
  58. return base64.b64decode(data, altchars)
def run_inference_for_single_image(image, graph,type):
    """Run one detection model on a single batched image and return its outputs.

    :param image: image batch fed to ``'image_tensor:0'``. The caller in this
        file passes ``np.expand_dims(image_np, axis=0)``, i.e. presumably shape
        ``(1, H, W, 3)`` uint8. The masks path reads ``image.shape[1]`` and
        ``image.shape[2]`` as height and width.
    :param graph: ``tf.Graph`` holding the imported frozen model for *type*.
    :param type: ``"coco"`` or ``"damage"``. Any other value raises
        ``KeyError`` at ``switches[type]``.
        NOTE(review): shadows the builtin ``type``.
    :returns: dict of outputs.
        - On a non-``"tedankara"`` user, the module-level ``linux_def`` dict is
          returned as-is, without running TensorFlow.
        - ``"coco"``: ``num_detections`` (int), ``detection_classes`` (int64
          array), ``detection_boxes``, ``detection_scores``, and
          ``detection_masks`` if the graph has it.
        - ``"damage"``: ``detection_classes`` (int32), ``detection_scores`` and
          ``detection_boxes``, squeezed. No ``num_detections`` key.
    """
    global switches
    global sess_coco
    global sess_damage
    # Only the inference host has the models loaded; elsewhere return a stub.
    if not getpass.getuser() == "tedankara":
        return linux_def
    with graph.as_default():
        # Lazily create one tf.Session per model type on first use; the
        # session binds to `graph` because it is the default graph here.
        # NOTE(review): no lock -- concurrent first requests could each create a session.
        if(switches[type]):
            if type == "coco":
                sess_coco = tf.Session()
            elif type == "damage":
                sess_damage = tf.Session()
            switches[type] = 0
        if type == "coco":
            # Collect whichever standard detection output tensors the graph has.
            ops = tf.get_default_graph().get_operations()
            all_tensor_names = {output.name for op in ops for output in op.outputs}
            tensor_dict = {}
            for key in [
                'num_detections', 'detection_boxes', 'detection_scores',
                'detection_classes', 'detection_masks'
            ]:
                tensor_name = key + ':0'
                if tensor_name in all_tensor_names:
                    tensor_dict[key] = tf.get_default_graph().get_tensor_by_name(
                        tensor_name)
            # NOTE(review): this branch adds new ops (squeeze/slice/reframe/cast)
            # to the graph on every call, so the graph keeps growing for mask models.
            if 'detection_masks' in tensor_dict:
                # The following processing is only for single image
                detection_boxes = tf.squeeze(tensor_dict['detection_boxes'], [0])
                detection_masks = tf.squeeze(tensor_dict['detection_masks'], [0])
                # Reframe is required to translate mask from box coordinates to image coordinates and fit the image size.
                real_num_detection = tf.cast(tensor_dict['num_detections'][0], tf.int32)
                detection_boxes = tf.slice(detection_boxes, [0, 0], [real_num_detection, -1])
                detection_masks = tf.slice(detection_masks, [0, 0, 0], [real_num_detection, -1, -1])
                detection_masks_reframed = utils_ops.reframe_box_masks_to_image_masks(
                    detection_masks, detection_boxes, image.shape[1], image.shape[2])
                # Binarize the masks at a 0.5 threshold.
                detection_masks_reframed = tf.cast(
                    tf.greater(detection_masks_reframed, 0.5), tf.uint8)
                # Follow the convention by adding back the batch dimension
                tensor_dict['detection_masks'] = tf.expand_dims(
                    detection_masks_reframed, 0)
            image_tensor = tf.get_default_graph().get_tensor_by_name('image_tensor:0')
            # Run inference
            output_dict = sess_coco.run(tensor_dict,
                                        feed_dict={image_tensor: image})
            # all outputs are float32 numpy arrays, so convert types as appropriate
            # ([0] drops the batch dimension of the single image).
            output_dict['num_detections'] = int(output_dict['num_detections'][0])
            output_dict['detection_classes'] = output_dict[
                'detection_classes'][0].astype(np.int64)
            output_dict['detection_boxes'] = output_dict['detection_boxes'][0]
            output_dict['detection_scores'] = output_dict['detection_scores'][0]
            if 'detection_masks' in output_dict:
                output_dict['detection_masks'] = output_dict['detection_masks'][0]
        elif type=="damage":
            image_tensor = graph.get_tensor_by_name('image_tensor:0')
            # Each box represents a part of the image where a particular object was detected.
            detection_boxes = graph.get_tensor_by_name('detection_boxes:0')
            # Each score represent how level of confidence for each of the objects.
            # Score is shown on the result image, together with the class label.
            detection_scores = graph.get_tensor_by_name('detection_scores:0')
            detection_classes = graph.get_tensor_by_name('detection_classes:0')
            num_detections = graph.get_tensor_by_name('num_detections:0')
            # Actual detection.
            (boxes, scores, classes, num) = sess_damage.run(
                [detection_boxes, detection_scores, detection_classes, num_detections],
                feed_dict={image_tensor: image})
            # np.squeeze removes the batch dimension; `num` is fetched but not returned.
            output_dict = {'detection_classes': np.squeeze(classes).astype(np.int32), 'detection_scores': np.squeeze(scores),'detection_boxes': np.squeeze(boxes)}
    return output_dict
  126. class Process(Resource):
  127. def post(self):
  128. base64_img = request.form['img']
  129. image = Image.open(io.BytesIO(decode_base64(base64_img)))
  130. type = request.form["type"]
  131. image_np = load_image_into_numpy_array(image)
  132. image_np_expanded = np.expand_dims(image_np, axis=0)
  133. if type == "coco":
  134. output_dict = run_inference_for_single_image(image_np_expanded, detection_graph_coco,type)
  135. elif type == "damage":
  136. output_dict = run_inference_for_single_image(image_np_expanded, detection_graph_damage,type)
  137. if getpass.getuser() == "tedankara":
  138. output_dict["detection_boxes"] = output_dict["detection_boxes"].tolist()
  139. output_dict["detection_scores"] = output_dict["detection_scores"].tolist()
  140. output_dict["detection_classes"] = output_dict["detection_classes"].tolist()
  141. return output_dict
  142. class NumpyEncoder(json.JSONEncoder):
  143. def default(self, obj):
  144. if isinstance(obj, np.ndarray):
  145. return obj.tolist()
  146. return json.JSONEncoder.default(self, obj)
  147. if __name__ == '__main__':
  148. context = ('encryption/mycity.crt', 'encryption/mycity-decrypted.key')
  149. api.add_resource(Process, '/ai', '/ai/')
  150. app.run(host='0.0.0.0', port=5001, ssl_context=context, debug=False)