[9d22e8]: / main.py

Download this file

229 lines (194 with data), 8.9 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
import logging
import math

import numpy as np
import pandas as pd
import streamlit as st

from data.data_loader import load_data, preprocess_data
from models.evaluation import calculate_metrics
from models.model import train_model, predict
from visualization.plots import generate_shap_plots, plot_feature_importance
def setup_logging():
    """Set up root logging at INFO level with a "LEVEL - message" layout."""
    record_layout = '%(levelname)s - %(message)s'
    logging.basicConfig(format=record_layout, level=logging.INFO)
def is_valid_number(value):
    """Return True when *value* can be parsed as a finite real number.

    Args:
        value: The raw input to validate, normally the text typed into a
            form field.

    Returns:
        True if ``float(value)`` succeeds and the result is finite;
        False otherwise (unparseable text, ``None`` or other non-numeric
        objects, and the special spellings "nan" / "inf" / "infinity").
    """
    try:
        number = float(value)
    except (TypeError, ValueError, OverflowError):
        # TypeError: None / containers; ValueError: unparseable text;
        # OverflowError: an int too large to represent as a float.
        return False
    # float() happily parses "nan" and "inf"; neither is usable as a
    # feature value, so treat them as invalid input.
    return math.isfinite(number)
# Select-box choices for the binned inputs, keyed by feature name.
_CATEGORICAL_FEATURES = {
    'age': ['18-25', '26-35', '36-45', '46-55', '56-65', '65+'],
    'length_of_stay': ['1-3 days', '4-7 days', '8-14 days', '15+ days'],
}

# Accepted range for each free-text numeric input, keyed by feature name.
_NUMERICAL_FEATURES = {
    'previous_admissions': {'min': 0, 'max': 100},
    'num_procedures': {'min': 0, 'max': 50},
    'num_medications': {'min': 0, 'max': 100},
    'num_diagnoses': {'min': 0, 'max': 50},
    'num_lab_procedures': {'min': 0, 'max': 200},
    'num_outpatient': {'min': 0, 'max': 100},
    'num_emergency': {'min': 0, 'max': 100},
    'num_inpatient': {'min': 0, 'max': 100},
}

# Numeric value substituted for each age bin before scaling.
_AGE_VALUES = {
    '18-25': 21.5, '26-35': 30.5, '36-45': 40.5,
    '46-55': 50.5, '56-65': 60.5, '65+': 70,
}

# Numeric value substituted for each length-of-stay bin before scaling.
_LOS_VALUES = {
    '1-3 days': 2, '4-7 days': 5.5, '8-14 days': 11, '15+ days': 15,
}

# (display label, key in the metrics mapping) in left-to-right display order.
_METRIC_LABELS = (
    ("Accuracy", 'accuracy'),
    ("Precision", 'precision'),
    ("Recall", 'recall'),
    ("F1 Score", 'f1'),
    ("AUC-ROC", 'auc_roc'),
)


def _field_label(feature):
    """Turn a feature name such as ``num_lab_procedures`` into a UI label."""
    return feature.replace('_', ' ').title()


def _show_metrics(metrics):
    """Render the model performance metrics as one row of metric widgets."""
    st.header("Model Performance Metrics")
    columns = st.columns(len(_METRIC_LABELS))
    for column, (label, key) in zip(columns, _METRIC_LABELS):
        with column:
            st.metric(label, f"{metrics[key]:.2f}")


def _numeric_input(feature, bounds):
    """Render a text box for *feature* and return the number to use for it.

    Invalid text shows an error and falls back to 0. A value outside
    *bounds* shows a warning but is still returned.
    """
    raw_text = st.text_input(
        _field_label(feature),
        value="0",
        key=f"input_{feature}"
    )
    if not is_valid_number(raw_text):
        st.error(f"Please enter a valid number for {feature}")
        return 0

    value = float(raw_text)
    # Written as a negated chained comparison so that NaN, which compares
    # False against everything, also triggers the warning.
    if not (bounds['min'] <= value <= bounds['max']):
        st.warning(f"Value should be between {bounds['min']} and {bounds['max']}")
    return value


def _collect_inputs():
    """Render the two-column patient form and return {feature: raw value}."""
    left, right = st.columns(2)
    input_values = {}

    numeric_names = list(_NUMERICAL_FEATURES)
    half = len(numeric_names) // 2

    # Left column: the select boxes followed by the first half of the
    # numeric fields.
    with left:
        for feature, options in _CATEGORICAL_FEATURES.items():
            input_values[feature] = st.selectbox(
                _field_label(feature),
                options=options,
                key=f"select_{feature}"
            )
        for feature in numeric_names[:half]:
            input_values[feature] = _numeric_input(
                feature, _NUMERICAL_FEATURES[feature]
            )

    # Right column: the remaining numeric fields.
    with right:
        for feature in numeric_names[half:]:
            input_values[feature] = _numeric_input(
                feature, _NUMERICAL_FEATURES[feature]
            )

    return input_values


