"""
=============================================================================================
File        : app.py
Project     : Chest X-Ray Pathology Detection and Localization using Deep Learning
Author Name : Rammuni Ravidu Suien Silva
UoW No      : 16267097
IIT No      : 2016134
Module      : Final Year Project 20/21
Supervisor  : Mr Pumudu Fernando

Prototype   : Web Interface - BackEnd [Draft: .v01]
University of Westminster, UK || IIT Sri Lanka
=============================================================================================
"""
14
import json
15
import os
16
from datetime import datetime
17
18
import numpy as np
19
# Flask Imports
20
from flask import Flask, request, render_template
21
from flask import send_file
22
from flask_jsglue import JSGlue
23
# Tensorflow Keras imports
24
from tensorflow.keras.models import load_model
25
# For secure src links
26
from werkzeug.utils import secure_filename
27
28
# System Library import
29
from lab_cxr_scripts.lab_cxr import CXRPrediction, CXRLocalization
30
31
# Pathology label sets, indexed in parallel with `models`.
# Model ids 0, 2, 4, ... map to xray_labels_set[0]; model ids 1, 3, ... map to
# xray_labels_set[1] (the /predict handler selects both with `model_id % len(...)`).
_MIMIC_LABELS = [
    "Enlarged Cardiomediastinum", "Cardiomegaly", "Lung Lesion", "Lung Opacity", "Edema",
    "Consolidation", "Pneumonia", "Atelectasis", "Pneumothorax", "Pleural Effusion",
    "Pleural Other", "Fracture", "Support Devices",
]
_NIH_LABELS = [
    "Nodule", "Cardiomegaly", "Emphysema", "Fibrosis", "Edema", "Consolidation", "Pneumonia",
    "Atelectasis", "Pneumothorax", "Pleural Effusion", "Mass", "Infiltration", "Hernia",
    # NOTE(review): "Plueral" looks like a typo for "Pleural". It is kept verbatim because
    # the label text is returned to the client and used to build heatmap file names;
    # confirm downstream consumers before correcting it.
    "Plueral Thickening",
]
xray_labels_set = [_MIMIC_LABELS, _NIH_LABELS]

# Labels for the currently selected model (reassigned by the /predict handler).
xray_labels = xray_labels_set[0]
41
# Dependency: pip install pyopenssl
# ---- Flask configuration ----
app = Flask(__name__)
MAX_UPLOAD_MB = 20
# Cap the size of incoming request bodies at MAX_UPLOAD_MB megabytes.
app.config.update(MAX_CONTENT_LENGTH=MAX_UPLOAD_MB * 1024 * 1024)
# Flask-JSGlue integration for the front-end scripts.
jsglue = JSGlue(app)
46
47
# TODO: USER GUIDE


def _load_cxr_model(model_path):
    """Load one saved Keras model, registering its custom 'weighted_loss' object.

    A fresh weighted loss (weights 1, 1) is requested for every file.
    """
    loss_fn = CXRPrediction.get_weighted_loss(1, 1)
    return load_model(model_path, custom_objects={'weighted_loss': loss_fn})


# models[i] is a [PAR-64, PAR-128] pair aligned with xray_labels_set[i]
# (index 0: MIMIC models, index 1: NIH models).
_MODEL_FILES = [
    ['models/MIMIC/PAR-64-MODEL-MIMIC-FINAL-2.h5', 'models/MIMIC/PAR-128-MODEL-MIMIC-FINAL-2.h5'],
    ['models/NIH/PAR-64-MODEL-FINAL-NIH-2.h5', 'models/NIH/PAR-128-MODEL-FINAL-NIH-2.h5'],
]
models = [[_load_cxr_model(path) for path in pair] for pair in _MODEL_FILES]

# Module-level state shared by the request handlers: the active model pair and
# the hash of the most recently uploaded image ('none' until the first upload).
# NOTE(review): this global state is shared across all clients of the server.
model = models[0]
cur_cxr_hash = 'none'
63
64
"""
65
==================================================================================================================
66
                                            Web request functions
67
==================================================================================================================
68
"""
69
70
71
# Landing page
@app.route('/')
def start_web():
    """Render the web UI's entry page, ``index.html``."""
    page_template = "index.html"
    return render_template(page_template)
75
76
77
# Maximum number of CXR images accepted by a single /predict request.
MAX_CXR_FILES_PER_REQUEST = 8


def _format_detection_results(preds, labels):
    """Summarise per-image predictions into a label -> "mean% (max% - min%)" mapping.

    :param preds: one prediction vector per uploaded image (equal-length sequences).
    :param labels: pathology names, indexed like the prediction vectors.
    :return: dict mapping each label to a string such as "12.34% (20.0% - 4.68%)".
    """
    # Aggregate across images (axis 0), scale to percentages, round to 2 dp.
    mean_pct = np.round(np.multiply(np.mean(preds, axis=0), 100), 2)
    max_pct = np.round(np.multiply(np.max(preds, axis=0), 100), 2)
    min_pct = np.round(np.multiply(np.min(preds, axis=0), 100), 2)
    print(mean_pct)

    results = {}
    for i, label in enumerate(labels):
        results[label] = str(mean_pct[i]) + "% (" + str(max_pct[i]) + "% - " + str(min_pct[i]) + "%)"
    return results


