Diff of /visualization/plots.py [000000] .. [9d22e8]

Switch to unified view

a b/visualization/plots.py
1
# visualization/plots.py
2
import plotly.express as px
3
import plotly.graph_objects as go
4
import shap
5
import matplotlib.pyplot as plt
6
from sklearn.metrics import confusion_matrix
7
import logging
8
import pandas as pd
9
import numpy as np
10
11
def plot_feature_importance(model, X):
    """Create a horizontal bar chart of the model's feature importances.

    Args:
        model: Fitted estimator exposing a ``feature_importances_`` attribute.
        X: Feature data matching the model's inputs. When it has a
            ``columns`` attribute (e.g. a DataFrame) those are used as the
            feature names; otherwise (e.g. a NumPy array) names of the form
            ``feature_0``, ``feature_1``, ... are generated.

    Returns:
        The Plotly figure, or ``None`` if the plot could not be built (the
        error is logged together with its traceback).
    """
    try:
        importance = model.feature_importances_

        # Fall back to positional names so array inputs still get a plot
        # instead of failing on the missing ``columns`` attribute.
        feature_names = getattr(X, 'columns', None)
        if feature_names is None:
            feature_names = [f"feature_{i}" for i in range(len(importance))]

        feat_importance = pd.DataFrame({
            'feature': feature_names,
            'importance': importance
        })
        # Ascending order: a horizontal bar chart places the first row at the
        # bottom, so the most important feature ends up on top.
        feat_importance = feat_importance.sort_values('importance', ascending=True)

        fig = px.bar(
            feat_importance,
            x='importance',
            y='feature',
            title='Feature Importance',
            orientation='h'
        )

        fig.update_layout(
            height=500,
            margin=dict(l=20, r=20, t=40, b=20),
            title_x=0.5,  # centre the title
            xaxis_title="Relative Importance",
            yaxis_title="Features"
        )

        return fig
    except Exception as e:
        # Best-effort plotting: callers receive None rather than an exception,
        # but the traceback is kept in the log for diagnosis.
        logging.exception(f"Error plotting feature importance: {e}")
        return None
41
42
def plot_confusion_matrix(y_true, y_pred):
    """Create a confusion matrix heatmap.

    Args:
        y_true: Ground-truth class labels.
        y_pred: Predicted class labels, same length as ``y_true``.

    Returns:
        The Plotly figure, or ``None`` if the plot could not be built (the
        error is logged together with its traceback).
    """
    try:
        cm = confusion_matrix(y_true, y_pred)

        if cm.shape == (2, 2):
            # Binary case: keep the readmission wording.
            # NOTE(review): assumes the lower-sorted class means
            # "not readmitted" -- confirm against the label encoding.
            class_names = ['Not Readmitted', 'Readmitted']
        else:
            # Only one class present, or more than two: the fixed pair of
            # names would not match the matrix size, so label the axes with
            # the observed classes (sorted union, one per matrix row).
            class_names = [
                f"Class {label}" for label in np.union1d(y_true, y_pred)
            ]

        fig = px.imshow(
            cm,
            labels=dict(x="Predicted", y="Actual"),
            x=class_names,
            y=class_names,
            title="Confusion Matrix"
        )
        return fig
    except Exception as e:
        # Best-effort plotting: callers receive None rather than an exception,
        # but the traceback is kept in the log for diagnosis.
        logging.exception(f"Error plotting confusion matrix: {e}")
        return None
57
58
def generate_shap_plots(model, X_test_df):
    """Generate a compact SHAP summary bar plot for a tree-based model.

    Args:
        model: Fitted model passed to ``shap.TreeExplainer``.
        X_test_df: Feature data for which SHAP values are computed and
            plotted.

    Returns:
        The matplotlib figure holding the plot, or ``None`` on failure (the
        error is logged together with its traceback). The returned figure is
        left open; the caller should close it once it has been rendered.
    """
    fig = None
    try:
        explainer = shap.TreeExplainer(model)
        shap_values = explainer.shap_values(X_test_df)

        # Small figure so the plot fits a compact layout; keep a handle so it
        # can be released if plotting fails part-way.
        fig = plt.figure(figsize=(6, 4))

        shap.summary_plot(
            shap_values,
            X_test_df,
            show=False,  # keep the figure open so it can be returned
            plot_size=(6, 4),
            max_display=10,  # Limit number of features shown
            plot_type="bar"  # Use bar plot for more compact display
        )

        # Shrink fonts to match the reduced figure size.
        plt.xticks(fontsize=8)
        plt.yticks(fontsize=8)
        # The bar plot shows the mean absolute SHAP value per feature, not
        # raw SHAP values, so the axis label says so.
        plt.xlabel("mean(|SHAP value|) (average impact on model output)", fontsize=8)

        # Adjust layout to prevent labels being cut off.
        plt.tight_layout()

        # Return whichever figure is current after plotting (summary_plot is
        # expected to draw into the figure created above).
        return plt.gcf()
    except Exception as e:
        logging.exception(f"Error generating SHAP plots: {e}")
        # pyplot keeps figures alive until they are closed explicitly, so
        # release the one created here to avoid leaking it on failure.
        if fig is not None:
            plt.close(fig)
        return None