# src/iterpretability/experiments/experiments_base.py

from pathlib import Path
import os
import time

import catenets.models as cate_models
from catenets.models.diffpo import DiffPOLearner
import numpy as np
import pandas as pd
import wandb
import shap
from PIL import Image

import src.iterpretability.logger as log
from src.plotting import (
    plot_results_datasets_compare,
    plot_expertise_metrics,
    plot_performance_metrics,
    plot_performance_metrics_f_cf,
)
from src.iterpretability.explain import Explainer
from src.iterpretability.datasets.data_loader import load
from src.iterpretability.simulators import (
    SimulatorBase,
    TYSimulator,
    TSimulator,
)
from src.iterpretability.utils import (
    attribution_accuracy,
)

# Metrics and baseline models (also used for contour plotting)
from sklearn.metrics import mean_squared_error, roc_auc_score, f1_score, auc
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.multioutput import MultiOutputRegressor, MultiOutputClassifier
from sklearn.linear_model import LinearRegression, Lasso, LogisticRegression, LogisticRegressionCV, LassoCV
from sklearn.model_selection import LeaveOneOut

# Hydra for configuration
from omegaconf import DictConfig, OmegaConf
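
# Usage sketch (assumption, not part of this file): experiments are expected to be
# driven by a Hydra config whose fields match the attribute accesses below
# (cfg.seeds, cfg.n_splits, cfg.simulator.*, cfg.model_names, ...), e.g.
#
#   @hydra.main(config_path="../../configs", config_name="config")
#   def main(cfg: DictConfig) -> None:
#       exp = SomeExperiment(cfg)  # SomeExperiment subclasses ExperimentBase
#       exp.run()                  # hypothetical subclass entry point
#
# "SomeExperiment" and "run" are illustrative names only.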

class ExperimentBase:
    """
    Base class for all experiments.
    """
    def __init__(self, cfg: DictConfig) -> None:
        # Store configuration
        self.cfg = cfg

        # General experiment settings
        self.seeds = cfg.seeds
        self.n_splits = cfg.n_splits
        self.experiment_name = cfg.experiment_name
        self.simulation_type = cfg.simulator.simulation_type
        self.evaluate_inference = cfg.evaluate_inference

        # Load data
        if self.simulation_type == "TY":
            self.X, self.feature_names = load(cfg.dataset,
                                              train_ratio=1,
                                              debug=cfg.debug,
                                              sim_type=self.simulation_type,
                                              directory_path_=cfg.directory_path + '/',
                                              repo_path=cfg.repo_path,
                                              n_samples=cfg.n_samples)
            self.outcomes = None
            # Currently the only supported discrete outcome is binary
            self.discrete_outcome = cfg.simulator.num_binary_outcome == cfg.simulator.dim_Y

        elif self.simulation_type == "T":
            self.X, self.outcomes, self.feature_names = load(cfg.dataset,
                                                             train_ratio=1,
                                                             debug=cfg.debug,
                                                             directory_path_=cfg.directory_path + '/',
                                                             repo_path=cfg.repo_path,
                                                             sim_type=self.simulation_type)

            # Outcomes are treated as discrete when binary outcomes are configured
            self.discrete_outcome = cfg.simulator.num_binary_outcome

            cfg.simulator.num_T = self.outcomes.shape[1]

        else:
            raise ValueError(f"Simulation type {self.cfg.simulator.simulation_type} not supported.")

        # Initialize learners
        self.discrete_treatment = True  # The simulation currently only supports discrete treatments

        # Results path and directories
        self.file_name = (
            f"{self.cfg.results_dictionary_prefix}_{self.cfg.dataset}"
            f"_T{self.cfg.simulator.num_T}_Y{self.cfg.simulator.dim_Y}"
            f"_{self.cfg.simulator.simulation_type}sim"
            f"_numbin{self.cfg.simulator.num_binary_outcome}"
        )
        self.results_path = Path(self.cfg.results_path) / Path(self.cfg.experiment_name) / Path(self.file_name)

        if not self.results_path.exists():
            self.results_path.mkdir(parents=True, exist_ok=True)

        # Variables set by some or all experiments
        self.learners = None
        self.baseline_learners = None
        self.pred_learners = None
        self.prog_learner = None
        self.select_learner = None
        self.explanations = None
        self.all_important_features = None
        self.pred_features = None
        self.prog_features = None
        self.select_features = None
        self.true_num_swaps = None
        self.swap_counter = None
        self.training_times = {}

    def get_learners(self,
                     num_features: int,
                     seed: int = 123) -> dict:
        """
        Get learners for the experiment, based on the configuration settings.
        """
        if self.cfg.simulator.dim_Y > 1 and "Torch_XLearner" in self.cfg.model_names:
            raise ValueError("Torch_XLearner only supports one outcome dimension.")
        elif self.cfg.simulator.dim_Y > 1 and "EconML_SparseLinearDRLearner" in self.cfg.model_names:
            raise ValueError("EconML_SparseLinearDRLearner only supports one outcome dimension.")
        # Torch and DiffPO models only support binary treatments
        if self.cfg.simulator.num_T > 2 and ("Torch" in self.cfg.model_names or "DiffPOLearner" in self.cfg.model_names):
            raise ValueError("Torch and DiffPO models do not support more than two treatments.")

        binary_y = self.discrete_outcome and self.cfg.simulator.num_binary_outcome == 1

        # All EconML estimators share the same constructor signature
        econml_names = [
            "EconML_CausalForestDML",
            "EconML_DML",
            "EconML_DRLearner",
            "EconML_DMLOrthoForest",
            "EconML_DROrthoForest",
            "EconML_ForestDRLearner",
            "EconML_LinearDML",
            "EconML_LinearDRLearner",
            "EconML_SparseLinearDML",
            "EconML_SparseLinearDRLearner",
            "EconML_SLearner_Lasso",
            "EconML_TLearner_Lasso",
            "EconML_XLearner_Lasso",
        ]
        learners = {
            name: cate_models.econml.EconMlEstimator2(self.cfg,
                                                      model_name=name,
                                                      discrete_treatment=self.discrete_treatment,
                                                      discrete_outcome=self.discrete_outcome,
                                                      seed=seed)
            for name in econml_names
        }

        # Torch learners: model class and config section per name.
        # Note: the CFRNet entries reuse TARNet with the (CRFNet-spelled) config sections.
        torch_models = {
            "Torch_SLearner": (cate_models.torch.SLearner, self.cfg.Torch_SLearner),
            "Torch_TLearner": (cate_models.torch.TLearner, self.cfg.Torch_TLearner),
            "Torch_XLearner": (cate_models.torch.XLearner, self.cfg.Torch_XLearner),
            "Torch_DRLearner": (cate_models.torch.DRLearner, self.cfg.Torch_DRLearner),
            "Torch_DragonNet": (cate_models.torch.DragonNet, self.cfg.Torch_DragonNet),
            "Torch_DragonNet_2": (cate_models.torch.DragonNet, self.cfg.Torch_DragonNet_2),
            "Torch_DragonNet_4": (cate_models.torch.DragonNet, self.cfg.Torch_DragonNet_4),
            "Torch_ActionNet": (cate_models.torch.ActionNet, self.cfg.Torch_ActionNet),
            # "Torch_FlexTENet": (cate_models.torch.FlexTENet, self.cfg.Torch_FlexTENet),
            "Torch_PWLearner": (cate_models.torch.PWLearner, self.cfg.Torch_PWLearner),
            "Torch_RALearner": (cate_models.torch.RALearner, self.cfg.Torch_RALearner),
            "Torch_RLearner": (cate_models.torch.RLearner, self.cfg.Torch_RLearner),
            "Torch_TARNet": (cate_models.torch.TARNet, self.cfg.Torch_TARNet),
            "Torch_ULearner": (cate_models.torch.ULearner, self.cfg.Torch_ULearner),
            "Torch_CFRNet_0.01": (cate_models.torch.TARNet, self.cfg.Torch_CRFNet_0_01),
            "Torch_CFRNet_0.001": (cate_models.torch.TARNet, self.cfg.Torch_CRFNet_0_001),
            "Torch_CFRNet_0.0001": (cate_models.torch.TARNet, self.cfg.Torch_CRFNet_0_0001),
        }
        for name, (model_cls, model_cfg) in torch_models.items():
            learners[name] = model_cls(num_features, binary_y=binary_y, **model_cfg)

        learners["DiffPOLearner"] = DiffPOLearner(self.cfg,
                                                  num_features,
                                                  binary_y=binary_y)

        # Deal with the case where the model is not available
        for name in self.cfg.model_names:
            if name not in learners:
                raise Exception(f"Unknown model name {name}.")

        # Only return learners from cfg.model_names
        return {k: v for k, v in learners.items() if k in self.cfg.model_names}
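
    # Usage sketch (hypothetical config values): with
    #   cfg.model_names = ["EconML_DML", "Torch_TARNet"]
    # a call such as
    #   learners = self.get_learners(num_features=self.X.shape[1], seed=self.seeds[0])
    # instantiates every known learner but returns only the two requested ones, so
    # unused entries still pay their construction cost.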

    def get_baseline_learners(self,
                              seed: int = 123) -> list:
        """
        Get baseline learners for the experiment, based on the configuration settings.
        These models will be used as a baseline for retrieving absolute outcomes from CATE predictions.
        """
        # Instantiate one baseline learner per treatment option
        baseline_learners = []

        if self.cfg.debug or self.cfg.dataset == "cytof_normalized_with_fastdrug" or self.cfg.dataset.startswith("melanoma"):
            cv = LeaveOneOut()
        else:
            cv = 5

        for t in range(self.cfg.simulator.num_T):
            if self.discrete_outcome == 1:
                base_model = LogisticRegressionCV(penalty='l1', solver='liblinear', cv=cv, random_state=seed)
                # base_model = RandomForestClassifier(n_estimators=100, random_state=seed)
                model = MultiOutputClassifier(base_model)
                baseline_learners.append(model)
            else:
                base_model = LassoCV(cv=cv, n_alphas=5, max_iter=50, random_state=seed)
                # base_model = RandomForestRegressor(n_estimators=100, random_state=seed)
                model = MultiOutputRegressor(base_model)
                baseline_learners.append(model)

        return baseline_learners
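
    # Minimal sketch of the baseline-learner pattern above (synthetic data, for
    # illustration only):
    #
    #   import numpy as np
    #   from sklearn.linear_model import LassoCV
    #   from sklearn.multioutput import MultiOutputRegressor
    #   X = np.random.randn(20, 5)              # (n, dim_X)
    #   Y = np.random.randn(20, 3)              # (n, dim_Y)
    #   m = MultiOutputRegressor(LassoCV(cv=5)).fit(X, Y)
    #   m.predict(X).shape                      # -> (20, 3)
    #
    # i.e. one independent LassoCV per outcome dimension, and one such wrapper per
    # treatment arm.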

    def get_select_learner(self,
                           seed: int = 123):
        """
        Get the learner for feature selection (fitted on X to predict T in
        train_select_learner below).
        """
        if self.cfg.debug:
            cv = LeaveOneOut()
        else:
            cv = 5
        select_learner = LogisticRegressionCV(penalty='l1', solver='liblinear', cv=cv, random_state=seed)
        # select_learner = RandomForestRegressor(n_estimators=100, max_depth=6)

        return select_learner

    def get_prog_learner(self,
                         seed: int = 123):
        """
        Get the learner for the prognostic part of the outcomes.
        """
        if self.cfg.debug:
            cv = LeaveOneOut()
        else:
            cv = 5
        base_model = LassoCV(cv=cv, n_alphas=5, max_iter=50, random_state=seed)
        # base_model = RandomForestRegressor(n_estimators=100, random_state=seed)
        prog_learner = MultiOutputRegressor(base_model)
        return prog_learner

    def get_pred_learners(self,
                          seed: int = 123) -> list:
        """
        Get learners for the treatment-specific CATEs (one per treatment option).
        """
        # Instantiate one predictive learner per treatment option
        pred_learners = []

        if self.cfg.debug or self.cfg.dataset.startswith("melanoma"):
            cv = LeaveOneOut()
        else:
            cv = 5
        # base_model = RandomForestRegressor(n_estimators=100, random_state=seed)
        for t in range(self.cfg.simulator.num_T):
            base_model = LassoCV(cv=cv, n_alphas=5, max_iter=50, random_state=seed)
            model = MultiOutputRegressor(base_model)
            pred_learners.append(model)

        return pred_learners
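
    # Note (inferred from train_pred_learner below): pred_learners[t] regresses X onto
    # CATE labels expressed relative to reference treatment t, so each treatment arm
    # gets its own "one vs. all" CATE model; see the derivation sketch after
    # train_pred_learner.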

    def train_learners(self,
                       X_train: np.ndarray,
                       Y_train: np.ndarray,
                       T_train: np.ndarray,
                       outcomes_train: np.ndarray) -> None:
        """
        Train all learners.
        """
        self.training_times = {}
        for name in self.learners:
            # Measure training time
            start_time = time.time()

            log.info(f"Fitting {name}.")
            if self.evaluate_inference:
                if not (name.startswith("EconML") or name.startswith("DiffPOLearner")):
                    raise ValueError("Only EconML and DiffPO models support inference.")

            if name == "DiffPOLearner":
                self.learners[name].train(X=X_train, y=Y_train, w=T_train, outcomes=outcomes_train)
            else:
                self.learners[name].train(X=X_train, y=Y_train, w=T_train)

            end_time = time.time()
            self.training_times[name] = end_time - start_time
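
    # Expected shapes (inferred from the training methods in this class; an
    # assumption rather than a documented contract):
    #   X_train:        (n, dim_X)
    #   Y_train:        (n,) or (n, dim_Y) factual outcomes
    #   T_train:        (n,) integer treatment assignments in [0, num_T)
    #   outcomes_train: (n, num_T, dim_Y) all potential outcomes (used by DiffPO only)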

    def train_baseline_learners(self,
                                X_train: np.ndarray,
                                outcomes_train: np.ndarray,
                                T_train: np.ndarray) -> None:
        """
        Train all baseline learners.
        """
        for t in range(self.cfg.simulator.num_T):
            # Get all data points where treatment is t
            mask = T_train == t
            Y_train_t = outcomes_train[mask, t, :]
            X_train_t = X_train[mask, :]

            log.debug(
                f'Check baseline data for treatment {t}:\n'
                f'============================================'
                f'\nX_train: {X_train.shape}'
                f'\n{X_train}'
                f'\noutcomes_train: {outcomes_train.shape}'
                f'\n{outcomes_train}'
                f'\nT_train: {T_train.shape}'
                f'\n{T_train}'
                f'\nX_train_t: {X_train_t.shape}'
                f'\n{X_train_t}'
                f'\nY_train_t: {Y_train_t.shape}'
                f'\n{Y_train_t}'
                f'\n============================================\n\n'
            )
            # Check that there are data points for this treatment
            assert Y_train_t.shape[0] > 0, f"No data points for treatment {t}."

            log.info(f"Fitting baseline learner for treatment {t}.")

            if self.discrete_outcome:
                Y_train_t = Y_train_t.astype(bool)

            self.baseline_learners[t].fit(X_train_t, Y_train_t)
    def train_select_learner(self,
                             X_train: np.ndarray,
                             T_train: np.ndarray) -> None:
        """
        Train the feature selection learner.
        """
        log.info("Fitting feature selection learner.")
        self.select_learner.fit(X_train, T_train)

    def train_prog_learner(self,
                           X_train: np.ndarray,
                           pred_outcomes_train: np.ndarray) -> None:
        """
        Train the model learning the prognostic part of the outcomes.
        Here the prognostic part is the average of all possible outcomes.
        """
        # pred_outcomes_train shape: (n, num_T, dim_Y)
        # X_train shape: (n, dim_X)

        # Compute the average of all possible outcomes
        prog_Y_train = np.mean(pred_outcomes_train, axis=1)

        # Train the model
        self.prog_learner.fit(X_train, prog_Y_train)

        log.debug(
            f'Check prog learner data:\n'
            f'============================================'
            f'\nX_train: {X_train.shape}'
            f'\n{X_train}'
            f'\npred_outcomes_train: {pred_outcomes_train.shape}'
            f'\n{pred_outcomes_train}'
            f'\nprog_Y_train: {prog_Y_train.shape}'
            f'\n{prog_Y_train}'
            f'\n============================================\n\n'
        )
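
    # Rationale (sketch for the binary-treatment case; generalizes by averaging over
    # all num_T arms): with potential-outcome predictions y0_hat(x), y1_hat(x),
    #
    #   prog(x) = (y0_hat(x) + y1_hat(x)) / 2
    #
    # which is exactly np.mean(pred_outcomes_train, axis=1) above.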

    def train_pred_learner(self,
                           X_train: np.ndarray,
                           pred_cates_train: np.ndarray,
                           T_train: np.ndarray) -> None:
        """
        Train the model learning the CATEs.
        """
        for t in range(self.cfg.simulator.num_T):
            # Get treatment mask
            mask = T_train == t

            # For every patient with treatment t, we can simply take the mean CATE.
            # For the others, we additionally adjust by the CATE for t so that the
            # target is expressed relative to treatment t (see the biomarker
            # attribution notes).
            pred_Y_train = pred_cates_train.mean(axis=1)
            pred_Y_train[~mask] -= pred_cates_train[~mask, t, :]

            # Train the model
            self.pred_learners[t].fit(X_train, pred_Y_train)

            log.debug(
                f'Check pred learner data for treatment {t}:\n'
                f'============================================'
                f'\nX_train: {X_train.shape}'
                f'\n{X_train[:10]}'
                f'\npred_cates_train: {pred_cates_train.shape}'
                f'\n{pred_cates_train[:10]}'
                f'\nT_train: {T_train.shape}'
                f'\n{T_train[:10]}'
                f'\npred_Y_train: {pred_Y_train.shape}'
                f'\n{pred_Y_train[:10]}'
                f'\n============================================\n\n'
            )
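
    # Worked reading of the adjustment above (inferred from the code; the cited
    # biomarker attribution notes are not in this repository): with num_T = 2 and
    # pred_cates_train[i] = [tau_0(x_i), tau_1(x_i)], a patient assigned to an arm
    # other than t gets the label
    #
    #   (tau_0 + tau_1) / 2 - tau_t
    #
    # which re-expresses the training target relative to reference arm t, while
    # patients already assigned to t keep the plain mean CATE.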

    def get_learner_explanations(self,
                                 X: np.ndarray,
                                 return_explainer_names: bool = True,
                                 ignore_explainer_limit: bool = False,
                                 type: str = "pred"):  # note: `type` shadows the builtin; kept for API compatibility
        """
        Get explanations for all learners.
        """
        learner_explainers = {}
        learner_explanations = {}
        explainer_names = {}

        for name in self.learners:
            log.info(f"Explaining {name}.")

            if ignore_explainer_limit:
                explainer_limit = X.shape[0]
            else:
                explainer_limit = self.cfg.explainer_limit

            if "EconML" in name:
                if self.cfg.explainer_econml != "shap":
                    raise ValueError("Only shap is supported for EconML models.")

                if type == "pred":
                    cate_est = lambda X: self.learners[name].predict(X)
                    shap_values_avg = shap.Explainer(cate_est, X[:explainer_limit]).shap_values(X[:explainer_limit])
                    # shap_values = self.learners[name].explain(X[:explainer_limit], background_samples=None)

                    # treatment_names = self.learners[name].est.cate_treatment_names()
                    # output_names = self.learners[name].est.cate_output_names()

                    # # average absolute shaps over all treatment names
                    # shap_values_avg = np.zeros_like(shap_values[output_names[0]][treatment_names[0]].values)
                    # for output_name in output_names:
                    #     for treatment_name in treatment_names:
                    #         shap_values_avg += np.abs(shap_values[output_name][treatment_name].values)

                # Does not work yet!
                elif type == "prog":
                    if name == "EconML_SLearner_Lasso":
                        y0_est = lambda X: self.learners[name].est.overall_model.predict(np.hstack([X, np.ones((X.shape[0], 1)), np.zeros((X.shape[0], 1))]))
                        # pred outcomes shape should be: (n, num_T, dim_Y)
                    elif name == "EconML_TLearner_Lasso":
                        y0_est = lambda X: self.learners[name].est.models[0].predict(X)
                    elif name == "EconML_DML":
                        y0_est = lambda X: self.learners[name].predict_outcomes(X, outcomes=None)[:, 0]
                    else:
                        raise ValueError(f"Model {name} not supported for prog learner explanations.")

                    shap_values_avg = shap.Explainer(y0_est, X[:explainer_limit]).shap_values(X[:explainer_limit])
                    # average absolute shaps over all treatment names
                    # shap_values_avg = np.zeros_like(shap_values[0])
                    # for i in range(len(shap_values)):
                    #     shap_values_avg += np.abs(shap_values[i])

                learner_explanations[name] = shap_values_avg
                explainer_names[name] = "shap"

            else:
                prog_capable = ["Torch_SLearner", "Torch_TLearner", "Torch_TARNet",
                                "Torch_CFRNet_0.001", "Torch_CFRNet_0.01", "Torch_CFRNet_0.0001",
                                "Torch_DragonNet", "Torch_DragonNet_2", "Torch_DragonNet_4"]

                # Also doesn't work yet!
                if type == "prog" and name in prog_capable:
                    # model_to_explain = self.learners[name]._po_estimators[0]
                    # X_to_explain = self._repr_estimator(X).squeeze()
                    y0_est = lambda X: self.learners[name].predict(X, return_po=True)[1]
                    learner_explanations[name] = shap.Explainer(y0_est, X[:explainer_limit]).shap_values(X[:explainer_limit])
                    explainer_names[name] = "shap"

                elif type == "prog" and name not in prog_capable:
                    learner_explanations[name] = None
                    explainer_names[name] = None
                    # raise ValueError(f"Model {name} not supported for prog learner explanations.")

                else:
                    cate_est = lambda X: self.learners[name].predict(X)
                    learner_explanations[name] = shap.Explainer(cate_est, X[:explainer_limit]).shap_values(X[:explainer_limit])
                    explainer_names[name] = "shap"
                    # model_to_explain = self.learners[name]
                    # X_to_explain = X

                    # learner_explainers[name] = Explainer(
                    #     model_to_explain,
                    #     feature_names=list(range(X.shape[1])),
                    #     explainer_list=[self.cfg.explainer_torch],
                    # )
                    # learner_explanations[name] = learner_explainers[name].explain(
                    #     X[: explainer_limit]
                    # )
                    # learner_explanations[name] = learner_explanations[name][self.cfg.explainer_torch]
                    # explainer_names[name] = self.cfg.explainer_torch

        # Check dimensions of explanations by looking at the shape of all explanation arrays
        # for name in learner_explanations:
        #     log.debug(
        #         f"\nExplanations for {name} have shape {learner_explanations[name].shape}."
        #     )

        if return_explainer_names:
            return learner_explanations, explainer_names
        else:
            return learner_explanations
    def get_select_learner_explanations(self,
                                        X_reference: np.ndarray,
                                        X_to_explain: np.ndarray) -> np.ndarray:
        """
        Get explanations for the feature selection learner.
        Returns shape: (n, dim_X, num_T).
        """
        explainer = shap.Explainer(self.select_learner, X_reference)
        shap_values = explainer(X_to_explain).values  # shape: (n, dim_X, num_T)

        return shap_values
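
    # Shape sketch (assuming a multiclass LogisticRegressionCV select_learner, as
    # constructed in get_select_learner): shap returns one attribution per class, so
    #   explainer(X_to_explain).values  ->  (n, dim_X, num_T)
    # matching the docstring above. With num_T == 2 some shap versions return
    # (n, dim_X) instead; callers relying on the 3-D shape may need to handle that.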

    def get_prog_learner_explanations(self,
                                      X_reference: np.ndarray,
                                      X_to_explain: np.ndarray) -> np.ndarray:
        """
        Get explanations for the prognostic learner.
        Returns shape: (n, dim_X, dim_Y).
        """
        # Get explanations for every outcome
        shap_values = np.zeros((X_to_explain.shape[0], X_to_explain.shape[1], self.cfg.simulator.dim_Y))
        for i in range(self.cfg.simulator.dim_Y):
            explainer = shap.Explainer(self.prog_learner.estimators_[i], X_reference, check_additivity=False)
            shap_values[:, :, i] = explainer(X_to_explain).values  # check_additivity=False for forest models

        return shap_values

    def get_pred_learner_explanations(self,
                                      X_reference: np.ndarray,
                                      X_to_explain: np.ndarray):
        """
        Get explanations for the treatment-specific CATEs (one vs. all).
        Returns shapes: (num_T, n, dim_X, dim_Y) and (num_T, dim_Y) for the base values.
        """
        # Get explanations for every outcome and every reference treatment
        shap_values = np.zeros((self.cfg.simulator.num_T, X_to_explain.shape[0], X_to_explain.shape[1], self.cfg.simulator.dim_Y))
        shap_base_values = np.zeros((self.cfg.simulator.num_T, self.cfg.simulator.dim_Y))

        for t in range(self.cfg.simulator.num_T):
            for i in range(self.cfg.simulator.dim_Y):
                explainer = shap.Explainer(self.pred_learners[t].estimators_[i], X_reference)
                shap_explanation = explainer(X_to_explain)
                shap_values[t, :, :, i] = shap_explanation.values
                shap_base_values[t, i] = shap_explanation.base_values[0]

        log.debug(
            f"\nExplanations for pred_learner have shape {shap_values.shape}."
            f"\n{shap_values[:, :10, :10, :]}"
        )

        return shap_values, shap_base_values
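
    # Note (shap behavior, stated as an assumption): for a fitted linear model,
    # Explanation.base_values is the expected prediction over the background data
    # X_reference and is identical across rows, so taking base_values[0] per
    # (treatment, outcome) pair is sufficient here.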

    def save_results(self, metrics_df: pd.DataFrame, compare_axis_values=None, plot_only: bool = False, save_df_only: bool = False, compare_axis=None, fig_type="all") -> None:
        """
        Save results of the experiment.
        """
        file_name = self.file_name
        results_path = self.results_path

        if save_df_only:
            log.info(f"Saving intermediate results in {results_path}...")
        else:
            log.info(f"Saving final results in {results_path}...")
        if not results_path.exists():
            results_path.mkdir(parents=True, exist_ok=True)

        if not plot_only:
            # Save metrics csv and a metadata yaml with the configuration
            if save_df_only:
                metrics_df.to_csv(
                    results_path / Path(file_name + "_checkpoint.csv")
                )
            else:
                metrics_df.to_csv(
                    results_path / Path(file_name + ".csv")
                )
            OmegaConf.save(self.cfg, results_path / Path(file_name + ".yaml"))

        # Choose figure size and legend position
        n_rows = len(self.cfg.metrics_to_plot)
        figsize = None  # (10, 100)
        legend_position = 0.95
        log_x_axis = False

        # Save plots
        if self.cfg.plot_results and not save_df_only:
            if self.cfg.experiment_name == 'sample_size_sensitivity':
                compare_axis = "Propensity Scale"
                x_axis = "Sample Portion"
                x_label_name = r'$\omega_{\mathrm{ss}}$'
                x_values_to_plot = self.cfg.sample_sizes

            elif self.cfg.experiment_name == 'important_feature_num_sensitivity':
                compare_axis = "Feature Overlap"
                x_axis = "Num Important Features"
                x_label_name = r'$\omega_{\mathrm{IFs}}$'
                x_values_to_plot = self.cfg.important_feature_nums

            elif self.cfg.experiment_name == 'propensity_scale_sensitivity':
                # compare_axis = "Unbalancedness Exp"
                compare_axis = "Propensity Type"
                x_axis = "Propensity Scale"
                x_label_name = r'${\mathrm{\beta}}$'
                x_values_to_plot = self.cfg.propensity_scales
                log_x_axis = True

            elif self.cfg.experiment_name == 'predictive_scale_sensitivity':
                compare_axis = "Nonlinearity Scale"
                x_axis = "Predictive Scale"
                x_label_name = r'$\omega_{\mathrm{pds}}$'
                x_values_to_plot = self.cfg.predictive_scales

            elif self.cfg.experiment_name == 'treatment_space_sensitivity':
                compare_axis = "Binary Outcome"
                x_axis = "Treatment Options"
                x_label_name = r'$\omega_{\mathrm{Ts}}$'
                x_values_to_plot = self.cfg.num_Ts

            elif self.cfg.experiment_name == 'data_dimensionality_sensitivity':
                # compare_axis is taken from the caller for this experiment
                x_axis = "Data Dimension"
                x_label_name = "Num Features"
                x_values_to_plot = self.cfg.data_dims

            elif self.cfg.experiment_name == 'feature_overlap_sensitivity':
                compare_axis = "Overlap Type"
                x_axis = "Propensity Scale"
                x_label_name = r'$\omega_{\mathrm{pps}}$'
                x_values_to_plot = self.cfg.propensity_scales

            elif self.cfg.experiment_name == 'expertise_sensitivity':
                compare_axis = "Propensity Type"
                x_axis = "Alpha"
                x_label_name = r'$\alpha$'
                x_values_to_plot = self.cfg.alphas

            else:
                raise ValueError(f"Experiment {self.cfg.experiment_name} not supported for plotting.")

            # All figure types share the same plotting signature
            plot_fns = {
                "all": plot_results_datasets_compare,
                "expertise": plot_expertise_metrics,
                "performance": plot_performance_metrics,
                "performance_f_cf": plot_performance_metrics_f_cf,
            }
            fig = plot_fns[fig_type](
                results_df=metrics_df,
                model_names=self.cfg.model_names,
                dataset=self.cfg.dataset,
                compare_axis=compare_axis,
                compare_axis_values=compare_axis_values,
                x_axis=x_axis,
                x_label_name=x_label_name,
                x_values_to_plot=x_values_to_plot,
                metrics_list=self.cfg.metrics_to_plot,
                learners_list=self.cfg.model_names,
                figsize=figsize,
                legend_position=legend_position,
                seeds_list=self.cfg.seeds,
                n_splits=self.cfg.n_splits,
                sharey="row",
                legend_rows=1,
                dim_X=self.X.shape[0],  # note: X.shape[0] is the sample count, not the feature dimension
                log_x_axis=log_x_axis,
            )

            # Save figure
            try:
                fig.savefig(results_path / f"{self.cfg.plot_name_prefix}.png", bbox_inches='tight')
            except Exception:
                fig.savefig(results_path / f"{file_name}.png", bbox_inches='tight')
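
    # Calling convention sketch (inferred from the flags above, not a documented API):
    #   self.save_results(df, save_df_only=True)               # checkpoint CSV only
    #   self.save_results(df, compare_axis_values=[...])       # final CSV + YAML + figure
    #   self.save_results(df, plot_only=True, fig_type="all")  # re-plot existing results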

    def load_and_plot_results(self, compare_axis_values, fig_type: str = "all") -> None:
        """
        Load and plot results of the experiment.
        """
        file_name = self.file_name
        results_path = self.results_path

        log.info(f"Loading results from {results_path}...")
        if not results_path.exists():
            raise FileNotFoundError(f"Results path {results_path} does not exist.")

        # Load metrics
        metrics_df = pd.read_csv(results_path / Path(file_name + "_checkpoint.csv"))

        if self.cfg.experiment_name == "data_dimensionality_sensitivity":
            if self.cfg.compare_axis == "propensity":
                compare_axis = "Propensity Scale"
            elif self.cfg.compare_axis == "num_features":
                compare_axis = "Num Important Features"
            self.save_results(metrics_df, plot_only=True, compare_axis=compare_axis, compare_axis_values=compare_axis_values, fig_type=fig_type)
        else:
            self.save_results(metrics_df, plot_only=True, compare_axis_values=compare_axis_values, fig_type=fig_type)

    def compute_outcome_mse(self,
                            pred_outcomes: np.ndarray,
                            true_outcomes: np.ndarray,
                            T: np.ndarray = None) -> tuple:
        """
        Compute RMSEs for all outcomes, split into factual and counterfactual parts
        and into the T=0 / T=1 subgroups.
        """
        # Split predictions into factual entries (the assigned treatment) and
        # counterfactual entries (all other treatments)
        pred_outcomes_factual = np.zeros((pred_outcomes.shape[0], pred_outcomes.shape[2]))  # (n, dim_Y)
        true_outcomes_factual = np.zeros((true_outcomes.shape[0], true_outcomes.shape[2]))
        pred_outcomes_factual_Y0 = np.zeros(((T == 0).sum(), pred_outcomes.shape[2]))  # (n_T0, dim_Y)
        true_outcomes_factual_Y0 = np.zeros(((T == 0).sum(), pred_outcomes.shape[2]))
        pred_outcomes_factual_Y1 = np.zeros(((T == 1).sum(), pred_outcomes.shape[2]))  # (n_T1, dim_Y)
        true_outcomes_factual_Y1 = np.zeros(((T == 1).sum(), pred_outcomes.shape[2]))

        pred_outcomes_cf = np.zeros((pred_outcomes.shape[0], pred_outcomes.shape[1] - 1, pred_outcomes.shape[2]))  # (n, num_T-1, dim_Y)
        true_outcomes_cf = np.zeros((true_outcomes.shape[0], true_outcomes.shape[1] - 1, true_outcomes.shape[2]))
        pred_outcomes_cf_Y0 = np.zeros(((T == 1).sum(), pred_outcomes.shape[1] - 1, pred_outcomes.shape[2]))
        true_outcomes_cf_Y0 = np.zeros(((T == 1).sum(), pred_outcomes.shape[1] - 1, pred_outcomes.shape[2]))
        pred_outcomes_cf_Y1 = np.zeros(((T == 0).sum(), pred_outcomes.shape[1] - 1, pred_outcomes.shape[2]))
        true_outcomes_cf_Y1 = np.zeros(((T == 0).sum(), pred_outcomes.shape[1] - 1, pred_outcomes.shape[2]))

        counter_Y0 = 0
        counter_Y1 = 0
        for i in range(pred_outcomes.shape[0]):
            mask_factual = np.zeros(pred_outcomes.shape[1], dtype=bool)
            mask_cf = np.ones(pred_outcomes.shape[1], dtype=bool)
            mask_factual[T[i]] = True
            mask_cf[T[i]] = False

            pred_outcomes_factual[i, :] = pred_outcomes[i, mask_factual, :]
            true_outcomes_factual[i, :] = true_outcomes[i, mask_factual, :]
            pred_outcomes_cf[i, :, :] = pred_outcomes[i, mask_cf, :]
            true_outcomes_cf[i, :, :] = true_outcomes[i, mask_cf, :]

            if T[i] == 0:
                pred_outcomes_factual_Y0[counter_Y0, :] = pred_outcomes[i, mask_factual, :]
                true_outcomes_factual_Y0[counter_Y0, :] = true_outcomes[i, mask_factual, :]
                pred_outcomes_cf_Y1[counter_Y0, :, :] = pred_outcomes[i, mask_cf, :]
                true_outcomes_cf_Y1[counter_Y0, :, :] = true_outcomes[i, mask_cf, :]
                counter_Y0 += 1
            else:
                pred_outcomes_factual_Y1[counter_Y1, :] = pred_outcomes[i, mask_factual, :]
                true_outcomes_factual_Y1[counter_Y1, :] = true_outcomes[i, mask_factual, :]
                pred_outcomes_cf_Y0[counter_Y1, :, :] = pred_outcomes[i, mask_cf, :]
                true_outcomes_cf_Y0[counter_Y1, :, :] = true_outcomes[i, mask_cf, :]
                counter_Y1 += 1

        rmse_factual = mean_squared_error(true_outcomes_factual.reshape(-1), pred_outcomes_factual.reshape(-1), squared=False)
        rmse_factual_Y0 = mean_squared_error(true_outcomes_factual_Y0.reshape(-1), pred_outcomes_factual_Y0.reshape(-1), squared=False)
        rmse_factual_Y1 = mean_squared_error(true_outcomes_factual_Y1.reshape(-1), pred_outcomes_factual_Y1.reshape(-1), squared=False)
        rmse_cf = mean_squared_error(true_outcomes_cf.reshape(-1), pred_outcomes_cf.reshape(-1), squared=False)
        rmse_cf_Y0 = mean_squared_error(true_outcomes_cf_Y0.reshape(-1), pred_outcomes_cf_Y0.reshape(-1), squared=False)
        rmse_cf_Y1 = mean_squared_error(true_outcomes_cf_Y1.reshape(-1), pred_outcomes_cf_Y1.reshape(-1), squared=False)

        rmse_Y0 = mean_squared_error(np.concatenate([true_outcomes_factual_Y0.reshape(-1), true_outcomes_cf_Y0.reshape(-1)]), np.concatenate([pred_outcomes_factual_Y0.reshape(-1), pred_outcomes_cf_Y0.reshape(-1)]), squared=False)
        rmse_Y1 = mean_squared_error(np.concatenate([true_outcomes_factual_Y1.reshape(-1), true_outcomes_cf_Y1.reshape(-1)]), np.concatenate([pred_outcomes_factual_Y1.reshape(-1), pred_outcomes_cf_Y1.reshape(-1)]), squared=False)

        # Variance of the true outcomes, factual and counterfactual, used to normalize the RMSEs
        # factual_var = np.std(true_outcomes_factual)
        # cf_var = np.std(true_outcomes_cf)
        factual_var = np.var(true_outcomes_factual)
        cf_var = np.var(true_outcomes_cf)

        log.debug(
            f"\nPred outcomes: {pred_outcomes.shape}: \n{pred_outcomes}"
            f"\n\nT: \n{T}"
            f"\n\nPred outcomes factual: {pred_outcomes_factual.shape}: \n{pred_outcomes_factual}"
        )

        # Returned in order: per-arm overall RMSEs, per-arm factual RMSEs, per-arm
        # counterfactual RMSEs, overall factual/cf RMSEs, variance-normalized
        # factual/cf RMSEs, and means/stds of the true factual and cf outcomes
        return rmse_Y0, rmse_Y1, rmse_factual_Y0, rmse_factual_Y1, rmse_cf_Y0, rmse_cf_Y1, rmse_factual, rmse_cf, rmse_factual / factual_var, rmse_cf / cf_var, np.mean(true_outcomes_factual), np.mean(true_outcomes_cf), np.std(true_outcomes_factual), np.std(true_outcomes_cf)
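
    # Note (sklearn API, version-dependent): mean_squared_error(..., squared=False)
    # returns the RMSE; in newer scikit-learn releases the `squared` argument is
    # deprecated in favor of sklearn.metrics.root_mean_squared_error, so these calls
    # may need updating depending on the pinned version.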

    def compute_outcome_auroc(self,
                              pred_outcomes: np.ndarray,
                              true_outcomes: np.ndarray,
                              T: np.ndarray = None) -> tuple:
        """
        Compute AUROC for all outcomes, split into factual and counterfactual parts.
        """
        # Split into factual and counterfactual outcomes
        pred_outcomes_factual = np.zeros((pred_outcomes.shape[0], pred_outcomes.shape[2]))  # (n, dim_Y)
        true_outcomes_factual = np.zeros((true_outcomes.shape[0], true_outcomes.shape[2]))
        pred_outcomes_cf = np.zeros((pred_outcomes.shape[0], pred_outcomes.shape[1] - 1, pred_outcomes.shape[2]))  # (n, num_T-1, dim_Y)
        true_outcomes_cf = np.zeros((true_outcomes.shape[0], pred_outcomes.shape[1] - 1, true_outcomes.shape[2]))

        for i in range(pred_outcomes.shape[0]):
            mask_factual = np.zeros(pred_outcomes.shape[1], dtype=bool)
            mask_cf = np.ones(pred_outcomes.shape[1], dtype=bool)
            mask_factual[T[i]] = True
            mask_cf[T[i]] = False

            pred_outcomes_factual[i, :] = pred_outcomes[i, mask_factual, :]
            true_outcomes_factual[i, :] = true_outcomes[i, mask_factual, :]
            pred_outcomes_cf[i, :, :] = pred_outcomes[i, mask_cf, :]
            true_outcomes_cf[i, :, :] = true_outcomes[i, mask_cf, :]

        auroc_factual = roc_auc_score(true_outcomes_factual.reshape(-1), pred_outcomes_factual.reshape(-1))
        auroc_cf = roc_auc_score(true_outcomes_cf.reshape(-1), pred_outcomes_cf.reshape(-1))

        log.debug(
            f"\nPred outcomes: {pred_outcomes.shape}: \n{pred_outcomes}"
            f"\n\nT: \n{T}"
            f"\n\nPred outcomes factual: {pred_outcomes_factual.shape}: \n{pred_outcomes_factual}"
        )

        return auroc_factual, auroc_cf

        # Earlier (disabled) variant based on F1/AUROC with clipped predictions:
        # if factual:
        #     # Only keep factual cates
        #     pred_outcomes_factual = np.zeros((pred_outcomes.shape[0], pred_outcomes.shape[2]))  # (n, dim_Y)
        #     true_outcomes_factual = np.zeros((true_outcomes.shape[0], true_outcomes.shape[2]))

        #     for i in range(pred_outcomes.shape[0]):
        #         mask = np.zeros(pred_outcomes.shape[1], dtype=bool)
        #         mask[T[i]] = True
        #         pred_outcomes_factual[i,:] = pred_outcomes[i, mask,:]
        #         true_outcomes_factual[i,:] = true_outcomes[i, mask,:]
        #     # f1 = f1_score(true_outcomes_factual.reshape(-1), pred_outcomes_factual.reshape(-1))
        #     f1 = roc_auc_score(true_outcomes_factual.reshape(-1), pred_outcomes_factual.clip(0,1).reshape(-1))

        # else:
        #     auroc = roc_auc_score(true_outcomes.reshape(-1), pred_outcomes.clip(0,1).reshape(-1))
        #     f1 = auroc
        #     # f1 = f1_score(true_outcomes.reshape(-1), pred_outcomes.reshape(-1))

        # auroc = f1  # TODO: Change name to f1
        # return auroc
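
    # Micro-averaging note (an observation about the computation above, not new
    # behavior): reshape(-1) pools every (sample, treatment, outcome) entry into one
    # long vector before calling roc_auc_score, i.e. this is a single micro-averaged
    # AUROC rather than an average of per-outcome AUROCs.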
1003
1004
1005
    def compute_overall_pehe(self,
                                pred_cates: np.ndarray,
                                true_cates: np.ndarray,
                                T: np.ndarray) -> tuple:
        """
        Compute average PEHE across all treatment options for all outcomes.
        """
        # Clip discrete-outcome CATEs to the valid range
        if self.discrete_outcome:
            pred_cates = pred_cates.clip(-1, 1)

        pehe_sum_total = 0
        pehe_normalized_sum_total = 0
        mean_sum_total = 0
        std_sum_total = 0

        for outcome_idx in range(pred_cates.shape[2]):
            pred_cates_curr = pred_cates[:,:,outcome_idx]
            true_cates_curr = true_cates[:,:,outcome_idx]

            counter = 0
            pehe_sum = 0
            pehe_normalized_sum = 0
            mean_sum = 0
            std_sum = 0

            # Iterate over treatment pairs; CATEs where baseline and assigned
            # treatment coincide evaluate to zero and are skipped as uninformative.
            for i in range(pred_cates.shape[1]):
                for j in range(i):
                    mask_j = T == j
                    mask_i = T == i

                    counter += 1

                    pred_cates_curr_cate_j = pred_cates_curr[mask_j,i]
                    true_cates_curr_cate_j = true_cates_curr[mask_j,i]
                    pred_cates_curr_cate_i = -pred_cates_curr[mask_i,j] # Record the CATE in a consistent direction
                    true_cates_curr_cate_i = -true_cates_curr[mask_i,j] # TODO: Verify this for multiple treatments & outcomes

                    pred_cates_curr_cate = np.concatenate((pred_cates_curr_cate_j, pred_cates_curr_cate_i)).reshape(-1)
                    true_cates_curr_cate = np.concatenate((true_cates_curr_cate_j, true_cates_curr_cate_i)).reshape(-1)

                    # Compute mean CATE
                    true_cates_mean = np.mean(true_cates_curr_cate)
                    mean_sum += true_cates_mean

                    # Compute std of CATE
                    true_cates_std = np.std(true_cates_curr_cate)
                    std_sum += true_cates_std

                    pehe_curr = mean_squared_error(true_cates_curr_cate, pred_cates_curr_cate)
                    pehe_sum += pehe_curr
                    pehe_normalized_sum += pehe_curr / np.var(true_cates_curr_cate)

                    log.debug(
                        f'Check pehe computation for outcome: {outcome_idx}, cate: {i}-{j}'
                        f'============================================'
                        f'\npred_cates: {pred_cates.shape}'
                        f'\n{pred_cates}'
                        f'\n\ntrue_cates: {true_cates.shape}'
                        f'\n{true_cates}'
                        f'\n\nT: {T.shape}'
                        f'\n{T}'
                        f'\n\npred_cates_curr: {pred_cates_curr.shape}'
                        f'\n{pred_cates_curr}'
                        f'\n\ntrue_cates_curr: {true_cates_curr.shape}'
                        f'\n{true_cates_curr}'
                        f'\n\nmask_i: {mask_i.shape}'
                        f'\n{mask_i}'
                        f'\n\nmask_j: {mask_j.shape}'
                        f'\n{mask_j}'
                        f'\n\npred_cates_curr_cate: {pred_cates_curr_cate.shape}'
                        f'\n{pred_cates_curr_cate}'
                        f'\n\ntrue_cates_curr_cate: {true_cates_curr_cate.shape}'
                        f'\n{true_cates_curr_cate}'
                        f'\n\ncounter is currently: {counter}'
                        f'\n\npehe_curr: {pehe_curr}'
                        f'\n\nmean_sum: {mean_sum}'
                        f'\n\nstd_sum: {std_sum}'
                        f'\n============================================\n\n'
                    )

            pehe = pehe_sum / counter
            pehe_normalized = pehe_normalized_sum / counter
            pehe_sum_total += pehe
            pehe_normalized_sum_total += pehe_normalized
            mean_sum_total += mean_sum / counter
            std_sum_total += std_sum / counter

        pehe_total = pehe_sum_total / pred_cates.shape[2]
        pehe_normalized_total = pehe_normalized_sum_total / pred_cates.shape[2]
        true_cates_mean_total = mean_sum_total / pred_cates.shape[2]
        true_cates_std_total = std_sum_total / pred_cates.shape[2]
        return pehe_total, pehe_normalized_total, true_cates_mean_total, true_cates_std_total
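    # The per-pair PEHE computation in compute_overall_pehe, isolated for reference.
    def _example_pehe(self, pred_cate: np.ndarray, true_cate: np.ndarray) -> tuple:
        """
        Illustrative sketch (not called by the pipeline): for one treatment pair
        and outcome, PEHE is the MSE between true and predicted CATEs; the
        normalized variant divides by the variance of the true CATEs.
        """
        pehe = mean_squared_error(true_cate, pred_cate)
        return pehe, pehe / np.var(true_cate)
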
    def get_pred_cates(self, 
                       model_name: str, 
                       X: np.ndarray, 
                       T: np.ndarray,
                       outcomes_test: np.ndarray) -> np.ndarray:
        """
        Get the predicted CATEs for a model.
        """
        # Predict the CATE for every treatment option, using the assigned treatment as baseline.
        # Confidence intervals are computed separately in get_effect_cis.
        T0 = T
        T1 = np.zeros_like(T)
        pred_cates = np.zeros((X.shape[0], self.cfg.simulator.num_T, self.cfg.simulator.dim_Y))

        for i in range(self.cfg.simulator.num_T):
            # Set to current treatment
            T1[:] = i

            if self.cfg.simulator.num_T > 2:
                pred = self.learners[model_name].predict(X, T0=T0, T1=T1) # This predicts y[T1]-y[T0]

            elif model_name == "DiffPOLearner":
                mask = T1 == T0
                pred = self.learners[model_name].predict(X, T0=T0, T1=T1, outcomes=outcomes_test)
                pred[mask] = 0
                if i == 0:
                    pred = -pred

            else:
                # Torch models with only two treatment options predict y[1]-y[0] directly
                mask = T1 == T0
                pred = self.learners[model_name].predict(X) # This predicts y[1]-y[0]
                pred[mask] = 0
                if i == 0:
                    pred = -pred

            pred = pred.reshape(-1, self.cfg.simulator.dim_Y)
            pred_cates[:, i, :] = pred.cpu().detach().numpy()

        log.debug(f"Predicted CATEs for {model_name} have shape {pred_cates.shape}.")

        # Fill nan values with zeros and warn the user
        if np.isnan(pred_cates).any():
            log.warning(f"There are nan values in the predicted CATEs for model: {model_name}. They were filled with 0s.")
        pred_cates = np.nan_to_num(pred_cates)

        return pred_cates
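    # Layout convention (sketch): pred_cates[i, t, :] above holds the predicted
    # effect y(t) - y(T_i) for sample i relative to its assigned treatment T_i,
    # so entries at t == T_i are zero by construction, e.g.:
    #
    #   pred_cates[np.arange(len(T)), T, :]  # -> all zeros
    #
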
    def get_pred_outcomes(self,
                          model_name: str,
                          pred_cates: np.ndarray,
                          X: np.ndarray,
                          T: np.ndarray) -> np.ndarray:
        """
        Get the predicted outcomes for a model.
        """
        # Predicted outcomes for all treatment options
        pred_outcomes_baselines = np.zeros((X.shape[0], self.cfg.simulator.num_T, self.cfg.simulator.dim_Y))

        # Use directly predicted outcomes if the learner provides them
        if model_name in ["DiffPOLearner", "EconML_SLearner_Lasso", "EconML_TLearner_Lasso", "EconML_DML"]:
            pred_outcomes = self.learners[model_name].predict_outcomes(X, T0=T, outcomes=None)
            return pred_outcomes

        elif model_name in ["Torch_ActionNet", "Torch_TARNet", "Torch_TLearner", "Torch_SLearner", "Torch_DragonNet_2", "Torch_DragonNet_4", "Torch_DragonNet", "TorchSNet", "Torch_CFRNet_0.001", "Torch_CFRNet_0.01", "Torch_CFRNet_0.0001"]:
            _, y0_pred, y1_pred = self.learners[model_name].predict(X, return_po=True)
            y0_pred = y0_pred.cpu().detach().numpy()
            y1_pred = y1_pred.cpu().detach().numpy()
            pred_outcomes = np.zeros((X.shape[0], self.cfg.simulator.num_T, self.cfg.simulator.dim_Y))
            pred_outcomes[:, 0, :] = y0_pred.reshape(-1, self.cfg.simulator.dim_Y)
            pred_outcomes[:, 1, :] = y1_pred.reshape(-1, self.cfg.simulator.dim_Y)
            return pred_outcomes

        else:
            for i in range(self.cfg.simulator.num_T):
                # Select the patients assigned to treatment i
                mask = T == i

                # Fill in baseline outcomes according to the selected treatment of each patient
                if self.discrete_outcome:
                    baseline_preds = np.zeros((X[mask, :].shape[0], self.cfg.simulator.dim_Y))
                    baseline_preds_list = self.baseline_learners[i].predict_proba(X[mask, :])
                    for out_dim in range(self.cfg.simulator.dim_Y):
                        baseline_preds_curr = baseline_preds_list[out_dim][:,1]
                        baseline_preds[:,out_dim] = baseline_preds_curr

                else:
                    baseline_preds = self.baseline_learners[i].predict(X[mask, :])

                # Copy the baseline predictions to all treatment options, matching dimensions
                baseline_preds = np.repeat(baseline_preds[:,np.newaxis,:], self.cfg.simulator.num_T, axis=1)
                pred_outcomes_baselines[mask, :, :] = baseline_preds

            # Add predicted CATEs to baseline outcomes
            pred_outcomes = pred_outcomes_baselines + pred_cates

            log.debug(
                f'Check predicted outcomes for model:'
                f'============================================'
                f'X: {X.shape}'
                f'\n{X}'
                f'\nT: {T.shape}'
                f'\n{T}'
                f'\npred_cates: {pred_cates.shape}'
                f'\n{pred_cates}'
                f'\npred_outcomes_baselines: {pred_outcomes_baselines.shape}'
                f'\n{pred_outcomes_baselines}'
                f'\npred_outcomes: {pred_outcomes.shape}'
                f'\n{pred_outcomes}'
                f'\n============================================\n\n'
            )

            # Fill nan values with zeros and warn the user
            if np.isnan(pred_outcomes).any():
                log.warning("There are nan values in the predicted outcomes. They were filled with zeros.")
            pred_outcomes = np.nan_to_num(pred_outcomes)
            return pred_outcomes
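    # How get_pred_outcomes combines baselines and CATEs, isolated for reference.
    def _example_outcome_from_cate(self, baseline_y: np.ndarray, pred_cates_i: np.ndarray) -> np.ndarray:
        """
        Illustrative sketch (not called by the pipeline): broadcast the factual
        baseline prediction of one sample across treatments and add the
        predicted CATEs to reconstruct all potential outcomes.

        baseline_y:   (dim_Y,)       outcome predicted under the assigned treatment
        pred_cates_i: (num_T, dim_Y) predicted effects relative to the assigned treatment
        """
        return baseline_y[np.newaxis, :] + pred_cates_i  # (num_T, dim_Y)
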
    def get_effect_cis(self,
                        model_name: str,
                        X: np.ndarray,
                        T: np.ndarray) -> np.ndarray:
        """
        Get the confidence intervals for the predicted CATEs.
        """
        # Predict the CATE CI for every treatment option, using the assigned treatment as baseline.
        T0 = T
        T1 = np.zeros_like(T)
        pred_cates_conf = np.zeros((2, X.shape[0], self.cfg.simulator.num_T, self.cfg.simulator.dim_Y))

        for i in range(self.cfg.simulator.num_T):
            # Set to current treatment
            T1[:] = i

            effect_cis = self.learners[model_name].infer_effect_ci(X=X, T0=T0) # dim: 2, n, dim_Y (lower and upper bound of the confidence interval)
            pred_cates_conf[:, :, i, :] = effect_cis

        return pred_cates_conf
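    # Sanity check on the CI tensor produced above, shown for reference.
    def _example_ci_widths(self, pred_cates_conf: np.ndarray) -> np.ndarray:
        """
        Illustrative sketch (not called by the pipeline): index 0 stacks the
        lower bounds and index 1 the upper bounds, so interval widths follow by
        subtraction and should never be negative.
        """
        widths = pred_cates_conf[1] - pred_cates_conf[0]  # (n, num_T, dim_Y)
        assert (widths >= 0).all(), "upper bounds should not fall below lower bounds"
        return widths
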
    def get_swap_statistics_single_outcome(self,
                                           T: np.ndarray, 
                                           outcomes: np.ndarray, # n, num_T
                                           pred_cates: np.ndarray, # n, num_T
                                           shap_values_pred: np.ndarray, # num_T, n, dim_X
                                           shap_base_values_pred: np.ndarray, # num_T
                                           k: int = 1, 
                                           threshold: int = 0):
        """
        Evaluates, for a given decision threshold and number of decision variables k,
        how well the personalized predictions indicate whether the treatment should be swapped.
        """
        # shap_values_pred, shap_base_values_pred, and k are kept for the (disabled) top-k SHAP variant.
        tp, fp, tn, fn = 0, 0, 0, 0

        for i in range(T.shape[0]):
            # Check whether the true outcomes speak for swapping the treatment
            outcome = outcomes[i,T[i]]
            outcome_mean = np.mean(outcomes[i])
            swap_true = outcome < outcome_mean # swap if below mean

            # Check whether the average predicted CATE speaks for swapping the treatment
            swap_pred = np.mean(pred_cates[i,:]) > threshold

            if swap_true and swap_pred:
                tp += 1
            elif swap_true and not swap_pred:
                fn += 1
            elif not swap_true and swap_pred:
                fp += 1
            else:
                tn += 1

        # Compute the FPR, TPR, precision, and recall
        fpr = fp / (fp + tn) if fp + tn > 0 else np.nan
        tpr = tp / (tp + fn) if tp + fn > 0 else np.nan
        precision = tp / (tp + fp) if tp + fp > 0 else np.nan
        recall = tp / (tp + fn) if tp + fn > 0 else np.nan

        return fpr, tpr, precision, recall
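    # Example (sketch): a single threshold with counts tp=8, fp=2, tn=18, fn=12
    # yields TPR = 8/(8+12) = 0.4 and FPR = 2/(2+18) = 0.1, i.e. one point on
    # the ROC curve traced out by compute_swap_metrics_single_outcome below.
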
    def compute_swap_metrics_single_outcome(self,
                                            T: np.ndarray,
                                            outcomes: np.ndarray, # n, num_T
                                            pred_cates: np.ndarray, # n, num_T
                                            shap_values_pred: np.ndarray, # num_T, n, dim_X
                                            shap_base_values_pred: np.ndarray, # num_T
                                            k: int = 1) -> tuple:
        """
        Computes the swap metrics for a single outcome.
        """
        # Use the predicted CATEs as decision thresholds for the AUROC and AUPRC computation
        thresholds = list(pred_cates.reshape(-1))

        # Only use unique thresholds
        thresholds = np.unique(thresholds)

        if thresholds.shape[0] == 1:
            thresholds = np.array([-1, thresholds[0], 1])

        # Subsample to at most ~200 thresholds to accelerate the computation
        if len(thresholds) >= 200:
            step = len(thresholds)//200
        else:
            step = 1
        thresholds = thresholds[::step]

        # Iterate through all thresholds and compute statistics
        fprs, tprs, precisions, recalls = [], [], [], []
        for threshold in thresholds:
            fpr, tpr, precision, recall = self.get_swap_statistics_single_outcome(T, 
                                                                                  outcomes, 
                                                                                  pred_cates,
                                                                                  shap_values_pred, 
                                                                                  shap_base_values_pred,
                                                                                  k, threshold)

            fprs.append(fpr)
            tprs.append(tpr)

            if not np.isnan(precision):
                precisions.append(precision)

            if not np.isnan(recall):
                recalls.append(recall)

        # Compute the AUROC: sort the (FPR, TPR) points by FPR
        fprs_sort, tprs_sort = zip(*sorted(zip(fprs, tprs)))
        fprs_sort, tprs_sort = np.array(fprs_sort), np.array(tprs_sort)
        roc_auc = auc(fprs_sort, tprs_sort)

        # Compute the AUC for the precision-recall curve
        if len(recalls) > 1 and len(precisions) > 1:
            # Sort the (recall, precision) points by recall
            recalls_sort, precisions_sort = zip(*sorted(zip(recalls, precisions)))
            recalls_sort, precisions_sort = np.array(recalls_sort), np.array(precisions_sort)
            pr_auc = auc(recalls_sort, precisions_sort)
        else:
            pr_auc = -1

        return roc_auc, pr_auc
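    # How the threshold sweep turns into a single score, shown for reference.
    def _example_swap_auc(self, fprs: list, tprs: list) -> float:
        """
        Illustrative sketch (not called by the pipeline): sort the (FPR, TPR)
        points by FPR and integrate with sklearn's trapezoidal auc helper,
        as in compute_swap_metrics_single_outcome.
        """
        fprs_sort, tprs_sort = zip(*sorted(zip(fprs, tprs)))
        return float(auc(np.array(fprs_sort), np.array(tprs_sort)))
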
    def compute_swap_metrics(self,
                            T: np.ndarray,
                            true_outcomes: np.ndarray,
                            pred_cates: np.ndarray,
                            shap_values_pred: np.ndarray,
                            shap_base_values_pred: np.ndarray,
                            k: int = 1) -> tuple:
        """
        Compute swap metrics for all outcomes.
        """
        roc_aucs, pr_aucs = [], []

        for i in range(self.cfg.simulator.dim_Y):
            roc_auc, pr_auc = self.compute_swap_metrics_single_outcome(T, 
                                                                      true_outcomes[:,:,i], 
                                                                      pred_cates[:,:,i],
                                                                      shap_values_pred[:,:,:,i], 
                                                                      shap_base_values_pred[:,i],
                                                                      k)
            roc_aucs.append(roc_auc)
            pr_aucs.append(pr_auc)

        roc_auc_total = np.mean(roc_aucs)
        pr_auc_total = np.mean(pr_aucs)

        # Compute the swap percentage at threshold 0
        counter = 0
        swap_counter = 0
        for j in range(self.cfg.simulator.dim_Y):
            for i in range(T.shape[0]):
                # Check whether the true outcomes speak for swapping the treatment
                true_outcome = true_outcomes[i,T[i],j]
                true_outcome_mean = np.mean(true_outcomes[i,:,j])
                swap_true = true_outcome < true_outcome_mean # swap if below mean
                counter += 1.0
                swap_counter += swap_true

        return roc_auc_total, pr_auc_total, swap_counter/counter
    def compute_ci_coverage(self,
                            pred_effect_cis: np.ndarray, # dim: 2, n, num_T, dim_Y
                            true_cates: np.ndarray, # n, num_T, dim_Y
                            T: np.ndarray) -> float:
        """
        Compute the coverage of the confidence intervals.
        """
        ci_coverage = 0
        counter = 0
        for i in range(pred_effect_cis.shape[2]):
            for j in range(pred_effect_cis.shape[3]):
                lb = pred_effect_cis[0,:,i,j]
                ub = pred_effect_cis[1,:,i,j]
                true_cate = true_cates[:,i,j]

                # Only consider the cates where the baseline and assigned treatment differ
                mask = T != i
                lb = lb[mask]
                ub = ub[mask]
                true_cate = true_cate[mask]

                ci_coverage += np.sum((lb <= true_cate) & (true_cate <= ub))
                counter += len(true_cate)

                log.debug(
                    f'Check ci coverage computation for outcome: {j}, cate: {i}'
                    f'============================================'
                    f'\ntrue_cates: {true_cate.shape}'
                    f'\n{true_cate}'
                    f'\n\nT: {T.shape}'
                    f'\n{T}'
                    f'\n\nlb: {lb.shape}'
                    f'\n{lb}'
                    f'\n\nub: {ub.shape}'
                    f'\n{ub}'
                    f'\n\ncoverage is currently: {ci_coverage}'
                    f'\n============================================\n\n'
                )

        ci_coverage /= counter

        return ci_coverage
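    # Worked example of the coverage formula used above, shown for reference.
    def _example_ci_coverage(self, lb: np.ndarray, ub: np.ndarray, true_cate: np.ndarray) -> float:
        """
        Illustrative sketch (not called by the pipeline): coverage is the share
        of true CATEs inside their interval, e.g. lb=[0, 0], ub=[1, 1],
        true_cate=[0.5, 2.0] gives 0.5.
        """
        return float(np.mean((lb <= true_cate) & (true_cate <= ub)))
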
    def compute_expertise_metrics(self,
                                    T: np.ndarray,
                                    outcomes: np.ndarray,
                                    type: str = "prognostic") -> float:
        """
        Compute the predictive, prognostic, treatment, or total expertise.
        """
        if self.cfg.simulator.num_T != 2:
            raise ValueError("Expertise metrics can only be computed for binary treatments.")

        # Get potential outcomes
        y0 = outcomes[:, 0, 0]
        y1 = outcomes[:, 1, 0]

        if type in ("predictive", "prognostic", "treatment"):
            if type == "predictive":
                _, gt_bins = np.histogram(y1 - y0, bins="auto")
                gt_uncond_result = y1 - y0

            elif type == "prognostic":
                _, gt_bins = np.histogram(y0, bins="auto")
                gt_uncond_result = y0

            else: # treatment
                _, gt_bins = np.histogram(y1, bins="auto")
                gt_uncond_result = y1

            # Action entropy
            actions, _ = np.histogram(T, bins='auto')
            actions = actions / T.shape[0]
            actentropy = -np.sum(actions * np.ma.log(actions))

            num_ones = np.sum(T) / T.shape[0]

            # Unconditional entropy of the target quantity
            gt_uncond_hist, _ = np.histogram(gt_uncond_result, bins=gt_bins)
            gt_uncond_hist = gt_uncond_hist / gt_uncond_result.shape[0]
            gt_uncond_hist = gt_uncond_hist[gt_uncond_hist != 0]
            gt_uncond_entropy = -np.sum(gt_uncond_hist * np.log(gt_uncond_hist))

            # Entropy conditioned on T == 1
            gt_cond_one = gt_uncond_result * T
            gt_cond_one = gt_cond_one[gt_cond_one != 0]
            gt_one_hist, _ = np.histogram(gt_cond_one, bins=gt_bins)
            gt_one_hist = gt_one_hist / gt_cond_one.shape[0]
            gt_one_hist = gt_one_hist[gt_one_hist != 0]
            gt_one_entropy = -np.sum(gt_one_hist * np.log(gt_one_hist))

            # Entropy conditioned on T == 0
            gt_cond_zero = gt_uncond_result * (1 - T)
            gt_cond_zero = gt_cond_zero[gt_cond_zero != 0]
            gt_zero_hist, _ = np.histogram(gt_cond_zero, bins=gt_bins)
            gt_zero_hist = gt_zero_hist / gt_cond_zero.shape[0]
            gt_zero_hist = gt_zero_hist[gt_zero_hist != 0]
            gt_zero_entropy = -np.sum(gt_zero_hist * np.log(gt_zero_hist))

            gt_expertise = gt_uncond_entropy - num_ones * gt_one_entropy - (1 - num_ones) * gt_zero_entropy
            gt_expertise = gt_expertise / actentropy

        elif type == "total":
            _, gt1_bins, gt0_bins = np.histogram2d(y1, y0)

            # Action entropy
            actions, _ = np.histogram(T, bins='auto')
            actions = actions / T.shape[0]
            actentropy = -np.sum(actions * np.ma.log(actions))

            num_ones = np.sum(T) / T.shape[0]

            # Unconditional joint entropy of (y1, y0)
            gt_uncond_hist, _, _ = np.histogram2d(y1, y0, bins=[gt1_bins, gt0_bins])
            gt_uncond_hist = gt_uncond_hist / y1.shape[0]
            gt_uncond_hist = gt_uncond_hist[gt_uncond_hist != 0]
            gt_uncond_entropy = -np.sum(gt_uncond_hist * np.log(gt_uncond_hist))

            # Joint entropy conditioned on T == 1
            gt_cond_one1 = y1 * T
            gt_cond_one1 = gt_cond_one1[gt_cond_one1 != 0]
            gt_cond_one0 = y0 * T
            gt_cond_one0 = gt_cond_one0[gt_cond_one0 != 0]
            gt_one_hist, _, _ = np.histogram2d(gt_cond_one1, gt_cond_one0, bins=[gt1_bins, gt0_bins])
            gt_one_hist = gt_one_hist / gt_cond_one1.shape[0]
            gt_one_hist = gt_one_hist[gt_one_hist != 0]
            gt_one_entropy = -np.sum(gt_one_hist * np.log(gt_one_hist))

            # Joint entropy conditioned on T == 0
            gt_cond_zero1 = y1 * (1 - T)
            gt_cond_zero1 = gt_cond_zero1[gt_cond_zero1 != 0]
            gt_cond_zero0 = y0 * (1 - T)
            gt_cond_zero0 = gt_cond_zero0[gt_cond_zero0 != 0]
            gt_zero_hist, _, _ = np.histogram2d(gt_cond_zero1, gt_cond_zero0, bins=[gt1_bins, gt0_bins])
            gt_zero_hist = gt_zero_hist / gt_cond_zero1.shape[0]
            gt_zero_hist = gt_zero_hist[gt_zero_hist != 0]
            gt_zero_entropy = -np.sum(gt_zero_hist * np.log(gt_zero_hist))

            gt_expertise = gt_uncond_entropy - num_ones * gt_one_entropy - (1 - num_ones) * gt_zero_entropy
            gt_expertise = gt_expertise / actentropy

        else:
            raise ValueError("Invalid expertise type. Choose between 'predictive', 'prognostic', 'treatment' or 'total'.")

        return gt_expertise
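    # The histogram-entropy estimate underlying the expertise metrics, isolated
    # for reference.
    def _example_hist_entropy(self, values: np.ndarray, bins) -> float:
        """
        Illustrative sketch (not called by the pipeline): estimate the entropy
        of `values` from a normalized histogram, dropping empty bins so the
        logarithm is defined, as done repeatedly in compute_expertise_metrics.
        """
        hist, _ = np.histogram(values, bins=bins)
        hist = hist / values.shape[0]
        hist = hist[hist != 0]
        return float(-np.sum(hist * np.log(hist)))
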
    def compute_incontext_variability(self,
                                      T: np.ndarray,
                                      propensities: np.ndarray) -> float:
        """
        Compute the in-context variability.
        """
        # Sum over the contribution of all patients
        props = propensities.reshape(-1)
        props = props[props != 0]
        cond_entropy = -2 * np.mean(props * np.log(props))

        # Get the action entropy
        actions, _ = np.histogram(T, bins='auto')
        actions = actions / T.shape[0]
        actentropy = -np.sum(actions * np.ma.log(actions))

        return cond_entropy / actentropy
    def get_pred_assignment_precision(self, pred_outcomes, true_outcomes):
        """
        Compute the fraction of samples for which the predicted outcomes rank
        the (binary) treatments in the same order as the true outcomes.
        """
        counter = 0
        correct_counter = 0
        for j in range(self.cfg.simulator.dim_Y):
            for i in range(pred_outcomes.shape[0]):
                # Compare the true and predicted treatment rankings
                true_y0 = true_outcomes[i,0,j]
                true_y1 = true_outcomes[i,1,j]
                pred_y0 = pred_outcomes[i,0,j]
                pred_y1 = pred_outcomes[i,1,j]

                true_ranking = true_y0 < true_y1
                pred_ranking = pred_y0 < pred_y1

                if true_ranking == pred_ranking:
                    correct_counter += 1

                counter += 1.0
        return correct_counter / counter
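    # Example (sketch): with true outcomes (y0, y1) = (0.2, 0.7) and predicted
    # (0.3, 0.5), both rankings place y1 above y0, so the sample counts as a
    # correct assignment in get_pred_assignment_precision.
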
    def compute_metrics(self,
                        results_data: list,
                        sim: SimulatorBase,
                        X_train: np.ndarray,
                        Y_train: np.ndarray,
                        T_train: np.ndarray,
                        X_test: np.ndarray,
                        Y_test: np.ndarray,
                        T_test: np.ndarray,
                        outcomes_train: np.ndarray,
                        outcomes_test: np.ndarray,
                        propensities_train: np.ndarray,
                        propensities_test: np.ndarray,
                        x_axis_value: float,
                        x_axis_name: str,
                        compare_value: float,
                        compare_name: str,
                        seed: int,
                        split_id: int) -> pd.DataFrame:
        """
        Compute metrics for a given experiment.
        """
        # Get learners
        self.baseline_learners = self.get_baseline_learners(seed=seed)
        self.learners = self.get_learners(num_features=X_train.shape[1], seed=seed)

        # Train learners
        if self.cfg.train_baseline_learner:
            self.train_baseline_learners(X_train, outcomes_train, T_train)
        self.train_learners(X_train, Y_train, T_train, outcomes_train)

        # Get the treatment distributions for train and test
        train_treatment_distribution = np.bincount(T_train) / len(T_train)
        test_treatment_distribution = np.bincount(T_test) / len(T_test)

        if self.cfg.simulator.num_T == 2: # Keep a scalar so it can be plotted
            train_treatment_distribution = train_treatment_distribution[0]
            test_treatment_distribution = test_treatment_distribution[0]

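        # Example (sketch): np.bincount(np.array([0, 1, 1, 0, 1])) / 5 gives
        # array([0.4, 0.6]); for binary treatments only the first entry (the
        # share of treatment 0) is kept so the distribution can be plotted.
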
        # Get learner explanations
        (learner_explanations, _) = self.get_learner_explanations(X=X_test, type="pred") if self.cfg.evaluate_explanations else (None, None)
        (learner_prog_explanations, _) = self.get_learner_explanations(X=X_test, type="prog") if self.cfg.evaluate_prog_explanations else (None, None)

        # Get and train the select learner (treatment selection model)
        if self.cfg.evaluate_in_context_variability:
            self.select_learner = self.get_select_learner(seed=seed)
            self.train_select_learner(X_train, T_train)
            propensities_pred_test = self.select_learner.predict_proba(X_test)
            propensities_pred_train = self.select_learner.predict_proba(X_train)

        # Compute metrics
        baseline_metrics = {} # The first model in the list is used as baseline

        for i, model_name in enumerate(self.cfg.model_names):
            # Compute attribution accuracies
            if self.cfg.evaluate_explanations and learner_explanations[model_name] is not None:
                attribution_est = np.abs(learner_explanations[model_name])
                pred_acc_scores_all_features = attribution_accuracy(self.all_important_features, attribution_est)
                pred_acc_scores_predictive_features = attribution_accuracy(self.pred_features, attribution_est)
                pred_acc_scores_prog_features = attribution_accuracy(self.prog_features, attribution_est)
                pred_acc_scores_selective_features = attribution_accuracy(self.select_features, attribution_est)

            else:
                pred_acc_scores_all_features = -1
                pred_acc_scores_predictive_features = -1
                pred_acc_scores_prog_features = -1
                pred_acc_scores_selective_features = -1

            if self.cfg.evaluate_prog_explanations and learner_prog_explanations[model_name] is not None:
                prog_attribution_est = np.abs(learner_prog_explanations[model_name])
                prog_acc_scores_all_features = attribution_accuracy(self.all_important_features, prog_attribution_est)
                prog_acc_scores_predictive_features = attribution_accuracy(self.pred_features, prog_attribution_est)
                prog_acc_scores_prog_features = attribution_accuracy(self.prog_features, prog_attribution_est)
                prog_acc_scores_selective_features = attribution_accuracy(self.select_features, prog_attribution_est)

            else:
                prog_acc_scores_all_features = -1
                prog_acc_scores_predictive_features = -1
                prog_acc_scores_prog_features = -1
                prog_acc_scores_selective_features = -1

            # Compute predicted cates and outcomes
            pred_cates = self.get_pred_cates(model_name, X_test, T_test, outcomes_test)
            pred_cates_train = self.get_pred_cates(model_name, X_train, T_train, outcomes_train)

            pred_outcomes = self.get_pred_outcomes(model_name, pred_cates, X_test, T_test)
            pred_outcomes_train = self.get_pred_outcomes(model_name, pred_cates_train, X_train, T_train)

            pred_effect_cis = self.get_effect_cis(model_name, X_test, T_test) if self.evaluate_inference else None

            # Train models for treatment-specific explanations; SHAP-based
            # explanations are currently disabled, so zero placeholders are used below.
            temp = self.discrete_outcome
            self.discrete_outcome = 0
            self.pred_learners = self.get_pred_learners(seed=seed) # One for each treatment as reference treatment
            self.prog_learner = self.get_prog_learner(seed=seed) # One for the average over all treatments
            self.discrete_outcome = temp

            ## COMPUTE METRICS ##
            # Compute PEHE
            true_cates = sim.get_true_cates(X_test, T_test, outcomes_test)

            # print(f"Pred cates: {pred_cates[:5]}")
1790
            # print(f"Pred outcomes: {pred_outcomes[:5]}")
1791
            # print(f"True cates: {true_cates[:5]}")
1792
            # print(f"True outcomes: {outcomes_test[:5]}")
1793
            # quit()
1794
            (
1795
                pehe, 
1796
                pehe_normalized, 
1797
                true_cates_mean, 
1798
                true_cates_std
1799
            ) = self.compute_overall_pehe(pred_cates, true_cates, T_test)
1800
1801
            (
1802
                rmse_Y0, rmse_Y1, rmse_factual_Y0, rmse_factual_Y1, rmse_cf_Y0, rmse_cf_Y1, 
1803
                factual_outcomes_rmse, cf_outcomes_rmse,
1804
                factual_outcomes_rmse_normalized, cf_outcomes_rmse_normalized,
1805
                factual_outcomes_mean, cf_outcomes_mean,
1806
                factual_outcomes_std, cf_outcomes_std
1807
            )= self.compute_outcome_mse(pred_outcomes, outcomes_test, T_test)
1808
            f_cf_diff = np.abs(factual_outcomes_rmse - cf_outcomes_rmse)
1809
            f_cf_diff_norm = np.abs(factual_outcomes_rmse_normalized - cf_outcomes_rmse_normalized)
1810
1811
1812
            # Get aurocs in case of discrete outcomes
1813
            cf_outcomes_auroc = -1
1814
            factual_outcomes_auroc = -1
1815
            if self.discrete_outcome:
1816
                factual_outcomes_auroc, cf_outcomes_auroc = self.compute_outcome_auroc(pred_outcomes, outcomes_test, T_test)
1817
1818
            # Compute personalized explanations metrics
1819
            #k_tre = self.cfg.simulator.num_pred_features * self.cfg.simulator.num_T 
1820
            k_all = X_test.shape[1]
1821
1822
            # Compute confidence intervall coverage
1823
            ci_coverage = self.compute_ci_coverage(pred_effect_cis, true_cates, T_test) if self.evaluate_inference else -1
1824
1825
            # Compute expertise metrics
1826
            # Get learned policy
1827
            T_pred = pred_outcomes[:, :, 0].argmax(axis=1)
1828
1829
            try:
                gt_pred_expertise = self.compute_expertise_metrics(T_train, outcomes_train, type="predictive")
                gt_prog_expertise = self.compute_expertise_metrics(T_train, outcomes_train, type="prognostic")
                gt_tre_expertise = self.compute_expertise_metrics(T_train, outcomes_train, type="treatment")

                updated_gt_pred_expertise = self.compute_expertise_metrics(T_pred, outcomes_test, type="predictive")
                updated_gt_prog_expertise = self.compute_expertise_metrics(T_pred, outcomes_test, type="prognostic")
                updated_gt_tre_expertise = self.compute_expertise_metrics(T_pred, outcomes_test, type="treatment")

                gt_total_expertise = self.compute_expertise_metrics(T_train, outcomes_train, type="total") if self.discrete_outcome == 0 else -1
                es_pred_expertise = self.compute_expertise_metrics(T_train, pred_outcomes_train, type="predictive")
                es_prog_expertise = self.compute_expertise_metrics(T_train, pred_outcomes_train, type="prognostic")
                es_tre_expertise = self.compute_expertise_metrics(T_train, pred_outcomes_train, type="treatment")
                es_total_expertise = self.compute_expertise_metrics(T_train, pred_outcomes_train, type="total") if self.discrete_outcome == 0 else -1

            except Exception:
                log.warning("Error while computing expertise metrics. Defaulting to -1.")
                gt_pred_expertise = -1
                gt_prog_expertise = -1
                gt_tre_expertise = -1
                gt_total_expertise = -1
                updated_gt_pred_expertise = -1
                updated_gt_prog_expertise = -1
                updated_gt_tre_expertise = -1
                es_pred_expertise = -1
                es_prog_expertise = -1
                es_total_expertise = -1
                es_tre_expertise = -1

            # Compute in-context variability
            try:
                gt_incontext_variability = self.compute_incontext_variability(T_train, propensities_train)
                es_incontext_variability = self.compute_incontext_variability(T_train, propensities_pred_train) if self.cfg.evaluate_in_context_variability else -1
            except Exception:
                gt_incontext_variability = -1
                es_incontext_variability = -1

            # SHAP-based swap metrics are disabled; pass zero placeholders instead
            pred_learner_explanations = np.zeros((self.cfg.simulator.num_T, X_test.shape[0], X_test.shape[1], self.cfg.simulator.dim_Y))
            pred_learner_base_values = np.zeros((self.cfg.simulator.num_T, self.cfg.simulator.dim_Y))
            (
                swap_auroc_all, 
                swap_auprc_all, 
                swap_perc 
            ) = self.compute_swap_metrics(T=T_test,
                                        true_outcomes=outcomes_test,
                                        pred_cates=pred_cates,
                                        shap_values_pred=pred_learner_explanations,
                                        shap_base_values_pred=pred_learner_base_values,
                                        k=k_all)

            pred_correct_assignment_precision = self.get_pred_assignment_precision(pred_outcomes, outcomes_test)
            policy_correct_assignment_precision = 1 - swap_perc

            # Get scores relative to the baseline (first model in the list)
            if i == 0:
                baseline_metrics["pehe"] = pehe
                baseline_metrics["pehe_normalized"] = pehe_normalized
                baseline_metrics["factual_outcomes_rmse"] = factual_outcomes_rmse
                baseline_metrics["factual_outcomes_rmse_normalized"] = factual_outcomes_rmse_normalized
                baseline_metrics["cf_outcomes_rmse"] = cf_outcomes_rmse
                baseline_metrics["cf_outcomes_rmse_normalized"] = cf_outcomes_rmse_normalized
                baseline_metrics["swap_auroc_all"] = swap_auroc_all
                baseline_metrics["swap_auprc_all"] = swap_auprc_all

            fc_pehe = pehe_normalized - baseline_metrics["pehe_normalized"]
            fc_factual_outcomes = factual_outcomes_rmse_normalized - baseline_metrics["factual_outcomes_rmse_normalized"]
            fc_cf_outcomes = cf_outcomes_rmse_normalized - baseline_metrics["cf_outcomes_rmse_normalized"]
            fc_swap_auroc_all = swap_auroc_all - baseline_metrics["swap_auroc_all"]
            fc_swap_auprc_all = swap_auprc_all - baseline_metrics["swap_auprc_all"]

            results_data.append(
                [
                    seed,
                    split_id,
                    x_axis_value,
                    compare_value,
                    model_name,
                    true_cates_mean,
                    true_cates_std,
                    pehe,
                    pehe_normalized,
                    factual_outcomes_auroc,
                    cf_outcomes_auroc,
                    rmse_Y0,
                    rmse_Y1,
                    rmse_factual_Y0,
                    rmse_factual_Y1,
                    rmse_cf_Y0,
                    rmse_cf_Y1,
                    factual_outcomes_rmse,
                    cf_outcomes_rmse,
                    factual_outcomes_rmse_normalized,
                    cf_outcomes_rmse_normalized,
                    factual_outcomes_mean,
                    cf_outcomes_mean,
                    factual_outcomes_std,
                    cf_outcomes_std,
                    f_cf_diff,
                    f_cf_diff_norm,
                    fc_pehe,
                    fc_factual_outcomes,
                    fc_cf_outcomes,
                    fc_swap_auroc_all,
                    fc_swap_auprc_all,
                    swap_auroc_all,
                    swap_auprc_all,
                    swap_perc,
                    pred_correct_assignment_precision,
                    policy_correct_assignment_precision,
                    ci_coverage,
                    pred_acc_scores_all_features,
                    pred_acc_scores_predictive_features,
                    pred_acc_scores_prog_features,
                    pred_acc_scores_selective_features,
                    prog_acc_scores_all_features,
                    prog_acc_scores_predictive_features,
                    prog_acc_scores_prog_features,
                    prog_acc_scores_selective_features,
                    train_treatment_distribution,
                    test_treatment_distribution,
                    gt_pred_expertise,
                    gt_prog_expertise,
                    gt_tre_expertise,
                    gt_pred_expertise/(gt_pred_expertise + gt_prog_expertise),
                    gt_total_expertise,
                    updated_gt_pred_expertise,
                    updated_gt_prog_expertise,
                    updated_gt_tre_expertise,
                    es_pred_expertise,
                    es_prog_expertise,
                    es_tre_expertise,
                    es_total_expertise,
                    np.abs(gt_pred_expertise - es_pred_expertise),
                    np.abs(gt_prog_expertise - es_prog_expertise),
                    np.abs(gt_total_expertise - es_total_expertise),
                    1 - gt_incontext_variability,
                    1 - es_incontext_variability,
                    self.training_times[model_name],
                ]
            )

            metrics_df = pd.DataFrame(
                results_data,
                columns=[
                    "Seed",
                    "Split ID",
                    x_axis_name,
                    compare_name,
                    "Learner",
                    "CATE true mean",
                    "CATE true std",
                    "PEHE",
                    "Normalized PEHE",
                    "Factual AUROC",
                    "CF AUROC",
                    "RMSE Y0",
                    "RMSE Y1",
                    "Factual RMSE Y0",
                    "Factual RMSE Y1",
                    "CF RMSE Y0",
                    "CF RMSE Y1",
                    "Factual RMSE",
                    "CF RMSE",
                    "Normalized F-RMSE",
                    "Normalized CF-RMSE",
                    "F-Outcome true mean",
                    "CF-Outcome true mean",
                    "F-Outcome true std",
                    "CF-Outcome true std",
                    "F-CF Outcome Diff",
                    "Normalized F-CF Diff",
                    "FC PEHE",
                    "FC F-RMSE",
                    "FC CF-RMSE",
                    "FC Swap AUROC",
                    "FC Swap AUPRC",
                    "Swap AUROC@all",
                    "Swap AUPRC@all",
                    "True Swap Perc",
                    "Pred Precision",
                    "Policy Precision",
                    "CI Coverage",
                    "Pred: All features ACC",
                    "Pred: Pred features ACC",
                    "Pred: Prog features ACC",
                    "Pred: Select features ACC",
                    "Prog: All features ACC",
                    "Prog: Pred features ACC",
                    "Prog: Prog features ACC",
                    "Prog: Select features ACC",
                    "T Distribution: Train",
                    "T Distribution: Test",
                    "GT Pred Expertise",
                    "GT Prog Expertise",
                    "GT Tre Expertise",
                    "GT Expertise Ratio",
                    "GT Total Expertise",
                    "Upd. GT Pred Expertise",
                    "Upd. GT Prog Expertise",
                    "Upd. GT Tre Expertise",
                    "ES Pred Expertise",
                    "ES Prog Expertise",
                    "ES Tre Expertise",
                    "ES Total Expertise",
                    "GT-ES Pred Expertise Diff",
                    "GT-ES Prog Expertise Diff",
                    "GT-ES Total Expertise Diff",
                    "GT In-context Var",
                    "ES In-context Var",
                    "Training Duration",
                ],
            )

            # Save intermediate results
            self.save_results(metrics_df, save_df_only=True, compare_axis_values=None)

        return metrics_df