src/iterpretability/utils.py
|
|
# stdlib
import random
from typing import Optional

# third party
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from matplotlib.lines import Line2D
from sklearn.metrics import mean_squared_error
|
|
# Legend abbreviations for each explainer
abbrev_dict = {
    "shapley_value_sampling": "SVS",
    "integrated_gradients": "IG",
    "kernel_shap": "SHAP",
    "gradient_shap": "GSHAP",
    "feature_permutation": "FP",
    "feature_ablation": "FA",
    "deeplift": "DL",
    "lime": "LIME",
}
|
|
# Matplotlib marker symbols for each explainer
explainer_symbols = {
    "shapley_value_sampling": "D",
    "integrated_gradients": "8",
    "kernel_shap": "s",
    "feature_permutation": "<",
    "feature_ablation": "x",
    "deeplift": "H",
    "lime": ">",
}
|
|
# Colorblind-friendly palette; keep it as an indexable list of colors
# (a colormap returned with as_cmap=True would not support integer indexing).
cblind_palete = sns.color_palette("colorblind")
learner_colors = {
    "SLearner": cblind_palete[0],
    "TLearner": cblind_palete[1],
    "TARNet": cblind_palete[3],
    "CFRNet_0.01": cblind_palete[4],
    "CFRNet_0.001": cblind_palete[6],
    "CFRNet_0.0001": cblind_palete[7],
    "DRLearner": cblind_palete[8],
    "XLearner": cblind_palete[5],
    "Truth": cblind_palete[9],
}
|
|
def enable_reproducible_results(seed: int = 42) -> None:
    """
    Set a fixed seed for all the libraries used (numpy, torch, random).

    Args:
        seed: int
            The seed to use
    """
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)
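
# A minimal sketch of extra seeding that fully deterministic GPU runs usually also
# need (an assumption, not called anywhere in this module):
#   torch.cuda.manual_seed_all(seed)
#   torch.backends.cudnn.deterministic = True
#   torch.backends.cudnn.benchmark = False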
|
|
def dataframe_line_plot(
    df: pd.DataFrame,
    x_axis: str,
    y_axis: str,
    explainers: list,
    learners: list,
    x_logscale: bool = True,
    aggregate: bool = False,
    aggregate_type: str = "mean",
) -> plt.Figure:
    """
    Plot y_axis against x_axis for every (learner, explainer) pair in df.

    df must contain "Learner" and "Explainer" columns plus the x_axis and y_axis
    columns. Learners are distinguished by color and explainers by marker, with a
    separate legend for each.
    """
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    sns.set_style("white")
    for learner_name in learners:
        for explainer_name in explainers:
            sub_df = df.loc[
                (df["Learner"] == learner_name) & (df["Explainer"] == explainer_name)
            ]
            if aggregate:
                sub_df = sub_df.groupby(x_axis).agg(aggregate_type).reset_index()
            x_values = sub_df.loc[:, x_axis].values
            y_values = sub_df.loc[:, y_axis].values
            ax.plot(
                x_values,
                y_values,
                color=learner_colors[learner_name],
                marker=explainer_symbols[explainer_name],
            )

    learner_lines = [
        Line2D([0], [0], color=learner_colors[learner_name], lw=2)
        for learner_name in learners
    ]
    explainer_lines = [
        Line2D([0], [0], color="black", marker=explainer_symbols[explainer_name])
        for explainer_name in explainers
    ]

    legend_learners = plt.legend(
        learner_lines, learners, loc="lower left", bbox_to_anchor=(1.04, 0.7)
    )
    legend_explainers = plt.legend(
        explainer_lines,
        [abbrev_dict[explainer_name] for explainer_name in explainers],
        loc="lower left",
        bbox_to_anchor=(1.04, 0),
    )
    plt.subplots_adjust(right=0.75)
    ax.add_artist(legend_learners)
    ax.add_artist(legend_explainers)
    if x_logscale:
        ax.set_xscale("log")
    ax.set_xlabel(x_axis)
    ax.set_ylabel(y_axis)
    return fig
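
# A minimal usage sketch (the "Learner"/"Explainer" column names follow the
# convention assumed by dataframe_line_plot; the metric column and values are
# invented for illustration):
#   results = pd.DataFrame(
#       {
#           "Learner": ["SLearner", "SLearner", "TLearner", "TLearner"],
#           "Explainer": ["lime", "lime", "lime", "lime"],
#           "Num features": [10, 100, 10, 100],
#           "PEHE": [0.9, 1.2, 0.6, 0.8],
#       }
#   )
#   fig = dataframe_line_plot(
#       results, x_axis="Num features", y_axis="PEHE",
#       explainers=["lime"], learners=["SLearner", "TLearner"],
#   )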
|
|
def compute_pehe(
    cate_true: np.ndarray,
    cate_pred: torch.Tensor,
) -> float:
    """
    Return the PEHE: the root mean squared error between true and predicted CATE.
    """
    pehe = np.sqrt(mean_squared_error(cate_true, cate_pred.detach().cpu().numpy()))
    return pehe
|
|
def compute_cate_metrics(
    cate_true: np.ndarray,
    y_true: np.ndarray,
    w_true: np.ndarray,
    mu0_pred: torch.Tensor,
    mu1_pred: torch.Tensor,
) -> tuple:
    """
    Return (pehe, factual_rmse) for predicted potential outcomes.

    PEHE is the RMSE between cate_true and mu1_pred - mu0_pred; factual_rmse is the
    RMSE between y_true and the predicted outcome under the observed treatment w_true.
    """
    mu0_pred = mu0_pred.detach().cpu().numpy()
    mu1_pred = mu1_pred.detach().cpu().numpy()

    cate_pred = mu1_pred - mu0_pred
    pehe = np.sqrt(mean_squared_error(cate_true, cate_pred))

    n = len(cate_true)
    w = w_true.reshape(n)
    y_pred = w * mu1_pred.reshape(n) + (1 - w) * mu0_pred.reshape(n)
    factual_rmse = np.sqrt(mean_squared_error(y_true.reshape(n), y_pred))
    return pehe, factual_rmse
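
# A minimal usage sketch (all values are illustrative; the only real assumptions are
# the shapes: 1-D arrays of length n, with mu0_pred/mu1_pred as torch tensors):
#   cate_true = np.array([1.0, 0.5, -0.2])        # true treatment effects
#   y_true = np.array([2.0, 1.0, 0.3])            # observed factual outcomes
#   w_true = np.array([1.0, 0.0, 1.0])            # treatment assignments
#   mu0_pred = torch.tensor([1.0, 0.6, 0.4])      # predicted control outcomes
#   mu1_pred = torch.tensor([2.1, 1.0, 0.1])      # predicted treated outcomes
#   pehe, factual_rmse = compute_cate_metrics(
#       cate_true, y_true, w_true, mu0_pred, mu1_pred
#   )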
|
|
def attribution_accuracy(
    target_features: list, feature_attributions: np.ndarray
) -> float:
    """
    Compute the fraction of the most important features that are truly important.

    Args:
        target_features: list of truly important feature indices
        feature_attributions: feature attributions output by a feature importance method
            (one row per sample, one column per feature)

    Returns:
        Fraction of the most highly attributed features that are truly important
    """
    if target_features is None:
        return -1

    n_important = len(target_features)  # Number of truly important features
    # Indices of the features with the largest attribution for each sample
    largest_attribution_idx = torch.topk(
        torch.from_numpy(feature_attributions), n_important
    )[1]
    accuracy = 0  # Running count of correctly attributed features
    for k in range(len(largest_attribution_idx)):
        accuracy += len(np.intersect1d(largest_attribution_idx[k], target_features))
    return accuracy / (len(feature_attributions) * n_important)
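
# A small, self-contained check of attribution_accuracy (the data below is invented
# for illustration and is not part of the original module):
if __name__ == "__main__":
    example_attributions = np.array(
        [
            [0.1, 0.9, 0.3, 0.0],  # top-2 features: {1, 2}
            [0.5, 0.2, 0.4, 0.1],  # top-2 features: {0, 2}
        ]
    )
    example_targets = [1, 2]  # truly important feature indices
    # Overlaps of size 2 and 1 give an accuracy of 3 / (2 * 2) = 0.75
    print(attribution_accuracy(example_targets, example_attributions))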