# src/iterpretability/simulators.py

# stdlib
import random
from abc import ABC, abstractmethod
from typing import Tuple

import src.iterpretability.logger as log

# third party
import numpy as np
import torch
from scipy.special import expit
from scipy.stats import zscore
from omegaconf import DictConfig, OmegaConf
from src.iterpretability.utils import enable_reproducible_results

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# For computing the propensities from scores
from scipy.special import softmax
from sklearn.model_selection import train_test_split

EPS = 0


class SimulatorBase(ABC):
    """
    Base class for simulators.
    """

    @abstractmethod
    def simulate(self, X: np.ndarray, outcomes: np.ndarray = None) -> None:
        raise NotImplementedError

    @abstractmethod
    def get_simulated_data(self, train_ratio: float) -> Tuple:
        raise NotImplementedError

    @property
    @abstractmethod
    def selective_features(self) -> np.ndarray:
        raise NotImplementedError

    @property
    @abstractmethod
    def prognostic_features(self) -> np.ndarray:
        raise NotImplementedError

    @property
    @abstractmethod
    def predictive_features(self) -> np.ndarray:
        raise NotImplementedError


class TYSimulator(SimulatorBase):
    """
    Data generation process class for simulating treatment selection and outcomes (and effects).
    """

    # Pool of candidate nonlinearities; all but the Gaussian bump are currently disabled.
    nonlinear_fcts = [
        # lambda x: np.abs(x),
        lambda x: np.exp(-(x**2) / 2),
        # lambda x: 1 / (1 + x**2),
        # lambda x: np.sqrt(x) * (1 + x),
        # lambda x: np.cos(5 * x),
        # lambda x: x**2,
        # lambda x: np.arctan(x),
        # lambda x: np.tanh(x),
        # lambda x: np.sin(x),
        # lambda x: np.log(1 + x**2),
        # lambda x: np.sqrt(0.02 + x**2),
        # lambda x: np.cosh(x),
    ]
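
    # Usage sketch (illustrative, not from the original source; assumes `X` is an
    # (n, dim_X) float array, ideally scaled to [0, 1]):
    #
    #   sim = TYSimulator(dim_X=30, num_T=2, dim_Y=1, propensity_type="prog_pred")
    #   sim.simulate(X)
    #   X, T, Y, outcomes, propensities = sim.get_simulated_data()
    #   cates = sim.get_true_cates(X, T, outcomes)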

    def __init__(
        self,
        # Data dimensionality
        dim_X: int,

        # Seed
        seed: int = 42,

        # Simulation type
        simulation_type: str = "ty",

        # Dimensionality of treatments and outcome
        num_binary_outcome: int = 0,
        outcome_unbalancedness_ratio: float = 0,
        standardize_outcome: bool = False,
        num_T: int = 3,
        dim_Y: int = 3,

        # Scale parameters
        predictive_scale: float = 1,
        prognostic_scale: float = 1,
        propensity_scale: float = 1,
        unbalancedness_exp: float = 0,
        nonlinearity_scale: float = 1,
        propensity_type: str = "prog_pred",
        alpha: float = 0.5,
        enforce_balancedness: bool = False,

        # Control
        include_control: bool = False,

        # Important features
        num_pred_features: int = 5,
        num_prog_features: int = 5,
        num_select_features: int = 5,
        feature_type_overlap: str = "sel_none",
        treatment_feature_overlap: bool = False,

        # Feature selection
        random_feature_selection: bool = False,
        nonlinearity_selection_type: str = "random",

        # Noise
        noise: bool = True,
        noise_std: float = 0.1,
    ) -> None:
        # Number of features
        self.dim_X = dim_X

        # Make sure results are reproducible by setting seed for np, torch, random
        self.seed = seed
        enable_reproducible_results(seed=self.seed)

        # Simulation type
        self.simulation_type = simulation_type

        # Store dimensions
        self.num_binary_outcome = num_binary_outcome
        self.outcome_unbalancedness_ratio = outcome_unbalancedness_ratio
        self.standardize_outcome = standardize_outcome
        self.num_T = num_T
        self.dim_Y = dim_Y

        # Scale parameters
        self.predictive_scale = predictive_scale
        self.prognostic_scale = prognostic_scale
        self.propensity_scale = propensity_scale
        self.unbalancedness_exp = unbalancedness_exp
        self.nonlinearity_scale = nonlinearity_scale
        self.propensity_type = propensity_type
        self.alpha = alpha
        self.enforce_balancedness = enforce_balancedness

        # Control
        self.include_control = include_control

        # Important features
        self.num_pred_features = num_pred_features
        self.num_prog_features = num_prog_features
        self.num_select_features = num_select_features
        self.num_important_features = (
            self.num_T * (num_pred_features + num_select_features) + num_prog_features
        )
        self.feature_type_overlap = feature_type_overlap
        self.treatment_feature_overlap = treatment_feature_overlap

        # Feature selection
        self.random_feature_selection = random_feature_selection
        self.nonlinearity_selection_type = nonlinearity_selection_type

        # Noise
        self.noise = noise
        self.noise_std = noise_std

        # Setup variables
        self.nonlinearities = None
        self.prog_mask, self.pred_masks, self.select_masks = None, None, None
        self.prog_weights, self.pred_weights, self.select_weights = None, None, None

        # Setup
        self.setup()

        # Simulation variables
        self.X = None
        self.prog_scores, self.pred_scores, self.select_scores = None, None, None
        self.select_scores_pred_overlap = None
        self.select_scores_prog_overlap = None
        self.propensities, self.outcomes, self.T, self.Y = None, None, None, None

    def get_simulated_data(self, train_ratio: float = 1.0):
        """
        Return the simulated covariates, treatment assignments, factual outcomes,
        all (counterfactual) outcomes, and propensities. The full dataset is
        returned; `train_ratio` is accepted for interface compatibility with
        SimulatorBase.
        """
        return self.X, self.T, self.Y, self.outcomes, self.propensities

    def simulate(self, X, outcomes=None) -> None:
        """
        Simulate treatment and outcome for a dataset based on the configuration.
        """
        log.debug(
            f'Simulating treatment and outcome for a dataset with:'
            f'\n==================================================================='
            f'\nDim X: {self.dim_X}'
            f'\nDim T: {self.num_T}'
            f'\nDim Y: {self.dim_Y}'
            f'\nPredictive Scale: {self.predictive_scale}'
            f'\nPrognostic Scale: {self.prognostic_scale}'
            f'\nPropensity Scale: {self.propensity_scale}'
            f'\nUnbalancedness Exponent: {self.unbalancedness_exp}'
            f'\nNonlinearity Scale: {self.nonlinearity_scale}'
            f'\nNum Pred Features: {self.num_pred_features}'
            f'\nNum Prog Features: {self.num_prog_features}'
            f'\nNum Select Features: {self.num_select_features}'
            f'\nFeature Overlap: {self.treatment_feature_overlap}'
            f'\nRandom Feature Selection: {self.random_feature_selection}'
            f'\nNonlinearity Selection Type: {self.nonlinearity_selection_type}'
            f'\nNoise: {self.noise}'
            f'\nNoise Std: {self.noise_std}'
            f'\n===================================================================\n'
        )

        # 1. Store data (min-max scaling to [0, 1] is currently disabled)
        self.X = X
        # self.X = (X - X.min(axis=0)) / (X.max(axis=0) - X.min(axis=0) + EPS)

        # 2. Compute scores for prognostic, predictive, and selective features
        self.compute_scores()

        # 3. Compute factual and counterfactual outcomes based on the data and the predictive and prognostic scores
        self.compute_all_outcomes()

        # 4. Compute propensities based on the data and the selective scores
        self.compute_propensities()

        # 5. Sample treatment assignment based on the propensities
        self.sample_T()

        # 6. Extract the outcome based on the treatment assignment
        self.extract_Y()

        return None

    def setup(self) -> None:
        """
        Set up the simulator by defining variables which remain the same across
        simulations with different samples but the same configuration.
        """
        # 1. Sample the nonlinearities: one prognostic, one selective, and a
        # different one per outcome (predictive), shared across treatments
        num_nonlinearities = 2 + self.dim_Y
        self.nonlinearities = self.sample_nonlinearities(num_nonlinearities)

        # 2. Set important feature masks - determine which features should be used for treatment selection and outcome prediction
        self.sample_important_feature_masks()

        # 3. Sample weights for features
        self.sample_uniform_weights()

    def get_true_cates(self,
                       X: np.ndarray,
                       T: np.ndarray,
                       outcomes: np.ndarray) -> np.ndarray:
        """
        Compute true CATEs for each treatment based on the data and the outcomes.
        The factually assigned treatment is always used as the base treatment.
        """
        # Compute CATEs for each treatment relative to the assigned one
        cates = np.zeros((X.shape[0], self.num_T, self.dim_Y))

        for i in range(X.shape[0]):
            for j in range(self.num_T):
                cates[i, j, :] = outcomes[i, j, :] - outcomes[i, int(T[i]), :]

        log.debug(
            f'\nCheck if true CATEs are computed correctly:'
            f'\n==================================================================='
            f'\nOutcomes: {outcomes.shape}'
            f'\n{outcomes}'
            f'\n\nTreatment Assignment: {T.shape}'
            f'\n{T}'
            f'\n\nTrue CATEs: {cates.shape}'
            f'\n{cates}'
            f'\n===================================================================\n'
        )

        return cates
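
    # Vectorized equivalent of the loop in get_true_cates (a sketch; produces the
    # same array as the nested loops above):
    #
    #   base = outcomes[np.arange(X.shape[0]), T.astype(int)]   # (n, dim_Y)
    #   cates = outcomes - base[:, None, :]                     # (n, num_T, dim_Y)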

    def extract_Y(self) -> None:
        """
        Extract the factual outcome for each sample based on its assigned treatment.
        """
        # Fancy indexing: row i picks outcomes[i, T[i], :]
        self.Y = self.outcomes[np.arange(self.X.shape[0]), self.T]

        log.debug(
            f'\nCheck if outcomes are extracted correctly:'
            f'\n==================================================================='
            f'\nOutcomes'
            f'\n{self.outcomes}'
            f'\n{self.outcomes.shape}'
            f'\n\nTreatment Assignment'
            f'\n{self.T}'
            f'\n{self.T.shape}'
            f'\n\nExtracted Outcomes'
            f'\n{self.Y}'
            f'\n{self.Y.shape}'
            f'\n===================================================================\n'
        )

        return None

    def compute_all_outcomes_toy(self) -> np.ndarray:
        # Compute outcomes for each treatment and outcome dimension
        outcomes = np.zeros((self.X.shape[0], self.num_T, self.dim_Y))
        X0 = self.X[:, 0]
        X1 = self.X[:, 1]

        # Steep logistic nonlinearity centered at 0.5
        k = 20
        nonlinearity = lambda x: 1 / (1 + np.exp(-k * (x - 0.5)))

        if self.propensity_type.startswith(("toy1", "toy3", "toy4")):
            fun_y0 = lambda X0, X1: X0
            fun_y1 = lambda X0, X1: 1 - X0

        elif self.propensity_type.startswith("toy2"):
            fun_y0 = lambda X0, X1: X0
            fun_y1 = lambda X0, X1: 1 - X1

        elif self.propensity_type.startswith("toy6"):
            fun_y0 = lambda X0, X1: X0
            fun_y1 = lambda X0, X1: X1

        elif self.propensity_type.startswith("toy5"):
            fun_y0 = lambda X0, X1: np.sin(X0 * 10 * np.pi)
            fun_y1 = lambda X0, X1: np.sin((1 - X0) * 10 * np.pi)

        elif self.propensity_type.startswith("toy7"):
            fun_y0 = lambda X0, X1: nonlinearity(X0) - nonlinearity(X1)
            fun_y1 = lambda X0, X1: nonlinearity(X0) + nonlinearity(X1)

        elif self.propensity_type.startswith("toy8"):
            fun_y0 = lambda X0, X1: X0
            fun_y1 = lambda X0, X1: 1 - X0

        else:
            raise ValueError(f"Unknown toy propensity type {self.propensity_type}.")

        Y = np.array([fun_y0(X0, X1), fun_y1(X0, X1)]).T

        if self.propensity_type.endswith("nonlinear"):
            Y = nonlinearity(Y)

        Y = zscore(Y, axis=None)

        outcomes[:, :, 0] = Y

        return outcomes

    def compute_all_outcomes(self) -> None:
        """
        Compute factual and counterfactual outcomes based on the data and the predictive and prognostic scores.
        """
        if self.propensity_type.startswith("toy"):
            outcomes = self.compute_all_outcomes_toy()

        else:
            # Compute outcomes for each treatment and outcome
            outcomes = np.zeros((self.X.shape[0], self.num_T, self.dim_Y))

            for i in range(self.num_T):
                for j in range(self.dim_Y):
                    if self.include_control and i == 0:
                        # Treatment 0 acts as control: prognostic effect only
                        outcomes[:, i, j] = self.prognostic_scale * self.prog_scores[:, j]

                    else:
                        outcomes[:, i, j] = self.prognostic_scale * self.prog_scores[:, j] + self.predictive_scale * self.pred_scores[:, i, j]

        # Add gaussian noise to outcomes
        if self.noise:
            outcomes = outcomes + np.random.normal(0, self.noise_std, size=outcomes.shape)

        # Create binary outcomes and introduce unbalancedness: standardize,
        # squash through a sigmoid, and threshold at the unbalancedness ratio
        if int(self.num_binary_outcome) > 0:
            for j in range(self.num_binary_outcome):
                scores = zscore(outcomes[:, :, j], axis=0)
                prob = expit(scores)
                outcomes[:, :, j] = prob > self.outcome_unbalancedness_ratio

        self.outcomes = outcomes

        # Standardize outcomes (per outcome dimension)
        if self.standardize_outcome:
            self.outcomes = zscore(self.outcomes, axis=0)

        log.debug(
            f'\nCheck if outcomes are computed correctly:'
            f'\n==================================================================='
            f'\nProg Scores'
            f'\n{self.prog_scores}'
            f'\n{self.prog_scores.shape}'
            f'\n\nPred Scores'
            f'\n{self.pred_scores}'
            f'\n{self.pred_scores.shape}'
            f'\n\nOutcomes'
            f'\n{self.outcomes}'
            f'\n{self.outcomes.shape}'
            f'\n\nMean Outcomes'
            f'\n{self.outcomes.mean(axis=0)}'
            f'\n\nVariance Outcomes'
            f'\n{self.outcomes.var(axis=0)}'
            f'\n===================================================================\n'
        )

        return None

    def sample_T(self) -> None:
        """
        Sample treatment assignment based on the propensities.
        """
        # Sample from the resulting categorical distribution per row
        self.T = np.array([
            np.random.choice(np.arange(self.propensities.shape[1]), p=row)
            for row in self.propensities
        ])

        log.debug(
            f'\nCheck if treatment assignment is sampled correctly:'
            f'\n==================================================================='
            f'\nPropensities'
            f'\n{self.propensities}'
            f'\n{self.propensities.shape}'
            f'\n\nTreatment Assignment'
            f'\n{self.T}'
            f'\n{self.T.shape}'
            f'\n\nUnique Treatment Counts'
            f'\n{np.unique(self.T, return_counts=True)}'
            f'\n===================================================================\n'
        )

        return None
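
    # Vectorized alternative to the per-row np.random.choice above (a sketch,
    # not the method used by this class): invert the CDF of each categorical row.
    #
    #   cum = self.propensities.cumsum(axis=1)        # row-wise CDF
    #   u = np.random.uniform(size=(cum.shape[0], 1))
    #   T = (u > cum).sum(axis=1)                     # index of the bin u falls into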

    def get_unbalancedness_weights(self, size: int) -> np.ndarray:
        """
        Create weights for introducing unbalancedness into the class probabilities.
        """
        # Sample an initial distribution over treatment assignments
        unb_weights = np.random.uniform(0, 1, size=size)
        unb_weights = unb_weights / unb_weights.sum()

        # Min-max scale the weights into [0.01, 0.99] so that no treatment
        # disappears completely for small unbalancedness exponents
        min_val = unb_weights.min()
        range_val = unb_weights.max() - min_val
        unb_weights = (unb_weights - min_val) / range_val
        unb_weights = 0.01 + unb_weights * 0.98

        return unb_weights
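
    # Worked example of the rescaling above: raw weights [0.2, 0.5, 0.3] sum to 1,
    # are min-max scaled to [0.0, 1.0, 0.333...], and then mapped into [0.01, 0.99]
    # via 0.01 + w * 0.98, giving [0.01, 0.99, 0.3367]. No treatment's probability
    # is ever zeroed out entirely.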

    def compute_propensity_scores_toy(self) -> np.ndarray:
        X0 = self.X[:, 0]
        X1 = self.X[:, 1]

        if self.propensity_type.startswith("toy1"):
            fun_t0 = lambda X0, X1: X0
            fun_t1 = lambda X0, X1: 1 - X0

        elif self.propensity_type.startswith("toy2"):
            fun_t0 = lambda X0, X1: X0
            fun_t1 = lambda X0, X1: 1 - X1

        elif self.propensity_type.startswith("toy3"):
            fun_t0 = lambda X0, X1: X1
            fun_t1 = lambda X0, X1: 1 - X1

        elif self.propensity_type.startswith("toy4"):
            fun_t0 = lambda X0, X1: np.sin(X0 * 10 * np.pi)
            fun_t1 = lambda X0, X1: np.sin((1 - X0) * 10 * np.pi)

        # toy5 through toy8 share the same selection mechanism
        elif self.propensity_type.startswith(("toy5", "toy6", "toy7", "toy8")):
            fun_t0 = lambda X0, X1: 1 - X0
            fun_t1 = lambda X0, X1: X0

        else:
            raise ValueError(f"Unknown toy propensity type {self.propensity_type}.")

        scores = np.array([fun_t0(X0, X1), fun_t1(X0, X1)]).T

        return scores

    def compute_propensities(self) -> None:
        """
        Compute propensities based on the data and the selective scores.
        NOTE: the score construction below assumes num_T == 2 and uses the
        first outcome dimension only.
        """
        select_scores_pred_overlap = zscore(self.select_scores_pred_overlap, axis=0)  # Comment for Predictive Expertise
        select_scores_prog_overlap = zscore(self.select_scores_prog_overlap, axis=0)  # Comment for Predictive Expertise
        select_scores_none = zscore(self.select_scores, axis=0)  # Comment for Predictive Expertise

        select_scores_pred = np.zeros((self.X.shape[0], self.num_T))
        select_scores_pred_flipped = np.zeros((self.X.shape[0], self.num_T))
        select_scores_prog = np.zeros((self.X.shape[0], self.num_T))
        select_scores_tre = np.zeros((self.X.shape[0], self.num_T))

        # Predictive expertise: select according to the treatment effect
        select_scores_pred[:, 0] = self.outcomes[:, 0, 0] - self.outcomes[:, 1, 0]
        select_scores_pred[:, 1] = self.outcomes[:, 1, 0] - self.outcomes[:, 0, 0]

        select_scores_pred_flipped[:, 0] = self.outcomes[:, 1, 0] - self.outcomes[:, 0, 0]
        select_scores_pred_flipped[:, 1] = self.outcomes[:, 0, 0] - self.outcomes[:, 1, 0]

        # Prognostic expertise: select according to the control outcome
        select_scores_prog[:, 0] = self.outcomes[:, 0, 0]
        select_scores_prog[:, 1] = -self.outcomes[:, 0, 0]

        select_scores_tre[:, 0] = -self.outcomes[:, 1, 0]
        select_scores_tre[:, 1] = self.outcomes[:, 1, 0]

        # For "prog_tre" the blend is intentionally computed on the raw
        # (unstandardized) scores
        if self.propensity_type == "prog_tre":
            scores = self.alpha * select_scores_tre + (1 - self.alpha) * select_scores_prog

        # Standardize all scores
        select_scores_pred = zscore(select_scores_pred, axis=0)
        select_scores_pred_flipped = zscore(select_scores_pred_flipped, axis=0)
        select_scores_prog = zscore(select_scores_prog, axis=0)
        select_scores_tre = zscore(select_scores_tre, axis=0)

        if self.propensity_type == "prog_pred":
            scores = self.alpha * select_scores_pred + (1 - self.alpha) * select_scores_prog

        elif self.propensity_type == "prog_tre":
            pass  # already computed above on the raw scores

        elif self.propensity_type == "none_prog":
            scores = self.alpha * select_scores_prog + (1 - self.alpha) * select_scores_none

        elif self.propensity_type == "none_pred":
            scores = self.alpha * select_scores_pred + (1 - self.alpha) * select_scores_none

        elif self.propensity_type == "none_tre":
            scores = self.alpha * select_scores_tre + (1 - self.alpha) * select_scores_none

        elif self.propensity_type == "none_pred_flipped":
            scores = self.alpha * select_scores_pred_flipped + (1 - self.alpha) * select_scores_none

        elif self.propensity_type == "pred_pred_flipped":
            scores = self.alpha * select_scores_pred_flipped + (1 - self.alpha) * select_scores_pred

        elif self.propensity_type == "none_pred_overlap":
            scores = self.alpha * select_scores_pred_overlap + (1 - self.alpha) * select_scores_none

        elif self.propensity_type == "none_prog_overlap":
            scores = self.alpha * select_scores_prog_overlap + (1 - self.alpha) * select_scores_none

        elif self.propensity_type == "pred_overalp_prog_overlap":
            scores = self.alpha * select_scores_prog_overlap + (1 - self.alpha) * select_scores_pred_overlap

        elif self.propensity_type == "rct_none":
            scores = select_scores_none

        elif self.propensity_type.startswith("toy"):
            scores = self.compute_propensity_scores_toy()

        else:
            raise ValueError(f"Unknown propensity type {self.propensity_type}.")

        if self.enforce_balancedness:
            scores = zscore(scores, axis=0)

        # For "rct_none", the final scores are the alpha-scaled random selective
        # scores (this overrides the assignment above)
        if self.propensity_type == "rct_none":
            scores = self.alpha * select_scores_none

        # Introduce unbalancedness; the weights are rescaled so that experiments
        # with different seeds remain comparable
        unb_weights = self.get_unbalancedness_weights(size=scores.shape[1])

        # Apply the softmax function to each row to get probabilities
        p = softmax(self.propensity_scale * scores, axis=1)

        # Scale probabilities to introduce unbalancedness
        p = p * (1 - unb_weights) ** self.unbalancedness_exp

        # Make sure rows add up to one again
        row_sums = p.sum(axis=1, keepdims=True)
        p = p / row_sums
        self.propensities = p

        log.debug(
            f'\nCheck if propensities are computed correctly:'
            f'\n==================================================================='
            f'\nSelect Scores'
            f'\n{self.select_scores}'
            f'\n{self.select_scores.shape}'
            f'\n\nPropensities'
            f'\n{self.propensities}'
            f'\n{self.propensities.shape}'
            f'\n===================================================================\n'
        )

        return None
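
    # Worked example of the propensity computation above (illustrative): with
    # propensity_scale = 1, a score row [1.0, -1.0] becomes
    # softmax([1.0, -1.0]) = [0.881, 0.119]. With unbalancedness_exp = 0 the
    # unbalancedness weights have no effect, since (1 - w) ** 0 = 1.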

    def compute_scores(self) -> None:
        """
        Compute scores for prognostic, predictive, and selective features based on the data and the feature weights.
        """
        # Each column of a score matrix corresponds to one outcome; rows correspond to samples.
        prog_lin = self.X @ self.prog_weights.T
        select_lin = self.X @ self.select_weights.T
        select_lin_pred = self.X @ self.select_weights_pred.T
        select_lin_prog = self.X @ self.select_weights_prog.T

        log.debug(
            f'\nCheck if linear scores are computed correctly for selective features:'
            f'\n==================================================================='
            f'\nself.X'
            f'\n{self.X}'
            f'\n{self.X.shape}'
            f'\n\nSelect Weights'
            f'\n{self.select_weights}'
            f'\n{self.select_weights.shape}'
            f'\n\nSelect Lin'
            f'\n{select_lin}'
            f'\n{select_lin.shape}'
            f'\n===================================================================\n'
        )

        # Compute scores for predictive and selective features for each treatment and outcome
        pred_lin = np.zeros((self.X.shape[0], self.num_T, self.dim_Y))

        # This creates a score for each treatment and outcome for each sample
        for i in range(self.num_T):
            pred_lin[:, i, :] = self.X @ self.pred_weights[i].T

        # Introduce non-linearity and get the final scores
        prog_scores = (1 - self.nonlinearity_scale) * prog_lin + self.nonlinearity_scale * self.nonlinearities[0](prog_lin)
        select_scores = (1 - self.nonlinearity_scale) * select_lin + self.nonlinearity_scale * self.nonlinearities[1](select_lin)
        select_scores_pred_overlap = (1 - self.nonlinearity_scale) * select_lin_pred + self.nonlinearity_scale * self.nonlinearities[1](select_lin_pred)
        select_scores_prog_overlap = (1 - self.nonlinearity_scale) * select_lin_prog + self.nonlinearity_scale * self.nonlinearities[1](select_lin_prog)

        pred_scores = np.zeros((self.X.shape[0], self.num_T, self.dim_Y))
        for i in range(self.dim_Y):
            pred_scores[:, :, i] = (1 - self.nonlinearity_scale) * pred_lin[:, :, i] + self.nonlinearity_scale * self.nonlinearities[i + 2](pred_lin[:, :, i])

        log.debug(
            f'\nCheck if all scores are computed correctly for predictive features:'
            f'\n==================================================================='
            f'\nself.X'
            f'\n{self.X}'
            f'\n{self.X.shape}'
            f'\n\nPred Weights'
            f'\n{self.pred_weights}'
            f'\n{self.pred_weights.shape}'
            f'\n\nPred Lin'
            f'\n{pred_lin}'
            f'\n{pred_lin.shape}'
            f'\n\nPred Scores'
            f'\n{pred_scores}'
            f'\n{pred_scores.shape}'
            f'\n===================================================================\n'
        )

        self.prog_scores = prog_scores
        self.select_scores = select_scores
        self.select_scores_pred_overlap = select_scores_pred_overlap
        self.select_scores_prog_overlap = select_scores_prog_overlap

        self.pred_scores = pred_scores

        return None
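
    # The scores computed above interpolate between linear and nonlinear responses,
    #   score = (1 - s) * (X @ w) + s * f(X @ w),  with s = nonlinearity_scale,
    # so s = 0 yields a purely linear DGP and s = 1 a purely nonlinear one.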
    

    @property
    def weights(self) -> Tuple:
        """
        Return weights for prognostic, predictive, and selective features.
        """
        return self.prog_weights, self.pred_weights, self.select_weights
    def sample_uniform_weights(self) -> None:
        """
        Sample uniform weights for the features.
        """
        if self.propensity_type.startswith("toy"):
            self.prog_weights = np.zeros((self.dim_Y, self.dim_X))
            self.pred_weights = np.zeros((self.num_T, self.dim_Y, self.dim_X))
            self.select_weights = np.zeros((self.num_T, self.dim_X))
            self.select_weights_pred = np.zeros((self.num_T, self.dim_X))
            self.select_weights_prog = np.zeros((self.num_T, self.dim_X))
            return None

        # Sample weights for prognostic features, a weight for every outcome
        prog_weights = np.random.uniform(-1, 1, size=(self.dim_Y, self.dim_X)) * self.prog_mask

        # Sample weights for predictive and selective features, a weight for every dimension for every treatment and outcome
        pred_weights = np.random.uniform(-1, 1, size=(self.num_T, self.dim_Y, self.dim_X))
        select_weights = np.random.uniform(-1, 1, size=(self.num_T, self.dim_X))
        select_weights_pred = select_weights.copy()
        select_weights_prog = select_weights.copy()

        # # Sample weights for prognostic features, a weight for every outcome
        # prog_weights = np.random.uniform(0, 1, size=(self.dim_Y, self.dim_X)) * self.prog_mask

        # # Sample weights for predictive and selective features, a weight for every dimension for every treatment and outcome
        # pred_weights = np.random.uniform(0, 1, size=(self.num_T, self.dim_Y, self.dim_X))
        # select_weights = np.random.uniform(0, 1, size=(self.num_T, self.dim_X))

        # # Make sure treatments are different
        # pred_weights[0] = -pred_weights[0]
        # select_weights[0] = -select_weights[0]

        # # Ones as weights
        # prog_weights = np.ones((self.dim_Y, self.dim_X)) * self.prog_mask  # / self.prog_mask.sum()
        # pred_weights = np.ones((self.num_T, self.dim_Y, self.dim_X))  # / self.pred_masks.sum(axis=1, keepdims=True)
        # select_weights = np.ones((self.num_T, self.dim_X))  # / self.select_masks.sum(axis=1, keepdims=True)

        # Mask weights for features that are not important
        for i in range(self.num_T):
            pred_weights[i] = pred_weights[i] * self.pred_masks[:, i]
            select_weights[i] = select_weights[i] * self.select_masks[:, i]
            select_weights_pred[i] = select_weights_pred[i] * self.select_masks_pred[:, i]
            select_weights_prog[i] = select_weights_prog[i] * self.select_masks_prog[:, i]

        # for i in range(self.num_T):
        #     row_sums = pred_weights[i].sum(axis=1, keepdims=True)
        #     pred_weights[i] = pred_weights[i] / row_sums

        #     row_sums = select_weights[i].sum()
        #     select_weights[i] = select_weights[i] / row_sums

        # # Make sure that prog weights sum to one per outcome
        # row_sums = prog_weights.sum(axis=1, keepdims=True)
        # prog_weights = prog_weights / row_sums

        log.debug(
            f'\nCheck if masks are applied correctly:'
            f'\n==================================================================='
            f'\nSelect Weights'
            f'\n{select_weights}'
            f'\n{select_weights.shape}'
            f'\n\nSelect Masks'
            f'\n{self.select_masks}'
            f'\n{self.select_masks.shape}'
            f'\n\nPred Weights'
            f'\n{pred_weights}'
            f'\n{pred_weights.shape}'
            f'\n\nPred Masks'
            f'\n{self.pred_masks}'
            f'\n{self.pred_masks.shape}'
            f'\n===================================================================\n'
        )

        self.prog_weights = prog_weights
        self.pred_weights = pred_weights
        self.select_weights = select_weights
        self.select_weights_pred = select_weights_pred
        self.select_weights_prog = select_weights_prog

        return None
    

    @property
    def all_important_features(self) -> np.ndarray:
        """
        Return all important feature indices.
        """
        all_important_features = np.union1d(self.predictive_features, self.prognostic_features)
        all_important_features = np.union1d(all_important_features, self.selective_features)

        log.debug(
            f'\nCheck if all important features are computed correctly:'
            f'\n==================================================================='
            f'\nProg Features'
            f'\n{self.prognostic_features}'
            f'\n\nPred Features'
            f'\n{self.predictive_features}'
            f'\n\nSelect Features'
            f'\n{self.selective_features}'
            f'\n\nAll Important Features'
            f'\n{all_important_features}'
            f'\n===================================================================\n'
        )

        return all_important_features

    @property
    def prognostic_features(self) -> np.ndarray:
        """
        Return prognostic feature indices.
        """
        prog_features = np.where(self.prog_mask.astype(np.int32) != 0)[0]
        return prog_features

    @property
    def predictive_features(self) -> np.ndarray:
        """
        Return predictive feature indices.
        """
        pred_features = np.where(self.pred_masks.sum(axis=1).astype(np.int32) != 0)[0]
        return pred_features

    @property
    def selective_features(self) -> np.ndarray:
        """
        Return selective feature indices.
        """
        select_features = np.where(self.select_masks.sum(axis=1).astype(np.int32) != 0)[0]
        return select_features

    def sample_important_feature_masks(self) -> None:
        """
        Pick features that are important for treatment selection, outcome prediction, and prognostic prediction based on the configuration.
        """
        if self.propensity_type.startswith("toy"):
            self.prog_mask = np.zeros(shape=(self.dim_X))
            self.pred_masks = np.zeros(shape=(self.dim_X, self.num_T))
            self.select_masks = np.zeros(shape=(self.dim_X, self.num_T))

            self.prog_mask[0] = 1
            self.pred_masks[0, 0] = 1
            self.pred_masks[1, 1] = 1
            self.select_masks[0, 0] = 1
            self.select_masks[1, 1] = 1

            return None

        # Get indices for features and shuffle if random_feature_selection is True
        all_indices = np.arange(self.dim_X)
        # NOTE: a single block size n (= num_pred_features) is used for all feature types below
        n = self.num_pred_features

        if self.random_feature_selection:
            np.random.shuffle(all_indices)

        # Initialize masks
        prog_mask = np.zeros(shape=(self.dim_X))
        pred_masks = np.zeros(shape=(self.dim_X, self.num_T))
        select_masks = np.zeros(shape=(self.dim_X, self.num_T))

        # Selective features coincide with predictive features
        if self.feature_type_overlap == "sel_pred":
            prog_indices = all_indices[:n]
            prog_mask[prog_indices] = 1

            if self.treatment_feature_overlap:
                assert 2 * n <= int(self.dim_X)
                pred_indices = np.array(self.num_T * [all_indices[n:2 * n]])
                select_indices = np.array(self.num_T * [all_indices[n:2 * n]])

                prog_mask[prog_indices] = 1
                pred_masks[pred_indices] = 1
                select_masks[select_indices] = 1

            else:
                assert n * (1 + self.num_T) <= int(self.dim_X)
                for i in range(self.num_T):
                    pred_indices = all_indices[(i + 1) * n: (i + 2) * n]
                    select_indices = all_indices[(i + 1) * n: (i + 2) * n]

                    pred_masks[pred_indices, i] = 1
                    select_masks[select_indices, i] = 1

        # Selective features coincide with prognostic features
        elif self.feature_type_overlap == "sel_prog":
            if self.treatment_feature_overlap:
                assert 2 * n <= int(self.dim_X)
                prog_indices = all_indices[:n]
                prog_mask[prog_indices] = 1
                pred_indices = np.array(self.num_T * [all_indices[n:2 * n]])
                select_indices = np.array(self.num_T * [all_indices[:n]])

                prog_mask[prog_indices] = 1
                pred_masks[pred_indices] = 1
                select_masks[select_indices] = 1

            else:
                assert 2 * n * self.num_T <= int(self.dim_X)
                prog_indices = all_indices[:n * self.num_T:self.num_T]
                prog_mask[prog_indices] = 1
                for i in range(self.num_T):
                    select_indices = all_indices[i * n: (i + 1) * n]
                    pred_indices = all_indices[(i + self.num_T + 1) * n: (i + self.num_T + 2) * n]

                    pred_masks[pred_indices, i] = 1
                    select_masks[select_indices, i] = 1

        # Selective features are disjoint from predictive and prognostic features
        elif self.feature_type_overlap == "sel_none":
            prog_indices = all_indices[:n]
            prog_mask[prog_indices] = 1

            if self.treatment_feature_overlap:
                assert 3 * n <= int(self.dim_X)
                pred_indices = np.array(self.num_T * [all_indices[n:2 * n]])
                select_indices = np.array(self.num_T * [all_indices[2 * n:3 * n]])

                prog_mask[prog_indices] = 1
                pred_masks[pred_indices] = 1
                select_masks[select_indices] = 1

            else:
                # assert n + 2 * n * self.num_T <= int(self.dim_X)
                for i in range(1, self.num_T + 1):
                    select_indices = all_indices[i * n: (i + 1) * n]
                    pred_indices = all_indices[(i + self.num_T) * n: (i + self.num_T + 1) * n]
                    pred_masks[pred_indices, i - 1] = 1
                    select_masks[select_indices, i - 1] = 1

        # # Handle case with feature overlap
        # if self.feature_overlap:
        #     assert max(self.num_pred_features, self.num_prog_features, self.num_select_features) <= int(self.dim_X)

        #     prog_indices = all_indices[:self.num_prog_features]
        #     pred_indices = np.array(self.num_T * [all_indices[:self.num_pred_features]])
        #     select_indices = np.array(self.num_T * [all_indices[:self.num_select_features]])

        #     prog_mask[prog_indices] = 1
        #     pred_masks[pred_indices] = 1
        #     select_masks[select_indices] = 1

        # # Handle case without feature overlap
        # else:
        #     assert (self.num_prog_features + self.num_T * (self.num_pred_features + self.num_select_features)) <= int(self.dim_X)

        #     prog_indices = all_indices[:self.num_prog_features]
        #     prog_mask[prog_indices] = 1
        #     pred_indices = all_indices[self.num_prog_features : (self.num_prog_features + self.num_T*self.num_pred_features)]
        #     select_indices = all_indices[(self.num_prog_features + self.num_T*self.num_pred_features):(self.num_prog_features + self.num_T*(self.num_pred_features+self.num_select_features))]

        #     # Mask features for every treatment
        #     for i in range(self.num_T):
        #         pred_masks[pred_indices[i*self.num_pred_features:(i+1)*self.num_pred_features],i] = 1
        #         select_masks[select_indices[i*self.num_select_features:(i+1)*self.num_select_features],i] = 1

        self.prog_mask = prog_mask
        self.pred_masks = pred_masks
        self.select_masks = select_masks
        self.select_masks_pred = pred_masks.copy()
        self.select_masks_prog = select_masks.copy()

        log.debug(
            f'\nCheck if important features are sampled correctly:'
            f'\n==================================================================='
            f'\nProg Indices'
            f'\n{prog_indices}'
            f'\n\nPred Indices'
            f'\n{pred_indices}'
            f'\n\nSelect Indices'
            f'\n{select_indices}'
            f'\n\nProg Mask'
            f'\n{prog_mask}'
            f'\n\nPred Masks'
            f'\n{pred_masks}'
            f'\n\nSelect Masks'
            f'\n{select_masks}'
            f'\n===================================================================\n'
        )
        return None
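
    # Example of the resulting layout (illustrative): with dim_X = 10, num_T = 2,
    # n = 2, feature_type_overlap = "sel_none", no treatment overlap, and
    # sequential (unshuffled) indices, the masks select:
    #   prognostic features:          {0, 1}
    #   selective features:  T0 -> {2, 3},  T1 -> {4, 5}
    #   predictive features: T0 -> {6, 7},  T1 -> {8, 9}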

    def sample_nonlinearities(self, num_nonlinearities: int):
        """
        Sample non-linearities for each outcome.
        """
        if self.nonlinearity_selection_type == "random":
            # Pick num_nonlinearities functions from the pool, with replacement
            return random.choices(population=self.nonlinear_fcts, k=num_nonlinearities)

        else:
            raise ValueError(
                f"Unknown nonlinearity selection type {self.nonlinearity_selection_type}."
            )


class TSimulator(SimulatorBase):
    """
    Data generation process class for simulating treatment selection only, for
    settings where counterfactual outcomes are already available (as for
    in-vitro/pharmacoscopy data).
    """

    nonlinear_fcts = [
        lambda x: np.abs(x),
        lambda x: np.exp(-(x**2) / 2),
        lambda x: 1 / (1 + x**2),
        lambda x: np.cos(x),
        lambda x: np.arctan(x),
        lambda x: np.tanh(x),
        lambda x: np.sin(x),
        lambda x: np.log(1 + x**2),
        lambda x: np.sqrt(1 + x**2),
        lambda x: np.cosh(x),
    ]
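
    # Usage sketch (illustrative, not from the original source; unlike TYSimulator,
    # the counterfactual outcomes are supplied by the caller):
    #
    #   sim = TSimulator(dim_X=30, num_T=2, dim_Y=1)
    #   sim.simulate(X, outcomes)      # outcomes: (n, num_T, dim_Y)
    #   X, T, Y, outcomes, propensities = sim.get_simulated_data()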

    def __init__(
        self,
        # Data dimensionality
        dim_X: int,

        # Seed
        seed: int = 42,

        # Simulation type
        simulation_type: str = "T",

        # Dimensionality of treatments and outcome
        num_binary_outcome: int = 0,
        standardize_outcome: bool = False,
        standardize_per_outcome: bool = False,
        num_T: int = 3,
        dim_Y: int = 3,

        # Scale parameters
        propensity_scale: float = 1,
        unbalancedness_exp: float = 0,
        nonlinearity_scale: float = 1,
        propensity_type: str = "prog_pred",
        alpha: float = 0.5,
        enforce_balancedness: bool = False,

        # Important features
        num_select_features: int = 5,
        treatment_feature_overlap: bool = False,

        # Feature selection
        random_feature_selection: bool = True,
        nonlinearity_selection_type: str = "random",
    ) -> None:
        # Number of features
        self.dim_X = dim_X

        # Make sure results are reproducible by setting seed for np, torch, random
        self.seed = seed
        enable_reproducible_results(seed=self.seed)

        # Simulation type
        self.simulation_type = simulation_type

        # Store dimensions
        self.num_binary_outcome = num_binary_outcome
        self.standardize_outcome = standardize_outcome
        self.standardize_per_outcome = standardize_per_outcome
        self.num_T = num_T
        self.dim_Y = dim_Y

        # Scale parameters
        self.propensity_scale = propensity_scale
        self.unbalancedness_exp = unbalancedness_exp
        self.nonlinearity_scale = nonlinearity_scale
        self.propensity_type = propensity_type
        self.alpha = alpha
        self.enforce_balancedness = enforce_balancedness

        # Important features
        self.num_select_features = num_select_features
        self.treatment_feature_overlap = treatment_feature_overlap
        self.num_important_features = num_select_features

        # Feature selection
        self.random_feature_selection = random_feature_selection
        self.nonlinearity_selection_type = nonlinearity_selection_type

        # Setup variables
        self.nonlinearities = None
        self.select_masks = None
        self.select_weights = None

        # Setup
        self.setup()

        # Simulation variables
        self.X = None
        self.select_scores = None
        self.propensities, self.outcomes, self.T, self.Y = None, None, None, None

    def get_simulated_data(self, train_ratio: float = 0.8):
        """
        Return the simulated covariates, treatment assignments, factual outcomes,
        all (counterfactual) outcomes, and propensities. The full dataset is
        returned; `train_ratio` is accepted for interface compatibility with
        SimulatorBase.
        """
        return self.X, self.T, self.Y, self.outcomes, self.propensities

    def simulate(self, X, outcomes=None) -> None:
        """
        Simulate treatment assignment for a dataset with given outcomes, based on the configuration.
        """
        log.debug(
            f'Simulating treatment and outcome for a dataset with:'
            f'\n==================================================================='
            f'\nDim X: {self.dim_X}'
            f'\nDim T: {self.num_T}'
            f'\nDim Y: {self.dim_Y}'
            f'\nPropensity Scale: {self.propensity_scale}'
            f'\nUnbalancedness Exponent: {self.unbalancedness_exp}'
            f'\nNonlinearity Scale: {self.nonlinearity_scale}'
            f'\nNum Select Features: {self.num_select_features}'
            f'\nFeature Overlap: {self.treatment_feature_overlap}'
            f'\nRandom Feature Selection: {self.random_feature_selection}'
            f'\nNonlinearity Selection Type: {self.nonlinearity_selection_type}'
            f'\n===================================================================\n'
        )

        # 1. Store data
        self.X = X

        # 2. Compute scores for the selective features
        self.compute_scores()

        # 3. Retrieve the factual and counterfactual outcomes supplied by the caller
        self.outcomes = outcomes
        assert self.outcomes.shape == (self.X.shape[0], self.num_T, self.dim_Y)

        if self.standardize_outcome:
            if self.standardize_per_outcome:
                # Standardize each outcome dimension separately
                self.outcomes = zscore(self.outcomes, axis=0)
            else:
                # Standardize across all outcomes jointly
                self.outcomes = zscore(self.outcomes, axis=None)

        log.debug(
            f'\nCheck if outcomes are processed correctly:'
            f'\n==================================================================='
            f'\n\nOutcomes'
            f'\n{self.outcomes}'
            f'\n{self.outcomes.shape}'
            f'\n\nMean Outcomes'
            f'\n{self.outcomes.mean(axis=0)}'
            f'\n\nVariance Outcomes'
            f'\n{self.outcomes.var(axis=0)}'
            f'\n===================================================================\n'
        )

        # 4. Compute propensities based on the data and the selective scores
        self.compute_propensities()

        # 5. Sample treatment assignment based on the propensities
        self.sample_T()

        # 6. Extract the outcome based on the treatment assignment
        self.extract_Y()

        return None
    

    def setup(self) -> None:
        """
        Set up the simulator by defining variables which remain the same across
        simulations with different samples but the same configuration.
        """
        # 1. Sample nonlinearities: the same non-linearity is used for all treatment selection mechanisms
        num_nonlinearities = 1
        self.nonlinearities = self.sample_nonlinearities(num_nonlinearities)

        # 2. Set important feature masks - determine which features should be used for treatment selection
        self.sample_important_feature_masks()

        # 3. Sample weights for features
        self.sample_uniform_weights()

    def get_true_cates(self,
                       X: np.ndarray,
                       T: np.ndarray,
                       outcomes: np.ndarray) -> np.ndarray:
        """
        Compute true CATEs for each treatment based on the data and the outcomes.
        The factually assigned treatment is always used as the base treatment.
        """
        # Compute CATEs for each treatment relative to the assigned one
        cates = np.zeros((X.shape[0], self.num_T, self.dim_Y))

        for i in range(X.shape[0]):
            for j in range(self.num_T):
                cates[i, j, :] = outcomes[i, j, :] - outcomes[i, int(T[i]), :]

        log.debug(
            f'\nCheck if true CATEs are computed correctly:'
            f'\n==================================================================='
            f'\nOutcomes: {outcomes.shape}'
            f'\n{outcomes}'
            f'\n\nTreatment Assignment: {T.shape}'
            f'\n{T}'
            f'\n\nTrue CATEs: {cates.shape}'
            f'\n{cates}'
            f'\n===================================================================\n'
        )

        return cates
    

    def extract_Y(self) -> None:
        """
        Extract the factual outcome for each sample based on its assigned treatment.
        """
        # Fancy indexing: row i picks outcomes[i, T[i], :]
        self.Y = self.outcomes[np.arange(self.X.shape[0]), self.T]

        log.debug(
            f'\nCheck if outcomes are extracted correctly:'
            f'\n==================================================================='
            f'\nOutcomes'
            f'\n{self.outcomes}'
            f'\n{self.outcomes.shape}'
            f'\n\nTreatment Assignment'
            f'\n{self.T}'
            f'\n{self.T.shape}'
            f'\n\nExtracted Outcomes'
            f'\n{self.Y}'
            f'\n{self.Y.shape}'
            f'\n===================================================================\n'
        )

        return None

    def sample_T(self) -> None:
        """
        Sample treatment assignment based on the propensities.
        """
        # Sample from the resulting categorical distribution per row
        self.T = np.array([
            np.random.choice(np.arange(self.propensities.shape[1]), p=row)
            for row in self.propensities
        ])

        log.debug(
            f'\nCheck if treatment assignment is sampled correctly:'
            f'\n==================================================================='
            f'\nPropensities'
            f'\n{self.propensities}'
            f'\n{self.propensities.shape}'
            f'\n\nTreatment Assignment'
            f'\n{self.T}'
            f'\n{self.T.shape}'
            f'\n\nUnique Treatment Counts'
            f'\n{np.unique(self.T, return_counts=True)}'
            f'\n===================================================================\n'
        )

        return None
    

    def get_unbalancedness_weights(self, size: int) -> np.ndarray:
        """
        Create weights for introducing unbalancedness into the class probabilities.
        """
        # Sample an initial distribution over treatment assignments
        unb_weights = np.random.uniform(0, 1, size=size)
        unb_weights = unb_weights / unb_weights.sum()

        # Min-max scale the weights into [0.01, 0.99] so that no treatment
        # disappears completely for small unbalancedness exponents
        min_val = unb_weights.min()
        range_val = unb_weights.max() - min_val
        unb_weights = (unb_weights - min_val) / range_val
        unb_weights = 0.01 + unb_weights * 0.98

        return unb_weights
    

    def compute_propensities(self) -> None:
        """
        Compute propensities based on the data and the selective scores.
        NOTE: the score construction below assumes num_T == 2 and uses the
        first outcome dimension only.
        """
        select_scores_none = zscore(self.select_scores, axis=0)  # Comment for Predictive Expertise

        select_scores_pred = np.zeros((self.X.shape[0], self.num_T))
        select_scores_pred_flipped = np.zeros((self.X.shape[0], self.num_T))
        select_scores_prog = np.zeros((self.X.shape[0], self.num_T))
        select_scores_tre = np.zeros((self.X.shape[0], self.num_T))

        # Predictive expertise: select according to the treatment effect
        select_scores_pred[:, 0] = self.outcomes[:, 0, 0] - self.outcomes[:, 1, 0]
        select_scores_pred[:, 1] = self.outcomes[:, 1, 0] - self.outcomes[:, 0, 0]

        select_scores_pred_flipped[:, 0] = self.outcomes[:, 1, 0] - self.outcomes[:, 0, 0]
        select_scores_pred_flipped[:, 1] = self.outcomes[:, 0, 0] - self.outcomes[:, 1, 0]

        # Prognostic expertise: select according to the control outcome
        select_scores_prog[:, 0] = self.outcomes[:, 0, 0]
        select_scores_prog[:, 1] = -self.outcomes[:, 0, 0]

        select_scores_tre[:, 0] = -self.outcomes[:, 1, 0]
        select_scores_tre[:, 1] = self.outcomes[:, 1, 0]

        # For "prog_tre" the blend is intentionally computed on the raw
        # (unstandardized) scores
        if self.propensity_type == "prog_tre":
            scores = self.alpha * select_scores_tre + (1 - self.alpha) * select_scores_prog

        # Standardize all scores
        select_scores_pred = zscore(select_scores_pred, axis=0)
        select_scores_pred_flipped = zscore(select_scores_pred_flipped, axis=0)
        select_scores_prog = zscore(select_scores_prog, axis=0)
        select_scores_tre = zscore(select_scores_tre, axis=0)

        if self.propensity_type == "prog_pred":
            scores = self.alpha * select_scores_pred + (1 - self.alpha) * select_scores_prog

        elif self.propensity_type == "prog_tre":
            pass  # already computed above on the raw scores

        elif self.propensity_type == "none_prog":
            scores = self.alpha * select_scores_prog + (1 - self.alpha) * select_scores_none

        elif self.propensity_type == "none_pred":
            scores = self.alpha * select_scores_pred + (1 - self.alpha) * select_scores_none

        elif self.propensity_type == "none_tre":
            scores = self.alpha * select_scores_tre + (1 - self.alpha) * select_scores_none

        elif self.propensity_type == "none_pred_flipped":
            scores = self.alpha * select_scores_pred_flipped + (1 - self.alpha) * select_scores_none

        elif self.propensity_type == "pred_pred_flipped":
            scores = self.alpha * select_scores_pred_flipped + (1 - self.alpha) * select_scores_pred

        elif self.propensity_type == "rct_none":
            scores = select_scores_none

        else:
            raise ValueError(f"Unknown propensity type {self.propensity_type}.")

        if self.enforce_balancedness:
            scores = zscore(scores, axis=0)

        # For "rct_none", the final scores are the alpha-scaled random selective
        # scores (this overrides the assignment above)
        if self.propensity_type == "rct_none":
            scores = self.alpha * select_scores_none

        # Introduce unbalancedness; the weights are rescaled so that experiments
        # with different seeds remain comparable
        unb_weights = self.get_unbalancedness_weights(size=scores.shape[1])

        # Apply the softmax function to each row to get probabilities
        p = softmax(self.propensity_scale * scores, axis=1)

        # Scale probabilities to introduce unbalancedness
        p = p * (1 - unb_weights) ** self.unbalancedness_exp

        # Make sure rows add up to one again
        row_sums = p.sum(axis=1, keepdims=True)
        p = p / row_sums
        self.propensities = p

        log.debug(
            f'\nCheck if propensities are computed correctly:'
            f'\n==================================================================='
            f'\nSelect Scores'
            f'\n{self.select_scores}'
            f'\n{self.select_scores.shape}'
            f'\n\nPropensities'
            f'\n{self.propensities}'
            f'\n{self.propensities.shape}'
            f'\n===================================================================\n'
        )

        return None
1355
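    # Illustrative sketch of the compute_propensities pipeline (hypothetical
    # numbers): for two treatments with standardized scores [1.0, -1.0],
    # propensity_scale=1 and unbalancedness_exp=0, a row of propensities is
    #   softmax([1.0, -1.0]) ≈ [0.881, 0.119]
    # A positive unbalancedness_exp then tilts all rows towards the treatments
    # with small unbalancedness weights before the rows are re-normalized.
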
    def compute_scores(self) -> None:
        """
        Compute scores for prognostic, predictive, and selective features based on the data and the feature weights.
        """
        # Each column of the score matrix corresponds to the score for a specific outcome. Rows correspond to samples.
        select_lin = self.X @ self.select_weights.T

        log.debug(
            f'\nCheck if linear scores are computed correctly for selective features:'
            f'\n==================================================================='
            f'\nself.X'
            f'\n{self.X}'
            f'\n{self.X.shape}'
            f'\n\nSelect Weights'
            f'\n{self.select_weights}'
            f'\n{self.select_weights.shape}'
            f'\n\nSelect Lin'
            f'\n{select_lin}'
            f'\n{select_lin.shape}'
            f'\n===================================================================\n'
        )

        # Introduce non-linearity and get final scores
        select_scores = (1 - self.nonlinearity_scale) * select_lin + self.nonlinearity_scale * self.nonlinearities[0](select_lin)
        self.select_scores = select_scores

        return None

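    # Illustrative sketch of the interpolation in compute_scores: with
    # nonlinearity_scale=0.5 and f(x) = exp(-x**2 / 2) (the first entry of
    # nonlinear_fcts), a linear score of 1.0 becomes
    #   0.5 * 1.0 + 0.5 * np.exp(-0.5) ≈ 0.803
    # so nonlinearity_scale=0 keeps the scores linear and 1 makes them fully nonlinear.
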
    @property
    def weights(self) -> Tuple:
        """
        Return weights for prognostic, predictive, and selective features.
        Only selective weights are used by this simulator; the first two
        entries are None.
        """
        return None, None, self.select_weights

    def sample_uniform_weights(self) -> None:
        """
        Sample uniform weights for the features.
        """
        # Sample weights for selective features: one weight per feature dimension for every treatment
        select_weights = np.random.uniform(-1, 1, size=(self.num_T, self.dim_X))

        # Mask weights for features that are not important
        for i in range(self.num_T):
            select_weights[i] = select_weights[i] * self.select_masks[:, i]

        log.debug(
            f'\nCheck if masks are applied correctly:'
            f'\n==================================================================='
            f'\nSelect Weights'
            f'\n{select_weights}'
            f'\n{select_weights.shape}'
            f'\n\nSelect Masks'
            f'\n{self.select_masks}'
            f'\n{self.select_masks.shape}'
            f'\n===================================================================\n'
        )

        self.select_weights = select_weights

        return None

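    # Illustrative sketch of sample_uniform_weights (hypothetical shapes):
    # with dim_X=4, num_T=2 and masks selecting features {0, 1} for treatment 0
    # and {2, 3} for treatment 1, select_weights has shape (2, 4) and each row
    # is zero outside its masked features.
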
    @property
    def all_important_features(self) -> np.ndarray:
        """
        Return all important feature indices. For this simulator these are
        exactly the selective features.
        """
        all_important_features = self.selective_features
        log.debug(
            f'\nCheck if all important features are computed correctly:'
            f'\n==================================================================='
            f'\n\nSelect Features'
            f'\n{self.selective_features}'
            f'\n\nAll Important Features'
            f'\n{all_important_features}'
            f'\n===================================================================\n'
        )

        return all_important_features

    @property
    def predictive_features(self) -> np.ndarray:
        """
        Return predictive feature indices (not tracked by this simulator, hence None).
        """
        return None

    @property
    def prognostic_features(self) -> np.ndarray:
        """
        Return prognostic feature indices (not tracked by this simulator, hence None).
        """
        return None

    @property
    def selective_features(self) -> np.ndarray:
        """
        Return selective feature indices.
        """
        # np.where returns a tuple of arrays; take the first element to get the index array
        select_features = np.where(self.select_masks.sum(axis=1).astype(np.int32) != 0)[0]
        return select_features

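    # Illustrative sketch of selective_features (hypothetical mask): if
    # self.select_masks.sum(axis=1) equals [2, 0, 1, 0], the property returns
    # np.array([0, 2]).
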
    def sample_important_feature_masks(self) -> None:
        """
        Pick features that are important for treatment selection based on the configuration.
        """
        # Get indices for features and shuffle if random_feature_selection is True
        all_indices = np.arange(self.dim_X)

        if self.random_feature_selection:
            np.random.shuffle(all_indices)

        # Initialize masks
        select_masks = np.zeros(shape=(self.dim_X, self.num_T))

        # Handle case with feature overlap: all treatments share the same features
        if self.treatment_feature_overlap:
            assert self.num_select_features <= int(self.dim_X)
            select_masks[all_indices[:self.num_select_features], :] = 1

        # Handle case without feature overlap: each treatment gets its own disjoint block of features
        else:
            assert (self.num_T * self.num_select_features) <= int(self.dim_X)
            select_indices = all_indices[:self.num_select_features * self.num_T]

            # Mask features for every treatment
            for i in range(self.num_T):
                select_masks[select_indices[i * self.num_select_features:(i + 1) * self.num_select_features], i] = 1

        self.select_masks = select_masks

        return None

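    # Worked example for sample_important_feature_masks (hypothetical
    # configuration): dim_X=4, num_T=2, num_select_features=2, no shuffling and
    # no overlap yields
    #   select_masks = [[1, 0],
    #                   [1, 0],
    #                   [0, 1],
    #                   [0, 1]]
    # i.e. treatment 0 selects on features {0, 1} and treatment 1 on {2, 3}.
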
    def sample_nonlinearities(self, num_nonlinearities: int):
        """
        Sample non-linearities for each outcome.
        """
        if self.nonlinearity_selection_type == "random":
            # Pick num_nonlinearities functions (with replacement) from the candidate pool
            return random.choices(population=self.nonlinear_fcts, k=num_nonlinearities)

        else:
            raise ValueError(f"Unknown nonlinearity selection type {self.nonlinearity_selection_type}.")