# src/plotting.py — learner styling tables (colors, line styles, markers),
# display-name maps, and figure builders for comparing CATE learners.
1
import seaborn as sns
2
import matplotlib.pyplot as plt
3
import pickle
4
import pandas as pd 
5
import numpy as np
6
from pathlib import Path
7
import warnings
8
warnings.simplefilter(action='ignore', category=FutureWarning)
9
from PIL import Image
10
import sys 
11
12
# Hydra for configuration
13
import hydra
14
from omegaconf import DictConfig, OmegaConf
15
from matplotlib.ticker import ScalarFormatter
16
from matplotlib.ticker import MaxNLocator
17
from matplotlib.ticker import FuncFormatter
18
19
# Custom tick formatter: integers rendered plain, common fractions as LaTeX.
def custom_formatter(x, pos):
    """Format an axis tick value for a symlog x-axis.

    Parameters
    ----------
    x : float or int
        Tick value. Matplotlib normally passes floats, but plain ints are
        accepted too (coerced below).
    pos : int
        Tick position; unused, required by the ``FuncFormatter`` signature.

    Returns
    -------
    str
        ``'2'`` style for whole numbers, ``$1/2$`` / ``$1/4$`` for those
        fractions, otherwise the value with two decimals.
    """
    # Coerce to float: int has no .is_integer() before Python 3.12.
    x = float(x)
    if x.is_integer():
        return f'{int(x)}'
    elif x == 0.5:
        return r'$1/2$'
    elif x == 0.25:
        return r'$1/4$'
    # Fallback: fixed two-decimal rendering.
    else:
        return f'{x:.2f}'
32
cblind_palete = sns.color_palette("colorblind", as_cmap=True)
33
learner_colors = {
34
    "Torch_SLearner": cblind_palete[0],
35
    "Torch_TLearner": cblind_palete[1],
36
    "Torch_XLearner": cblind_palete[2],
37
    "Torch_TARNet": cblind_palete[3],
38
    'Torch_CFRNet_0.01': cblind_palete[4],
39
    "Torch_CFRNet_0.001": cblind_palete[6],
40
    'Torch_CFRNet_0.0001': cblind_palete[9],
41
    'Torch_ActionNet': cblind_palete[7],
42
    "Torch_DRLearner": cblind_palete[8],
43
    "Torch_RALearner": cblind_palete[9],
44
    "Torch_DragonNet": cblind_palete[5],
45
    "Torch_DragonNet_2": cblind_palete[5],
46
    "Torch_DragonNet_4": cblind_palete[3],
47
    "Torch_ULearner": cblind_palete[6],
48
    "Torch_PWLearner": cblind_palete[7],
49
    "Torch_RLearner": cblind_palete[8],
50
    "Torch_FlexTENet": cblind_palete[9],
51
    "EconML_CausalForestDML": cblind_palete[2],
52
    "EconML_DML": cblind_palete[0],
53
    "EconML_DMLOrthoForest": cblind_palete[1],
54
    "EconML_DRLearner": cblind_palete[6],
55
    "EconML_DROrthoForest": cblind_palete[9],
56
    "EconML_ForestDRLearner": cblind_palete[7],
57
    "EconML_LinearDML": cblind_palete[8],
58
    "EconML_LinearDRLearner": cblind_palete[5],
59
    "EconML_SparseLinearDML": cblind_palete[3],
60
    "EconML_SparseLinearDRLearner": cblind_palete[4],
61
    "EconML_XLearner_Lasso": cblind_palete[7],
62
    "EconML_TLearner_Lasso": cblind_palete[8],
63
    "EconML_SLearner_Lasso": cblind_palete[9],
64
    "DiffPOLearner": cblind_palete[0],
65
    "Truth": cblind_palete[9],
66
}
67
68
# Line style per learner identifier, keyed consistently with
# `learner_colors` / `learner_markers`.
learner_linestyles = {
    "Torch_SLearner": "-",
    "Torch_TLearner": "--",
    # "Torch_XLearner" was listed twice (":" then "--"); the later entry
    # silently won, so the shadowed duplicate has been removed.
    "Torch_XLearner": "--",
    "Torch_TARNet": "-.",
    "Torch_DragonNet": "--",
    "Torch_DragonNet_2": "-",
    "Torch_DragonNet_4": "-.",
    "Torch_CFRNet_0.01": "-",
    "Torch_CFRNet_0.001": ":",
    "Torch_CFRNet_0.0001": "--",
    "Torch_DRLearner": "-",
    "Torch_RALearner": "--",
    "Torch_ULearner": "-",
    "Torch_PWLearner": "-",
    "Torch_RLearner": "-",
    "Torch_FlexTENet": "-",
    'Torch_ActionNet': "-",
    "EconML_CausalForestDML": "-",
    "EconML_DML": "--",
    "EconML_DMLOrthoForest": ":",
    "EconML_DRLearner": "-.",
    "EconML_DROrthoForest": "--",
    "EconML_ForestDRLearner": "-.",
    "EconML_LinearDML": ":",
    "EconML_LinearDRLearner": "-",
    "EconML_SparseLinearDML": "--",
    "EconML_SparseLinearDRLearner": ":",
    "EconML_SLearner_Lasso": "-.",
    "EconML_TLearner_Lasso": "--",
    "EconML_XLearner_Lasso": "-",
    "DiffPOLearner": "-.",
    "Truth": ":",
}
# Marker style per learner identifier, keyed consistently with
# `learner_colors` / `learner_linestyles`.
learner_markers = {
    "Torch_SLearner": "d",
    "Torch_TLearner": "o",
    # "Torch_XLearner" was listed twice ("^" then "D"); the later entry
    # silently won, so the shadowed duplicate has been removed.
    "Torch_XLearner": "D",
    "Torch_TARNet": "*",
    "Torch_DragonNet": "x",
    "Torch_DragonNet_2": "o",
    "Torch_DragonNet_4": "*",
    "Torch_CFRNet_0.01": "8",
    "Torch_CFRNet_0.001": "s",
    "Torch_CFRNet_0.0001": "x",
    "Torch_DRLearner": "x",
    "Torch_RALearner": "H",
    "Torch_ULearner": "x",
    "Torch_PWLearner": "*",
    "Torch_RLearner": "*",
    "Torch_FlexTENet": "*",
    'Torch_ActionNet': "*",
    "EconML_CausalForestDML": "d",
    "EconML_DML": "o",
    "EconML_DMLOrthoForest": "^",
    "EconML_DRLearner": "*",
    "EconML_DROrthoForest": "D",
    "EconML_ForestDRLearner": "8",
    "EconML_LinearDML": "s",
    "EconML_LinearDRLearner": "x",
    "EconML_SparseLinearDML": "x",
    "EconML_SparseLinearDRLearner": "H",
    "EconML_TLearner_Lasso": "o",
    "EconML_SLearner_Lasso": "^",
    "EconML_XLearner_Lasso": "d",
    "DiffPOLearner": "H",
    "Truth": "<",
}
# Dataset identifier -> display name used in figure titles/labels.
# Entries mapping to themselves have no prettified display name yet.
datasets_names_map = {
    "tcga_100": "TCGA",
    "twins": "Twins",
    "news_100": "News",
    "all_notupro_technologies": "AllNoTuproTechnologies",
    "all_notupro_technologies_small": "AllNoTuproTechnologiesSmall",
    "dummy_data": "DummyData",
    "selected_technologies_pategan_1000": "selected_technologies_pategan_1000",
    "selected_technologies_with_fastdrug": "selected_technologies_with_fastdrug",
    "cytof_normalized":"cytof_normalized",
    "cytof_normalized_with_fastdrug":"cytof_normalized_with_fastdrug",
    "cytof_pategan_1000_normalized": "cytof_pategan_1000_normalized",
    "all_notupro_technologies_with_fastdrug": "all_notupro_technologies_with_fastdrug",
    "acic": "ACIC2016",
    "depmap_drug_screen_2_drugs": "depmap_drug_screen_2_drugs",
    "depmap_drug_screen_2_drugs_norm": "depmap_drug_screen_2_drugs_norm",
    "depmap_drug_screen_2_drugs_all_features": "depmap_drug_screen_2_drugs_all_features",
    "depmap_drug_screen_2_drugs_all_features_norm": "depmap_drug_screen_2_drugs_all_features_norm",
    "depmap_crispr_screen_2_kos": "depmap_crispr_screen_2_kos",
    "depmap_crispr_screen_2_kos_norm": "depmap_crispr_screen_2_kos_norm",
    "depmap_crispr_screen_2_kos_all_features": "depmap_crispr_screen_2_kos_all_features",
    "depmap_crispr_screen_2_kos_all_features_norm": "depmap_crispr_screen_2_kos_all_features_norm",
    "depmap_drug_screen_2_drugs_100pcs_norm":"depmap_drug_screen_2_drugs_100pcs_norm",
    "depmap_drug_screen_2_drugs_5000hv_norm":"depmap_drug_screen_2_drugs_5000hv_norm",
    "depmap_crispr_screen_2_kos_100pcs_norm":"depmap_crispr_screen_2_kos_100pcs_norm",
    "depmap_crispr_screen_2_kos_5000hv_norm":"depmap_crispr_screen_2_kos_5000hv_norm",
    "independent_normally_dist": "independent_normally_dist",
    "ovarian_semi_synthetic_l1":"ovarian_semi_synthetic_l1",
    "ovarian_semi_synthetic_rf":"ovarian_semi_synthetic_rf",
    "melanoma_semi_synthetic_l1": "melanoma_semi_synthetic_l1",
    # Synthetic confounding scenarios.
    "pred": "Predictive confounding",
    "prog": "Prognostic confounding",
    "irrelevant_var": "Non-confounded propensity",
    "selective": "General confounding"}
