Switch to unified view

a b/src/iterpretability/utils.py
1
# stdlib
2
import random
3
from typing import Optional
4
5
# third party
6
import matplotlib.pyplot as plt
7
import numpy as np
8
import pandas as pd
9
import seaborn as sns
10
import torch
11
from matplotlib.lines import Line2D
12
from sklearn.metrics import mean_squared_error
13
14
15
abbrev_dict = {
16
    "shapley_value_sampling": "SVS",
17
    "integrated_gradients": "IG",
18
    "kernel_shap": "SHAP",
19
    "gradient_shap": "GSHAP",
20
    "feature_permutation": "FP",
21
    "feature_ablation": "FA",
22
    "deeplift": "DL",
23
    "lime": "LIME",
24
}
25
26
explainer_symbols = {
27
    "shapley_value_sampling": "D",
28
    "integrated_gradients": "8",
29
    "kernel_shap": "s",
30
    "feature_permutation": "<",
31
    "feature_ablation": "x",
32
    "deeplift": "H",
33
    "lime": ">",
34
}
35
36
cblind_palete = sns.color_palette("colorblind", as_cmap=True)
37
learner_colors = {
38
    "SLearner": cblind_palete[0],
39
    "TLearner": cblind_palete[1],
40
    "TARNet": cblind_palete[3],
41
    "CFRNet_0.01": cblind_palete[4],
42
    "CFRNet_0.001": cblind_palete[6],
43
    "CFRNet_0.0001": cblind_palete[7],
44
    "DRLearner": cblind_palete[8],
45
    "XLearner": cblind_palete[5],
46
    "Truth": cblind_palete[9],
47
}
48
49
50
def enable_reproducible_results(seed: int = 42) -> None:
51
    """
52
    Set a fixed seed for all the used libraries
53
54
    Args:
55
        seed: int
56
            The seed to use
57
    """
58
    np.random.seed(seed)
59
    torch.manual_seed(seed)
60
    random.seed(seed)
61
62
63
def dataframe_line_plot(
64
    df: pd.DataFrame,
65
    x_axis: str,
66
    y_axis: str,
67
    explainers: list,
68
    learners: list,
69
    x_logscale: bool = True,
70
    aggregate: bool = False,
71
    aggregate_type: str = "mean",
72
) -> plt.Figure:
73
    fig = plt.figure()
74
    ax = fig.add_subplot(1, 1, 1)
75
    sns.set_style("white")
76
    for learner_name in learners:
77
        for explainer_name in explainers:
78
            sub_df = df.loc[
79
                (df["Learner"] == learner_name) & (df["Explainer"] == explainer_name)
80
            ]
81
            if aggregate:
82
                sub_df = sub_df.groupby(x_axis).agg(aggregate_type).reset_index()
83
            x_values = sub_df.loc[:, x_axis].values
84
            y_values = sub_df.loc[:, y_axis].values
85
            ax.plot(
86
                x_values,
87
                y_values,
88
                color=learner_colors[learner_name],
89
                marker=explainer_symbols[explainer_name],
90
            )
91
92
    learner_lines = [
93
        Line2D([0], [0], color=learner_colors[learner_name], lw=2)
94
        for learner_name in learners
95
    ]
96
    explainer_lines = [
97
        Line2D([0], [0], color="black", marker=explainer_symbols[explainer_name])
98
        for explainer_name in explainers
99
    ]
100
101
    legend_learners = plt.legend(
102
        learner_lines, learners, loc="lower left", bbox_to_anchor=(1.04, 0.7)
103
    )
104
    legend_explainers = plt.legend(
105
        explainer_lines,
106
        [abbrev_dict[explainer_name] for explainer_name in explainers],
107
        loc="lower left",
108
        bbox_to_anchor=(1.04, 0),
109
    )
110
    plt.subplots_adjust(right=0.75)
111
    ax.add_artist(legend_learners)
112
    ax.add_artist(legend_explainers)
113
    if x_logscale:
114
        ax.set_xscale("log")
115
    ax.set_xlabel(x_axis)
116
    ax.set_ylabel(y_axis)
117
    return fig
118
119
120
def compute_pehe(
121
    cate_true: np.ndarray,
122
    cate_pred: torch.Tensor,
123
) -> tuple:
124
    pehe = np.sqrt(mean_squared_error(cate_true, cate_pred.detach().cpu().numpy()))
125
    return pehe
126
127
128
def compute_cate_metrics(
129
    cate_true: np.ndarray,
130
    y_true: np.ndarray,
131
    w_true: np.ndarray,
132
    mu0_pred: torch.Tensor,
133
    mu1_pred: torch.Tensor,
134
) -> tuple:
135
    mu0_pred = mu0_pred.detach().cpu().numpy()
136
    mu1_pred = mu1_pred.detach().cpu().numpy()
137
138
    cate_pred = mu1_pred - mu0_pred
139
140
    pehe = np.sqrt(mean_squared_error(cate_true, cate_pred))
141
142
    y_pred = w_true.reshape(len(cate_true),) * mu1_pred.reshape(len(cate_true),) + (
143
        1
144
        - w_true.reshape(
145
            len(cate_true),
146
        )
147
    ) * mu0_pred.reshape(
148
        len(cate_true),
149
    )
150
    factual_rmse = np.sqrt(
151
        mean_squared_error(
152
            y_true.reshape(
153
                len(cate_true),
154
            ),
155
            y_pred,
156
        )
157
    )
158
    return pehe, factual_rmse
159
160
161
def attribution_accuracy(
162
    target_features: list, feature_attributions: np.ndarray
163
) -> float:
164
    """
165
    Computes the fraction of the most important features that are truly important
166
    Args:
167
        target_features: list of truly important feature indices
168
        feature_attributions: feature attribution outputted by a feature importance method
169
170
    Returns:
171
        Fraction of the most important features that are truly important
172
    """
173
174
    if target_features is None:
175
        return -1
176
    
177
    n_important = len(target_features)  # Number of features that are important
178
    largest_attribution_idx = torch.topk(
179
        torch.from_numpy(feature_attributions), n_important
180
    )[
181
        1
182
    ]  # Features with largest attribution
183
    accuracy = 0  # Attribution accuracy
184
    for k in range(len(largest_attribution_idx)):
185
        accuracy += len(np.intersect1d(largest_attribution_idx[k], target_features))
186
    return accuracy / (len(feature_attributions) * n_important)