src/iterpretability/experiments.py
from pathlib import Path
import os
import catenets.models as cate_models
import numpy as np
import pandas as pd

import src.iterpretability.logger as log
from src.iterpretability.explain import Explainer
from src.iterpretability.datasets.data_loader import load
from src.iterpretability.synthetic_simulate import (
    SyntheticSimulatorLinear,
    SyntheticSimulatorModulatedNonLinear,
)
from src.iterpretability.utils import (
    attribution_accuracy,
    compute_pehe,
)

# For contour plotting
import umap
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.linear_model import LogisticRegression
import matplotlib.pyplot as plt
import matplotlib.tri as tri
import imageio
import torch
import shap

def get_learners(model_list, X_train, Y_train, n_iter, batch_size, batch_norm, discrete_treatment=True, discrete_outcome=False):
    learners = {
                    "TLearner": cate_models.torch.TLearner(
                        X_train.shape[1],
                        binary_y=(len(np.unique(Y_train)) == 2),
                        n_layers_out=2,
                        n_units_out=100,
                        batch_size=batch_size,
                        n_iter=n_iter,
                        batch_norm=batch_norm,
                        nonlin="relu",
                    ),
                    "SLearner": cate_models.torch.SLearner(
                        X_train.shape[1],
                        binary_y=(len(np.unique(Y_train)) == 2),
                        n_layers_out=2,
                        n_units_out=100,
                        n_iter=n_iter,
                        batch_size=batch_size,
                        batch_norm=batch_norm,
                        nonlin="relu",
                    ),
                    "TARNet": cate_models.torch.TARNet(
                        X_train.shape[1],
                        binary_y=(len(np.unique(Y_train)) == 2),
                        n_layers_r=1,
                        n_layers_out=1,
                        n_units_out=100,
                        n_units_r=100,
                        batch_size=batch_size,
                        n_iter=n_iter,
                        batch_norm=batch_norm,
                        nonlin="relu",
                    ),
                    "DRLearner": cate_models.torch.DRLearner(
                        X_train.shape[1],
                        binary_y=(len(np.unique(Y_train)) == 2),
                        n_layers_out=2,
                        n_units_out=100,
                        n_iter=n_iter,
                        batch_size=batch_size,
                        batch_norm=batch_norm,
                        nonlin="relu",
                    ),
                    "XLearner": cate_models.torch.XLearner(
                        X_train.shape[1],
                        binary_y=(len(np.unique(Y_train)) == 2),
                        n_layers_out=2,
                        n_units_out=100,
                        n_iter=n_iter,
                        batch_size=batch_size,
                        batch_norm=batch_norm,
                        nonlin="relu",
                    ),
                    "CFRNet_0.01": cate_models.torch.TARNet(
                        X_train.shape[1],
                        binary_y=(len(np.unique(Y_train)) == 2),
                        n_layers_r=1,
                        n_layers_out=1,
                        n_units_out=100,
                        n_units_r=100,
                        batch_size=batch_size,
                        n_iter=n_iter,
                        batch_norm=batch_norm,
                        nonlin="relu",
                        penalty_disc=0.01,
                    ),
                    "CFRNet_0.001": cate_models.torch.TARNet(
                        X_train.shape[1],
                        binary_y=(len(np.unique(Y_train)) == 2),
                        n_layers_r=1,
                        n_layers_out=1,
                        n_units_out=100,
                        n_units_r=100,
                        batch_size=batch_size,
                        n_iter=n_iter,
                        batch_norm=batch_norm,
                        nonlin="relu",
                        penalty_disc=0.001,
                    ),
                    "CFRNet_0.0001": cate_models.torch.TARNet(
                        X_train.shape[1],
                        binary_y=(len(np.unique(Y_train)) == 2),
                        n_layers_r=1,
                        n_layers_out=1,
                        n_units_out=100,
                        n_units_r=100,
                        batch_size=batch_size,
                        n_iter=n_iter,
                        batch_norm=batch_norm,
                        nonlin="relu",
                        penalty_disc=0.0001,
                    ),
                    "EconML_CausalForestDML": cate_models.econml.EconMlEstimator(
                        model_name="EconML_CausalForestDML",
                        # cv_model_selection=3,
                        discrete_treatment=discrete_treatment,
                        discrete_outcome=discrete_outcome,
                    ),
                    "EconML_DMLOrthoForest": cate_models.econml.EconMlEstimator(
                        model_name="EconML_DMLOrthoForest",
                        # cv_model_selection=3,
                        discrete_treatment=discrete_treatment,
                        discrete_outcome=discrete_outcome,
                    ),

                    "EconML_SparseLinearDML": cate_models.econml.EconMlEstimator(
                        model_name="EconML_SparseLinearDML",
                        # cv_model_selection=3,
                        discrete_treatment=discrete_treatment,
                        discrete_outcome=discrete_outcome,
                    ),
                    "EconML_SparseLinearDRLearner": cate_models.econml.EconMlEstimator(
                        model_name="EconML_SparseLinearDRLearner",
                        # cv_model_selection=3,
                        discrete_treatment=discrete_treatment,
                        discrete_outcome=discrete_outcome,
                    ),
                    "EconML_LinearDRLearner": cate_models.econml.EconMlEstimator(
                        model_name="EconML_LinearDRLearner",
                        # cv_model_selection=3,
                        discrete_treatment=discrete_treatment,
                        discrete_outcome=discrete_outcome,
                    ),
                    "EconML_DRLearner": cate_models.econml.EconMlEstimator(
                        model_name="EconML_DRLearner",
                        # cv_model_selection=3,
                        discrete_treatment=discrete_treatment,
                        discrete_outcome=discrete_outcome,
                    ),
                    "EconML_XLearner": cate_models.econml.EconMlEstimator(
                        model_name="EconML_XLearner",
                        # cv_model_selection=3,
                        discrete_treatment=discrete_treatment,
                        discrete_outcome=discrete_outcome,
                    ),
                    "EconML_SLearner": cate_models.econml.EconMlEstimator(
                        model_name="EconML_SLearner",
                        # cv_model_selection=3,
                        discrete_treatment=discrete_treatment,
                        discrete_outcome=discrete_outcome,
                    ),
                    "EconML_TLearner": cate_models.econml.EconMlEstimator(
                        model_name="EconML_TLearner",
                        # cv_model_selection=3,
                        discrete_treatment=discrete_treatment,
                        discrete_outcome=discrete_outcome,
                    ),
                    "EconML_SparseLinearDRIV": cate_models.econml.EconMlEstimator(
                        model_name="EconML_SparseLinearDRIV",
                        # cv_model_selection=3,
                        discrete_treatment=discrete_treatment,
                        discrete_outcome=discrete_outcome,
                    ),

                }

    for name in model_list:
        if name not in learners:
            raise Exception(f"Unknown model name {name}.")

    # Only return the learners that are in the model_list
    learners = {name: learners[name] for name in model_list}

    return learners

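# Illustrative usage sketch (not executed anywhere in this module): assuming
# X_train, Y_train and W_train are arrays produced by one of the simulators
# below, a subset of the learners can be built and fitted like this. The
# hyperparameter values shown are placeholders, not the experiment settings.
#
#   learners = get_learners(
#       model_list=["TLearner", "TARNet"],
#       X_train=X_train,
#       Y_train=Y_train,
#       n_iter=1000,
#       batch_size=1024,
#       batch_norm=False,
#   )
#   learners["TLearner"].fit(X=X_train, y=Y_train, w=W_train)
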
def get_learner_explanations(learners, X_test, X_train, Y_train, W_train, explainer_limit, explainer_list, return_learners=False, already_trained=False):
    learner_explainers = {}
    learner_explanations = {}

    for name in learners:
        log.info(f"Fitting {name}.")

        if not already_trained:
            learners[name].fit(X=X_train, y=Y_train, w=W_train)

        log.info(f"Explaining {name}.")

        if "EconML" in name:
            shap_values = learners[name].est.shap_values(X_test[:explainer_limit], background_samples=None)
            treatment_names = learners[name].est.cate_treatment_names()
            output_names = learners[name].est.cate_output_names()
            output_name = output_names[0]
            treatment_name = treatment_names[0]
            learner_explanations[name] = {"kernel_shap": shap_values[output_name][treatment_name].values}

        else:
            learner_explainers[name] = Explainer(
                learners[name],
                feature_names=list(range(X_train.shape[1])),
                explainer_list=explainer_list,
            )
            learner_explanations[name] = learner_explainers[name].explain(
                X_test[:explainer_limit]
            )

    if return_learners:
        return learner_explanations, learners
    else:
        return learner_explanations

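# Illustrative usage sketch (assumptions: X_*, Y_train and W_train come from a
# simulator, and `learners` from get_learners above). Attributions are returned
# per learner and per explainer, restricted to the first explainer_limit test
# rows (expected shape roughly (explainer_limit, n_features)):
#
#   explanations = get_learner_explanations(
#       learners, X_test, X_train, Y_train, W_train,
#       explainer_limit=100, explainer_list=["feature_ablation"],
#   )
#   attribution = explanations["TLearner"]["feature_ablation"]
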
class PredictiveSensitivity:
    """
    Sensitivity analysis for predictive scale.
    """

    def __init__(
        self,
        n_units_hidden: int = 50,
        n_layers: int = 1,
        penalty_orthogonal: float = 0.01,
        batch_size: int = 1024,
        batch_norm: bool = False,
        n_iter: int = 1000,
        seed: int = 42,
        explainer_limit: int = 1000,
        save_path: Path = Path.cwd(),
        propensity_type: str = "pred",
        predictive_scales: list = [1e-3, 1e-2, 1e-1, 0.5, 1, 2],
        num_interactions: int = 1,
        synthetic_simulator_type: str = "linear",
        selection_type: str = "random",
        non_linearity_scale: float = 0,
        model_list: list = ["TLearner"]
    ) -> None:

        self.n_units_hidden = n_units_hidden
        self.n_layers = n_layers
        self.penalty_orthogonal = penalty_orthogonal
        self.batch_size = batch_size
        self.batch_norm = batch_norm
        self.n_iter = n_iter
        self.seed = seed
        self.explainer_limit = explainer_limit
        self.save_path = save_path
        self.predictive_scales = predictive_scales
        self.propensity_type = propensity_type
        self.num_interactions = num_interactions
        self.synthetic_simulator_type = synthetic_simulator_type
        self.selection_type = selection_type
        self.non_linearity_scale = non_linearity_scale
        self.model_list = model_list

    def run(
        self,
        dataset: str = "tcga_10",
        train_ratio: float = 0.8,
        num_important_features: int = 2,
        binary_outcome: bool = False,
        random_feature_selection: bool = True,
        explainer_list: list = [
            "feature_ablation",
            "feature_permutation",
            "integrated_gradients",
            "shapley_value_sampling",
        ],
        debug: bool = False,
        directory_path_: str = None,
    ) -> None:
        log.info(
            f"Using dataset {dataset} with num_important features = {num_important_features}."
        )

        X_raw_train, X_raw_test = load(dataset, train_ratio=train_ratio, debug=debug, directory_path_=directory_path_)

        if self.synthetic_simulator_type == "linear":
            sim = SyntheticSimulatorLinear(
                X_raw_train,
                num_important_features=num_important_features,
                random_feature_selection=random_feature_selection,
                seed=self.seed,
            )
        elif self.synthetic_simulator_type == "nonlinear":
            sim = SyntheticSimulatorModulatedNonLinear(
                X_raw_train,
                num_important_features=num_important_features,
                non_linearity_scale=self.non_linearity_scale,
                seed=self.seed,
                selection_type=self.selection_type,
            )
        else:
            raise Exception("Unknown simulator type.")

        explainability_data = []

        for predictive_scale in self.predictive_scales:
            log.info(f"Now working with predictive_scale = {predictive_scale}...")
            (
                X_train,
                W_train,
                Y_train,
                po0_train,
                po1_train,
                propensity_train,
            ) = sim.simulate_dataset(
                X_raw_train,
                predictive_scale=predictive_scale,
                binary_outcome=binary_outcome,
                treatment_assign=self.propensity_type,
            )

            X_test, W_test, Y_test, po0_test, po1_test, _ = sim.simulate_dataset(
                X_raw_test,
                predictive_scale=predictive_scale,
                binary_outcome=binary_outcome,
                treatment_assign=self.propensity_type,
            )

            log.info("Fitting and explaining learners...")

            learners = get_learners(
                model_list=self.model_list,
                X_train=X_train,
                Y_train=Y_train,
                n_iter=self.n_iter,
                batch_size=self.batch_size,
                batch_norm=self.batch_norm,
                discrete_outcome=binary_outcome
            )

            learner_explanations = get_learner_explanations(learners,
                                                            X_test, X_train, Y_train, W_train,
                                                            self.explainer_limit, explainer_list)

            all_important_features = sim.get_all_important_features(with_selective=True)
            pred_features = sim.get_predictive_features()
            prog_features = sim.get_prognostic_features()
            select_features = sim.get_selective_features()

            cate_test = sim.te(X_test)

            for explainer_name in explainer_list:
                for learner_name in learners:
                    attribution_est = np.abs(
                        learner_explanations[learner_name][explainer_name]
                    )
                    acc_scores_all_features = attribution_accuracy(
                        all_important_features, attribution_est
                    )
                    acc_scores_predictive_features = attribution_accuracy(
                        pred_features, attribution_est
                    )
                    acc_scores_prog_features = attribution_accuracy(
                        prog_features, attribution_est
                    )
                    acc_scores_selective_features = attribution_accuracy(
                        select_features, attribution_est
                    )
                    cate_pred = learners[learner_name].predict(X=X_test)

                    pehe_test = compute_pehe(cate_true=cate_test, cate_pred=cate_pred)

                    explainability_data.append(
                        [
                            predictive_scale,
                            learner_name,
                            explainer_name,
                            acc_scores_all_features,
                            acc_scores_predictive_features,
                            acc_scores_prog_features,
                            acc_scores_selective_features,
                            pehe_test,
                            np.mean(cate_test),
                            np.var(cate_test),
                            pehe_test / np.sqrt(np.var(cate_test)),
                        ]
                    )

        metrics_df = pd.DataFrame(
            explainability_data,
            columns=[
                "Predictive Scale",
                "Learner",
                "Explainer",
                "All features ACC",
                "Pred features ACC",
                "Prog features ACC",
                "Select features ACC",
                "PEHE",
                "CATE true mean",
                "CATE true var",
                "Normalized PEHE",
            ],
        )

        results_path = self.save_path / "results/predictive_sensitivity"
        log.info(f"Saving results in {results_path}...")
        if not results_path.exists():
            results_path.mkdir(parents=True, exist_ok=True)

        metrics_df.to_csv(
            results_path / f"predictive_scale_{dataset}_{num_important_features}_"
            f"{self.synthetic_simulator_type}_random_{random_feature_selection}_"
            f"binary_{binary_outcome}-seed{self.seed}.csv"
        )

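# Illustrative sketch of running the predictive-scale sweep above (dataset name
# and output location depend on the local data setup and are assumptions here):
#
#   exp = PredictiveSensitivity(seed=42, model_list=["TLearner"])
#   exp.run(dataset="tcga_10", num_important_features=2)
#   # -> writes a CSV under <save_path>/results/predictive_sensitivity/
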
class NonLinearitySensitivity:
    """
    Sensitivity analysis for nonlinearity in prognostic and predictive functions.
    """

    def __init__(
        self,
        n_units_hidden: int = 50,
        n_layers: int = 1,
        penalty_orthogonal: float = 0.01,
        batch_size: int = 1024,
        batch_norm: bool = False,
        n_iter: int = 1000,
        seed: int = 42,
        explainer_limit: int = 1000,
        save_path: Path = Path.cwd(),
        propensity_type: str = "pred",
        nonlinearity_scales: list = [0.0, 0.2, 0.5, 0.7, 1.0],
        selection_type: str = "random",
        predictive_scale: float = 1,
        synthetic_simulator_type: str = "random",
        model_list: list = ["TLearner"]
    ) -> None:

        self.n_units_hidden = n_units_hidden
        self.n_layers = n_layers
        self.penalty_orthogonal = penalty_orthogonal
        self.batch_size = batch_size
        self.batch_norm = batch_norm
        self.n_iter = n_iter
        self.seed = seed
        self.explainer_limit = explainer_limit
        self.save_path = save_path
        self.propensity_type = propensity_type
        self.nonlinearity_scales = nonlinearity_scales
        self.selection_type = selection_type
        self.predictive_scale = predictive_scale
        self.synthetic_simulator_type = synthetic_simulator_type
        self.model_list = model_list

    def run(
        self,
        dataset: str = "tcga_100",
        num_important_features: int = 15,
        explainer_list: list = [
            "feature_ablation",
            "feature_permutation",
            "integrated_gradients",
            "shapley_value_sampling",
        ],
        train_ratio: float = 0.8,
        binary_outcome: bool = False,
        debug=False,
        directory_path_: str = None,
    ) -> None:
        log.info(
            f"Using dataset {dataset} with num_important features = {num_important_features}."
        )
        X_raw_train, X_raw_test = load(dataset, train_ratio=train_ratio, debug=debug, directory_path_=directory_path_)
        explainability_data = []

        for nonlinearity_scale in self.nonlinearity_scales:
            log.info(f"Now working with a nonlinearity scale {nonlinearity_scale}...")

            if self.synthetic_simulator_type == "linear":
                raise Exception("Linear simulator not supported for nonlinearity sensitivity.")

            elif self.synthetic_simulator_type == "nonlinear":
                sim = SyntheticSimulatorModulatedNonLinear(
                    X_raw_train,
                    num_important_features=num_important_features,
                    non_linearity_scale=nonlinearity_scale,
                    seed=self.seed,
                    selection_type=self.selection_type
                )
            else:
                raise Exception("Unknown simulator type.")

            (
                X_train,
                W_train,
                Y_train,
                po0_train,
                po1_train,
                propensity_train,
            ) = sim.simulate_dataset(
                X_raw_train,
                predictive_scale=self.predictive_scale,
                binary_outcome=binary_outcome,
                treatment_assign=self.propensity_type,
            )
            X_test, W_test, Y_test, po0_test, po1_test, _ = sim.simulate_dataset(
                X_raw_test,
                predictive_scale=self.predictive_scale,
                binary_outcome=binary_outcome,
                treatment_assign=self.propensity_type,
            )

            log.info("Fitting and explaining learners...")
            learners = get_learners(
                model_list=self.model_list,
                X_train=X_train,
                Y_train=Y_train,
                n_iter=self.n_iter,
                batch_size=self.batch_size,
                batch_norm=False,
                discrete_outcome=binary_outcome
            )

            learner_explanations = get_learner_explanations(learners,
                                                            X_test, X_train, Y_train, W_train,
                                                            self.explainer_limit, explainer_list)

            all_important_features = sim.get_all_important_features(with_selective=True)
            pred_features = sim.get_predictive_features()
            prog_features = sim.get_prognostic_features()
            select_features = sim.get_selective_features()

            cate_test = sim.te(X_test)

            for explainer_name in explainer_list:
                for learner_name in learners:
                    attribution_est = np.abs(
                        learner_explanations[learner_name][explainer_name]
                    )
                    acc_scores_all_features = attribution_accuracy(
                        all_important_features, attribution_est
                    )
                    acc_scores_predictive_features = attribution_accuracy(
                        pred_features, attribution_est
                    )
                    acc_scores_prog_features = attribution_accuracy(
                        prog_features, attribution_est
                    )
                    acc_scores_selective_features = attribution_accuracy(
                        select_features, attribution_est
                    )

                    cate_pred = learners[learner_name].predict(X=X_test)

                    pehe_test = compute_pehe(cate_true=cate_test, cate_pred=cate_pred)

                    explainability_data.append(
                        [
                            nonlinearity_scale,
                            learner_name,
                            explainer_name,
                            acc_scores_all_features,
                            acc_scores_predictive_features,
                            acc_scores_prog_features,
                            acc_scores_selective_features,
                            pehe_test,
                            np.mean(cate_test),
                            np.var(cate_test),
                            pehe_test / np.sqrt(np.var(cate_test)),
                        ]
                    )

        metrics_df = pd.DataFrame(
            explainability_data,
            columns=[
                "Nonlinearity Scale",
                "Learner",
                "Explainer",
                "All features ACC",
                "Pred features ACC",
                "Prog features ACC",
                "Select features ACC",
                "PEHE",
                "CATE true mean",
                "CATE true var",
                "Normalized PEHE",
            ],
        )

        results_path = (
            self.save_path
            / f"results/nonlinearity_sensitivity/{self.synthetic_simulator_type}"
        )
        log.info(f"Saving results in {results_path}...")
        if not results_path.exists():
            results_path.mkdir(parents=True, exist_ok=True)

        metrics_df.to_csv(
            results_path
            / f"{dataset}_{num_important_features}_binary_{binary_outcome}-seed{self.seed}.csv"
        )

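# Illustrative sketch of the nonlinearity sweep above (example values; note
# that run() only accepts the "nonlinear" simulator type, so it is passed
# explicitly here):
#
#   exp = NonLinearitySensitivity(nonlinearity_scales=[0.0, 0.5, 1.0], synthetic_simulator_type="nonlinear")
#   exp.run(dataset="tcga_100", num_important_features=15)
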
class PropensitySensitivity:
    """
    Sensitivity analysis for confounding.
    """

    def __init__(
        self,
        n_units_hidden: int = 50,
        n_layers: int = 1,
        penalty_orthogonal: float = 0.01,
        batch_size: int = 1024,
        batch_norm: bool = False,
        n_iter: int = 1000,
        seed: int = 42,
        explainer_limit: int = 1000,
        save_path: Path = Path.cwd(),
        num_interactions: int = 1,
        synthetic_simulator_type: str = "linear",
        nonlinearity_scale: float = 0,
        selection_type: str = "random",
        propensity_type: str = "pred",
        propensity_scales: list = [0, 0.5, 1, 2, 5, 10],
        model_list: list = ["TLearner"]
    ) -> None:

        self.n_units_hidden = n_units_hidden
        self.n_layers = n_layers
        self.penalty_orthogonal = penalty_orthogonal
        self.batch_size = batch_size
        self.batch_norm = batch_norm
        self.n_iter = n_iter
        self.seed = seed
        self.explainer_limit = explainer_limit
        self.save_path = save_path
        self.num_interactions = num_interactions
        self.synthetic_simulator_type = synthetic_simulator_type
        self.nonlinearity_scale = nonlinearity_scale
        self.selection_type = selection_type
        self.propensity_type = propensity_type
        self.propensity_scales = propensity_scales
        self.model_list = model_list

    def run(
        self,
        dataset: str = "tcga_10",
        train_ratio: float = 0.8,
        num_important_features: int = 2,
        binary_outcome: bool = False,
        random_feature_selection: bool = True,
        predictive_scale: float = 1,
        nonlinearity_scale: float = 0.5,
        explainer_list: list = [
            "feature_ablation",
            "feature_permutation",
            "integrated_gradients",
            "shapley_value_sampling",
        ],
        debug: bool = False,
        directory_path_: str = None,
    ) -> None:
        log.info(
            f"Using dataset {dataset} with num_important features = {num_important_features} and predictive scale {predictive_scale}."
        )

        X_raw_train, X_raw_test = load(dataset, train_ratio=train_ratio, debug=debug, directory_path_=directory_path_)

        if self.synthetic_simulator_type == "linear":
            sim = SyntheticSimulatorLinear(
                X_raw_train,
                num_important_features=num_important_features,
                random_feature_selection=random_feature_selection,
                seed=self.seed,
            )
        elif self.synthetic_simulator_type == "nonlinear":
            sim = SyntheticSimulatorModulatedNonLinear(
                X_raw_train,
                num_important_features=num_important_features,
                non_linearity_scale=self.nonlinearity_scale,
                seed=self.seed,
                selection_type=self.selection_type,
            )
        else:
            raise Exception("Unknown simulator type.")

        explainability_data = []

        for propensity_scale in self.propensity_scales:
            log.info(f"Now working with propensity_scale = {propensity_scale}...")
            (
                X_train,
                W_train,
                Y_train,
                po0_train,
                po1_train,
                propensity_train,
            ) = sim.simulate_dataset(
                X_raw_train,
                predictive_scale=predictive_scale,
                binary_outcome=binary_outcome,
                treatment_assign=self.propensity_type,
                prop_scale=propensity_scale,
            )

            X_test, W_test, Y_test, po0_test, po1_test, _ = sim.simulate_dataset(
                X_raw_test,
                predictive_scale=predictive_scale,
                binary_outcome=binary_outcome,
                treatment_assign=self.propensity_type,
                prop_scale=propensity_scale,
            )

            log.info("Fitting and explaining learners...")
            learners = get_learners(
                model_list=self.model_list,
                X_train=X_train,
                Y_train=Y_train,
                n_iter=self.n_iter,
                batch_size=self.batch_size,
                batch_norm=self.batch_norm,
                discrete_outcome=binary_outcome
            )

            learner_explanations = get_learner_explanations(learners,
                                                            X_test, X_train, Y_train, W_train,
                                                            self.explainer_limit, explainer_list)

            all_important_features = sim.get_all_important_features(with_selective=False)
            pred_features = sim.get_predictive_features()
            prog_features = sim.get_prognostic_features()

            cate_test = sim.te(X_test)

            for explainer_name in explainer_list:
                for learner_name in learners:
                    attribution_est = np.abs(
                        learner_explanations[learner_name][explainer_name]
                    )
                    acc_scores_all_features = attribution_accuracy(
                        all_important_features, attribution_est
                    )
                    acc_scores_predictive_features = attribution_accuracy(
                        pred_features, attribution_est
                    )
                    acc_scores_prog_features = attribution_accuracy(
                        prog_features, attribution_est
                    )

                    cate_pred = learners[learner_name].predict(X=X_test)
                    pehe_test = compute_pehe(cate_true=cate_test, cate_pred=cate_pred)

                    explainability_data.append(
                        [
                            propensity_scale,
                            learner_name,
                            explainer_name,
                            acc_scores_all_features,
                            acc_scores_predictive_features,
                            acc_scores_prog_features,
                            pehe_test,
                            np.mean(cate_test),
                            np.var(cate_test),
                            pehe_test / np.sqrt(np.var(cate_test)),
                        ]
                    )

        metrics_df = pd.DataFrame(
            explainability_data,
            columns=[
                "Propensity Scale",
                "Learner",
                "Explainer",
                "All features ACC",
                "Pred features ACC",
                "Prog features ACC",
                "PEHE",
                "CATE true mean",
                "CATE true var",
                "Normalized PEHE",
            ],
        )

        results_path = (
            self.save_path
            / f"results/propensity_sensitivity/{self.synthetic_simulator_type}"
        )
        log.info(f"Saving results in {results_path}...")
        if not results_path.exists():
            results_path.mkdir(parents=True, exist_ok=True)

        metrics_df.to_csv(
            results_path / f"propensity_scale_{dataset}_{num_important_features}_"
            f"proptype_{self.propensity_type}_"
            f"predscl_{predictive_scale}_"
            f"nonlinscl_{nonlinearity_scale}_"
            f"trainratio_{train_ratio}_"
            f"binary_{binary_outcome}-seed{self.seed}.csv"
        )

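# Illustrative sketch of the confounding sweep above (the propensity_scales
# list here is an example, not a prescribed configuration):
#
#   exp = PropensitySensitivity(propensity_scales=[0, 1, 5], propensity_type="pred")
#   exp.run(dataset="tcga_10", num_important_features=2)
#   # -> writes a CSV under <save_path>/results/propensity_sensitivity/<simulator_type>/
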
class CohortSizeSensitivity:
    """
    Sensitivity analysis for varying numbers of samples. This experiment will generate a .csv with the recorded metrics.
    It will also generate a gif, showing the progression on dimensionality-reduced spaces.
    """

    def __init__(
        self,
        n_units_hidden: int = 50,
        n_layers: int = 1,
        penalty_orthogonal: float = 0.01,
        batch_size: int = 1024,
        batch_norm: bool = False,
        n_iter: int = 1000,
        seed: int = 42,
        explainer_limit: int = 1000,
        save_path: Path = Path.cwd(),
        propensity_type: str = "selective",
        cohort_sizes: list = [0.5, 0.7, 1.0],
        nonlinearity_scale: float = 0.5,
        predictive_scale: float = 1,
        synthetic_simulator_type: str = "random",
        selection_type: str = "random",
        model_list: list = ["TLearner"],
        num_cube_samples: int = 1000,
        dim_reduction_method: str = "umap",
        dim_reduction_on_important_features: bool = True,
        visualize_progression: bool = True,
    ) -> None:

        self.n_units_hidden = n_units_hidden
        self.n_layers = n_layers
        self.penalty_orthogonal = penalty_orthogonal
        self.batch_size = batch_size
        self.batch_norm = batch_norm
        self.n_iter = n_iter
        self.seed = seed
        self.explainer_limit = explainer_limit
        self.save_path = save_path
        self.propensity_type = propensity_type
        self.cohort_sizes = cohort_sizes
        self.nonlinearity_scale = nonlinearity_scale
        self.predictive_scale = predictive_scale
        self.synthetic_simulator_type = synthetic_simulator_type
        self.selection_type = selection_type
        self.model_list = model_list
        self.num_cube_samples = num_cube_samples
        self.dim_reduction_method = dim_reduction_method
        self.dim_reduction_on_important_features = dim_reduction_on_important_features
        self.visualize_progression = visualize_progression

    def run(
        self,
        dataset: str = "tcga_100",
        num_important_features: int = 15,
        explainer_list: list = [
            "feature_ablation",
            "feature_permutation",
            "integrated_gradients",
            "shapley_value_sampling",
        ],
        train_ratio: float = 0.8,
        binary_outcome: bool = False,
        debug=False,
        directory_path_: str = None,
    ) -> None:
        # Log setting
        log.info(
            f"Using dataset {dataset} with num_important features = {num_important_features}."
        )

        # Load data
        X_raw_train_full, X_raw_test_full = load(dataset, train_ratio=train_ratio, debug=debug, directory_path_=directory_path_)
        explainability_data = []

        # Simulate treatment and outcome for train and test
        sim = SyntheticSimulatorModulatedNonLinear(
                X_raw_train_full,
                num_important_features=num_important_features,
                non_linearity_scale=self.nonlinearity_scale,
                seed=self.seed,
                selection_type=self.selection_type
            )

        (
            X_train_full,
            W_train_full,
            Y_train_full,
            po0_train_full,
            po1_train_full,
            propensity_train_full
        ) = sim.simulate_dataset(
                X_raw_train_full,
                predictive_scale=self.predictive_scale,
                binary_outcome=binary_outcome,
                treatment_assign=self.propensity_type,
            )

        X_test_full, W_test_full, Y_test_full, po0_test_full, po1_test_full, _ = sim.simulate_dataset(
            X_raw_test_full,
            predictive_scale=self.predictive_scale,
            binary_outcome=binary_outcome,
            treatment_assign=self.propensity_type,
        )

        # Retrieve important features
        all_important_features = sim.get_all_important_features(with_selective=True)
        pred_features = sim.get_predictive_features()
        prog_features = sim.get_prognostic_features()
        select_features = sim.get_selective_features()

        # Code for sampling from hypercube -> does not work well because data lives in very small part of that hypercube
        # # Sample from a hypercube grid for the X_raw_train_full dataset - making a complete hypercube grid would be too large
        # # Sample for important features only as we know these are the ones that matter and will make sampling from a hypercube more meaningful
        # X_raw_train_full_important = X_raw_train_full[:, all_important_features]
        # min_vals = X_raw_train_full_important.min(axis=0)
        # max_vals = X_raw_train_full_important.max(axis=0)

        # # Add some relative padding to the min and max values
        # min_vals = min_vals - 0.1 * np.abs(min_vals)
        # max_vals = max_vals + 0.1 * np.abs(max_vals)

        # # Sample points
        # grid_samples = np.zeros((self.num_cube_samples, X_raw_train_full.shape[1]))
        # grid_samples[:, all_important_features] = np.random.uniform(min_vals, max_vals, (self.num_cube_samples, X_raw_train_full_important.shape[1]))

        if self.visualize_progression:
            # Use full training set with added noisy samples as focused samples of space
            std_dev = np.std(X_raw_train_full, axis=0)
            grid_samples = X_raw_train_full
            grid_samples = np.vstack([grid_samples,
                                        X_raw_train_full + std_dev*np.random.normal(0, 1, X_raw_train_full.shape),
                                        X_raw_train_full + 0.1*std_dev*np.random.normal(0, 1, X_raw_train_full.shape),
                                        X_raw_train_full + 0.1*std_dev*np.random.normal(0, 1, X_raw_train_full.shape),
                                        # X_raw_train_full + 0.3*std_dev*np.random.normal(0, 1, X_raw_train_full.shape),
                                        X_raw_train_full + 0.1*std_dev*np.random.normal(0, 1, X_raw_train_full.shape),])

            # Reduce samples to two dimensions for plotting using umap
            if self.dim_reduction_method == "umap":
                reducer = umap.UMAP(min_dist=1, n_neighbors=30, spread=1)
                reducer_shap = umap.UMAP(min_dist=3, n_neighbors=40, spread=4)
                reducer_shap_prop = umap.UMAP(min_dist=2, n_neighbors=30, spread=3)

            elif self.dim_reduction_method == "pca":
                reducer = PCA(n_components=2)
                reducer_shap = PCA(n_components=2)
                reducer_shap_prop = PCA(n_components=2)

            elif self.dim_reduction_method == "tsne":
                raise Exception("t-SNE not supported for this analysis. Does not offer .transform() method.")

            else:
                raise Exception("Unknown dimensionality reduction method.")

            # Fit on grid samples and training data
            if self.dim_reduction_on_important_features:
                grid_samples_2d = reducer.fit_transform(grid_samples[:, all_important_features])
                train_samples_2d = reducer.transform(X_raw_train_full[:, all_important_features])
            else:
                grid_samples_2d = reducer.fit_transform(grid_samples)
                train_samples_2d = reducer.transform(X_raw_train_full)

            # Get model learners (here only one) and explanations for grid samples
            learners = get_learners(
                    model_list=self.model_list,
                    X_train=X_train_full,
                    Y_train=Y_train_full,
                    n_iter=self.n_iter,
                    batch_size=self.batch_size,
                    batch_norm=False,
                    discrete_outcome=binary_outcome
                )

            learner_explanations, learners = get_learner_explanations(learners,
                                                                grid_samples, X_train_full, Y_train_full, W_train_full,
                                                                grid_samples.shape[0], explainer_list,
                                                                return_learners=True)

            learner_explanations_train = get_learner_explanations(learners,
                                                                X_train_full, X_train_full, Y_train_full, W_train_full,
                                                                X_train_full.shape[0], explainer_list,
                                                                already_trained=True)

            # Get shap for grid samples
            shap_values_grid = learner_explanations[self.model_list[0]][explainer_list[0]]
            shap_values_train = learner_explanations_train[self.model_list[0]][explainer_list[0]]

            # Perform logistic regression for propensity
            est_prop = LogisticRegression().fit(X_train_full, W_train_full)
            explainer_prop = shap.LinearExplainer(est_prop, X_train_full)
            shap_values_prop_grid = explainer_prop.shap_values(grid_samples)
            shap_values_prop_train = explainer_prop.shap_values(X_train_full)

            # Fit umap reducer for shap grid samples
            if self.dim_reduction_on_important_features:
                shap_values_grid_2d = reducer_shap.fit_transform(shap_values_grid[:, all_important_features])
                shap_values_train_2d = reducer_shap.transform(shap_values_train[:, all_important_features])
                shap_values_prop_grid_2d = reducer_shap_prop.fit_transform(shap_values_prop_grid[:, all_important_features])
                shap_values_prop_train_2d = reducer_shap_prop.transform(shap_values_prop_train[:, all_important_features])
            else:
                shap_values_grid_2d = reducer_shap.fit_transform(shap_values_grid)
                shap_values_train_2d = reducer_shap.transform(shap_values_train)
                shap_values_prop_grid_2d = reducer_shap_prop.fit_transform(shap_values_prop_grid)
                shap_values_prop_train_2d = reducer_shap_prop.transform(shap_values_prop_train)

            # Initialize variable for storing frames
            frames = []  # To store each frame for the GIF

        cohort_size_full = X_train_full.shape[0]
        for cohort_size_perc in self.cohort_sizes:
            cohort_size = int(cohort_size_perc * cohort_size_full)

            # Get a subset of the training data
            X_train = X_train_full[:cohort_size]
            W_train = W_train_full[:cohort_size]
            Y_train = Y_train_full[:cohort_size]
            po0_train = po0_train_full[:cohort_size]
            po1_train = po1_train_full[:cohort_size]
            propensity_train = propensity_train_full[:cohort_size]
            cohort_size_train = X_train.shape[0]

            # Get subsets for test data
            X_test = X_test_full[:cohort_size]
            W_test = W_test_full[:cohort_size]
            Y_test = Y_test_full[:cohort_size]
            po0_test = po0_test_full[:cohort_size]
            po1_test = po1_test_full[:cohort_size]

            log.info(f"Now working with a cohort size of {cohort_size}/{X_train_full.shape[0]}...")
            log.info("Fitting and explaining learners...")
            learners = get_learners(
                model_list=self.model_list,
                X_train=X_train,
                Y_train=Y_train,
                n_iter=self.n_iter,
                batch_size=self.batch_size,
                batch_norm=False,
                discrete_outcome=binary_outcome
            )

            # Get learners and explanations for training data
            learner_explanations, learners = get_learner_explanations(learners,
                                                            X_train, X_train, Y_train, W_train,
                                                            X_train.shape[0], explainer_list,
                                                            return_learners=True)

            if self.visualize_progression:
                # Make logistic regression for propensity
                est_prop = LogisticRegression().fit(X_train, W_train)
                explainer_prop = shap.LinearExplainer(est_prop, X_train)

                # Set up plot
                fig, axs = plt.subplots(2, 4, figsize=(15, 10))
                eff_grid = sim.te(grid_samples)
                prop_grid = sim.prop(grid_samples)

                # Get cate estimator
                est_eff = learners[self.model_list[0]]
                cate_pred_train = est_eff.predict(X=X_train)

                # Get predictions for the first model
                p_eff_grid = est_eff.predict(X=grid_samples)
                p_prop_grid = est_prop.predict_proba(grid_samples)[:, 1]
                outcomes = [p_prop_grid, prop_grid, p_eff_grid, eff_grid]
                titles = ['prop(x)', 'prop_true(x)', 'cate(x)', 'cate_true(x)']

                # # Get X_Train in 2d
                # shap_values_train = learner_explanations[self.model_list[0]][explainer_list[0]]
                # shap_values_prop_train = explainer_prop.shap_values(X_train)

                # if self.dim_reduction_on_important_features:
                #     X_train_2d = reducer.transform(X_train[:, all_important_features])
                # else:
                #     X_train_2d = reducer.transform(X_train)

                # # Get shap values of X_Train in 2d
                # if self.dim_reduction_on_important_features:
                #     shap_values_train_2d = reducer_shap.transform(shap_values_train[:, all_important_features])
                #     shap_values_prop_train_2d = reducer_shap_prop.transform(shap_values_prop_train[:, all_important_features])
                # else:
                #     shap_values_train_2d = reducer_shap.transform(shap_values_train)
                #     shap_values_prop_train_2d = reducer_shap_prop.transform(shap_values_prop_train)

                # If there are any non-finite elements in the outcomes, remove them from the grid_samples and the other outcomes and raise a warning
                for i, outcome in enumerate(outcomes):
                    if type(outcome) == torch.Tensor:
                        outcome = outcome.cpu().detach().numpy()

                    outcome = np.array(outcome)
                    if not np.all(np.isfinite(outcome)):
                        print("-----------")
                        print(i, np.sum(~np.isfinite(outcome)))
                        log.warning(f'Found non-finite elements in outcomes. Removing {np.sum(~np.isfinite(outcome))} elements.')
                        mask = np.isfinite(outcome)
                        grid_samples_2d = grid_samples_2d[mask]
                        # Keep the SHAP embeddings aligned with the masked grid/outcomes,
                        # otherwise the contour plots below receive mismatched lengths
                        shap_values_grid_2d = shap_values_grid_2d[mask]
                        shap_values_prop_grid_2d = shap_values_prop_grid_2d[mask]
                        for j in range(len(outcomes)):
                            outcomes[j] = outcomes[j][mask]
                        break

                for j, outcome in enumerate(outcomes):
                    if type(outcome) == torch.Tensor:
                        outcome = outcome.cpu().detach().numpy()

                    # Plot settings
                    cmap = "viridis"
                    s = 15  # 4 for many samples
                    alpha = None
                    edgecolors = "w"
                    linewidths = 0.2

                    # Plot contours (levels is passed once, as a keyword)
                    if j == 0 or j == 1:
                        tcf = axs[0][j].tricontourf(grid_samples_2d[:, 0],
                                                    grid_samples_2d[:, 1],
                                                    outcome.ravel(), cmap=cmap, levels=50)

                        tcf_shap = axs[1][j].tricontourf(shap_values_prop_grid_2d[:, 0],
                                                        shap_values_prop_grid_2d[:, 1],
                                                        outcome.ravel(), cmap=cmap, levels=50)
                    else:
                        tcf = axs[0][j].tricontourf(grid_samples_2d[:, 0],
                                                    grid_samples_2d[:, 1],
                                                    outcome.ravel(), cmap=cmap, levels=50)

                        tcf_shap = axs[1][j].tricontourf(shap_values_grid_2d[:, 0],
                                                        shap_values_grid_2d[:, 1],
                                                        outcome.ravel(), cmap=cmap, levels=50)

                    # Version: Plot in shap space from current model

                    # if j == 0:
                    #     axs[0][j].scatter(X_train_2d[:,0], X_train_2d[:,1], c=W_train, cmap='coolwarm', edgecolors=edgecolors, s=s, label='Training data for a', alpha=alpha)
                    #     axs[1][j].scatter(shap_values_prop_train_2d[:,0], shap_values_prop_train_2d[:,1], c=W_train, cmap='coolwarm', edgecolors=edgecolors, s=s, alpha=alpha)
                    #     #fig.colorbar(tcf)
                    # if j == 2:
                    #     axs[0][j].scatter(X_train_2d[:,0], X_train_2d[:,1], c=cate_pred_train, cmap='coolwarm', edgecolors=edgecolors, s=s, label='Training data for a', alpha=alpha)
                    #     axs[1][j].scatter(shap_values_train_2d[:,0], shap_values_train_2d[:,1], c=Y_train, cmap='coolwarm', edgecolors=edgecolors, s=s, alpha=alpha)
                    #     #fig.colorbar(tcf_shap)

                    # Version: Always plot in same shap space from full model

                    if j == 0:
                        axs[0][j].scatter(train_samples_2d[:cohort_size_train, 0], train_samples_2d[:cohort_size_train, 1],
                                        c=W_train, cmap=cmap, s=s, edgecolors=edgecolors, linewidths=linewidths, alpha=0.5)

                        axs[1][j].scatter(shap_values_prop_train_2d[:cohort_size_train, 0], shap_values_prop_train_2d[:cohort_size_train, 1],
                                        c=W_train, cmap=cmap, s=s, edgecolors=edgecolors, linewidths=linewidths, alpha=0.5)
                        # fig.colorbar(tcf)
                    if j == 2:
                        axs[0][j].scatter(train_samples_2d[:cohort_size_train, 0], train_samples_2d[:cohort_size_train, 1],
                                        c=cate_pred_train, cmap=cmap, s=s, edgecolors=edgecolors, linewidths=linewidths)

                        axs[1][j].scatter(shap_values_train_2d[:cohort_size_train, 0], shap_values_train_2d[:cohort_size_train, 1],
                                        c=cate_pred_train, cmap=cmap, s=s, edgecolors=edgecolors, linewidths=linewidths)
                        # fig.colorbar(tcf_shap)

                    # axs[0][j].set_title(f'{titles[j]} (N={cohort_size})_data_space')
                    # axs[0][j].set_xlim([grid_samples_2d[:,0].min(), grid_samples_2d[:,0].max()])
                    # axs[0][j].set_ylim([grid_samples_2d[:,1].min(), grid_samples_2d[:,1].max()])
                    # axs[0][j].legend()

                    # if j == 0 or j == 1:
                    #     axs[1][j].set_title(f'{titles[j]} (N={cohort_size})_shap_prop_space')
                    #     axs[1][j].set_xlim([shap_values_prop_grid_2d[:,0].min(), shap_values_prop_grid_2d[:,0].max()])
                    #     axs[1][j].set_ylim([shap_values_prop_grid_2d[:,1].min(), shap_values_prop_grid_2d[:,1].max()])
                    #     axs[1][j].legend()
                    # else:
                    #     axs[1][j].set_title(f'{titles[j]} (N={cohort_size})_shap_space')
                    #     axs[1][j].set_xlim([shap_values_grid_2d[:,0].min(), shap_values_grid_2d[:,0].max()])
                    #     axs[1][j].set_ylim([shap_values_grid_2d[:,1].min(), shap_values_grid_2d[:,1].max()])
                    #     axs[1][j].legend()

                # Save the plot to a buffer
                plt.tight_layout()
                plt.savefig('temp_plot.png')
                plt.close()
                frames.append(imageio.imread('temp_plot.png'))

            for explainer_name in explainer_list:
                for learner_name in learners:
                    attribution_est = np.abs(
                        learner_explanations[learner_name][explainer_name]
                    )
                    acc_scores_all_features = attribution_accuracy(
                        all_important_features, attribution_est
                    )
                    acc_scores_predictive_features = attribution_accuracy(
                        pred_features, attribution_est
                    )
                    acc_scores_prog_features = attribution_accuracy(
                        prog_features, attribution_est
                    )
                    acc_scores_selective_features = attribution_accuracy(
                        select_features, attribution_est
                    )

                    cate_pred = learners[learner_name].predict(X=X_test)
                    cate_test = sim.te(X_test)
                    pehe_test = compute_pehe(cate_true=cate_test, cate_pred=cate_pred)

                    explainability_data.append(
                        [
                            cohort_size,
                            cohort_size_perc,
                            self.nonlinearity_scale,
                            learner_name,
                            explainer_name,
                            acc_scores_all_features,
                            acc_scores_predictive_features,
                            acc_scores_prog_features,
                            acc_scores_selective_features,
                            pehe_test,
                            np.mean(cate_test),
                            np.var(cate_test),
                            pehe_test / np.sqrt(np.var(cate_test)),
                        ]
                    )

        metrics_df = pd.DataFrame(
            explainability_data,
            columns=[
                "Cohort Size",
                "Cohort Size Perc",
                "Nonlinearity Scale",
                "Learner",
                "Explainer",
                "All features ACC",
                "Pred features ACC",
                "Prog features ACC",
                "Select features ACC",
                "PEHE",
                "CATE true mean",
                "CATE true var",
                "Normalized PEHE",
            ],
        )

        results_path = (
            self.save_path
            / f"results/cohort_size_sensitivity/{self.synthetic_simulator_type}"
        )

        log.info(f"Saving results in {results_path}...")
        if not results_path.exists():
            results_path.mkdir(parents=True, exist_ok=True)

        metrics_df.to_csv(
            results_path
            / f"{dataset}_{num_important_features}_binary_{binary_outcome}-seed{self.seed}.csv"
        )

        if self.visualize_progression:
            # Create GIF
            imageio.mimsave("progression.gif", frames, fps=1)
            imageio.mimsave(results_path / "progression.gif", frames, fps=1)  # Set fps=1 for slower transition to observe changes clearly
1278