# Metric identifier -> display label (plain text or LaTeX) used for y-axis
# labels. Keys must match the metric column names in the results DataFrame.
metric_names_map = {
    # Feature-attribution accuracy: Attr^{source}_{target feature set}.
    'Pred: Pred features ACC': r'Predictive $\mathrm{Attr}$', #^{\mathrm{pred}}_{\mathrm{pred}}$',
    'Pred: Prog features ACC': r'$\mathrm{Attr}^{\mathrm{pred}}_{\mathrm{prog}}$',
    'Pred: Select features ACC': r'$\mathrm{Attr}^{\mathrm{pred}}_{\mathrm{select}}$',
    'Prog: Pred features ACC': r'$\mathrm{Attr}^{\mathrm{prog}}_{\mathrm{pred}}$',
    'Prog: Prog features ACC': r'Prognostic $\mathrm{Attr}$', #^{\mathrm{prog}}$', #_{\mathrm{prog}}$',
    'Prog: Select features ACC': r'$\mathrm{Attr}^{\mathrm{prog}}_{\mathrm{select}}$',
    'Select: Pred features ACC': r'$\mathrm{Attr}^{\mathrm{select}}_{\mathrm{pred}}$',
    'Select: Prog features ACC': r'$\mathrm{Attr}^{\mathrm{select}}_{\mathrm{prog}}$',
    'Select: Select features ACC': r'$\mathrm{Attr}^{\mathrm{select}}_{\mathrm{select}}$',
    'CI Coverage': 'CI Coverage',
    # Estimation-error metrics ("N." = normalized, "F" = factual,
    # "CF" = counterfactual).
    'Normalized PEHE': 'N. PEHE',
    'PEHE': 'PEHE',
    'CF RMSE': 'CF-RMSE',
    'AUROC': 'AUROC',
    'Factual AUROC': 'Factual AUROC',
    'CF AUROC': "CF AUROC",
    'Factual RMSE': 'F-RMSE',
    "Factual RMSE Y0": "F-RMSE Y0",
    "Factual RMSE Y1": "F-RMSE Y1",
    "CF RMSE Y0": "CF-RMSE Y0",
    "CF RMSE Y1": "CF-RMSE Y1",
    'Normalized F-RMSE': 'N. F-RMSE',
    'Normalized CF-RMSE': 'N. CF-RMSE',
    "F-Outcome true mean":"F-Outcome true mean",
    "CF-Outcome true mean":"CF-Outcome true mean",
    "F-Outcome true std":"F-Outcome true std",
    "CF-Outcome true std":"CF-Outcome true std",
    "F-CF Outcome Diff":"F-CF Outcome Diff",
    # Swap-detection metrics at several cutoffs (@k).
    'Swap AUROC@1': 'AUROC@1',
    'Swap AUPRC@1': 'AUPRC@1',
    'Swap AUROC@5': 'AUROC@5',
    'Swap AUPRC@5': 'AUPRC@5',
    'Swap AUROC@tre': 'AUROC@tre',
    'Swap AUPRC@tre': 'AUPRC@tre',
    'Swap AUROC@all': 'AUROC',
    'Swap AUPRC@all': 'AUPRC',
    # Expertise/bias metrics: B^{policy}_{outcome quantity}
    # ("GT" = ground truth, "ES" = estimated, "Upd." = updated policy).
    "GT Pred Expertise": r'$\mathrm{B}^{\pi}_{Y_1-Y_0}$',
    "GT Prog Expertise": r'$\mathrm{B}^{\pi}_{Y_0}$',
    "GT Tre Expertise": r'$\mathrm{B}^{\pi}_{Y_1}$',
    "Upd. GT Pred Expertise": r'$\mathrm{B}^{\hat{\pi}}_{Y_1-Y_0}$',
    "Upd. GT Prog Expertise": r'$\mathrm{B}^{\hat{\pi}}_{Y_0}$',
    "Upd. GT Tre Expertise": r'$\mathrm{B}^{\hat{\pi}}_{Y_1}$',
    "GT Expertise Ratio": r'$\mathrm{E}^{\pi}_{\mathrm{ratio}}$',
    "GT Total Expertise": r'$\mathrm{B}^{\pi}_{Y_0,Y_1}$',
    "ES Pred Expertise": "ES Pred Bias",
    "ES Prog Expertise": "ES Prog Bias",
    "ES Total Expertise": "ES Outcome Bias",
    "Pred Precision": r'$\mathrm{Prec}^{\hat{\pi}}_{\mathrm{Ass.}}$',
    "Policy Precision": r'$\mathrm{Prec}^{\pi}_{\mathrm{Ass.}}$',
    'T Distribution: Train': 'T Distribution: Train',
    'T Distribution: Test': 'T Distribution: Test',
    'True Swap Perc': 'True Swap Perc',
    "Normalized F-CF Diff": "Normalized F-CF Diff",
    'Training Duration': 'Training Duration',
    # "FC" metrics are relative to a fixed comparator (TARNet).
    "FC PEHE":"PEHE(Model) - PEHE(TARNet)",
    "FC F-RMSE":"Rel. N. F-RMSE",
    "FC CF-RMSE":"Rel. N. CF-RMSE",
    "FC Swap AUROC":"Rel. AUROC",
    "FC Swap AUPRC":"Rel. AUPRC",
    "GT In-context Var":r'$\mathrm{B}^{\pi}_{X}$',
    "ES In-context Var":"ES Total Bias",
    "GT-ES Pred Expertise Diff":r'$\mathrm{E}^{\pi}_{\mathrm{pred}}$ Error',
    "GT-ES Prog Expertise Diff":r'$\mathrm{E}^{\pi}_{\mathrm{prog}}$ Error',
    "GT-ES Total Expertise Diff":r'$\mathrm{E}^{\pi}$ Error',
    "RMSE Y0":"RMSE Y0",
    "RMSE Y1":"RMSE Y1",
}
# Learner identifier -> human-readable legend label. Entries mapping to
# themselves have no prettified display name yet.
learners_names_map = {
    "Torch_TLearner":"T-Learner-MLP",
    "Torch_SLearner": "S-Learner-MLP",
    "Torch_TARNet": "Baseline-TAR",
    "Torch_DragonNet": "DragonNet-1 (Act. Pred.)",
    "Torch_DragonNet_2": "DragonNet-2 (Act. Pred.)",
    "Torch_DragonNet_4": "DragonNet-4 (Act. Pred.)",
    "Torch_DRLearner": "Direct-DR",
    "Torch_XLearner": "XLearner-MLP (Direct)",
    "Torch_CFRNet_0.001": 'CFRNet-0.001 (Balancing)', #-\gamma=0.001)$',
    "Torch_CFRNet_0.01": 'CFRNet-0.01 (Balancing)',
    "Torch_CFRNet_0.0001": 'CFRNet-0.0001 (Balancing)',
    'Torch_ActionNet': "ActionNet (Act. Pred.)",
    "Torch_RALearner": "RA-Learner",
    "Torch_ULearner": "U-Learner",
    "Torch_PWLearner":"Torch_PWLearner",
    "Torch_RLearner":"Torch_RLearner",
    "Torch_FlexTENet":"Torch_FlexTENet",
    "EconML_CausalForestDML": "CausalForestDML",
    "EconML_DML": "APred-Prop-Lasso",
    "EconML_DMLOrthoForest": "DMLOrthoForest",
    "EconML_DRLearner": "DRLearner",
    "EconML_DROrthoForest": "DROrthoForest",
    "EconML_ForestDRLearner": "ForestDRLearner",
    "EconML_LinearDML": "LinearDML",
    "EconML_LinearDRLearner": "LinearDRLearner",
    "EconML_SparseLinearDML": "SparseLinearDML",
    "EconML_SparseLinearDRLearner": "SparseLinearDRLearner",
    "EconML_TLearner_Lasso": "T-Learner-Lasso",
    "EconML_SLearner_Lasso": "S-Learner-Lasso",
    "EconML_XLearner_Lasso": "XLearnerLasso",
    "DiffPOLearner": "DiffPOLearner",

    "Truth": "Truth"
}
# Compare-axis value -> LaTeX panel title. Keys mix scenario name strings
# and beta values; the int and str variants of each beta are both kept
# because callers pass either form depending on how the axis was read.
compare_values_map = {
    # Propensity shift scenarios (RCT policy -> biased policy).
    "none_prog": r'$\pi_{\mathrm{RCT}} \rightarrow \pi_{\mathrm{Y_0}}^\beta$',
    "none_tre": r'$\pi_{\mathrm{RCT}} \rightarrow \pi_{\mathrm{Y_1}}^\beta$',
    "none_pred": r'$\pi_{\mathrm{RCT}} \rightarrow \pi_{\mathrm{Y_1-Y_0}}^\beta$',
    "rct_none": r'$\pi_{\mathrm{RCT}} \rightarrow \pi_{\mathrm{X_{irr}}}^\beta$',
    "none_pred_overlap": r'$\pi_{\mathrm{RCT}} \rightarrow \pi_{\mathrm{X_{pred}}}^\beta$',

    # Toy examples.
    "toy7": r'$\pi_{\mathrm{RCT}} \rightarrow \pi_{\mathrm{T_7}}^\beta$',
    "toy8_nonlinear": r'$\pi_{\mathrm{RCT}} \rightarrow \pi_{\mathrm{T_8}}^\beta$',
    "toy1_linear": r'$\pi_{\mathrm{RCT}} \rightarrow \pi_{\mathrm{T_1^{lin}}}^\beta$',
    "toy1_nonlinear": r'$\mathrm{Toy 1:} \pi_{\mathrm{T_1}}^\beta$', #r'$\pi_{\mathrm{RCT}} \rightarrow \pi_{\mathrm{T_1}}^\beta$',
    "toy2_linear": r'$\pi_{\mathrm{RCT}} \rightarrow \pi_{\mathrm{T_3^{lin}}}^\beta$',
    "toy2_nonlinear": r'$\mathrm{Toy 3:} \pi_{\mathrm{T_3}}^\beta$',
    "toy3_nonlinear": r'$\mathrm{Toy 2:} \pi_{\mathrm{T_2}}^\beta$',
    "toy4_nonlinear": r'$\pi_{\mathrm{RCT}} \rightarrow \pi_{\mathrm{T_4}}^\beta$',
    "toy5": r'$\pi_{\mathrm{RCT}} \rightarrow \pi_{\mathrm{T_5}}^\beta$',
    "toy6_nonlinear": r'$\mathrm{Toy 4:} \pi_{\mathrm{T_4}}^\beta$', #r'$\pi_{\mathrm{RCT}} \rightarrow \pi_{\mathrm{T_6}}^\beta$',

    # Expertise scenarios (retired variants kept for reference).
    # "prog_tre": r'$\pi_{\mathrm{Y_0}}^{\beta=4} \rightarrow \pi_{\mathrm{Y_1-Y_0}}^{\beta=4} \rightarrow \pi_{\mathrm{Y_1}}^{\beta=4}$',
    # "none_prog": r'$\pi_{\mathrm{X_{rand}}}^{\beta=4} \rightarrow \pi_{\mathrm{Y_0}}^{\beta=4}$',
    # "none_tre": r'$\pi_{\mathrm{X_{rand}}}^{\beta=4} \rightarrow \pi_{\mathrm{Y_1}}^{\beta=4}$',
    # "none_pred": r'$\pi_{\mathrm{X_{rand}}}^{\beta=4} \rightarrow \pi_{\mathrm{Y_1-Y_0}}^{\beta=4}$',
    #["prog_tre", "none_prog", "none_tre", "none_pred"]

    # Beta (selection strength) values, as ints ...
    0: r'$\pi_{\mathrm{RCT}}$',
    2: r'$\pi_{\mathrm{Y_1-Y_0}}^{\beta=2}$',
    100: r'$\pi_{\mathrm{Y_1-Y_0}}^{\beta=100}$',

    # ... and as strings.
    "0": r'$\pi_{\mathrm{RCT}}$',
    "2": r'$\pi_{\mathrm{Y_1-Y_0}}^{\beta=2}$',
    "100": r'$\pi_{\mathrm{Y_1-Y_0}}^{\beta=100}$',
}
def plot_results_datasets_compare(results_df: pd.DataFrame,
                          model_names: list,
                          dataset: str,
                          compare_axis: str,
                          compare_axis_values,
                          x_axis,
                          x_label_name,
                          x_values_to_plot,
                          metrics_list,
                          learners_list,
                          figsize,
                          legend_position,
                          seeds_list,
                          n_splits,
                          sharey=False,
                          legend_rows=1,
                          dim_X=1,
                          log_x_axis=False):
    """Plot a grid of metric curves for one dataset.

    Builds one subplot row per metric in ``metrics_list`` and one column per
    value of ``compare_axis``. For every learner in ``model_names`` the
    median of the metric (grouped by ``x_axis`` over seeds and splits) is
    drawn, with a shaded ``std / sqrt(n_splits * len(seeds_list))`` band.

    Parameters
    ----------
    results_df : pd.DataFrame
        Long-format results with at least the columns "Learner",
        ``compare_axis``, ``x_axis`` and each metric in ``metrics_list``.
    model_names : list
        Learner identifiers; keys into the module-level style maps.
    dataset : str
        Dataset name (unused; kept for interface compatibility).
    compare_axis : str
        Column whose values define the subplot columns.
    compare_axis_values : sequence or None
        Explicit column values; all unique values when None.
    x_axis, x_label_name, x_values_to_plot
        Column plotted on x, its display label, and the x values retained.
    metrics_list, learners_list
        Metrics (grid rows) and learners included in the legend.
    figsize, legend_position
        Unused; the figure size is recomputed from the grid shape and the
        legend is always placed right of the last column.
    seeds_list, n_splits
        Run counts used to scale the error band.
    sharey : bool
        Share the y axis across columns.
    legend_rows, dim_X : unused, kept for interface compatibility.
    log_x_axis : bool
        Use a base-2 symlog x scale with fraction-aware tick labels.

    Returns
    -------
    matplotlib.figure.Figure
    """
    # Get the unique values of the compare axis.
    if compare_axis_values is None:
        compare_axis_values = results_df[compare_axis].unique()

    # Grid shape; the figure size follows it so panels keep a
    # roughly constant aspect ratio (the figsize argument is ignored).
    nrows = len(metrics_list)
    columns = len(compare_axis_values)
    figsize = (3 * columns + 2, 3 * nrows)

    font_size = 10
    fig, axs = plt.subplots(nrows, columns, figsize=figsize,
                            squeeze=False, sharey=sharey, dpi=500)
    plt.gcf().subplots_adjust(bottom=0.15)

    # Aggregate results across seeds/splits for each panel.
    for i in range(columns):
        cmp_value = compare_axis_values[i]
        for metric_id, metric in enumerate(metrics_list):
            for model_name in model_names:
                # Extract the rows for this learner / compare value.
                sub_df = results_df.loc[results_df["Learner"] == model_name]
                sub_df = sub_df.loc[sub_df[compare_axis] == cmp_value][[x_axis, metric]]
                sub_df = sub_df[sub_df[x_axis].isin(x_values_to_plot)]
                sub_df_mean = sub_df.groupby(x_axis).agg('median').reset_index()
                sub_df_std = sub_df.groupby(x_axis).agg('std').reset_index()

                x_values = sub_df_mean.loc[:, x_axis].values
                try:
                    y_values = sub_df_mean.loc[:, metric].values
                except KeyError:
                    # Metric column missing for this learner: skip the curve.
                    continue

                # Error band: std scaled by the number of runs.
                y_err = sub_df_std.loc[:, metric].values / np.sqrt(n_splits * len(seeds_list))

                axs[metric_id][i].plot(
                    x_values, y_values,
                    label=learners_names_map[model_name],
                    color=learner_colors[model_name],
                    linestyle=learner_linestyles[model_name],
                    marker=learner_markers[model_name],
                    markersize=3, alpha=0.5)
                axs[metric_id][i].fill_between(
                    x_values, y_values - y_err, y_values + y_err,
                    alpha=0.1, color=learner_colors[model_name])

            axs[metric_id][i].tick_params(axis='x', labelsize=font_size - 2)
            axs[metric_id][i].tick_params(axis='y', labelsize=font_size - 1)

            axs[metric_id][i].set_title(compare_values_map[cmp_value],
                                        fontsize=font_size + 11, y=1.04)
            axs[metric_id][i].set_xlabel(x_label_name, fontsize=font_size - 1)
            if i == 0:
                # Metric label only on the first column.
                axs[metric_id][i].set_ylabel(metric_names_map[metric],
                                             fontsize=font_size - 1)

            if log_x_axis:
                axs[metric_id][i].set_xscale('symlog', linthresh=0.5, base=2)
                # Insert a tick at the midpoint of the first interval so the
                # compressed region near zero still gets a label.
                current_ticks = axs[metric_id][i].get_xticks()
                if len(current_ticks) > 1:
                    midpoint = (current_ticks[0] + current_ticks[1]) / 2
                    new_ticks = [current_ticks[0], midpoint] + list(current_ticks[1:])
                    axs[metric_id][i].set_xticks(new_ticks)
                # Always show a tick at 0.25 and render fractions nicely.
                axs[metric_id][i].set_xticks(sorted(set(axs[metric_id][i].get_xticks()).union({0.25})))
                axs[metric_id][i].xaxis.set_major_formatter(FuncFormatter(custom_formatter))

            # Metrics known to live in [0, 1] get a fixed y range.
            if metric in ["True Swap Perc", "T Distribution: Train", "T Distribution: Test",
                          "GT Total Expertise", "ES Total Expertise", "GT Expertise Ratio",
                          "GT Pred Expertise", "GT Prog Expertise", "GT Tre Expertise",
                          "ES Pred Expertise", "ES Prog Expertise",
                          "GT In-context Var", "ES In-context Var",
                          "GT-ES Pred Expertise Diff", "GT-ES Prog Expertise Diff",
                          "GT-ES Total Expertise Diff", "Policy Precision",
                          "Upd. GT Prog Expertise", "Upd. GT Tre Expertise",
                          "Upd. GT Pred Expertise"]:
                axs[metric_id][i].set_ylim(0, 1)

            if metric == "PEHE":
                axs[metric_id][i].set_ylim(top=1.75)

    # One legend per row, placed to the right of the last column.
    lines_labels = [ax.get_legend_handles_labels() for ax in fig.axes]
    lines, labels = [sum(lol, []) for lol in zip(*lines_labels)]

    for row in range(len(axs)):
        axs[row, -1].legend(
            lines[:len(learners_list)],
            labels[:len(learners_list)],
            ncol=1,
            loc='center right',
            bbox_to_anchor=(1.8, 0.5),
            prop={'size': font_size + 2})

    plt.subplots_adjust(wspace=0.07)
    return fig
