--- /dev/null
+++ b/src/plotting.py
@@ -0,0 +1,902 @@
+import warnings
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import seaborn as sns
+from matplotlib.ticker import FuncFormatter
+from PIL import Image
+
+warnings.simplefilter(action='ignore', category=FutureWarning)
+
+
+def custom_formatter(x, pos):
+    """Tick formatter: integers without decimals, 1/2 and 1/4 as LaTeX
+    fractions, everything else with two decimals."""
+    if float(x).is_integer():
+        return f'{int(x)}'
+    elif x == 0.5:
+        return r'$1/2$'
+    elif x == 0.25:
+        return r'$1/4$'
+    else:
+        # TODO: render the remaining fractions diagonally instead
+        return f'{x:.2f}'
+
+
+def _set_symlog_xaxis(ax):
+    """Put ``ax`` on a base-2 symlog x-scale with fraction-aware tick labels
+    (shared by all plotters below)."""
+    ax.set_xscale('symlog', linthresh=0.5, base=2)
+    # Insert the midpoint between the first two ticks so the linear region
+    # around zero gets a label as well.
+    current_ticks = ax.get_xticks()
+    if len(current_ticks) > 1:
+        midpoint = (current_ticks[0] + current_ticks[1]) / 2
+        ax.set_xticks([current_ticks[0], midpoint] + list(current_ticks[1:]))
+    # Always include a tick at 0.25.
+    ax.set_xticks(sorted(set(ax.get_xticks()).union({0.25})))
+    ax.xaxis.set_major_formatter(FuncFormatter(custom_formatter))
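+
+# Illustrative self-check for the formatter (not called anywhere; the expected
+# strings follow directly from the branches above):
+def _demo_custom_formatter():
+    assert custom_formatter(2.0, None) == '2'
+    assert custom_formatter(0.5, None) == r'$1/2$'
+    assert custom_formatter(0.25, None) == r'$1/4$'
+    assert custom_formatter(0.3, None) == '0.30'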
"Torch_RLearner": "*", + "Torch_FlexTENet": "*", + 'Torch_ActionNet': "*", + "EconML_CausalForestDML": "d", + "EconML_DML": "o", + "EconML_DMLOrthoForest": "^", + "EconML_DRLearner": "*", + "EconML_DROrthoForest": "D", + "EconML_ForestDRLearner": "8", + "EconML_LinearDML": "s", + "EconML_LinearDRLearner": "x", + "EconML_SparseLinearDML": "x", + "EconML_SparseLinearDRLearner": "H", + "EconML_TLearner_Lasso": "o", + "EconML_SLearner_Lasso": "^", + "EconML_XLearner_Lasso": "d", + "DiffPOLearner": "H", + "Truth": "<", +} + +datasets_names_map = { + "tcga_100": "TCGA", + "twins": "Twins", + "news_100": "News", + "all_notupro_technologies": "AllNoTuproTechnologies", + "all_notupro_technologies_small": "AllNoTuproTechnologiesSmall", + "dummy_data": "DummyData", + "selected_technologies_pategan_1000": "selected_technologies_pategan_1000", + "selected_technologies_with_fastdrug": "selected_technologies_with_fastdrug", + "cytof_normalized":"cytof_normalized", + "cytof_normalized_with_fastdrug":"cytof_normalized_with_fastdrug", + "cytof_pategan_1000_normalized": "cytof_pategan_1000_normalized", + "all_notupro_technologies_with_fastdrug": "all_notupro_technologies_with_fastdrug", + "acic": "ACIC2016", + "depmap_drug_screen_2_drugs": "depmap_drug_screen_2_drugs", + "depmap_drug_screen_2_drugs_norm": "depmap_drug_screen_2_drugs_norm", + "depmap_drug_screen_2_drugs_all_features": "depmap_drug_screen_2_drugs_all_features", + "depmap_drug_screen_2_drugs_all_features_norm": "depmap_drug_screen_2_drugs_all_features_norm", + "depmap_crispr_screen_2_kos": "depmap_crispr_screen_2_kos", + "depmap_crispr_screen_2_kos_norm": "depmap_crispr_screen_2_kos_norm", + "depmap_crispr_screen_2_kos_all_features": "depmap_crispr_screen_2_kos_all_features", + "depmap_crispr_screen_2_kos_all_features_norm": "depmap_crispr_screen_2_kos_all_features_norm", + "depmap_drug_screen_2_drugs_100pcs_norm":"depmap_drug_screen_2_drugs_100pcs_norm", + "depmap_drug_screen_2_drugs_5000hv_norm":"depmap_drug_screen_2_drugs_5000hv_norm", + "depmap_crispr_screen_2_kos_100pcs_norm":"depmap_crispr_screen_2_kos_100pcs_norm", + "depmap_crispr_screen_2_kos_5000hv_norm":"depmap_crispr_screen_2_kos_5000hv_norm", + "independent_normally_dist": "independent_normally_dist", + "ovarian_semi_synthetic_l1":"ovarian_semi_synthetic_l1", + "ovarian_semi_synthetic_rf":"ovarian_semi_synthetic_rf", + "melanoma_semi_synthetic_l1": "melanoma_semi_synthetic_l1", + "pred": "Predictive confounding", + "prog": "Prognostic confounding", + "irrelevant_var": "Non-confounded propensity", + "selective": "General confounding"} + +metric_names_map = { + 'Pred: Pred features ACC': r'Predictive $\mathrm{Attr}$', #^{\mathrm{pred}}_{\mathrm{pred}}$', + 'Pred: Prog features ACC': r'$\mathrm{Attr}^{\mathrm{pred}}_{\mathrm{prog}}$', + 'Pred: Select features ACC': r'$\mathrm{Attr}^{\mathrm{pred}}_{\mathrm{select}}$', + 'Prog: Pred features ACC': r'$\mathrm{Attr}^{\mathrm{prog}}_{\mathrm{pred}}$', + 'Prog: Prog features ACC': r'Prognostic $\mathrm{Attr}$', #^{\mathrm{prog}}$', #_{\mathrm{prog}}$', + 'Prog: Select features ACC': r'$\mathrm{Attr}^{\mathrm{prog}}_{\mathrm{select}}$', + 'Select: Pred features ACC': r'$\mathrm{Attr}^{\mathrm{select}}_{\mathrm{pred}}$', + 'Select: Prog features ACC': r'$\mathrm{Attr}^{\mathrm{select}}_{\mathrm{prog}}$', + 'Select: Select features ACC': r'$\mathrm{Attr}^{\mathrm{select}}_{\mathrm{select}}$', + 'CI Coverage': 'CI Coverage', + 'Normalized PEHE': 'N. 
+
+metric_names_map = {
+    'Pred: Pred features ACC': r'Predictive $\mathrm{Attr}$',
+    'Pred: Prog features ACC': r'$\mathrm{Attr}^{\mathrm{pred}}_{\mathrm{prog}}$',
+    'Pred: Select features ACC': r'$\mathrm{Attr}^{\mathrm{pred}}_{\mathrm{select}}$',
+    'Prog: Pred features ACC': r'$\mathrm{Attr}^{\mathrm{prog}}_{\mathrm{pred}}$',
+    'Prog: Prog features ACC': r'Prognostic $\mathrm{Attr}$',
+    'Prog: Select features ACC': r'$\mathrm{Attr}^{\mathrm{prog}}_{\mathrm{select}}$',
+    'Select: Pred features ACC': r'$\mathrm{Attr}^{\mathrm{select}}_{\mathrm{pred}}$',
+    'Select: Prog features ACC': r'$\mathrm{Attr}^{\mathrm{select}}_{\mathrm{prog}}$',
+    'Select: Select features ACC': r'$\mathrm{Attr}^{\mathrm{select}}_{\mathrm{select}}$',
+    'CI Coverage': 'CI Coverage',
+    'Normalized PEHE': 'N. PEHE',
+    'PEHE': 'PEHE',
+    'CF RMSE': 'CF-RMSE',
+    'AUROC': 'AUROC',
+    'Factual AUROC': 'Factual AUROC',
+    'CF AUROC': 'CF AUROC',
+    'Factual RMSE': 'F-RMSE',
+    'Factual RMSE Y0': 'F-RMSE Y0',
+    'Factual RMSE Y1': 'F-RMSE Y1',
+    'CF RMSE Y0': 'CF-RMSE Y0',
+    'CF RMSE Y1': 'CF-RMSE Y1',
+    'Normalized F-RMSE': 'N. F-RMSE',
+    'Normalized CF-RMSE': 'N. CF-RMSE',
+    'F-Outcome true mean': 'F-Outcome true mean',
+    'CF-Outcome true mean': 'CF-Outcome true mean',
+    'F-Outcome true std': 'F-Outcome true std',
+    'CF-Outcome true std': 'CF-Outcome true std',
+    'F-CF Outcome Diff': 'F-CF Outcome Diff',
+    'Swap AUROC@1': 'AUROC@1',
+    'Swap AUPRC@1': 'AUPRC@1',
+    'Swap AUROC@5': 'AUROC@5',
+    'Swap AUPRC@5': 'AUPRC@5',
+    'Swap AUROC@tre': 'AUROC@tre',
+    'Swap AUPRC@tre': 'AUPRC@tre',
+    'Swap AUROC@all': 'AUROC',
+    'Swap AUPRC@all': 'AUPRC',
+    'GT Pred Expertise': r'$\mathrm{B}^{\pi}_{Y_1-Y_0}$',
+    'GT Prog Expertise': r'$\mathrm{B}^{\pi}_{Y_0}$',
+    'GT Tre Expertise': r'$\mathrm{B}^{\pi}_{Y_1}$',
+    'Upd. GT Pred Expertise': r'$\mathrm{B}^{\hat{\pi}}_{Y_1-Y_0}$',
+    'Upd. GT Prog Expertise': r'$\mathrm{B}^{\hat{\pi}}_{Y_0}$',
+    'Upd. GT Tre Expertise': r'$\mathrm{B}^{\hat{\pi}}_{Y_1}$',
+    'GT Expertise Ratio': r'$\mathrm{E}^{\pi}_{\mathrm{ratio}}$',
+    'GT Total Expertise': r'$\mathrm{B}^{\pi}_{Y_0,Y_1}$',
+    'ES Pred Expertise': 'ES Pred Bias',
+    'ES Prog Expertise': 'ES Prog Bias',
+    'ES Total Expertise': 'ES Outcome Bias',
+    'Pred Precision': r'$\mathrm{Prec}^{\hat{\pi}}_{\mathrm{Ass.}}$',
+    'Policy Precision': r'$\mathrm{Prec}^{\pi}_{\mathrm{Ass.}}$',
+    'T Distribution: Train': 'T Distribution: Train',
+    'T Distribution: Test': 'T Distribution: Test',
+    'True Swap Perc': 'True Swap Perc',
+    'Normalized F-CF Diff': 'Normalized F-CF Diff',
+    'Training Duration': 'Training Duration',
+    'FC PEHE': 'PEHE(Model) - PEHE(TARNet)',
+    'FC F-RMSE': 'Rel. N. F-RMSE',
+    'FC CF-RMSE': 'Rel. N. CF-RMSE',
+    'FC Swap AUROC': 'Rel. AUROC',
+    'FC Swap AUPRC': 'Rel. AUPRC',
+    'GT In-context Var': r'$\mathrm{B}^{\pi}_{X}$',
+    'ES In-context Var': 'ES Total Bias',
+    'GT-ES Pred Expertise Diff': r'$\mathrm{E}^{\pi}_{\mathrm{pred}}$ Error',
+    'GT-ES Prog Expertise Diff': r'$\mathrm{E}^{\pi}_{\mathrm{prog}}$ Error',
+    'GT-ES Total Expertise Diff': r'$\mathrm{E}^{\pi}$ Error',
+    'RMSE Y0': 'RMSE Y0',
+    'RMSE Y1': 'RMSE Y1',
+}
+
+# Metrics that live on the unit interval; the plotters below pin their
+# y-axis to [0, 1] for these.
+UNIT_INTERVAL_METRICS = [
+    "True Swap Perc", "T Distribution: Train", "T Distribution: Test",
+    "GT Total Expertise", "ES Total Expertise", "GT Expertise Ratio",
+    "GT Pred Expertise", "GT Prog Expertise", "GT Tre Expertise",
+    "ES Pred Expertise", "ES Prog Expertise",
+    "GT In-context Var", "ES In-context Var",
+    "GT-ES Pred Expertise Diff", "GT-ES Prog Expertise Diff",
+    "GT-ES Total Expertise Diff", "Policy Precision",
+    "Upd. GT Prog Expertise", "Upd. GT Tre Expertise", "Upd. GT Pred Expertise",
+]
Pred.)", + "Torch_RALearner": "RA-Learner", + "Torch_ULearner": "U-Learner", + "Torch_PWLearner":"Torch_PWLearner", + "Torch_RLearner":"Torch_RLearner", + "Torch_FlexTENet":"Torch_FlexTENet", + "EconML_CausalForestDML": "CausalForestDML", + "EconML_DML": "APred-Prop-Lasso", + "EconML_DMLOrthoForest": "DMLOrthoForest", + "EconML_DRLearner": "DRLearner", + "EconML_DROrthoForest": "DROrthoForest", + "EconML_ForestDRLearner": "ForestDRLearner", + "EconML_LinearDML": "LinearDML", + "EconML_LinearDRLearner": "LinearDRLearner", + "EconML_SparseLinearDML": "SparseLinearDML", + "EconML_SparseLinearDRLearner": "SparseLinearDRLearner", + "EconML_TLearner_Lasso": "T-Learner-Lasso", + "EconML_SLearner_Lasso": "S-Learner-Lasso", + "EconML_XLearner_Lasso": "XLearnerLasso", + "DiffPOLearner": "DiffPOLearner", + + "Truth": "Truth" +} + +compare_values_map = { + # Propensity + "none_prog": r'$\pi_{\mathrm{RCT}} \rightarrow \pi_{\mathrm{Y_0}}^\beta$', + "none_tre": r'$\pi_{\mathrm{RCT}} \rightarrow \pi_{\mathrm{Y_1}}^\beta$', + "none_pred": r'$\pi_{\mathrm{RCT}} \rightarrow \pi_{\mathrm{Y_1-Y_0}}^\beta$', + "rct_none": r'$\pi_{\mathrm{RCT}} \rightarrow \pi_{\mathrm{X_{irr}}}^\beta$', + "none_pred_overlap": r'$\pi_{\mathrm{RCT}} \rightarrow \pi_{\mathrm{X_{pred}}}^\beta$', + + # Toy + "toy7": r'$\pi_{\mathrm{RCT}} \rightarrow \pi_{\mathrm{T_7}}^\beta$', + "toy8_nonlinear": r'$\pi_{\mathrm{RCT}} \rightarrow \pi_{\mathrm{T_8}}^\beta$', + "toy1_linear": r'$\pi_{\mathrm{RCT}} \rightarrow \pi_{\mathrm{T_1^{lin}}}^\beta$', + "toy1_nonlinear": r'$\mathrm{Toy 1:} \pi_{\mathrm{T_1}}^\beta$', #r'$\pi_{\mathrm{RCT}} \rightarrow \pi_{\mathrm{T_1}}^\beta$', + "toy2_linear": r'$\pi_{\mathrm{RCT}} \rightarrow \pi_{\mathrm{T_3^{lin}}}^\beta$', + "toy2_nonlinear": r'$\mathrm{Toy 3:} \pi_{\mathrm{T_3}}^\beta$', + "toy3_nonlinear": r'$\mathrm{Toy 2:} \pi_{\mathrm{T_2}}^\beta$', + "toy4_nonlinear": r'$\pi_{\mathrm{RCT}} \rightarrow \pi_{\mathrm{T_4}}^\beta$', + "toy5": r'$\pi_{\mathrm{RCT}} \rightarrow \pi_{\mathrm{T_5}}^\beta$', + "toy6_nonlinear": r'$\mathrm{Toy 4:} \pi_{\mathrm{T_4}}^\beta$', #r'$\pi_{\mathrm{RCT}} \rightarrow \pi_{\mathrm{T_6}}^\beta$', + + # Expertise + # "prog_tre": r'$\pi_{\mathrm{Y_0}}^{\beta=4} \rightarrow \pi_{\mathrm{Y_1-Y_0}}^{\beta=4} \rightarrow \pi_{\mathrm{Y_1}}^{\beta=4}$', + # "none_prog": r'$\pi_{\mathrm{X_{rand}}}^{\beta=4} \rightarrow \pi_{\mathrm{Y_0}}^{\beta=4}$', + # "none_tre": r'$\pi_{\mathrm{X_{rand}}}^{\beta=4} \rightarrow \pi_{\mathrm{Y_1}}^{\beta=4}$', + # "none_pred": r'$\pi_{\mathrm{X_{rand}}}^{\beta=4} \rightarrow \pi_{\mathrm{Y_1-Y_0}}^{\beta=4}$', + #["prog_tre", "none_prog", "none_tre", "none_pred"] + + 0: r'$\pi_{\mathrm{RCT}}$', + 2: r'$\pi_{\mathrm{Y_1-Y_0}}^{\beta=2}$', + 100: r'$\pi_{\mathrm{Y_1-Y_0}}^{\beta=100}$', + + "0": r'$\pi_{\mathrm{RCT}}$', + "2": r'$\pi_{\mathrm{Y_1-Y_0}}^{\beta=2}$', + "100": r'$\pi_{\mathrm{Y_1-Y_0}}^{\beta=100}$', +} + + +def plot_results_datasets_compare(results_df: pd.DataFrame, + model_names: list, + dataset: str, + compare_axis: str, + compare_axis_values, + x_axis, + x_label_name, + x_values_to_plot, + metrics_list, + learners_list, + figsize, + legend_position, + seeds_list, + n_splits, + sharey=False, + legend_rows=1, + dim_X=1, + log_x_axis = False): + """ + Plot the results for a given dataset. 
+ """ + # Get the unique values of the compare axis + if compare_axis_values is None: + compare_axis_values = results_df[compare_axis].unique() + + metrics_list = ["Pred Precision"] + # Initialize the plot + nrows = len(metrics_list) + columns = len(compare_axis_values) + figsize = (3*columns+2, 3*nrows) + #figsize = (3*columns, 3) + + font_size=10 + fig, axs = plt.subplots(len(metrics_list), len(compare_axis_values), figsize=figsize, squeeze=False, sharey=sharey, dpi=500) + plt.gcf().subplots_adjust(bottom=0.15) + + # Aggregate results across seeds for each metric + for i in range(len(compare_axis_values)): + cmp_value = compare_axis_values[i] + for metric_id, metric in enumerate(metrics_list): + for model_name in model_names: + # Extract results for individual cate models + sub_df = results_df.loc[(results_df["Learner"] == model_name)] + sub_df = sub_df.loc[(sub_df[compare_axis] == cmp_value)][[x_axis, metric]] + sub_df = sub_df[sub_df[x_axis].isin(x_values_to_plot)] + sub_df_mean = sub_df.groupby(x_axis).agg('median').reset_index() + sub_df_std = sub_df.groupby(x_axis).agg('std').reset_index() + sub_df_min = sub_df.groupby(x_axis).agg('min').reset_index() + sub_df_max = sub_df.groupby(x_axis).agg('max').reset_index() + + # Plot the results + x_values = sub_df_mean.loc[:, x_axis].values + + try: + y_values = sub_df_mean.loc[:, metric].values + except: + continue + + y_err = sub_df_std.loc[:, metric].values / (np.sqrt(n_splits*len(seeds_list))) + y_min = sub_df_min.loc[:, metric].values + y_max = sub_df_max.loc[:, metric].values + + # axs[metric_id][i].plot(x_values, y_values, label=learners_names_map[model_name], + # color=learner_colors[model_name], linestyle=learner_linestyles[model_name], marker=learner_markers[model_name], markersize=5) + axs[metric_id][i].plot(x_values, y_values, label=learners_names_map[model_name], + color=learner_colors[model_name], linestyle=learner_linestyles[model_name], marker=learner_markers[model_name], markersize=3, alpha=0.5) + axs[metric_id][i].fill_between(x_values, y_values-y_err, y_values+y_err, alpha=0.1, color=learner_colors[model_name]) + + + + # if log_x_axis: + # axs[metric_id][i].set_xscale('symlog', linthresh=0.5, base=2) + # #axs[metric_id][i].fill_between(x_values, y_min, y_max, alpha=0.1, color=learner_colors[model_name]) + + axs[metric_id][i].tick_params(axis='x', labelsize=font_size-2) + axs[metric_id][i].tick_params(axis='y', labelsize=font_size-1) + + + axs[metric_id][i].set_title(compare_values_map[cmp_value], fontsize=font_size+11, y=1.04) + + axs[metric_id][i].set_xlabel(x_label_name, fontsize=font_size-1) + if i == 0: + axs[metric_id][i].set_ylabel(metric_names_map[metric], fontsize=font_size-1) + + if log_x_axis: + axs[metric_id][i].set_xscale('symlog', linthresh=0.5, base=2) + # Display as fractions if not integers and as integers if integers + # axs[0][i].xaxis.set_major_formatter(ScalarFormatter()) + # Get the current ticks + current_ticks = axs[metric_id][i].get_xticks() + + # Calculate the midpoint between the first and second tick + if len(current_ticks) > 1: + midpoint = (current_ticks[0] + current_ticks[1]) / 2 + # Add the midpoint to the list of ticks + new_ticks = [current_ticks[0], midpoint] + list(current_ticks[1:]) + axs[metric_id][i].set_xticks(new_ticks) + + # Add a tick at 0.25 + axs[metric_id][i].set_xticks(sorted(set(axs[metric_id][i].get_xticks()).union({0.25}))) + axs[metric_id][i].xaxis.set_major_formatter(FuncFormatter(custom_formatter)) + + if metric in ["True Swap Perc", "T Distribution: Train", "T 
Distribution: Test", "GT Total Expertise", "ES Total Expertise", "GT Expertise Ratio", "GT Pred Expertise", "GT Prog Expertise", "ES Pred Expertise", "ES Prog Expertise","GT In-context Var","ES In-context Var","GT-ES Pred Expertise Diff","GT-ES Prog Expertise Diff","GT-ES Total Expertise Diff", "Policy Precision", "GT In-context Var", "GT Total Expertise", "GT Prog Expertise", "GT Tre Expertise", "GT Pred Expertise", "Upd. GT Prog Expertise", "Upd. GT Tre Expertise", "Upd. GT Pred Expertise"]: + axs[metric_id][i].set_ylim(0, 1) + + if metric == "PEHE": + axs[metric_id][i].set_ylim(top = 1.75) + #axs[metric_id][i].set_ylim(bottom=0.475) + #axs[metric_id][i].set_aspect(0.7/axs[metric_id][i].get_data_ratio(), adjustable='box') + #axs[metric_id][i].tick_params(axis='y', labelsize=font_size-1) + + # axs[metric_id][i].tick_params( + # axis='x', # changes apply to the x-axis + # which='both', # both major and minor ticks are affected + # bottom=False, # ticks along the bottom edge are off + # top=False, # ticks along the top edge are off + # labelbottom=False) # labels along the bottom edge are off + + # Add the legend + lines_labels = [ax.get_legend_handles_labels() for ax in fig.axes] + lines, labels = [sum(lol, []) for lol in zip(*lines_labels)] + legend_rows = 6 + + # Iterate over each row of subplots + for row in range(len(axs)): + # Create a legend for each row + handles, labels = axs[row, -1].get_legend_handles_labels() + axs[row, -1].legend( + lines[:len(learners_list)], + labels[:len(learners_list)], + ncol=1, #len(learners_list) if legend_rows == 1 else int((len(learners_list) + 1) / legend_rows), + loc='center right', + bbox_to_anchor=(1.8, 0.5), + prop={'size': font_size+2} + ) + + + #fig.tight_layout() + plt.subplots_adjust( wspace=0.07) + return fig + +def plot_performance_metrics(results_df: pd.DataFrame, + model_names: list, + dataset: str, + compare_axis: str, + compare_axis_values, + x_axis, + x_label_name, + x_values_to_plot, + metrics_list, + learners_list, + figsize, + legend_position, + seeds_list, + n_splits, + sharey=False, + legend_rows=1, + dim_X=1, + log_x_axis = False): + + # Get the unique values of the compare axis + if compare_axis_values is None: + compare_axis_values = results_df[compare_axis].unique() + + metrics_list = ['PEHE', 'FC PEHE', "Pred Precision", 'Pred: Pred features ACC', 'Prog: Prog features ACC'] #] + #log_x_axis=False + # Initialize the plot + nrows = len(metrics_list) + columns = len(compare_axis_values) + + #figsize = (3*columns+2, 3*nrows) #PREV + figsize = (3*columns+2, 3.4*nrows) + #figsize = (3*columns, 3) + + font_size=10 + fig, axs = plt.subplots(len(metrics_list), len(compare_axis_values), figsize=figsize, squeeze=False, sharey=sharey, dpi=500) + plt.gcf().subplots_adjust(bottom=0.15) + + model_names_cpy = model_names.copy() + # Aggregate results across seeds for each metric + for i in range(len(compare_axis_values)): + cmp_value = compare_axis_values[i] + for metric_id, metric in enumerate(metrics_list): + # if metric in ["FC PEHE", 'Prog: Prog features ACC', 'Prog: Pred features ACC']: + # model_names = ["Torch_TARNet","Torch_DragonNet","Torch_CFRNet_0.001","EconML_TLearner_Lasso"] + # else: + model_names = model_names_cpy #["Torch_TARNet","Torch_DragonNet","Torch_ActionNet", "Torch_CFRNet_0.001","EconML_TLearner_Lasso"] + + for model_name in model_names: + # Extract results for individual cate models + sub_df = results_df.loc[(results_df["Learner"] == model_name)] + sub_df = sub_df.loc[(sub_df[compare_axis] == cmp_value)][[x_axis, 
+
+
+def plot_performance_metrics(results_df: pd.DataFrame,
+                             model_names: list,
+                             dataset: str,
+                             compare_axis: str,
+                             compare_axis_values,
+                             x_axis,
+                             x_label_name,
+                             x_values_to_plot,
+                             metrics_list,
+                             learners_list,
+                             figsize,
+                             legend_position,
+                             seeds_list,
+                             n_splits,
+                             sharey=False,
+                             legend_rows=1,
+                             dim_X=1,
+                             log_x_axis=False):
+    """Plot CATE performance metrics (PEHE, relative PEHE, assignment
+    precision, attribution accuracies) in the same layout as
+    ``plot_results_datasets_compare``.
+
+    Note: ``metrics_list`` and ``figsize`` are overridden below.
+    """
+    # Get the unique values of the compare axis
+    if compare_axis_values is None:
+        compare_axis_values = results_df[compare_axis].unique()
+
+    metrics_list = ['PEHE', 'FC PEHE', "Pred Precision",
+                    'Pred: Pred features ACC', 'Prog: Prog features ACC']  # overrides the argument
+
+    # Initialize the plot
+    nrows = len(metrics_list)
+    columns = len(compare_axis_values)
+    figsize = (3 * columns + 2, 3.4 * nrows)  # overrides the figsize argument
+    font_size = 10
+    fig, axs = plt.subplots(len(metrics_list), len(compare_axis_values),
+                            figsize=figsize, squeeze=False, sharey=sharey, dpi=500)
+    plt.gcf().subplots_adjust(bottom=0.15)
+
+    # Kept as a hook for plotting a per-metric subset of models.
+    model_names_cpy = model_names.copy()
+
+    # Aggregate results across seeds for each metric
+    for i in range(len(compare_axis_values)):
+        cmp_value = compare_axis_values[i]
+        for metric_id, metric in enumerate(metrics_list):
+            model_names = model_names_cpy
+
+            for model_name in model_names:
+                # Extract results for the individual CATE model
+                sub_df = results_df.loc[results_df["Learner"] == model_name]
+                sub_df = sub_df.loc[sub_df[compare_axis] == cmp_value][[x_axis, metric]]
+                sub_df = sub_df[sub_df[x_axis].isin(x_values_to_plot)]
+                sub_df_med = sub_df.groupby(x_axis).agg('median').reset_index()
+                sub_df_std = sub_df.groupby(x_axis).agg('std').reset_index()
+
+                # Plot the results
+                x_values = sub_df_med.loc[:, x_axis].values
+                try:
+                    y_values = sub_df_med.loc[:, metric].values
+                except KeyError:
+                    # Metric column missing for this learner; skip it.
+                    continue
+
+                # Standard error across seeds x CV splits (band around the median).
+                y_err = sub_df_std.loc[:, metric].values / np.sqrt(n_splits * len(seeds_list))
+
+                axs[metric_id][i].plot(x_values, y_values,
+                                       label=learners_names_map[model_name],
+                                       color=learner_colors[model_name],
+                                       linestyle=learner_linestyles[model_name],
+                                       marker=learner_markers[model_name],
+                                       markersize=3, alpha=0.5)
+                axs[metric_id][i].fill_between(x_values, y_values - y_err, y_values + y_err,
+                                               alpha=0.1, color=learner_colors[model_name])
+
+            # Panel cosmetics (once per panel)
+            axs[metric_id][i].tick_params(axis='x', labelsize=font_size - 2)
+            axs[metric_id][i].tick_params(axis='y', labelsize=font_size - 1)
+            if metric_id == 0:
+                axs[metric_id][i].set_title(compare_values_map[cmp_value],
+                                            fontsize=font_size + 1, y=1.0)
+            axs[metric_id][i].set_xlabel(x_label_name, fontsize=font_size - 1)
+            if i == 0:
+                axs[metric_id][i].set_ylabel(metric_names_map[metric], fontsize=font_size - 1)
+
+            if log_x_axis:
+                _set_symlog_xaxis(axs[metric_id][i])
GT Pred Expertise"]: + axs[metric_id][i].set_ylim(0, 1) + + # if metric == "PEHE": + # axs[metric_id][i].set_ylim(top = 1.75) + #axs[metric_id][i].set_ylim(bottom=0.475) + #axs[metric_id][i].set_aspect(0.7/axs[metric_id][i].get_data_ratio(), adjustable='box') + #axs[metric_id][i].tick_params(axis='y', labelsize=font_size-1) + + axs[metric_id][i].tick_params( + axis='x', # changes apply to the x-axis + which='both', # both major and minor ticks are affected + bottom=False, # ticks along the bottom edge are off + top=False, # ticks along the top edge are off + labelbottom=False) # labels along the bottom edge are off + + # Add the legend + lines_labels = [ax.get_legend_handles_labels() for ax in fig.axes] + lines, labels = [sum(lol, []) for lol in zip(*lines_labels)] + legend_rows = 6 + + # Iterate over each row of subplots + for row in range(len(axs)): + # Create a legend for each row + handles, labels = axs[row, -1].get_legend_handles_labels() + axs[row, -1].legend( + lines[:len(learners_list)], + labels[:len(learners_list)], + ncol=1, #len(learners_list) if legend_rows == 1 else int((len(learners_list) + 1) / legend_rows), + loc='center right', + bbox_to_anchor=(1.9, 0.5), + prop={'size': font_size+2} + ) + + + plt.subplots_adjust( wspace=0.07) + #fig.tight_layout() + return fig + + + +def plot_performance_metrics_f_cf(results_df: pd.DataFrame, + model_names: list, + dataset: str, + compare_axis: str, + compare_axis_values, + x_axis, + x_label_name, + x_values_to_plot, + metrics_list, + learners_list, + figsize, + legend_position, + seeds_list, + n_splits, + sharey=False, + legend_rows=1, + dim_X=1, + log_x_axis = False): + # Get the unique values of the compare axis + if compare_axis_values is None: + compare_axis_values = results_df[compare_axis].unique() + + # Initialize the plot + #model_names = model_names[0] #["Torch_TARNet"] #[EconML_TLearner_Lasso"] + columns = len(compare_axis_values) + rows = len(model_names) + figsize = (3*columns+2, 3.3*rows) + #figsize = (3*columns, 3) + font_size=10 + fig, axs = plt.subplots(len(model_names), len(compare_axis_values), figsize=figsize, squeeze=False, sharey=sharey, dpi=500) + #plt.gcf().subplots_adjust(bottom=0.15) + + # Filter results_df for first model and first seed and first split + + # results_df = results_df.loc[(results_df["Seed"] == seeds_list[0])] + # results_df = results_df.loc[(results_df["Split ID"] == 0)] + + # Only consider expertise metrics + #colors = ['black', 'orange', 'darkorange', 'orchid', 'darkorchid'] + colors = ['blue', 'lightcoral', 'lightgreen', 'red', 'green'] + + markers = ['o', 'D', 'D', 'x', 'x'] + metrics_list = ["PEHE", "Factual RMSE Y0", "Factual RMSE Y1", "CF RMSE Y0", "CF RMSE Y1"] + filtered_df = results_df[[x_axis, compare_axis] + metrics_list] + + # Aggregate results across seeds for each metric + for model_id, model_name in enumerate(model_names): + filtered_df_model = filtered_df.loc[(results_df["Learner"] == model_name)] + for i in range(len(compare_axis_values)): + cmp_value = compare_axis_values[i] + + # Plot all metric outcomes as lines in a single plot for the given cmp_value and use x_axis as x-axis + x_values = filtered_df_model[x_axis].values + + for metric_id, metric in enumerate(metrics_list): + # Extract results for individual cate models + + sub_df = filtered_df_model.loc[(filtered_df_model[compare_axis] == cmp_value)][[x_axis, metric]] + sub_df = sub_df[sub_df[x_axis].isin(x_values_to_plot)] + sub_df_mean = sub_df.groupby(x_axis).agg('median').reset_index() + sub_df_std = 
+
+
+def plot_performance_metrics_f_cf(results_df: pd.DataFrame,
+                                  model_names: list,
+                                  dataset: str,
+                                  compare_axis: str,
+                                  compare_axis_values,
+                                  x_axis,
+                                  x_label_name,
+                                  x_values_to_plot,
+                                  metrics_list,
+                                  learners_list,
+                                  figsize,
+                                  legend_position,
+                                  seeds_list,
+                                  n_splits,
+                                  sharey=False,
+                                  legend_rows=1,
+                                  dim_X=1,
+                                  log_x_axis=False):
+    """Plot factual/counterfactual error decompositions: one row per model,
+    one column per value of ``compare_axis``.
+
+    Note: ``metrics_list`` and ``figsize`` are overridden below.
+    """
+    # Get the unique values of the compare axis
+    if compare_axis_values is None:
+        compare_axis_values = results_df[compare_axis].unique()
+
+    # Initialize the plot
+    columns = len(compare_axis_values)
+    rows = len(model_names)
+    figsize = (3 * columns + 2, 3.3 * rows)  # overrides the figsize argument
+    font_size = 10
+    fig, axs = plt.subplots(len(model_names), len(compare_axis_values),
+                            figsize=figsize, squeeze=False, sharey=sharey, dpi=500)
+
+    # Only consider factual/counterfactual error metrics.
+    colors = ['blue', 'lightcoral', 'lightgreen', 'red', 'green']
+    markers = ['o', 'D', 'D', 'x', 'x']
+    metrics_list = ["PEHE", "Factual RMSE Y0", "Factual RMSE Y1",
+                    "CF RMSE Y0", "CF RMSE Y1"]  # overrides the argument
+    filtered_df = results_df[[x_axis, compare_axis, "Learner"] + metrics_list]
+
+    # Aggregate results across seeds for each metric
+    for model_id, model_name in enumerate(model_names):
+        filtered_df_model = filtered_df.loc[filtered_df["Learner"] == model_name]
+        for i in range(len(compare_axis_values)):
+            cmp_value = compare_axis_values[i]
+
+            # Plot all metrics as lines in a single panel for this cmp_value.
+            for metric_id, metric in enumerate(metrics_list):
+                sub_df = filtered_df_model.loc[
+                    filtered_df_model[compare_axis] == cmp_value][[x_axis, metric]]
+                sub_df = sub_df[sub_df[x_axis].isin(x_values_to_plot)]
+                sub_df_med = sub_df.groupby(x_axis).agg('median').reset_index()
+                sub_df_std = sub_df.groupby(x_axis).agg('std').reset_index()
+
+                # Plot the results
+                x_values = sub_df_med.loc[:, x_axis].values
+                try:
+                    y_values = sub_df_med.loc[:, metric].values
+                except KeyError:
+                    # Metric column missing for this model; skip it.
+                    continue
+
+                # Standard error across seeds x CV splits.
+                y_err = sub_df_std.loc[:, metric].values / np.sqrt(n_splits * len(seeds_list))
+
+                axs[model_id][i].plot(x_values, y_values, label=metric_names_map[metric],
+                                      color=colors[metric_id], linestyle='-',
+                                      marker=markers[metric_id], alpha=0.5, markersize=3)
+                axs[model_id][i].fill_between(x_values, y_values - y_err, y_values + y_err,
+                                              alpha=0.1, color=colors[metric_id])
+
+            # Panel cosmetics (once per panel)
+            axs[model_id][i].tick_params(axis='x', labelsize=font_size - 2)
+            axs[model_id][i].tick_params(axis='y', labelsize=font_size - 1)
+            axs[model_id][i].set_title(compare_values_map[cmp_value],
+                                       fontsize=font_size + 2, y=1.01)
+            axs[model_id][i].set_xlabel(x_label_name, fontsize=font_size - 1)
+
+            if log_x_axis:
+                _set_symlog_xaxis(axs[model_id][i])
+
+            # Hide x tick labels; the symlog ticks get crowded otherwise.
+            axs[model_id][i].tick_params(axis='x', which='both', bottom=False,
+                                         top=False, labelbottom=False)
+
+    # Add a legend to the right of each row, titled with the model name.
+    for i, row in enumerate(axs):
+        lines_labels = [ax.get_legend_handles_labels() for ax in row]
+        lines, labels = [sum(lol, []) for lol in zip(*lines_labels)]
+        row[-1].legend(
+            lines[:len(metrics_list)],
+            labels[:len(metrics_list)],
+            loc='center right',
+            bbox_to_anchor=(1.8, 0.5),
+            ncol=1,
+            prop={'size': font_size + 2},
+            title_fontsize=font_size + 4,
+            title=learners_names_map[model_names[i]],
+        )
+
+    plt.subplots_adjust(wspace=0.07)
+    return fig
+
+
+def plot_expertise_metrics(results_df: pd.DataFrame,
+                           model_names: list,
+                           dataset: str,
+                           compare_axis: str,
+                           compare_axis_values,
+                           x_axis,
+                           x_label_name,
+                           x_values_to_plot,
+                           metrics_list,
+                           learners_list,
+                           figsize,
+                           legend_position,
+                           seeds_list,
+                           n_splits,
+                           sharey=False,
+                           legend_rows=1,
+                           dim_X=1,
+                           log_x_axis=False):
+    """Plot ground-truth expertise/bias metrics of the treatment-assignment
+    policy, one column per value of ``compare_axis``.
+
+    Note: ``metrics_list`` is overridden below.
+    """
+    if compare_axis_values is None:
+        compare_axis_values = results_df[compare_axis].unique()
+
+    # Initialize the plot
+    columns = len(compare_axis_values)
+    figsize = (3 * columns + 2, 3)  # overrides the figsize argument
+    font_size = 10
+    fig, axs = plt.subplots(1, len(compare_axis_values), figsize=figsize,
+                            squeeze=False, sharey=sharey, dpi=500)
+    plt.gcf().subplots_adjust(bottom=0.15)
+
+    # Restrict to the first model, first seed and first split: these metrics
+    # describe the data and policy, not the learner, so one run suffices.
+    results_df = results_df.loc[results_df["Learner"] == model_names[0]]
+    results_df = results_df.loc[results_df["Seed"] == seeds_list[0]]
+    results_df = results_df.loc[results_df["Split ID"] == 0]
+
+    # Only consider expertise metrics.
+    colors = ['black', 'grey', 'red', 'green', 'blue']
+    markers = ['o', 'x', 'x', 'x', 'x']
+    metrics_list = ["Policy Precision", "GT In-context Var", "GT Prog Expertise",
+                    "GT Tre Expertise", "GT Pred Expertise"]  # overrides the argument
+    sub_df = results_df[[x_axis, compare_axis, "Seed", "Split ID"] + metrics_list]
+
+    for i in range(len(compare_axis_values)):
+        cmp_value = compare_axis_values[i]
+
+        # Plot all metrics as lines in a single panel for this cmp_value.
+        filtered_df = sub_df[sub_df[compare_axis] == cmp_value]
+        x_values = filtered_df[x_axis].values
+
+        for metric_id, metric in enumerate(metrics_list):
+            y_values = filtered_df[metric].values
+            axs[0][i].plot(x_values, y_values, label=metric_names_map[metric],
+                           color=colors[metric_id], linestyle='-',
+                           marker=markers[metric_id], alpha=0.5, markersize=5)
+
+        # Panel cosmetics
+        axs[0][i].tick_params(axis='x', labelsize=font_size - 2)
+        axs[0][i].tick_params(axis='y', labelsize=font_size - 1)
+        axs[0][i].set_title(compare_values_map[cmp_value], fontsize=font_size + 11, y=1.04)
+        axs[0][i].set_xlabel(x_label_name, fontsize=font_size - 1)
+
+        if log_x_axis:
+            _set_symlog_xaxis(axs[0][i])
+
+        # Hide x tick labels; the symlog ticks get crowded otherwise.
+        axs[0][i].tick_params(axis='x', which='both', bottom=False, top=False,
+                              labelbottom=False)
+
+    # One legend for the whole figure.
+    lines_labels = [ax.get_legend_handles_labels() for ax in fig.axes]
+    lines, labels = [sum(lol, []) for lol in zip(*lines_labels)]
+    fig.legend(
+        lines[:len(metrics_list)],
+        labels[:len(metrics_list)],
+        loc='center right',
+        ncol=1,
+        prop={'size': font_size + 2},
+    )
+
+    plt.subplots_adjust(wspace=0.07)
+    return fig
+
+
+def merge_pngs(images, axis="horizontal"):
+    """Merge a list of PIL images into a single image, side by side
+    (``axis="horizontal"``) or stacked (``axis="vertical"``)."""
+    widths, heights = zip(*(i.size for i in images))
+
+    if axis == "vertical":
+        total_height = sum(heights)
+        max_width = max(widths)
+        new_im = Image.new('RGB', (max_width, total_height))
+        y_offset = 0
+        for im in images:
+            new_im.paste(im, (0, y_offset))
+            y_offset += im.size[1]
+        return new_im
+
+    elif axis == "horizontal":
+        total_width = sum(widths)
+        max_height = max(heights)
+        new_im = Image.new('RGB', (total_width, max_height))
+        x_offset = 0
+        for im in images:
+            new_im.paste(im, (x_offset, 0))
+            x_offset += im.size[0]
+        return new_im
+
+    raise ValueError(f"axis must be 'horizontal' or 'vertical', got {axis!r}")
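+
+# Typical save-and-merge flow (file names are hypothetical):
+#
+#     fig_a.savefig("panel_a.png", bbox_inches="tight")
+#     fig_b.savefig("panel_b.png", bbox_inches="tight")
+#     merged = merge_pngs([Image.open("panel_a.png"), Image.open("panel_b.png")],
+#                         axis="horizontal")
+#     merged.save("figure.png")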