Diff of /main.py [000000] .. [9d22e8]

Switch to side-by-side view

--- a
+++ b/main.py
@@ -0,0 +1,229 @@
+import streamlit as st
+import numpy as np
+import pandas as pd
+import logging
+from data.data_loader import load_data, preprocess_data
+from visualization.plots import generate_shap_plots, plot_feature_importance
+from models.model import train_model, predict
+from models.evaluation import calculate_metrics
+
+def setup_logging():
+    """Configure logging settings"""
+    logging.basicConfig(
+        level=logging.INFO,
+        format='%(levelname)s - %(message)s'
+    )
+
+def is_valid_number(value):
+    """Validate if input string is a valid number"""
+    try:
+        float(value)
+        return True
+    except ValueError:
+        return False
+
+def main():
+    st.set_page_config(
+        page_title="30-Day Readmission Risk Predictor",
+        layout="wide"
+    )
+    
+    setup_logging()
+    
+    st.title("30-Day Mental Health Hospital Readmission Risk Predictor")
+    
+    st.write("""
+    This application predicts the risk of patient readmission within 30 days
+    of discharge from a mental health hospital using machine learning.
+    """)
+    
+    try:
+        # Load and preprocess data
+        data = load_data()
+        if data is None:
+            st.error("Error loading data")
+            return
+            
+        X_train, X_test, y_train, y_test, feature_names, scaler = preprocess_data(data)
+        if X_train is None:
+            st.error("Error preprocessing data")
+            return
+        
+        # Ensure X_train and X_test are DataFrames with feature names
+        X_train_df = pd.DataFrame(X_train, columns=feature_names)
+        X_test_df = pd.DataFrame(X_test, columns=feature_names)
+        
+        # Train model
+        model = train_model(X_train_df, y_train)
+        if model is None:
+            st.error("Error training model")
+            return
+            
+        # Make predictions
+        y_pred, y_pred_proba = predict(model, X_test_df)
+        if y_pred is None:
+            st.error("Error making predictions")
+            return
+            
+        # Calculate and display metrics
+        metrics = calculate_metrics(y_test, y_pred, y_pred_proba)
+        if metrics:
+            st.header("Model Performance Metrics")
+            col1, col2, col3, col4, col5 = st.columns(5)
+            with col1:
+                st.metric("Accuracy", f"{metrics['accuracy']:.2f}")
+            with col2:
+                st.metric("Precision", f"{metrics['precision']:.2f}")
+            with col3:
+                st.metric("Recall", f"{metrics['recall']:.2f}")
+            with col4:
+                st.metric("F1 Score", f"{metrics['f1']:.2f}")
+            with col5:
+                st.metric("AUC-ROC", f"{metrics['auc_roc']:.2f}")
+        
+        # Define categorical features and their options
+        categorical_features = {
+            'age': ['18-25', '26-35', '36-45', '46-55', '56-65', '65+'],
+            'length_of_stay': ['1-3 days', '4-7 days', '8-14 days', '15+ days']
+        }
+
+        # Define numerical features and their ranges
+        numerical_features = {
+            'previous_admissions': {'min': 0, 'max': 100},
+            'num_procedures': {'min': 0, 'max': 50},
+            'num_medications': {'min': 0, 'max': 100},
+            'num_diagnoses': {'min': 0, 'max': 50},
+            'num_lab_procedures': {'min': 0, 'max': 200},
+            'num_outpatient': {'min': 0, 'max': 100},
+            'num_emergency': {'min': 0, 'max': 100},
+            'num_inpatient': {'min': 0, 'max': 100}
+        }
+
+        # Interactive Prediction
+        st.header("Predict Readmission Risk")
+        st.write("Enter patient information to predict readmission risk")
+        
+        # Create columns for input fields
+        col1, col2 = st.columns(2)
+        
+        # Store all input values
+        input_values = {}
+        
+        # First column of inputs
+        with col1:
+            # Handle categorical features
+            for feature in categorical_features:
+                input_values[feature] = st.selectbox(
+                    f"{feature.replace('_', ' ').title()}",
+                    options=categorical_features[feature],
+                    key=f"select_{feature}"
+                )
+            
+            # Handle first half of numerical features
+            numerical_features_list = list(numerical_features.keys())
+            for feature in numerical_features_list[:len(numerical_features_list)//2]:
+                input_text = st.text_input(
+                    f"{feature.replace('_', ' ').title()}",
+                    value="0",
+                    key=f"input_{feature}"
+                )
+                
+                if not is_valid_number(input_text):
+                    st.error(f"Please enter a valid number for {feature}")
+                    input_values[feature] = 0
+                else:
+                    value = float(input_text)
+                    if value < numerical_features[feature]['min'] or value > numerical_features[feature]['max']:
+                        st.warning(f"Value should be between {numerical_features[feature]['min']} and {numerical_features[feature]['max']}")
+                    input_values[feature] = value
+
+        # Second column of inputs
+        with col2:
+            # Handle second half of numerical features
+            for feature in numerical_features_list[len(numerical_features_list)//2:]:
+                input_text = st.text_input(
+                    f"{feature.replace('_', ' ').title()}",
+                    value="0",
+                    key=f"input_{feature}"
+                )
+                
+                if not is_valid_number(input_text):
+                    st.error(f"Please enter a valid number for {feature}")
+                    input_values[feature] = 0
+                else:
+                    value = float(input_text)
+                    if value < numerical_features[feature]['min'] or value > numerical_features[feature]['max']:
+                        st.warning(f"Value should be between {numerical_features[feature]['min']} and {numerical_features[feature]['max']}")
+                    input_values[feature] = value
+
+        # Prediction button
+        if st.button("Predict"):
+            # Convert categorical inputs to numerical
+            input_data = input_values.copy()
+            
+            # Process age categories
+            age_mapping = {'18-25': 21.5, '26-35': 30.5, '36-45': 40.5, 
+                          '46-55': 50.5, '56-65': 60.5, '65+': 70}
+            input_data['age'] = age_mapping[input_data['age']]
+            
+            # Process length of stay categories
+            los_mapping = {'1-3 days': 2, '4-7 days': 5.5, '8-14 days': 11, 
+                          '15+ days': 15}
+            input_data['length_of_stay'] = los_mapping[input_data['length_of_stay']]
+            
+            # Create DataFrame and make prediction
+            input_df = pd.DataFrame([input_data], columns=feature_names)
+            input_scaled = pd.DataFrame(
+                scaler.transform(input_df),
+                columns=feature_names
+            )
+            
+            _, probabilities = predict(model, input_scaled)
+            
+            if probabilities is not None:
+                prediction = probabilities[0]
+                st.write(f"Predicted probability of readmission: {prediction:.2%}")
+                
+                risk_category = (
+                    "High" if prediction > 0.7
+                    else "Medium" if prediction > 0.3
+                    else "Low"
+                )
+                st.write(f"Risk Category: {risk_category}")
+        
+        # Model Insights Section
+        st.header("Model Insights")
+        
+        # Feature Importance Plot
+        st.subheader("Feature Importance")
+        st.write("""
+        This plot shows the relative importance of each feature in making predictions.
+        Higher values indicate more important features.
+        """)
+        fig_importance = plot_feature_importance(model, X_train_df)
+        if fig_importance:
+            st.plotly_chart(fig_importance)
+        
+        # SHAP Analysis
+        st.subheader("SHAP Analysis")
+        st.write("""
+        SHAP (SHapley Additive exPlanations) values show how each feature
+        contributes to predictions. Features in red increase the prediction,
+        while features in blue decrease it.
+        """)
+        
+        # Create and display SHAP plot with custom width
+        fig_shap = generate_shap_plots(model, X_test_df)
+        if fig_shap:
+            # Use a container with custom width
+            with st.container():
+                st.pyplot(fig_shap, use_container_width=False)
+        else:
+            st.error("Error generating SHAP plot")
+        
+    except Exception as e:
+        st.error(f"An error occurred: {str(e)}")
+        logging.error(f"Error in main: {str(e)}")
+
+if __name__ == "__main__":
+    main() 
\ No newline at end of file