--- a +++ b/main.py @@ -0,0 +1,297 @@ +<<<<<<< HEAD +import os +import cv2 +import numpy as np +from flask import Flask, render_template, request, jsonify, send_from_directory +import torch +from torchvision import transforms, models +from PIL import Image +from pytorch_grad_cam import GradCAM +from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image +import uuid + + +app = Flask(__name__) + +# Class names +class_names = ['Covid-19', 'Emphysema', 'Healthy', 'Pneumonia', 'Tuberculosis', 'Random'] + +# Load the pre-trained ResNet101 model +model = models.resnet101(pretrained=False) +model.fc = torch.nn.Linear(in_features=2048, out_features=len(class_names)) +checkpoint_path = "resnet101_state_dict.pth" +model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu'))) +model.eval() + +# Image transformations +transform = transforms.Compose([ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), +]) + +# Function to preprocess image +def preprocess_image(image_path): + image = Image.open(image_path).convert('RGB') + return transform(image).unsqueeze(0) + +# Generate visualizations for Grad-CAM and other images +def generate_visualizations(image_path): + original_image = cv2.imread(image_path) + original_image_rgb = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB) + + # Grayscale image + grayscale_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2GRAY) + + # Histogram equalized image + equalized_image = cv2.equalizeHist(grayscale_image) + + # Edge detection result + edges_image = cv2.Canny(grayscale_image, 50, 150) + + # Segmented image + _, segmented_image = cv2.threshold(grayscale_image, 127, 255, cv2.THRESH_BINARY) + + # Grad-CAM Visualization + input_tensor = preprocess_image(image_path) + target_layer = model.layer4[-1] # Last layer of ResNet101 + cam = GradCAM(model=model, target_layers=[target_layer]) + grayscale_cam = cam(input_tensor=input_tensor)[0] + + # Normalize the Grad-CAM output to range [0, 1] + grayscale_cam = np.maximum(grayscale_cam, 0) + grayscale_cam = grayscale_cam / np.max(grayscale_cam) + + # Apply the Grad-CAM heatmap on the image with a red color map + input_image_rgb_resized = cv2.resize(original_image_rgb, (224, 224)) + + # Using a custom colormap (ensure red regions are highlighted) + heatmap = cv2.applyColorMap(np.uint8(255 * grayscale_cam), cv2.COLORMAP_JET) + grad_cam_image = cv2.addWeighted(input_image_rgb_resized, 0.7, heatmap, 0.3, 0) + + # Mask for ROI extraction (high-confidence areas are those with grayscale_cam > 0.5) + roi_mask = (grayscale_cam > 0.5).astype(np.uint8) + roi = cv2.bitwise_and(input_image_rgb_resized, input_image_rgb_resized, mask=roi_mask) + + # Generate unique file name + file_id = str(uuid.uuid4()) + + # Save images to server and return file paths + visualization_paths = { + "original": f'/uploads/{file_id}_original.png', + "grayscale": f'/uploads/{file_id}_grayscale.png', + "equalized": f'/uploads/{file_id}_equalized.png', + "edges": f'/uploads/{file_id}_edges.png', + "segmented": f'/uploads/{file_id}_segmented.png', + "grad_cam": f'/uploads/{file_id}_grad_cam.png', + "roi": f'/uploads/{file_id}_roi.png' + } + + # Create uploads directory if it doesn't exist + os.makedirs('uploads', exist_ok=True) + + # Save the visualizations as images + cv2.imwrite(f'uploads/{file_id}_original.png', original_image) + cv2.imwrite(f'uploads/{file_id}_grayscale.png', grayscale_image) + cv2.imwrite(f'uploads/{file_id}_equalized.png', equalized_image) + cv2.imwrite(f'uploads/{file_id}_edges.png', edges_image) + cv2.imwrite(f'uploads/{file_id}_segmented.png', segmented_image) + cv2.imwrite(f'uploads/{file_id}_grad_cam.png', grad_cam_image) + cv2.imwrite(f'uploads/{file_id}_roi.png', roi) + + return visualization_paths + +@app.route('/') +def index(): + return render_template('index.html') + +@app.route('/predict', methods=['POST']) +def predict(): + if 'file' not in request.files: + return jsonify({"error": "No file part"}), 400 + + file = request.files['file'] + if file.filename == '': + return jsonify({"error": "No selected file"}), 400 + + image_id = str(uuid.uuid4()) + file_path = os.path.join('uploads', f'{image_id}.png') + os.makedirs('uploads', exist_ok=True) + file.save(file_path) + + input_tensor = preprocess_image(file_path) + with torch.no_grad(): + output = model(input_tensor) + predicted_class = output.argmax(dim=1).item() + confidence_score = torch.nn.functional.softmax(output, dim=1)[0, predicted_class].item() * 100 + + result = { + "predicted_class": class_names[predicted_class], + "confidence_score": f"{confidence_score:.2f}%", + } + + if class_names[predicted_class] == "Random": + result["message"] = "Sorry... You inserted a Non X-ray image. Please try again with a chest x-ray image. Thank you." + result["visualizations"] = None + else: + visualizations = generate_visualizations(file_path) + result["visualizations"] = visualizations + + return jsonify(result) + +# To serve the images from the 'uploads' folder correctly +@app.route('/uploads/<filename>') +def uploaded_file(filename): + return send_from_directory('uploads', filename) + +if __name__ == '__main__': + app.run(debug=True) +======= +import os +import cv2 +import numpy as np +from flask import Flask, render_template, request, jsonify, send_from_directory +import torch +from torchvision import transforms, models +from PIL import Image +from pytorch_grad_cam import GradCAM +from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image +import uuid + + +app = Flask(__name__) + +# Class names +class_names = ['Covid-19', 'Emphysema', 'Healthy', 'Pneumonia', 'Tuberculosis', 'Random'] + +# Load the pre-trained ResNet101 model +model = models.resnet101(pretrained=False) +model.fc = torch.nn.Linear(in_features=2048, out_features=len(class_names)) +checkpoint_path = "resnet101_state_dict.pth" +model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu'))) +model.eval() + +# Image transformations +transform = transforms.Compose([ + transforms.Resize((224, 224)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), +]) + +# Function to preprocess image +def preprocess_image(image_path): + image = Image.open(image_path).convert('RGB') + return transform(image).unsqueeze(0) + +# Generate visualizations for Grad-CAM and other images +def generate_visualizations(image_path): + original_image = cv2.imread(image_path) + original_image_rgb = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB) + + # Grayscale image + grayscale_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2GRAY) + + # Histogram equalized image + equalized_image = cv2.equalizeHist(grayscale_image) + + # Edge detection result + edges_image = cv2.Canny(grayscale_image, 50, 150) + + # Segmented image + _, segmented_image = cv2.threshold(grayscale_image, 127, 255, cv2.THRESH_BINARY) + + # Grad-CAM Visualization + input_tensor = preprocess_image(image_path) + target_layer = model.layer4[-1] # Last layer of ResNet101 + cam = GradCAM(model=model, target_layers=[target_layer]) + grayscale_cam = cam(input_tensor=input_tensor)[0] + + # Normalize the Grad-CAM output to range [0, 1] + grayscale_cam = np.maximum(grayscale_cam, 0) + grayscale_cam = grayscale_cam / np.max(grayscale_cam) + + # Apply the Grad-CAM heatmap on the image with a red color map + input_image_rgb_resized = cv2.resize(original_image_rgb, (224, 224)) + + # Using a custom colormap (ensure red regions are highlighted) + heatmap = cv2.applyColorMap(np.uint8(255 * grayscale_cam), cv2.COLORMAP_JET) + grad_cam_image = cv2.addWeighted(input_image_rgb_resized, 0.7, heatmap, 0.3, 0) + + # Mask for ROI extraction (high-confidence areas are those with grayscale_cam > 0.5) + roi_mask = (grayscale_cam > 0.5).astype(np.uint8) + roi = cv2.bitwise_and(input_image_rgb_resized, input_image_rgb_resized, mask=roi_mask) + + # Generate unique file name + file_id = str(uuid.uuid4()) + + # Save images to server and return file paths + visualization_paths = { + "original": f'/uploads/{file_id}_original.png', + "grayscale": f'/uploads/{file_id}_grayscale.png', + "equalized": f'/uploads/{file_id}_equalized.png', + "edges": f'/uploads/{file_id}_edges.png', + "segmented": f'/uploads/{file_id}_segmented.png', + "grad_cam": f'/uploads/{file_id}_grad_cam.png', + "roi": f'/uploads/{file_id}_roi.png' + } + + # Create uploads directory if it doesn't exist + os.makedirs('uploads', exist_ok=True) + + # Save the visualizations as images + cv2.imwrite(f'uploads/{file_id}_original.png', original_image) + cv2.imwrite(f'uploads/{file_id}_grayscale.png', grayscale_image) + cv2.imwrite(f'uploads/{file_id}_equalized.png', equalized_image) + cv2.imwrite(f'uploads/{file_id}_edges.png', edges_image) + cv2.imwrite(f'uploads/{file_id}_segmented.png', segmented_image) + cv2.imwrite(f'uploads/{file_id}_grad_cam.png', grad_cam_image) + cv2.imwrite(f'uploads/{file_id}_roi.png', roi) + + return visualization_paths + +@app.route('/') +def index(): + return render_template('index.html') + +@app.route('/predict', methods=['POST']) +def predict(): + if 'file' not in request.files: + return jsonify({"error": "No file part"}), 400 + + file = request.files['file'] + if file.filename == '': + return jsonify({"error": "No selected file"}), 400 + + image_id = str(uuid.uuid4()) + file_path = os.path.join('uploads', f'{image_id}.png') + os.makedirs('uploads', exist_ok=True) + file.save(file_path) + + input_tensor = preprocess_image(file_path) + with torch.no_grad(): + output = model(input_tensor) + predicted_class = output.argmax(dim=1).item() + confidence_score = torch.nn.functional.softmax(output, dim=1)[0, predicted_class].item() * 100 + + result = { + "predicted_class": class_names[predicted_class], + "confidence_score": f"{confidence_score:.2f}%", + } + + if class_names[predicted_class] == "Random": + result["message"] = "Sorry... You inserted a Non X-ray image. Please try again with a chest x-ray image. Thank you." + result["visualizations"] = None + else: + visualizations = generate_visualizations(file_path) + result["visualizations"] = visualizations + + return jsonify(result) + +# To serve the images from the 'uploads' folder correctly +@app.route('/uploads/<filename>') +def uploaded_file(filename): + return send_from_directory('uploads', filename) + +if __name__ == '__main__': + app.run(debug=True) +>>>>>>> 5676c55 (Add ResNet101 model using Git LFS)