src/iterpretability/explain.py

from typing import Dict, List, Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
from captum._utils.models.linear_model import SkLearnLinearRegression
from captum.attr import (
    DeepLift,
    FeatureAblation,
    FeaturePermutation,
    GradientShap,
    IntegratedGradients,
    KernelShap,
    Lime,
    ShapleyValueSampling,
)
from captum.attr._core.lime import get_exp_kernel_similarity_function
from torch import nn

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Explainer:
    """
    Explainer instance bundling several captum attribution methods behind a
    single interface; see the usage sketch at the bottom of this module.
    """

    def __init__(
        self,
        model: nn.Module,
        feature_names: List,
        explainer_list: List = [
            "feature_ablation",
            "feature_permutation",
            "integrated_gradients",
            "deeplift",
            "shapley_value_sampling",
            "lime",
        ],
        n_steps: int = 500,
        perturbations_per_eval: int = 10,
        n_samples: int = 1000,
        kernel_width: float = 1.0,
        baseline: Optional[torch.Tensor] = None,
    ) -> None:
        self.baseline = baseline
        self.explainer_list = explainer_list
        self.feature_names = feature_names

        # Feature ablation: replace each feature with a baseline value and
        # measure the change in the model output.
        feature_ablation_model = FeatureAblation(model)

        def feature_ablation_cbk(X_test: torch.Tensor) -> torch.Tensor:
            # FeatureAblation.attribute does not accept n_steps; only
            # perturbations_per_eval applies here.
            return feature_ablation_model.attribute(
                X_test, perturbations_per_eval=perturbations_per_eval
            )

        # Integrated gradients: integrate gradients along a straight-line path
        # from the baseline to the input.
        integrated_gradients_model = IntegratedGradients(model)

        def integrated_gradients_cbk(X_test: torch.Tensor) -> torch.Tensor:
            return integrated_gradients_model.attribute(X_test, n_steps=n_steps)

        # DeepLift: backpropagate contribution scores relative to a baseline
        # activation.
        deeplift_model = DeepLift(model)

        def deeplift_cbk(X_test: torch.Tensor) -> torch.Tensor:
            return deeplift_model.attribute(X_test)

        # Feature permutation: permute each feature across the batch and
        # measure the resulting change in the model output.
        feature_permutation_model = FeaturePermutation(model)

        def feature_permutation_cbk(X_test: torch.Tensor) -> torch.Tensor:
            # FeaturePermutation.attribute likewise takes no n_steps argument.
            return feature_permutation_model.attribute(
                X_test, perturbations_per_eval=perturbations_per_eval
            )

        # LIME: fit a weighted linear surrogate model on perturbed samples,
        # weighted by an exponential kernel over euclidean distance.
        exp_eucl_distance = get_exp_kernel_similarity_function(
            "euclidean", kernel_width=kernel_width
        )
        lime_model = Lime(
            model,
            interpretable_model=SkLearnLinearRegression(),
            similarity_func=exp_eucl_distance,
        )

        def lime_cbk(X_test: torch.Tensor) -> torch.Tensor:
            return lime_model.attribute(
                X_test,
                n_samples=n_samples,
                perturbations_per_eval=perturbations_per_eval,
            )

        # Shapley value sampling: estimate Shapley values by averaging
        # marginal contributions over sampled feature orderings.
        shapley_value_sampling_model = ShapleyValueSampling(model)

        def shapley_value_sampling_cbk(X_test: torch.Tensor) -> torch.Tensor:
            return shapley_value_sampling_model.attribute(
                X_test,
                n_samples=n_samples,
                perturbations_per_eval=perturbations_per_eval,
            )

        # Kernel SHAP: LIME with the SHAP kernel, yielding approximate
        # Shapley values from a weighted linear regression.
        kernel_shap_model = KernelShap(model)

        def kernel_shap_cbk(X_test: torch.Tensor) -> torch.Tensor:
            return kernel_shap_model.attribute(
                X_test,
                n_samples=n_samples,
                perturbations_per_eval=perturbations_per_eval,
            )

        # Gradient SHAP: average gradients over random interpolations between
        # the inputs and samples from the baseline distribution.
        gradient_shap_model = GradientShap(model)

        def gradient_shap_cbk(X_test: torch.Tensor) -> torch.Tensor:
            # self.baseline is resolved in explain() before this is called.
            return gradient_shap_model.attribute(X_test, baselines=self.baseline)

        self.explainers = {
            "feature_ablation": feature_ablation_cbk,
            "integrated_gradients": integrated_gradients_cbk,
            "deeplift": deeplift_cbk,
            "feature_permutation": feature_permutation_cbk,
            "lime": lime_cbk,
            "shapley_value_sampling": shapley_value_sampling_cbk,
            "kernel_shap": kernel_shap_cbk,
            "gradient_shap": gradient_shap_cbk,
        }

    def _check_tensor(self, X: torch.Tensor) -> torch.Tensor:
        # Move tensors to DEVICE; coerce array-likes to float tensors first.
        if isinstance(X, torch.Tensor):
            return X.to(DEVICE)
        else:
            return torch.from_numpy(np.asarray(X)).float().to(DEVICE)

    def explain(self, X: torch.Tensor) -> Dict:
        output = {}
        if self.baseline is None:
            # Zero tensor as baseline if no baseline specified.
            self.baseline = torch.zeros(X.shape)
        for name in self.explainer_list:
            X_test = self._check_tensor(X)
            self.baseline = self._check_tensor(self.baseline)
            X_test.requires_grad_()
            explainer = self.explainers[name]
            output[name] = explainer(X_test).detach().cpu().numpy()
        return output

    def plot(self, X: torch.Tensor) -> None:
        explanations = self.explain(X)

        # squeeze=False keeps axs two-dimensional even when there is a single
        # row of subplots (fewer than three explainers).
        fig, axs = plt.subplots((len(explanations) + 1) // 2, 2, squeeze=False)

        for idx, name in enumerate(explanations):
            x_pos = np.arange(len(self.feature_names))

            ax = axs[idx // 2, idx % 2]
            ax.bar(x_pos, np.mean(np.abs(explanations[name]), axis=0), align="center")
            ax.set_xlabel("Features")
            ax.set_title(f"{name}")

        plt.tight_layout()
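

# A minimal usage sketch, not part of the original module: it builds a toy
# regression MLP and runs two of the cheaper explainers on random data. The
# model, data shapes, and feature names below are illustrative assumptions.
if __name__ == "__main__":
    torch.manual_seed(0)

    toy_model = nn.Sequential(
        nn.Linear(5, 16),
        nn.ReLU(),
        nn.Linear(16, 1),
    ).to(DEVICE)

    X_demo = torch.randn(32, 5)
    explainer = Explainer(
        toy_model,
        feature_names=[f"x{i}" for i in range(5)],
        explainer_list=["feature_ablation", "integrated_gradients"],
    )

    # explain() returns a dict mapping method name to a numpy array with one
    # row per input example and one column per feature.
    attributions = explainer.explain(X_demo)
    for method, attr in attributions.items():
        print(method, attr.shape)

    # plot() draws one bar chart of mean |attribution| per method.
    explainer.plot(X_demo)
    plt.show()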