Diff of /main.py [000000] .. [9d22e8]

Switch to unified view

a b/main.py
1
import streamlit as st
2
import numpy as np
3
import pandas as pd
4
import logging
5
from data.data_loader import load_data, preprocess_data
6
from visualization.plots import generate_shap_plots, plot_feature_importance
7
from models.model import train_model, predict
8
from models.evaluation import calculate_metrics
9
10
def setup_logging():
    """Initialise logging via ``logging.basicConfig``.

    Uses the INFO threshold and a ``LEVEL - message`` line layout.
    """
    line_format = '%(levelname)s - %(message)s'
    logging.basicConfig(format=line_format, level=logging.INFO)
16
17
def is_valid_number(value):
    """Report whether *value* can be used as a finite numeric input.

    Args:
        value: The raw input to check, typically the text typed into a form
            field.

    Returns:
        True when ``float(value)`` succeeds and the result is finite,
        False otherwise. Never raises for unparsable input.
    """
    # Function-scope import keeps this block self-contained.
    import math

    try:
        parsed = float(value)
    except (TypeError, ValueError, OverflowError):
        # float() raises TypeError (not ValueError) for None and other
        # non-numeric objects, and OverflowError for oversized integers.
        return False

    # "nan" and "inf" parse successfully but are not usable measurements;
    # NaN in particular defeats any later `<` / `>` range check.
    return math.isfinite(parsed)
24
25
def _read_numeric_inputs(features, limits):
    """Render one text box per feature and return the parsed values.

    Args:
        features: Feature names to render, in display order.
        limits: Mapping of feature name to a dict with ``'min'`` and
            ``'max'`` bounds used for the range warning.

    Returns:
        Dict mapping each feature name to the entered value as a float, or
        to 0 when the text is not a valid number (an error is shown).
    """
    values = {}
    for feature in features:
        input_text = st.text_input(
            feature.replace('_', ' ').title(),
            value="0",
            key=f"input_{feature}"
        )

        if not is_valid_number(input_text):
            st.error(f"Please enter a valid number for {feature}")
            values[feature] = 0
            continue

        value = float(input_text)
        low = limits[feature]['min']
        high = limits[feature]['max']
        # Written as a negated chained comparison so that NaN (for which
        # every comparison is False) also triggers the warning.
        if not (low <= value <= high):
            st.warning(f"Value should be between {low} and {high}")
        # Out-of-range values are warned about but still used.
        values[feature] = value
    return values


