"""Feature-attribution explainers built on top of Captum.

This module was added as ``src/iterpretability/explain.py``.
"""

from typing import Dict, List, Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
from captum._utils.models.linear_model import SkLearnLinearRegression
from captum.attr import (
    DeepLift,
    FeatureAblation,
    FeaturePermutation,
    IntegratedGradients,
    KernelShap,
    Lime,
    ShapleyValueSampling,
    GradientShap,
)
from captum.attr._core.lime import get_exp_kernel_similarity_function
from torch import nn


DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Explainer:
    """Bundle of Captum attribution methods applied to a single model.

    Each supported method is wrapped in a callback taking the input tensor and
    returning the attribution tensor. The callbacks are stored in
    ``self.explainers`` keyed by method name; ``self.explainer_list`` selects
    which of them ``explain`` and ``plot`` actually run.

    Args:
        model: Module whose predictions are attributed.
            NOTE(review): inputs are moved to ``DEVICE`` but the model is not
            moved here -- presumably the caller places it on ``DEVICE``.
        feature_names: One entry per input feature; its length sets the number
            of bars drawn by ``plot``.
        explainer_list: Names of the methods to run. ``None`` selects the
            default set (feature ablation, feature permutation, integrated
            gradients, DeepLift, Shapley value sampling and LIME).
        n_steps: Number of integration steps for integrated gradients.
        perturbations_per_eval: Perturbed inputs evaluated per forward pass by
            the perturbation-based methods.
        n_samples: Number of samples for LIME, Shapley value sampling and
            Kernel SHAP.
        kernel_width: Width of the exponential similarity kernel used by LIME.
        baseline: Baseline passed to Gradient SHAP. When ``None`` an all-zero
            tensor matching the explained input is used. The other methods
            are called without an explicit baseline.
    """

    # Immutable so the default selection can never be shared and mutated
    # across instances (the pitfall of a list literal as default argument).
    _DEFAULT_EXPLAINERS = (
        "feature_ablation",
        "feature_permutation",
        "integrated_gradients",
        "deeplift",
        "shapley_value_sampling",
        "lime",
    )

    def __init__(
        self,
        model: nn.Module,
        feature_names: List,
        explainer_list: Optional[List] = None,
        n_steps: int = 500,
        perturbations_per_eval: int = 10,
        n_samples: int = 1000,
        kernel_width: float = 1.0,
        baseline: Optional[torch.Tensor] = None,
    ) -> None:
        self.baseline = baseline
        self.explainer_list = (
            list(self._DEFAULT_EXPLAINERS)
            if explainer_list is None
            else explainer_list
        )
        self.feature_names = feature_names

        # Feature ablation
        # `n_steps` is not a parameter of FeatureAblation.attribute (it was
        # only swallowed by **kwargs), so it is no longer forwarded.
        feature_ablation_model = FeatureAblation(model)

        def feature_ablation_cbk(X_test: torch.Tensor) -> torch.Tensor:
            return feature_ablation_model.attribute(
                X_test, perturbations_per_eval=perturbations_per_eval
            )

        # Integrated gradients
        integrated_gradients_model = IntegratedGradients(model)

        def integrated_gradients_cbk(X_test: torch.Tensor) -> torch.Tensor:
            return integrated_gradients_model.attribute(X_test, n_steps=n_steps)

        # DeepLift
        deeplift_model = DeepLift(model)

        def deeplift_cbk(X_test: torch.Tensor) -> torch.Tensor:
            return deeplift_model.attribute(X_test)

        # Feature permutation (same remark about `n_steps` as for ablation)
        feature_permutation_model = FeaturePermutation(model)

        def feature_permutation_cbk(X_test: torch.Tensor) -> torch.Tensor:
            return feature_permutation_model.attribute(
                X_test, perturbations_per_eval=perturbations_per_eval
            )

        # LIME
        exp_eucl_distance = get_exp_kernel_similarity_function(
            kernel_width=kernel_width
        )
        lime_model = Lime(
            model,
            interpretable_model=SkLearnLinearRegression(),
            similarity_func=exp_eucl_distance,
        )

        def lime_cbk(X_test: torch.Tensor) -> torch.Tensor:
            return lime_model.attribute(
                X_test,
                n_samples=n_samples,
                perturbations_per_eval=perturbations_per_eval,
            )

        # Shapley value sampling
        shapley_value_sampling_model = ShapleyValueSampling(model)

        def shapley_value_sampling_cbk(X_test: torch.Tensor) -> torch.Tensor:
            return shapley_value_sampling_model.attribute(
                X_test,
                n_samples=n_samples,
                perturbations_per_eval=perturbations_per_eval,
            )

        # Kernel SHAP
        kernel_shap_model = KernelShap(model)

        def kernel_shap_cbk(X_test: torch.Tensor) -> torch.Tensor:
            return kernel_shap_model.attribute(
                X_test,
                n_samples=n_samples,
                perturbations_per_eval=perturbations_per_eval,
            )

        # Gradient SHAP
        gradient_shap_model = GradientShap(model)

        def gradient_shap_cbk(X_test: torch.Tensor) -> torch.Tensor:
            # The baseline is resolved per call so that a default zero
            # baseline always matches the input currently being explained.
            return gradient_shap_model.attribute(
                X_test, baselines=self._resolve_baseline(X_test)
            )

        self.explainers = {
            "feature_ablation": feature_ablation_cbk,
            "integrated_gradients": integrated_gradients_cbk,
            "deeplift": deeplift_cbk,
            "feature_permutation": feature_permutation_cbk,
            "lime": lime_cbk,
            "shapley_value_sampling": shapley_value_sampling_cbk,
            "kernel_shap": kernel_shap_cbk,
            "gradient_shap": gradient_shap_cbk,
        }

    def _check_tensor(self, X: torch.Tensor) -> torch.Tensor:
        """Return ``X`` as a floating-point tensor located on ``DEVICE``.

        Tensors are moved to ``DEVICE``; anything else is converted through
        ``numpy`` to a float32 tensor. Non-floating tensors are cast to
        float32 as well, because gradients cannot be requested for them.
        """
        if isinstance(X, torch.Tensor):
            tensor = X if X.is_floating_point() else X.float()
            return tensor.to(DEVICE)
        return torch.from_numpy(np.asarray(X)).float().to(DEVICE)

    def _resolve_baseline(self, X_test: torch.Tensor) -> torch.Tensor:
        """Return the configured baseline, or zeros shaped like ``X_test``.

        ``self.baseline`` is never overwritten here, so a default zero
        baseline built for one input is not reused for a later input.
        """
        if self.baseline is None:
            return torch.zeros_like(X_test)
        return self._check_tensor(self.baseline)

    def explain(self, X: torch.Tensor) -> Dict:
        """Run every method named in ``self.explainer_list`` on ``X``.

        Args:
            X: Input to explain, as a tensor or anything ``numpy.asarray``
                accepts.

        Returns:
            Dict mapping each method name to its attributions as a numpy
            array.

        Raises:
            KeyError: If ``self.explainer_list`` contains a name that is not
                a key of ``self.explainers``.
        """
        # Convert once (the conversion does not depend on the method).
        # `detach()` yields a new tensor object, so enabling gradients below
        # does not flip `requires_grad` on the caller's own tensor.
        X_test = self._check_tensor(X).detach().requires_grad_()
        return {
            name: self.explainers[name](X_test).detach().cpu().numpy()
            for name in self.explainer_list
        }

    def plot(self, X: torch.Tensor) -> None:
        """Draw one bar chart per method of the mean absolute attribution.

        The charts are laid out on a two-column grid of the current
        matplotlib backend; nothing is drawn when no method is selected.

        Args:
            X: Input to explain, forwarded to ``explain``.
        """
        explanations = self.explain(X)
        if not explanations:
            # `plt.subplots` cannot build a grid with zero rows.
            return

        n_cols = 2
        n_rows = (len(explanations) + 1) // n_cols
        # squeeze=False keeps `axs` two-dimensional even for a single row,
        # so the [row, col] indexing below is always valid.
        _, axs = plt.subplots(n_rows, n_cols, squeeze=False)
        x_pos = np.arange(len(self.feature_names))

        for idx, (name, attribution) in enumerate(explanations.items()):
            row, col = divmod(idx, n_cols)
            ax = axs[row, col]
            # Average the attribution magnitude over the first (sample) axis.
            ax.bar(x_pos, np.mean(np.abs(attribution), axis=0), align="center")
            ax.set_xlabel("Features")
            ax.set_title(f"{name}")

        # An odd number of methods leaves the last grid cell unused.
        for unused_ax in axs.flat[len(explanations):]:
            unused_ax.set_visible(False)
        plt.tight_layout()