def plot_performance_metrics(results_df: pd.DataFrame,
                          model_names: list,
                          dataset: str,
                          compare_axis: str,
                          compare_axis_values,
                          x_axis,
                          x_label_name,
                          x_values_to_plot,
                          metrics_list,
                          learners_list,
                          figsize,
                          legend_position,
                          seeds_list,
                          n_splits,
                          sharey=False,
                          legend_rows=1,
                          dim_X=1,
                          log_x_axis=False):
    """Plot a grid of performance-metric curves for one dataset.

    Same layout as :func:`plot_results_datasets_compare` (one row per metric
    in ``metrics_list``, one column per ``compare_axis`` value) but tuned
    for performance metrics: titles only on the top row, x tick labels
    hidden, and no fixed PEHE y limit.

    Parameters
    ----------
    results_df : pd.DataFrame
        Long-format results with at least the columns "Learner",
        ``compare_axis``, ``x_axis`` and each metric in ``metrics_list``.
    model_names : list
        Learner identifiers; keys into the module-level style maps.
    dataset : str
        Dataset name (unused; kept for interface compatibility).
    compare_axis : str
        Column whose values define the subplot columns.
    compare_axis_values : sequence or None
        Explicit column values; all unique values when None.
    x_axis, x_label_name, x_values_to_plot
        Column plotted on x, its display label, and the x values retained.
    metrics_list, learners_list
        Metrics (grid rows) and learners included in the legend.
    figsize, legend_position
        Unused; the figure size is recomputed from the grid shape and the
        legend is always placed right of the last column.
    seeds_list, n_splits
        Run counts used to scale the error band.
    sharey : bool
        Share the y axis across columns.
    legend_rows, dim_X : unused, kept for interface compatibility.
    log_x_axis : bool
        Use a base-2 symlog x scale with fraction-aware tick labels.

    Returns
    -------
    matplotlib.figure.Figure
    """
    # Get the unique values of the compare axis.
    if compare_axis_values is None:
        compare_axis_values = results_df[compare_axis].unique()

    # Grid shape; the figure size follows it (figsize argument is ignored).
    nrows = len(metrics_list)
    columns = len(compare_axis_values)
    figsize = (3 * columns + 2, 3.4 * nrows)

    font_size = 10
    fig, axs = plt.subplots(nrows, columns, figsize=figsize,
                            squeeze=False, sharey=sharey, dpi=500)
    plt.gcf().subplots_adjust(bottom=0.15)

    # Aggregate results across seeds/splits for each panel.
    for i in range(columns):
        cmp_value = compare_axis_values[i]
        for metric_id, metric in enumerate(metrics_list):
            for model_name in model_names:
                # Extract the rows for this learner / compare value.
                sub_df = results_df.loc[results_df["Learner"] == model_name]
                sub_df = sub_df.loc[sub_df[compare_axis] == cmp_value][[x_axis, metric]]
                sub_df = sub_df[sub_df[x_axis].isin(x_values_to_plot)]
                sub_df_mean = sub_df.groupby(x_axis).agg('median').reset_index()
                sub_df_std = sub_df.groupby(x_axis).agg('std').reset_index()

                x_values = sub_df_mean.loc[:, x_axis].values
                try:
                    y_values = sub_df_mean.loc[:, metric].values
                except KeyError:
                    # Metric column missing for this learner: skip the curve.
                    continue

                # Error band: std scaled by the number of runs.
                y_err = sub_df_std.loc[:, metric].values / np.sqrt(n_splits * len(seeds_list))

                axs[metric_id][i].plot(
                    x_values, y_values,
                    label=learners_names_map[model_name],
                    color=learner_colors[model_name],
                    linestyle=learner_linestyles[model_name],
                    marker=learner_markers[model_name],
                    markersize=3, alpha=0.5)
                axs[metric_id][i].fill_between(
                    x_values, y_values - y_err, y_values + y_err,
                    alpha=0.1, color=learner_colors[model_name])

            axs[metric_id][i].tick_params(axis='x', labelsize=font_size - 2)
            axs[metric_id][i].tick_params(axis='y', labelsize=font_size - 1)

            # Titles only on the top row of the grid.
            if metric_id == 0:
                axs[metric_id][i].set_title(compare_values_map[cmp_value],
                                            fontsize=font_size + 1, y=1.0)

            axs[metric_id][i].set_xlabel(x_label_name, fontsize=font_size - 1)
            if i == 0:
                # Metric label only on the first column.
                axs[metric_id][i].set_ylabel(metric_names_map[metric],
                                             fontsize=font_size - 1)

            if log_x_axis:
                axs[metric_id][i].set_xscale('symlog', linthresh=0.5, base=2)
                # Insert a tick at the midpoint of the first interval so the
                # compressed region near zero still gets a label.
                current_ticks = axs[metric_id][i].get_xticks()
                if len(current_ticks) > 1:
                    midpoint = (current_ticks[0] + current_ticks[1]) / 2
                    new_ticks = [current_ticks[0], midpoint] + list(current_ticks[1:])
                    axs[metric_id][i].set_xticks(new_ticks)
                # Always show a tick at 0.25 and render fractions nicely.
                axs[metric_id][i].set_xticks(sorted(set(axs[metric_id][i].get_xticks()).union({0.25})))
                axs[metric_id][i].xaxis.set_major_formatter(FuncFormatter(custom_formatter))

            # Metrics known to live in [0, 1] get a fixed y range.
            if metric in ["True Swap Perc", "T Distribution: Train", "T Distribution: Test",
                          "GT Total Expertise", "ES Total Expertise", "GT Expertise Ratio",
                          "GT Pred Expertise", "GT Prog Expertise", "GT Tre Expertise",
                          "ES Pred Expertise", "ES Prog Expertise",
                          "GT In-context Var", "ES In-context Var",
                          "GT-ES Pred Expertise Diff", "GT-ES Prog Expertise Diff",
                          "GT-ES Total Expertise Diff", "Policy Precision",
                          "Upd. GT Prog Expertise", "Upd. GT Tre Expertise",
                          "Upd. GT Pred Expertise"]:
                axs[metric_id][i].set_ylim(0, 1)

            # Hide per-panel x tick labels (shared x label is enough here).
            axs[metric_id][i].tick_params(
                axis='x',          # changes apply to the x-axis
                which='both',      # both major and minor ticks are affected
                bottom=False,      # ticks along the bottom edge are off
                top=False,         # ticks along the top edge are off
                labelbottom=False) # labels along the bottom edge are off

    # One legend per row, placed to the right of the last column.
    lines_labels = [ax.get_legend_handles_labels() for ax in fig.axes]
    lines, labels = [sum(lol, []) for lol in zip(*lines_labels)]

    for row in range(len(axs)):
        axs[row, -1].legend(
            lines[:len(learners_list)],
            labels[:len(learners_list)],
            ncol=1,
            loc='center right',
            bbox_to_anchor=(1.9, 0.5),
            prop={'size': font_size + 2})

    plt.subplots_adjust(wspace=0.07)
    return fig