def main():
    """Run the readmission-risk Streamlit page.

    Loads and preprocesses the data, trains the model, shows evaluation
    metrics, renders the patient input form with an on-demand prediction,
    and finally displays the feature-importance and SHAP plots. Each
    pipeline step that returns ``None`` aborts the page with an error
    message; any exception is shown in the page and logged.
    """
    st.set_page_config(
        page_title="30-Day Readmission Risk Predictor",
        layout="wide"
    )

    setup_logging()

    st.title("30-Day Mental Health Hospital Readmission Risk Predictor")

    st.write("""
    This application predicts the risk of patient readmission within 30 days
    of discharge from a mental health hospital using machine learning.
    """)

    try:
        # Load and preprocess data
        data = load_data()
        if data is None:
            st.error("Error loading data")
            return

        X_train, X_test, y_train, y_test, feature_names, scaler = preprocess_data(data)
        if X_train is None:
            st.error("Error preprocessing data")
            return

        # Ensure X_train and X_test are DataFrames with feature names
        X_train_df = pd.DataFrame(X_train, columns=feature_names)
        X_test_df = pd.DataFrame(X_test, columns=feature_names)

        # Train model
        model = train_model(X_train_df, y_train)
        if model is None:
            st.error("Error training model")
            return

        # Make predictions
        y_pred, y_pred_proba = predict(model, X_test_df)
        if y_pred is None:
            st.error("Error making predictions")
            return

        # Calculate and display metrics
        metrics = calculate_metrics(y_test, y_pred, y_pred_proba)
        if metrics:
            st.header("Model Performance Metrics")
            metric_labels = [
                ("Accuracy", 'accuracy'),
                ("Precision", 'precision'),
                ("Recall", 'recall'),
                ("F1 Score", 'f1'),
                ("AUC-ROC", 'auc_roc'),
            ]
            # One column per metric, in the order listed above.
            for column, (label, key) in zip(st.columns(5), metric_labels):
                with column:
                    st.metric(label, f"{metrics[key]:.2f}")

        # Define categorical features and their options
        categorical_features = {
            'age': ['18-25', '26-35', '36-45', '46-55', '56-65', '65+'],
            'length_of_stay': ['1-3 days', '4-7 days', '8-14 days', '15+ days']
        }

        # Define numerical features and their ranges
        numerical_features = {
            'previous_admissions': {'min': 0, 'max': 100},
            'num_procedures': {'min': 0, 'max': 50},
            'num_medications': {'min': 0, 'max': 100},
            'num_diagnoses': {'min': 0, 'max': 50},
            'num_lab_procedures': {'min': 0, 'max': 200},
            'num_outpatient': {'min': 0, 'max': 100},
            'num_emergency': {'min': 0, 'max': 100},
            'num_inpatient': {'min': 0, 'max': 100}
        }

        # Interactive Prediction
        st.header("Predict Readmission Risk")
        st.write("Enter patient information to predict readmission risk")

        # Create columns for input fields
        col1, col2 = st.columns(2)

        # Store all input values
        input_values = {}

        # The numerical inputs are split evenly between the two columns.
        numerical_features_list = list(numerical_features)
        midpoint = len(numerical_features_list) // 2

        # First column: categorical selectors, then first half of numerics
        with col1:
            for feature, options in categorical_features.items():
                input_values[feature] = st.selectbox(
                    feature.replace('_', ' ').title(),
                    options=options,
                    key=f"select_{feature}"
                )

            input_values.update(
                _read_numeric_inputs(
                    numerical_features_list[:midpoint], numerical_features
                )
            )

        # Second column: second half of numerics
        with col2:
            input_values.update(
                _read_numeric_inputs(
                    numerical_features_list[midpoint:], numerical_features
                )
            )

        # Prediction button
        if st.button("Predict"):
            # Convert categorical inputs to numerical
            input_data = input_values.copy()

            # Each age bracket is replaced by a single representative number.
            age_mapping = {'18-25': 21.5, '26-35': 30.5, '36-45': 40.5,
                           '46-55': 50.5, '56-65': 60.5, '65+': 70}
            input_data['age'] = age_mapping[input_data['age']]

            # Each length-of-stay bracket is likewise mapped to one number.
            los_mapping = {'1-3 days': 2, '4-7 days': 5.5, '8-14 days': 11,
                           '15+ days': 15}
            input_data['length_of_stay'] = los_mapping[input_data['length_of_stay']]

            # Create DataFrame and make prediction.
            # NOTE(review): any name in feature_names that is not a key of
            # input_data becomes a NaN column here - confirm that
            # preprocess_data returns exactly the features defined above.
            input_df = pd.DataFrame([input_data], columns=feature_names)
            input_scaled = pd.DataFrame(
                scaler.transform(input_df),
                columns=feature_names
            )

            _, probabilities = predict(model, input_scaled)

            if probabilities is not None:
                # assumes predict returns one scalar probability per row -
                # TODO confirm against models.model.predict
                prediction = probabilities[0]
                st.write(f"Predicted probability of readmission: {prediction:.2%}")

                risk_category = (
                    "High" if prediction > 0.7
                    else "Medium" if prediction > 0.3
                    else "Low"
                )
                st.write(f"Risk Category: {risk_category}")

        # Model Insights Section
        st.header("Model Insights")

        # Feature Importance Plot
        st.subheader("Feature Importance")
        st.write("""
        This plot shows the relative importance of each feature in making predictions.
        Higher values indicate more important features.
        """)
        fig_importance = plot_feature_importance(model, X_train_df)
        if fig_importance:
            st.plotly_chart(fig_importance)

        # SHAP Analysis
        st.subheader("SHAP Analysis")
        st.write("""
        SHAP (SHapley Additive exPlanations) values show how each feature
        contributes to predictions. Features in red increase the prediction,
        while features in blue decrease it.
        """)

        # Create and display SHAP plot with custom width
        fig_shap = generate_shap_plots(model, X_test_df)
        if fig_shap:
            # Use a container with custom width
            with st.container():
                st.pyplot(fig_shap, use_container_width=False)
        else:
            st.error("Error generating SHAP plot")

    except Exception as e:
        # Top-level boundary: show the failure in the page and log it with
        # the traceback so the cause can be diagnosed.
        st.error(f"An error occurred: {str(e)}")
        logging.exception("Error in main: %s", e)
227
228
# Script entry point: build the page only when this file is executed as the
# main module, not when it is imported.
if __name__ == "__main__":
    main()