Diff of /webapp.py [000000] .. [663859]

Switch to unified view

a b/webapp.py
1
import streamlit as st
2
import cv2
3
import numpy as np
4
import tempfile
5
import os
6
from ultralytics import YOLO
7
8
# Load the YOLOv8 model
9
try:
10
    model = YOLO('best.pt')  # Ensure the path to your trained YOLOv8 model weights is correct
11
except FileNotFoundError:
12
    st.error("Model weights file 'best.pt' not found. Please ensure the model file is in the correct directory.")
13
    st.stop()
14
15
# Predefined colors for each class
16
COLORS = {
17
    "high-pneumonia": [255, 0, 0],    # Red
18
    "low-pneumonia": [0, 255, 0],     # Green
19
    "no-pneumonia": [0, 0, 255]       # Blue
20
}
21
22
# Define a constant image size
23
IMAGE_SIZE = (640, 640)  # Width, Height
24
25
def segment_image(image):
26
    results = model(image)
27
    return results
28
29
def display_segmented_image(image, results):
30
    detected_classes = set()
31
    if results and results[0].masks is not None and results[0].boxes is not None:
32
        masks = results[0].masks.data.cpu().numpy()
33
        boxes = results[0].boxes.data.cpu().numpy()
34
        class_ids = results[0].boxes.cls.cpu().numpy()
35
36
        for mask, box, class_id in zip(masks, boxes, class_ids):
37
            class_name = model.names[int(class_id)]
38
            detected_classes.add(class_name)
39
40
        # If low-pneumonia or high-pneumonia is detected, ignore no-pneumonia masks
41
        if 'low-pneumonia' in detected_classes or 'high-pneumonia' in detected_classes:
42
            detected_classes.discard('no-pneumonia')
43
44
        for mask, box, class_id in zip(masks, boxes, class_ids):
45
            class_name = model.names[int(class_id)]
46
            if class_name not in detected_classes:
47
                continue
48
            mask = mask.astype(np.uint8)
49
            mask_resized = cv2.resize(mask, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_NEAREST)
50
            color = COLORS.get(class_name, [255, 255, 255])  # Default to white if class not found
51
            image[mask_resized == 1] = color
52
            x1, y1, x2, y2 = box[:4].astype(int)
53
            # Draw a rectangle behind the text for better visibility
54
            text_size = cv2.getTextSize(class_name, cv2.FONT_HERSHEY_SIMPLEX, 0.9, 2)[0]
55
            text_x = x1
56
            text_y = y1 - 10
57
            cv2.rectangle(image, (text_x, text_y - text_size[1]), (text_x + text_size[0], text_y), color, -1)
58
            cv2.putText(image, class_name, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255, 255, 255), 2)
59
            # Draw the bounding box rectangle
60
            cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
61
    else:
62
        st.warning("No objects detected.")
63
64
    # Resize the final image to the constant size
65
    resized_image = cv2.resize(image, IMAGE_SIZE, interpolation=cv2.INTER_LINEAR)
66
    
67
    return resized_image, detected_classes
68
69
70
st.title('X-ray Segmentation Project')
71
72
# Add an author section in the sidebar
73
st.sidebar.title('About the Author')
74
st.sidebar.markdown("""
75
    **Author Name**: Makhammadjonov Izzatullokh 
76
    **Email**: izzatullokhm@gmail.com  
77
    <a href="https://github.com/Izzatullokh24" target="_blank"><i class="fab fa-github"></i> GitHub</a>  
78
    <a href="https://www.linkedin.com/in/izzatullokh-makhammadjonov-242042195/" target="_blank"><i class="fab fa-linkedin"></i> LinkedIn</a>  
79
    <style>
80
        .fab {
81
            font-size: 24px;
82
            margin-right: 10px;
83
        }
84
    </style>
85
    <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/5.15.4/css/all.min.css">
86
    Izzatullokh is a machine learning engineer with a passion for computer vision and deep learning.
87
""", unsafe_allow_html=True)
88
89
uploaded_file = st.file_uploader("Upload an image or video", type=["jpg", "jpeg", "png", "mp4", "avi"])
90
91
if uploaded_file is not None:
92
    file_type = uploaded_file.type.split('/')[0]
93
94
    try:
95
        if file_type == 'image':
96
            file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
97
            image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
98
            results = segment_image(image)
99
            segmented_image, detected_classes = display_segmented_image(image, results)
100
            st.image(segmented_image, caption='Segmented Image', use_column_width=True)
101
102
             # Display pneumonia information
103
            if detected_classes:
104
                st.subheader('Diagnosis:')
105
                diagnosis_html = ""
106
                if 'high-pneumonia' in detected_classes:
107
                    diagnosis_html += '<p style="color: red; font-size: 20px;">The patient has <strong>high pneumonia</strong>.</p>'
108
                if 'low-pneumonia' in detected_classes:
109
                    diagnosis_html += '<p style="color: green; font-size: 20px;">The patient has <strong>low pneumonia</strong>.</p>'
110
                if 'no-pneumonia' in detected_classes:
111
                    diagnosis_html += '<p style="color: blue; font-size: 20px;">The patient does <strong>not have pneumonia</strong>.</p>'
112
                st.markdown(diagnosis_html, unsafe_allow_html=True)
113
            else:
114
                st.write('No pneumonia detected.')
115
116
        elif file_type == 'video':
117
            tfile = tempfile.NamedTemporaryFile(delete=False)
118
            tfile.write(uploaded_file.read())
119
            video_path = tfile.name
120
            segmented_frames = segment_video(video_path)
121
            stframe = st.empty()
122
            for frame, results in segmented_frames:
123
                segmented_frame, detected_classes = display_segmented_image(frame, results)
124
                stframe.image(segmented_frame, channels="BGR")
125
            os.remove(video_path)
126
127
        else:
128
            st.error("Unsupported file format. Please upload a jpg, jpeg, png, mp4, or avi file.")
129
130
    except Exception as e:
131
        st.error(f"An error occurred: {e}")
132
133
else:
134
    st.info("Please upload a file to proceed.")