def plot_performance_metrics_f_cf(results_df: pd.DataFrame, 
                          model_names: list,
                          dataset: str,
                          compare_axis: str,
                          compare_axis_values,
                          x_axis, 
                          x_label_name, 
                          x_values_to_plot, 
                          metrics_list, 
                          learners_list, 
                          figsize, 
                          legend_position, 
                          seeds_list, 
                          n_splits,
                          sharey=False, 
                          legend_rows=1,
                          dim_X=1,
                          log_x_axis = False): 
    """Plot factual/counterfactual error metrics for each model on a grid of subplots.

    Builds a (len(model_names) x len(compare_axis_values)) grid: one row per
    learner, one column per value of ``compare_axis``. In each cell the five
    hard-coded error metrics (PEHE, factual/CF RMSE for Y0 and Y1) are drawn as
    median-over-(seed, split) lines against ``x_axis``, with a standard-error
    band. A per-row legend titled with the learner's display name is placed to
    the right of the last column.

    NOTE(review): the ``metrics_list`` and ``figsize`` parameters are
    intentionally overwritten below — they are kept in the signature only for
    call-site compatibility with the sibling plotting functions.

    Relies on the module-level display-name maps ``metric_names_map``,
    ``compare_values_map`` and ``learners_names_map`` (defined elsewhere in
    this file).

    Returns:
        matplotlib.figure.Figure: the assembled figure.
    """
    # Default to every observed value of the compare axis.
    if compare_axis_values is None:
        compare_axis_values = results_df[compare_axis].unique()

    # Initialize the plot grid; figsize scales with the grid dimensions and
    # deliberately overrides the caller-supplied value.
    columns = len(compare_axis_values)
    rows = len(model_names)
    figsize = (3*columns+2, 3.3*rows)
    font_size=10
    fig, axs = plt.subplots(len(model_names), len(compare_axis_values), figsize=figsize, squeeze=False, sharey=sharey, dpi=500)

    # Fixed per-metric colors/markers; metrics_list is overridden to the five
    # factual/counterfactual error metrics this figure is dedicated to.
    colors = ['blue', 'lightcoral', 'lightgreen', 'red', 'green']
    markers = ['o', 'D', 'D', 'x',  'x']
    metrics_list = ["PEHE", "Factual RMSE Y0", "Factual RMSE Y1", "CF RMSE Y0", "CF RMSE Y1"]
    filtered_df = results_df[[x_axis, compare_axis] + metrics_list]

    # Aggregate results across seeds/splits for each metric.
    for model_id, model_name in enumerate(model_names):
        # Boolean mask comes from results_df but is index-aligned with
        # filtered_df, so this selects the rows of the current learner.
        filtered_df_model = filtered_df.loc[(results_df["Learner"] == model_name)]
        for i in range(len(compare_axis_values)):
            cmp_value = compare_axis_values[i]

            for metric_id, metric in enumerate(metrics_list):
                # Restrict to this column's compare value and the requested x values.
                sub_df = filtered_df_model.loc[(filtered_df_model[compare_axis] == cmp_value)][[x_axis, metric]]
                sub_df = sub_df[sub_df[x_axis].isin(x_values_to_plot)]
                # Median as the central tendency; std feeds the error band.
                sub_df_mean = sub_df.groupby(x_axis).agg('median').reset_index()
                sub_df_std = sub_df.groupby(x_axis).agg('std').reset_index()

                x_values = sub_df_mean.loc[:, x_axis].values

                # Skip metrics absent from the aggregated frame instead of
                # silently swallowing every exception (was a bare `except:`).
                try:
                    y_values = sub_df_mean.loc[:, metric].values
                except KeyError:
                    continue

                # Standard error over all (seed, split) repetitions.
                y_err = sub_df_std.loc[:, metric].values / (np.sqrt(n_splits*len(seeds_list)))

                axs[model_id][i].plot(x_values, y_values, label=metric_names_map[metric], 
                                                                color=colors[metric_id], linestyle='-', marker=markers[metric_id], alpha=0.5, markersize=3)
                axs[model_id][i].fill_between(x_values, y_values-y_err, y_values+y_err, alpha=0.1, color=colors[metric_id])

                axs[model_id][i].tick_params(axis='x', labelsize=font_size-2)
                axs[model_id][i].tick_params(axis='y', labelsize=font_size-1)

                axs[model_id][i].set_title(compare_values_map[cmp_value], fontsize=font_size+2, y=1.01)

                axs[model_id][i].set_xlabel(x_label_name, fontsize=font_size-1)

            if log_x_axis:
                # Symmetric log keeps small x values (e.g. 1/4, 1/2) readable.
                axs[model_id][i].set_xscale('symlog', linthresh=0.5, base=2)
                # Insert a tick halfway between the first two auto ticks so the
                # sub-1 region is labeled.
                current_ticks = axs[model_id][i].get_xticks()

                if len(current_ticks) > 1:
                    midpoint = (current_ticks[0] + current_ticks[1]) / 2
                    new_ticks = [current_ticks[0], midpoint] + list(current_ticks[1:])
                    axs[model_id][i].set_xticks(new_ticks)

                # Always show a tick at 0.25 and render fractions nicely.
                axs[model_id][i].set_xticks(sorted(set(axs[model_id][i].get_xticks()).union({0.25})))
                axs[model_id][i].xaxis.set_major_formatter(FuncFormatter(custom_formatter))
            # Hide x tick marks/labels; only the x-axis label remains.
            axs[model_id][i].tick_params(
                axis='x',          # changes apply to the x-axis
                which='both',      # both major and minor ticks are affected
                bottom=False,      # ticks along the bottom edge are off
                top=False,         # ticks along the top edge are off
                labelbottom=False) # labels along the bottom edge are off
            
            axs[model_id][i].tick_params(axis='y', labelsize=font_size-1)

    # Add one legend per row, titled with the learner's display name.
    for i, row in enumerate(axs):
        lines_labels = [ax.get_legend_handles_labels() for ax in row]
        lines, labels = [sum(lol, []) for lol in zip(*lines_labels)]
        row[-1].legend(
            lines[:len(metrics_list)],
            labels[:len(metrics_list)],
            loc='center right',
            bbox_to_anchor=(1.8, 0.5),
            ncol=1,
            prop={'size': font_size+2},
            title_fontsize=font_size+4,
            title=learners_names_map[model_names[i]]
        )

    plt.subplots_adjust( wspace=0.07)

    return fig
