--- a/src/iterpretability/utils.py
+++ b/src/iterpretability/utils.py
@@ -0,0 +1,178 @@
+# stdlib
+import random
+from typing import Optional
+
+# third party
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import seaborn as sns
+import torch
+from matplotlib.lines import Line2D
+from sklearn.metrics import mean_squared_error
+
+
+abbrev_dict = {
+    "shapley_value_sampling": "SVS",
+    "integrated_gradients": "IG",
+    "kernel_shap": "SHAP",
+    "gradient_shap": "GSHAP",
+    "feature_permutation": "FP",
+    "feature_ablation": "FA",
+    "deeplift": "DL",
+    "lime": "LIME",
+}
+
+explainer_symbols = {
+    "shapley_value_sampling": "D",
+    "integrated_gradients": "8",
+    "kernel_shap": "s",
+    "feature_permutation": "<",
+    "feature_ablation": "x",
+    "deeplift": "H",
+    "lime": ">",
+}
+
+cblind_palette = sns.color_palette("colorblind")  # indexable list of RGB tuples
+learner_colors = {
+    "SLearner": cblind_palette[0],
+    "TLearner": cblind_palette[1],
+    "TARNet": cblind_palette[3],
+    "CFRNet_0.01": cblind_palette[4],
+    "CFRNet_0.001": cblind_palette[6],
+    "CFRNet_0.0001": cblind_palette[7],
+    "DRLearner": cblind_palette[8],
+    "XLearner": cblind_palette[5],
+    "Truth": cblind_palette[9],
+}
+
+
+def enable_reproducible_results(seed: int = 42) -> None:
+    """
+    Set a fixed seed for all the libraries used (numpy, torch, random).
+
+    Args:
+        seed: int
+            The seed to use
+    """
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    random.seed(seed)
+
+
+def dataframe_line_plot(
+    df: pd.DataFrame,
+    x_axis: str,
+    y_axis: str,
+    explainers: list,
+    learners: list,
+    x_logscale: bool = True,
+    aggregate: bool = False,
+    aggregate_type: str = "mean",
+) -> plt.Figure:
+    """Plot y_axis against x_axis for every (learner, explainer) pair in df."""
+    fig = plt.figure()
+    ax = fig.add_subplot(1, 1, 1)
+    sns.set_style("white")
+    for learner_name in learners:
+        for explainer_name in explainers:
+            sub_df = df.loc[
+                (df["Learner"] == learner_name) & (df["Explainer"] == explainer_name)
+            ]
+            if aggregate:
+                sub_df = sub_df.groupby(x_axis).agg(aggregate_type).reset_index()
+            x_values = sub_df.loc[:, x_axis].values
+            y_values = sub_df.loc[:, y_axis].values
+            ax.plot(
+                x_values,
+                y_values,
+                color=learner_colors[learner_name],
+                marker=explainer_symbols[explainer_name],
+            )
+
+    # Two separate legends: line color encodes the learner, marker the explainer.
+    learner_lines = [
+        Line2D([0], [0], color=learner_colors[learner_name], lw=2)
+        for learner_name in learners
+    ]
+    explainer_lines = [
+        Line2D([0], [0], color="black", marker=explainer_symbols[explainer_name])
+        for explainer_name in explainers
+    ]
+
+    legend_learners = plt.legend(
+        learner_lines, learners, loc="lower left", bbox_to_anchor=(1.04, 0.7)
+    )
+    legend_explainers = plt.legend(
+        explainer_lines,
+        [abbrev_dict[explainer_name] for explainer_name in explainers],
+        loc="lower left",
+        bbox_to_anchor=(1.04, 0),
+    )
+    plt.subplots_adjust(right=0.75)
+    ax.add_artist(legend_learners)
+    ax.add_artist(legend_explainers)
+    if x_logscale:
+        ax.set_xscale("log")
+    ax.set_xlabel(x_axis)
+    ax.set_ylabel(y_axis)
+    return fig
+
+
+def compute_pehe(
+    cate_true: np.ndarray,
+    cate_pred: torch.Tensor,
+) -> float:
+    """Compute the precision in estimating heterogeneous effects (PEHE)."""
+    pehe = np.sqrt(mean_squared_error(cate_true, cate_pred.detach().cpu().numpy()))
+    return pehe
+
+
+def compute_cate_metrics(
+    cate_true: np.ndarray,
+    y_true: np.ndarray,
+    w_true: np.ndarray,
+    mu0_pred: torch.Tensor,
+    mu1_pred: torch.Tensor,
+) -> tuple:
+    """Compute PEHE and factual RMSE from the predicted potential outcomes."""
+    mu0_pred = mu0_pred.detach().cpu().numpy()
+    mu1_pred = mu1_pred.detach().cpu().numpy()
+
+    cate_pred = mu1_pred - mu0_pred
+
+    pehe = np.sqrt(mean_squared_error(cate_true, cate_pred))
+
+    # Factual prediction: mu1 where treated, mu0 where untreated.
+    n = len(cate_true)
+    w_true = w_true.reshape(n)
+    y_pred = w_true * mu1_pred.reshape(n) + (1 - w_true) * mu0_pred.reshape(n)
+    factual_rmse = np.sqrt(mean_squared_error(y_true.reshape(n), y_pred))
+    return pehe, factual_rmse
+
+
+def attribution_accuracy(
+    target_features: Optional[list], feature_attributions: np.ndarray
+) -> float:
+    """
+    Computes the fraction of the most important features that are truly important.
+
+    Args:
+        target_features: list of truly important feature indices
+        feature_attributions: feature attributions output by a feature importance method
+
+    Returns:
+        Fraction of the most important features that are truly important
+    """
+    if target_features is None:
+        return -1
+
+    n_important = len(target_features)  # Number of truly important features
+    # Per-sample indices of the n_important features with the largest attribution
+    largest_attribution_idx = torch.topk(
+        torch.from_numpy(feature_attributions), n_important
+    ).indices
+    accuracy = 0  # Count of correctly identified important features
+    for k in range(len(largest_attribution_idx)):
+        accuracy += len(np.intersect1d(largest_attribution_idx[k], target_features))
+    return accuracy / (len(feature_attributions) * n_important)
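
For reference, a minimal usage sketch of the metric helpers added above. This is an illustration only: it assumes the module is importable as src.iterpretability.utils, and all shapes and arrays (500 samples, 10 features, features 0-2 designated as truly important) are invented, not taken from the repository.

    # Hypothetical smoke test for the metric helpers; not part of the diff.
    import numpy as np
    import torch

    from src.iterpretability.utils import (
        attribution_accuracy,
        compute_cate_metrics,
        enable_reproducible_results,
    )

    enable_reproducible_results(seed=0)

    n, d = 500, 10
    cate_true = np.random.randn(n)          # ground-truth CATE (synthetic)
    y_true = np.random.randn(n)             # observed outcomes (synthetic)
    w_true = np.random.binomial(1, 0.5, n)  # treatment assignments
    mu0_pred = torch.randn(n)               # predicted control outcomes
    mu1_pred = torch.randn(n)               # predicted treated outcomes

    pehe, factual_rmse = compute_cate_metrics(
        cate_true, y_true, w_true, mu0_pred, mu1_pred
    )

    # Random attributions: rows are samples, columns are features.
    feature_attributions = np.random.rand(n, d)
    acc = attribution_accuracy([0, 1, 2], feature_attributions)

With random attributions, each sample's top-3 features overlap the 3 truly important ones in 3 * 3/10 = 0.9 features on average, so acc should land near 0.3 in expectation; a real explainer on a matching dataset should score well above that baseline.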