# CXR Image upload API
@app.route('/predict/<int:model_id>', methods=['GET', 'POST'])
def upload(model_id):
    """Run pathology detection on the CXR images posted with this request.

    Expects multipart files under the keys ``file_0``, ``file_1``, ... (between 1
    and MAX_CXR_FILES_PER_REQUEST of them). ``model_id`` selects the model pair
    and label set, wrapping around modulo the number available.

    Side effects: saves each image under ``uploads/`` (named by its hash) and
    updates the module-level ``model``, ``xray_labels`` and ``cur_cxr_hash``
    consumed by /localize and /get_cxr_detect_img. ``cur_cxr_hash`` ends up as
    the hash of the last uploaded file.

    :return: JSON string mapping each label to "mean% (max% - min%)" across the
             uploaded images; 405 for non-POST requests; 400 for a bad file count.
    """
    global model, xray_labels, cur_cxr_hash

    if request.method != 'POST':
        # A Flask view must return a response. The original `return None` made
        # every GET fail with a 500.
        return 'Use POST to upload CXR images.', 405, {'Allow': 'POST'}

    print("Model ID", model_id)

    # Validate before touching global state so a rejected request cannot leave
    # `model`/`xray_labels` out of sync with `cur_cxr_hash`.
    file_count = len(request.files)
    if file_count == 0:
        # Without this check, np.max over an empty list raises.
        return 'No CXR image uploaded.', 400
    if file_count > MAX_CXR_FILES_PER_REQUEST:
        return 'Too many CXR images: at most %d per request.' % MAX_CXR_FILES_PER_REQUEST, 400

    # Selecting model pair and matching labels set.
    model = models[model_id % len(models)]
    xray_labels = xray_labels_set[model_id % len(xray_labels_set)]

    upload_dir = os.path.join(os.path.dirname(__file__), 'uploads')
    os.makedirs(upload_dir, exist_ok=True)

    preds = []
    for file_num in range(file_count):
        # A missing key raises werkzeug's BadRequestKeyError (an HTTP 400).
        cxr_img_file = request.files['file_' + str(file_num)]
        hashed_filename = CXRPrediction.hash_cxr(cxr_img_file)
        print(hashed_filename)
        cur_cxr_hash = hashed_filename

        # NOTE(review): if hash_cxr consumes the upload stream without rewinding it,
        # save() would write an empty file. Confirm in lab_cxr.
        file_path = os.path.join(upload_dir, secure_filename(hashed_filename))
        cxr_img_file.save(file_path)

        # Detection results for this image (first row of the model output).
        preds.append(np.array(CXRPrediction.model_predict(file_path, model)[0]).tolist())

    # Combine the results of all uploaded images and send them as JSON.
    predictions_dict = _format_detection_results(preds, xray_labels)
    return json.dumps(predictions_dict, indent=4)
129
130
131
@app.route('/localize')
def localization():  # Localization API
    """Generate (or reuse) localization heatmaps for the most recently uploaded CXR.

    Heatmaps are produced with the last model of the active pair (``model[-1]``)
    for every label in ``xray_labels``. Generation is skipped when the image's
    ``localizations/<hash>`` directory already holds exactly one file per label.

    :return: number of localized labels as a string, or 400 if no image has
             been uploaded yet.
    """
    if cur_cxr_hash == 'none':
        # 'none' is the startup placeholder: there is no uploaded image to localize.
        return 'No CXR image has been uploaded yet.', 400

    start = datetime.now()
    filepath = os.path.join('localizations', cur_cxr_hash.split('.')[0])

    # If the localized images are already there, there is no need to re-process.
    already_localized = False
    if os.path.exists(filepath):
        file_count = sum(1 for name in os.listdir(filepath)
                         if os.path.isfile(os.path.join(filepath, name)))
        already_localized = file_count == len(xray_labels)

    if not already_localized:
        CXRLocalization.create_cxr_localization_heatmap(cur_cxr_hash, model[-1], xray_labels)

    print(datetime.now() - start)
    return str(len(xray_labels))  # Returning the number of localized labels
148
149
150
# Function for sending the localized CXR image
@app.route('/get_cxr_detect_img/<int:pathology_id>')
def get_cxr_detect_img(pathology_id):
    """Send the localization heatmap of the current CXR for one pathology.

    :param pathology_id: index into the active ``xray_labels`` list.
    :return: the heatmap image, or 404 when ``pathology_id`` is out of range
             (previously an IndexError and a 500).
    """
    print(pathology_id)
    # The <int:...> converter only matches non-negative integers, so only the
    # upper bound needs checking.
    if pathology_id >= len(xray_labels):
        return 'Unknown pathology id %d.' % pathology_id, 404

    localized_image_name = xray_labels[pathology_id] + '-localizedHeatmap-' + cur_cxr_hash
    image_path = os.path.join('localizations', cur_cxr_hash.split('.')[0], localized_image_name)
    # 'image/jpeg' is the registered MIME type; 'image/jpg' is non-standard.
    return send_file(image_path, mimetype='image/jpeg')
158
159
160
# Symptoms reference data (static JSON file)
@app.route('/get_symptoms')
def get_symptoms():
    """Return the static ``Symptoms.json`` file with a JSON content type."""
    symptoms_path = 'static/files/Symptoms.json'
    return send_file(symptoms_path, mimetype='application/json')
164
165
166
print("Server Running...")
167
if __name__ == '__main__':
168
    app.run(debug=True)  # Debugging