def plot_expertise_metrics(results_df: pd.DataFrame, 
                          model_names: list,
                          dataset: str,
                          compare_axis: str,
                          compare_axis_values,
                          x_axis, 
                          x_label_name, 
                          x_values_to_plot, 
                          metrics_list, 
                          learners_list, 
                          figsize, 
                          legend_position, 
                          seeds_list, 
                          n_splits,
                          sharey=False, 
                          legend_rows=1,
                          dim_X=1,
                          log_x_axis=False): 
    """Plot ground-truth expertise metrics in one row of subplots.

    Draws one column per value of ``compare_axis``; within each column the
    five hard-coded expertise metrics are plotted against ``x_axis`` for the
    first model, first seed and split 0 only. The ``metrics_list`` and
    ``figsize`` arguments are overridden internally (kept for call-site
    compatibility). Requires the module-level maps ``metric_names_map`` and
    ``compare_values_map``. Returns the matplotlib Figure.
    """
    if compare_axis_values is None:
        compare_axis_values = results_df[compare_axis].unique()

    # Set up a single-row grid sized by the number of compare values.
    n_cols = len(compare_axis_values)
    figsize = (3*n_cols+2, 3)
    font_size = 10
    fig, axs = plt.subplots(1, len(compare_axis_values), figsize=figsize, squeeze=False, sharey=sharey, dpi=500)
    plt.gcf().subplots_adjust(bottom=0.15)

    # Restrict to a single (model, seed, split) slice — no aggregation here.
    results_df = results_df.loc[(results_df["Learner"] == model_names[0])]
    results_df = results_df.loc[(results_df["Seed"] == seeds_list[0])]
    results_df = results_df.loc[(results_df["Split ID"] == 0)]

    # Fixed styling and the expertise metrics this figure is dedicated to.
    line_colors = ['black', 'grey', 'red', 'green', 'blue']
    line_markers = ['o', 'x', 'x',  'x', 'x']
    metrics_list = ["Policy Precision", "GT In-context Var", "GT Prog Expertise", "GT Tre Expertise", "GT Pred Expertise"]
    sub_df = results_df[[x_axis, compare_axis, "Seed", "Split ID"] + metrics_list]

    for col, cmp_value in enumerate(compare_axis_values):
        ax = axs[0][col]

        # All rows belonging to this column's compare value.
        panel_df = sub_df[(sub_df[compare_axis] == cmp_value)]
        xs = panel_df[x_axis].values

        for metric, color, marker in zip(metrics_list, line_colors, line_markers):
            ys = panel_df[metric].values
            ax.plot(xs, ys, label=metric_names_map[metric], color=color, linestyle='-', marker=marker, alpha=0.5,  markersize=5)

            ax.tick_params(axis='x', labelsize=font_size-2)
            ax.tick_params(axis='y', labelsize=font_size-1)
            ax.set_title(compare_values_map[cmp_value], fontsize=font_size+11, y=1.04)
            ax.set_xlabel(x_label_name, fontsize=font_size-1)

        if log_x_axis:
            # Symmetric log in base 2 keeps fractional x values readable.
            ax.set_xscale('symlog', linthresh=0.5, base=2)
            ticks = ax.get_xticks()
            if len(ticks) > 1:
                # Add a tick halfway between the first two auto ticks.
                halfway = (ticks[0] + ticks[1]) / 2
                ax.set_xticks([ticks[0], halfway] + list(ticks[1:]))
            # Force a tick at 0.25 and format ticks as fractions/integers.
            ax.set_xticks(sorted(set(ax.get_xticks()).union({0.25})))
            ax.xaxis.set_major_formatter(FuncFormatter(custom_formatter))

        # Hide x tick marks and labels on every panel.
        ax.tick_params(
            axis='x',          # changes apply to the x-axis
            which='both',      # both major and minor ticks are affected
            bottom=False,      # ticks along the bottom edge are off
            top=False,         # ticks along the top edge are off
            labelbottom=False) # labels along the bottom edge are off
        ax.tick_params(axis='y', labelsize=font_size-1)

    # One shared legend on the figure's right edge.
    lines_labels = [ax.get_legend_handles_labels() for ax in fig.axes]
    lines, labels = [sum(lol, []) for lol in zip(*lines_labels)]

    fig.legend(
        lines[:len(metrics_list)],
        labels[:len(metrics_list)],
        loc='center right',  # Position the legend to the right
        ncol=1,  # Set the number of columns to 1 for a vertical legend
        prop={'size': font_size+2}
    )

    plt.subplots_adjust( wspace=0.07)

    return fig