def _predict_for_patient(model, scaler, feature_names, input_values):
    """Encode the form values, score them with *model* and show the result."""
    input_data = input_values.copy()
    input_data['age'] = _AGE_VALUES[input_data['age']]
    input_data['length_of_stay'] = _LOS_VALUES[input_data['length_of_stay']]

    # Build a single-row frame in the column order given by feature_names,
    # then apply the fitted scaler to it.
    input_df = pd.DataFrame([input_data], columns=feature_names)
    input_scaled = pd.DataFrame(
        scaler.transform(input_df),
        columns=feature_names
    )

    _, probabilities = predict(model, input_scaled)
    if probabilities is None:
        return

    prediction = probabilities[0]
    st.write(f"Predicted probability of readmission: {prediction:.2%}")
    if prediction > 0.7:
        risk_category = "High"
    elif prediction > 0.3:
        risk_category = "Medium"
    else:
        risk_category = "Low"
    st.write(f"Risk Category: {risk_category}")


def _show_model_insights(model, X_train_df, X_test_df):
    """Render the feature-importance chart and the SHAP plot."""
    st.header("Model Insights")

    st.subheader("Feature Importance")
    st.write(
        "\n"
        "This plot shows the relative importance of each feature in making predictions.\n"
        "Higher values indicate more important features.\n"
    )
    fig_importance = plot_feature_importance(model, X_train_df)
    if fig_importance:
        st.plotly_chart(fig_importance)

    st.subheader("SHAP Analysis")
    st.write(
        "\n"
        "SHAP (SHapley Additive exPlanations) values show how each feature\n"
        "contributes to predictions. Features in red increase the prediction,\n"
        "while features in blue decrease it.\n"
    )
    fig_shap = generate_shap_plots(model, X_test_df)
    if fig_shap:
        # Rendered inside its own container, without stretching the figure
        # to the container width.
        with st.container():
            st.pyplot(fig_shap, use_container_width=False)
    else:
        st.error("Error generating SHAP plot")


def main():
    """Run the Streamlit page for the 30-day readmission risk predictor.

    Loads and preprocesses the data, trains the model, shows its evaluation
    metrics, renders an interactive single-patient prediction form, and
    finally shows the feature-importance and SHAP plots. Each pipeline step
    that returns ``None`` stops the page with an error message. Any
    exception raised along the way is shown on the page and logged with its
    traceback.
    """
    st.set_page_config(
        page_title="30-Day Readmission Risk Predictor",
        layout="wide"
    )
    setup_logging()

    st.title("30-Day Mental Health Hospital Readmission Risk Predictor")
    st.write(
        "\n"
        "This application predicts the risk of patient readmission within 30 days\n"
        "of discharge from a mental health hospital using machine learning.\n"
    )

    try:
        # Load and preprocess data
        data = load_data()
        if data is None:
            st.error("Error loading data")
            return

        X_train, X_test, y_train, y_test, feature_names, scaler = preprocess_data(data)
        if X_train is None:
            st.error("Error preprocessing data")
            return

        # Ensure X_train and X_test are DataFrames with feature names
        X_train_df = pd.DataFrame(X_train, columns=feature_names)
        X_test_df = pd.DataFrame(X_test, columns=feature_names)

        # Train model
        model = train_model(X_train_df, y_train)
        if model is None:
            st.error("Error training model")
            return

        # Make predictions
        y_pred, y_pred_proba = predict(model, X_test_df)
        if y_pred is None:
            st.error("Error making predictions")
            return

        # Calculate and display metrics
        metrics = calculate_metrics(y_test, y_pred, y_pred_proba)
        if metrics:
            _show_metrics(metrics)

        # Interactive prediction
        st.header("Predict Readmission Risk")
        st.write("Enter patient information to predict readmission risk")
        input_values = _collect_inputs()
        if st.button("Predict"):
            _predict_for_patient(model, scaler, feature_names, input_values)

        _show_model_insights(model, X_train_df, X_test_df)
    except Exception as e:
        # Top-level boundary: surface the failure in the UI and keep the
        # traceback in the log for diagnosis.
        st.error(f"An error occurred: {str(e)}")
        logging.exception("Error in main: %s", e)
# Run the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()