Switch to side-by-side view

--- a
+++ b/src/iterpretability/explain.py
@@ -0,0 +1,171 @@
+from typing import Dict, List, Optional
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from captum._utils.models.linear_model import SkLearnLinearRegression
+from captum.attr import (
+    DeepLift,
+    FeatureAblation,
+    FeaturePermutation,
+    IntegratedGradients,
+    KernelShap,
+    Lime,
+    ShapleyValueSampling,
+    GradientShap,
+)
+from captum.attr._core.lime import get_exp_kernel_similarity_function
+from torch import nn
+
+
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+class Explainer:
+    """
+    Explainer instance, consisting of several explainability methods.
+    """
+
+    def __init__(
+        self,
+        model: nn.Module,
+        feature_names: List,
+        explainer_list: List = [
+            "feature_ablation",
+            "feature_permutation",
+            "integrated_gradients",
+            "deeplift",
+            "shapley_value_sampling",
+            "lime",
+        ],
+        n_steps: int = 500,
+        perturbations_per_eval: int = 10,
+        n_samples: int = 1000,
+        kernel_width: float = 1.0,
+        baseline: Optional[torch.Tensor] = None,
+    ) -> None:
+        self.baseline = baseline
+        self.explainer_list = explainer_list
+        self.feature_names = feature_names
+
+        # Feature ablation
+        feature_ablation_model = FeatureAblation(model)
+
+        def feature_ablation_cbk(X_test: torch.Tensor) -> torch.Tensor:
+            out = feature_ablation_model.attribute(
+                X_test, n_steps=n_steps, perturbations_per_eval=perturbations_per_eval
+            )
+
+            return out
+
+        # Integrated gradients
+        integrated_gradients_model = IntegratedGradients(model)
+
+        def integrated_gradients_cbk(X_test: torch.Tensor) -> torch.Tensor:
+            return integrated_gradients_model.attribute(X_test, n_steps=n_steps)
+
+        # DeepLift
+        deeplift_model = DeepLift(model)
+
+        def deeplift_cbk(X_test: torch.Tensor) -> torch.Tensor:
+            return deeplift_model.attribute(X_test)
+
+        # Feature permutation
+        feature_permutation_model = FeaturePermutation(model)
+
+        def feature_permutation_cbk(X_test: torch.Tensor) -> torch.Tensor:
+            return feature_permutation_model.attribute(
+                X_test, n_steps=n_steps, perturbations_per_eval=perturbations_per_eval
+            )
+
+        # LIME
+        exp_eucl_distance = get_exp_kernel_similarity_function(
+            kernel_width=kernel_width
+        )
+        lime_model = Lime(
+            model,
+            interpretable_model=SkLearnLinearRegression(),
+            similarity_func=exp_eucl_distance,
+        )
+
+        def lime_cbk(X_test: torch.Tensor) -> torch.Tensor:
+            return lime_model.attribute(
+                X_test,
+                n_samples=n_samples,
+                perturbations_per_eval=perturbations_per_eval,
+            )
+
+        # Shapley value sampling
+        shapley_value_sampling_model = ShapleyValueSampling(model)
+
+        def shapley_value_sampling_cbk(X_test: torch.Tensor) -> torch.Tensor:
+            return shapley_value_sampling_model.attribute(
+                X_test,
+                n_samples=n_samples,
+                perturbations_per_eval=perturbations_per_eval,
+            )
+
+        # Kernel SHAP
+        kernel_shap_model = KernelShap(model)
+
+        def kernel_shap_cbk(X_test: torch.Tensor) -> torch.Tensor:
+            return kernel_shap_model.attribute(
+                X_test,
+                n_samples=n_samples,
+                perturbations_per_eval=perturbations_per_eval,
+            )
+
+        # Gradient SHAP
+        gradient_shap_model = GradientShap(model)
+
+        def gradient_shap_cbk(X_test: torch.Tensor) -> torch.Tensor:
+            return gradient_shap_model.attribute(X_test, baselines=self.baseline)
+
+        self.explainers = {
+            "feature_ablation": feature_ablation_cbk,
+            "integrated_gradients": integrated_gradients_cbk,
+            "deeplift": deeplift_cbk,
+            "feature_permutation": feature_permutation_cbk,
+            "lime": lime_cbk,
+            "shapley_value_sampling": shapley_value_sampling_cbk,
+            "kernel_shap": kernel_shap_cbk,
+            "gradient_shap": gradient_shap_cbk,
+        }
+
+    def _check_tensor(self, X: torch.Tensor) -> torch.Tensor:
+        if isinstance(X, torch.Tensor):
+            return X.to(DEVICE)
+        else:
+            return torch.from_numpy(np.asarray(X)).float().to(DEVICE)
+
+    def explain(self, X: torch.Tensor) -> Dict:
+        output = {}
+        if self.baseline is None:
+            self.baseline = torch.zeros(
+                X.shape
+            )  # Zero tensor as baseline if no baseline specified
+        for name in self.explainer_list:
+            X_test = self._check_tensor(X)
+            self.baseline = self._check_tensor(self.baseline)
+            X_test.requires_grad_()
+            explainer = self.explainers[name]
+            output[name] = explainer(X_test).detach().cpu().numpy()
+        return output
+
+    def plot(self, X: torch.Tensor) -> None:
+        explanations = self.explain(X)
+
+        fig, axs = plt.subplots(int((len(explanations) + 1) / 2), 2)
+
+        idx = 0
+        for name in explanations:
+            x_pos = np.arange(len(self.feature_names))
+
+            ax = axs[int(idx / 2), idx % 2]
+
+            ax.bar(x_pos, np.mean(np.abs(explanations[name]), axis=0), align="center")
+            ax.set_xlabel("Features")
+            ax.set_title(f"{name}")
+
+            idx += 1
+        plt.tight_layout()