def merge_pngs(images, axis="horizontal"):
    """Merge a list of PIL images into a single RGB image.

    Args:
        images: non-empty list of PIL.Image objects to concatenate.
        axis: "horizontal" to place images side by side, "vertical" to
            stack them top to bottom.

    Returns:
        PIL.Image.Image: the concatenated image. Along the merge axis the
        sizes are summed; across it, the maximum size is used (smaller
        images leave black padding).

    Raises:
        ValueError: if ``axis`` is neither "horizontal" nor "vertical"
            (previously this fell through and returned None silently).
    """
    widths, heights = zip(*(i.size for i in images))

    if axis == "vertical":
        total_height = sum(heights)
        max_width = max(widths)

        new_im = Image.new('RGB', (max_width, total_height))

        # Paste each image below the previous one.
        y_offset = 0
        for im in images:
            new_im.paste(im, (0, y_offset))
            y_offset += im.size[1]

        return new_im

    elif axis == "horizontal":
        total_width = sum(widths)
        max_height = max(heights)

        new_im = Image.new('RGB', (total_width, max_height))

        # Paste each image to the right of the previous one.
        x_offset = 0
        for im in images:
            new_im.paste(im, (x_offset, 0))
            x_offset += im.size[0]

        return new_im

    raise ValueError(f"axis must be 'horizontal' or 'vertical', got {axis!r}")