Diff of /visualization/plots.py [000000] .. [9d22e8]

Switch to side-by-side view

--- a
+++ b/visualization/plots.py
@@ -0,0 +1,91 @@
+# visualization/plots.py
+import plotly.express as px
+import plotly.graph_objects as go
+import shap
+import matplotlib.pyplot as plt
+from sklearn.metrics import confusion_matrix
+import logging
+import pandas as pd
+import numpy as np
+
+def plot_feature_importance(model, X):
+    """Create feature importance plot"""
+    try:
+        importance = model.feature_importances_
+        feat_importance = pd.DataFrame({
+            'feature': X.columns,
+            'importance': importance
+        })
+        feat_importance = feat_importance.sort_values('importance', ascending=True)
+        
+        fig = px.bar(
+            feat_importance,
+            x='importance',
+            y='feature',
+            title='Feature Importance',
+            orientation='h'
+        )
+        
+        fig.update_layout(
+            height=500,
+            margin=dict(l=20, r=20, t=40, b=20),
+            title_x=0.5,
+            xaxis_title="Relative Importance",
+            yaxis_title="Features"
+        )
+        
+        return fig
+    except Exception as e:
+        logging.error(f"Error plotting feature importance: {e}")
+        return None
+
+def plot_confusion_matrix(y_true, y_pred):
+    """Create confusion matrix plot"""
+    try:
+        cm = confusion_matrix(y_true, y_pred)
+        fig = px.imshow(
+            cm,
+            labels=dict(x="Predicted", y="Actual"),
+            x=['Not Readmitted', 'Readmitted'],
+            y=['Not Readmitted', 'Readmitted'],
+            title="Confusion Matrix"
+        )
+        return fig
+    except Exception as e:
+        logging.error(f"Error plotting confusion matrix: {e}")
+        return None
+
+def generate_shap_plots(model, X_test_df):
+    """Generate SHAP plots"""
+    try:
+        explainer = shap.TreeExplainer(model)
+        shap_values = explainer.shap_values(X_test_df)
+        
+        # Create figure with even smaller size
+        plt.figure(figsize=(6, 4))
+        
+        # Generate summary plot with smaller size and font
+        shap.summary_plot(
+            shap_values,
+            X_test_df,
+            show=False,
+            plot_size=(6, 4),
+            max_display=10,  # Limit number of features shown
+            plot_type="bar"  # Use bar plot for more compact display
+        )
+        
+        # Adjust layout and fonts
+        plt.xticks(fontsize=8)
+        plt.yticks(fontsize=8)
+        plt.xlabel("SHAP value (impact on model output)", fontsize=8)
+        
+        # Adjust layout to prevent cutoff
+        plt.tight_layout()
+        
+        # Get the current figure
+        fig = plt.gcf()
+        
+        return fig
+    except Exception as e:
+        logging.error(f"Error generating SHAP plots: {e}")
+        return None
\ No newline at end of file