--- a +++ b/main.py @@ -0,0 +1,229 @@ +import streamlit as st +import numpy as np +import pandas as pd +import logging +from data.data_loader import load_data, preprocess_data +from visualization.plots import generate_shap_plots, plot_feature_importance +from models.model import train_model, predict +from models.evaluation import calculate_metrics + +def setup_logging(): + """Configure logging settings""" + logging.basicConfig( + level=logging.INFO, + format='%(levelname)s - %(message)s' + ) + +def is_valid_number(value): + """Validate if input string is a valid number""" + try: + float(value) + return True + except ValueError: + return False + +def main(): + st.set_page_config( + page_title="30-Day Readmission Risk Predictor", + layout="wide" + ) + + setup_logging() + + st.title("30-Day Mental Health Hospital Readmission Risk Predictor") + + st.write(""" + This application predicts the risk of patient readmission within 30 days + of discharge from a mental health hospital using machine learning. + """) + + try: + # Load and preprocess data + data = load_data() + if data is None: + st.error("Error loading data") + return + + X_train, X_test, y_train, y_test, feature_names, scaler = preprocess_data(data) + if X_train is None: + st.error("Error preprocessing data") + return + + # Ensure X_train and X_test are DataFrames with feature names + X_train_df = pd.DataFrame(X_train, columns=feature_names) + X_test_df = pd.DataFrame(X_test, columns=feature_names) + + # Train model + model = train_model(X_train_df, y_train) + if model is None: + st.error("Error training model") + return + + # Make predictions + y_pred, y_pred_proba = predict(model, X_test_df) + if y_pred is None: + st.error("Error making predictions") + return + + # Calculate and display metrics + metrics = calculate_metrics(y_test, y_pred, y_pred_proba) + if metrics: + st.header("Model Performance Metrics") + col1, col2, col3, col4, col5 = st.columns(5) + with col1: + st.metric("Accuracy", f"{metrics['accuracy']:.2f}") + with col2: + st.metric("Precision", f"{metrics['precision']:.2f}") + with col3: + st.metric("Recall", f"{metrics['recall']:.2f}") + with col4: + st.metric("F1 Score", f"{metrics['f1']:.2f}") + with col5: + st.metric("AUC-ROC", f"{metrics['auc_roc']:.2f}") + + # Define categorical features and their options + categorical_features = { + 'age': ['18-25', '26-35', '36-45', '46-55', '56-65', '65+'], + 'length_of_stay': ['1-3 days', '4-7 days', '8-14 days', '15+ days'] + } + + # Define numerical features and their ranges + numerical_features = { + 'previous_admissions': {'min': 0, 'max': 100}, + 'num_procedures': {'min': 0, 'max': 50}, + 'num_medications': {'min': 0, 'max': 100}, + 'num_diagnoses': {'min': 0, 'max': 50}, + 'num_lab_procedures': {'min': 0, 'max': 200}, + 'num_outpatient': {'min': 0, 'max': 100}, + 'num_emergency': {'min': 0, 'max': 100}, + 'num_inpatient': {'min': 0, 'max': 100} + } + + # Interactive Prediction + st.header("Predict Readmission Risk") + st.write("Enter patient information to predict readmission risk") + + # Create columns for input fields + col1, col2 = st.columns(2) + + # Store all input values + input_values = {} + + # First column of inputs + with col1: + # Handle categorical features + for feature in categorical_features: + input_values[feature] = st.selectbox( + f"{feature.replace('_', ' ').title()}", + options=categorical_features[feature], + key=f"select_{feature}" + ) + + # Handle first half of numerical features + numerical_features_list = list(numerical_features.keys()) + for feature in numerical_features_list[:len(numerical_features_list)//2]: + input_text = st.text_input( + f"{feature.replace('_', ' ').title()}", + value="0", + key=f"input_{feature}" + ) + + if not is_valid_number(input_text): + st.error(f"Please enter a valid number for {feature}") + input_values[feature] = 0 + else: + value = float(input_text) + if value < numerical_features[feature]['min'] or value > numerical_features[feature]['max']: + st.warning(f"Value should be between {numerical_features[feature]['min']} and {numerical_features[feature]['max']}") + input_values[feature] = value + + # Second column of inputs + with col2: + # Handle second half of numerical features + for feature in numerical_features_list[len(numerical_features_list)//2:]: + input_text = st.text_input( + f"{feature.replace('_', ' ').title()}", + value="0", + key=f"input_{feature}" + ) + + if not is_valid_number(input_text): + st.error(f"Please enter a valid number for {feature}") + input_values[feature] = 0 + else: + value = float(input_text) + if value < numerical_features[feature]['min'] or value > numerical_features[feature]['max']: + st.warning(f"Value should be between {numerical_features[feature]['min']} and {numerical_features[feature]['max']}") + input_values[feature] = value + + # Prediction button + if st.button("Predict"): + # Convert categorical inputs to numerical + input_data = input_values.copy() + + # Process age categories + age_mapping = {'18-25': 21.5, '26-35': 30.5, '36-45': 40.5, + '46-55': 50.5, '56-65': 60.5, '65+': 70} + input_data['age'] = age_mapping[input_data['age']] + + # Process length of stay categories + los_mapping = {'1-3 days': 2, '4-7 days': 5.5, '8-14 days': 11, + '15+ days': 15} + input_data['length_of_stay'] = los_mapping[input_data['length_of_stay']] + + # Create DataFrame and make prediction + input_df = pd.DataFrame([input_data], columns=feature_names) + input_scaled = pd.DataFrame( + scaler.transform(input_df), + columns=feature_names + ) + + _, probabilities = predict(model, input_scaled) + + if probabilities is not None: + prediction = probabilities[0] + st.write(f"Predicted probability of readmission: {prediction:.2%}") + + risk_category = ( + "High" if prediction > 0.7 + else "Medium" if prediction > 0.3 + else "Low" + ) + st.write(f"Risk Category: {risk_category}") + + # Model Insights Section + st.header("Model Insights") + + # Feature Importance Plot + st.subheader("Feature Importance") + st.write(""" + This plot shows the relative importance of each feature in making predictions. + Higher values indicate more important features. + """) + fig_importance = plot_feature_importance(model, X_train_df) + if fig_importance: + st.plotly_chart(fig_importance) + + # SHAP Analysis + st.subheader("SHAP Analysis") + st.write(""" + SHAP (SHapley Additive exPlanations) values show how each feature + contributes to predictions. Features in red increase the prediction, + while features in blue decrease it. + """) + + # Create and display SHAP plot with custom width + fig_shap = generate_shap_plots(model, X_test_df) + if fig_shap: + # Use a container with custom width + with st.container(): + st.pyplot(fig_shap, use_container_width=False) + else: + st.error("Error generating SHAP plot") + + except Exception as e: + st.error(f"An error occurred: {str(e)}") + logging.error(f"Error in main: {str(e)}") + +if __name__ == "__main__": + main() \ No newline at end of file