# main.py
import os
import cv2
import numpy as np
from flask import Flask, render_template, request, jsonify, send_from_directory
import torch
from torchvision import transforms, models
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image  # the local preprocess_image() below replaces the library helper
import uuid


app = Flask(__name__)

# Class names ('Random' marks inputs that are not chest X-rays)
class_names = ['Covid-19', 'Emphysema', 'Healthy', 'Pneumonia', 'Tuberculosis', 'Random']

# Load the ResNet101 architecture and the fine-tuned classification checkpoint
model = models.resnet101(pretrained=False)  # no ImageNet weights; the checkpoint below supplies them
model.fc = torch.nn.Linear(in_features=2048, out_features=len(class_names))
checkpoint_path = "resnet101_state_dict.pth"
model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))
model.eval()
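# Note: the checkpoint is expected to hold a state dict for this ResNet101 with
# the 6-class head defined above; map_location='cpu' loads the weights onto the
# CPU, where inference also runs (the model is never moved to a GPU).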

# Image transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
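# The Normalize mean/std above are the standard ImageNet statistics used by
# torchvision's ImageNet-pretrained backbones.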

# Preprocess an image file into a normalized (1, 3, 224, 224) input tensor
def preprocess_image(image_path):
    image = Image.open(image_path).convert('RGB')
    return transform(image).unsqueeze(0)

# Generate visualizations for Grad-CAM and other images
def generate_visualizations(image_path):
    # Read the uploaded image; OpenCV loads it in BGR order
    original_image = cv2.imread(image_path)

    # Grayscale image
    grayscale_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2GRAY)

    # Histogram-equalized image
    equalized_image = cv2.equalizeHist(grayscale_image)

    # Edge detection result
    edges_image = cv2.Canny(grayscale_image, 50, 150)

    # Segmented image (simple binary threshold)
    _, segmented_image = cv2.threshold(grayscale_image, 127, 255, cv2.THRESH_BINARY)

    # Grad-CAM visualization
    input_tensor = preprocess_image(image_path)
    target_layer = model.layer4[-1]  # last residual block of ResNet101
    cam = GradCAM(model=model, target_layers=[target_layer])
    grayscale_cam = cam(input_tensor=input_tensor)[0]
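    # cam(...) returns one activation map per image in the batch, resized to the
    # model input size (224x224 here); [0] selects the map for this single image.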

    # Normalize the Grad-CAM output to the range [0, 1]
    grayscale_cam = np.maximum(grayscale_cam, 0)
    grayscale_cam = grayscale_cam / (np.max(grayscale_cam) + 1e-8)  # avoid division by zero on an all-zero map

    # Resize the original (BGR) image to the model input size for overlaying
    input_image_resized = cv2.resize(original_image, (224, 224))

    # Overlay the Grad-CAM heatmap using the JET colormap (high activations appear red)
    heatmap = cv2.applyColorMap(np.uint8(255 * grayscale_cam), cv2.COLORMAP_JET)
    grad_cam_image = cv2.addWeighted(input_image_resized, 0.7, heatmap, 0.3, 0)

    # Mask for ROI extraction (high-activation areas where grayscale_cam > 0.5)
    roi_mask = (grayscale_cam > 0.5).astype(np.uint8)
    roi = cv2.bitwise_and(input_image_resized, input_image_resized, mask=roi_mask)

    # Generate a unique file ID for this request's outputs
    file_id = str(uuid.uuid4())

    # Relative URLs under which the saved images are served back to the client
    visualization_paths = {
        "original": f'/uploads/{file_id}_original.png',
        "grayscale": f'/uploads/{file_id}_grayscale.png',
        "equalized": f'/uploads/{file_id}_equalized.png',
        "edges": f'/uploads/{file_id}_edges.png',
        "segmented": f'/uploads/{file_id}_segmented.png',
        "grad_cam": f'/uploads/{file_id}_grad_cam.png',
        "roi": f'/uploads/{file_id}_roi.png'
    }
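    # These paths resolve through the /uploads/<filename> route defined below.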

    # Create the uploads directory if it doesn't exist
    os.makedirs('uploads', exist_ok=True)

    # Save the visualizations as images (all arrays are BGR or single-channel, as cv2.imwrite expects)
    cv2.imwrite(f'uploads/{file_id}_original.png', original_image)
    cv2.imwrite(f'uploads/{file_id}_grayscale.png', grayscale_image)
    cv2.imwrite(f'uploads/{file_id}_equalized.png', equalized_image)
    cv2.imwrite(f'uploads/{file_id}_edges.png', edges_image)
    cv2.imwrite(f'uploads/{file_id}_segmented.png', segmented_image)
    cv2.imwrite(f'uploads/{file_id}_grad_cam.png', grad_cam_image)
    cv2.imwrite(f'uploads/{file_id}_roi.png', roi)

    return visualization_paths

@app.route('/')
def index():
    return render_template('index.html')

@app.route('/predict', methods=['POST'])
def predict():
    if 'file' not in request.files:
        return jsonify({"error": "No file part"}), 400

    file = request.files['file']
    if file.filename == '':
        return jsonify({"error": "No selected file"}), 400

    # Save the upload under a unique name so concurrent requests don't collide
    image_id = str(uuid.uuid4())
    file_path = os.path.join('uploads', f'{image_id}.png')
    os.makedirs('uploads', exist_ok=True)
    file.save(file_path)

    # Run inference without tracking gradients
    input_tensor = preprocess_image(file_path)
    with torch.no_grad():
        output = model(input_tensor)
        predicted_class = output.argmax(dim=1).item()
        confidence_score = torch.nn.functional.softmax(output, dim=1)[0, predicted_class].item() * 100
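        # softmax turns the raw logits into class probabilities; the probability
        # of the predicted class is reported as a percentage.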

    result = {
        "predicted_class": class_names[predicted_class],
        "confidence_score": f"{confidence_score:.2f}%",
    }

    if class_names[predicted_class] == "Random":
        result["message"] = "Sorry, this does not appear to be a chest X-ray. Please try again with a chest X-ray image."
        result["visualizations"] = None
    else:
        visualizations = generate_visualizations(file_path)
        result["visualizations"] = visualizations

    return jsonify(result)

# Serve the saved images from the 'uploads' folder
@app.route('/uploads/<filename>')
def uploaded_file(filename):
    return send_from_directory('uploads', filename)

if __name__ == '__main__':
    app.run(debug=True)  # Flask development server; not intended for production use
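
# Example request (a sketch, assuming the dev server's default address
# http://127.0.0.1:5000 and a hypothetical local file named xray.png):
#   curl -X POST -F "file=@xray.png" http://127.0.0.1:5000/predict
# The JSON response holds "predicted_class", "confidence_score", and, for chest
# X-ray inputs, "visualizations" paths served by the /uploads/<filename> route.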