# catenets/models/diffpo/diffpo_learner.py
1
from typing import Any, Callable, List
2
3
import numpy as np
4
import torch
5
from torch import nn
6
import os
7
import tqdm
8
import catenets.logger as log
9
from catenets.models.constants import (
10
    DEFAULT_BATCH_SIZE,
11
    DEFAULT_DIM_P_OUT,
12
    DEFAULT_DIM_P_R,
13
    DEFAULT_DIM_S_OUT,
14
    DEFAULT_DIM_S_R,
15
    DEFAULT_LAYERS_OUT,
16
    DEFAULT_LAYERS_R,
17
    DEFAULT_N_ITER,
18
    DEFAULT_N_ITER_MIN,
19
    DEFAULT_N_ITER_PRINT,
20
    DEFAULT_PATIENCE,
21
    DEFAULT_PENALTY_L2,
22
    DEFAULT_PENALTY_ORTHOGONAL,
23
    DEFAULT_SEED,
24
    DEFAULT_NJOBS,
25
    DEFAULT_STEP_SIZE,
26
    DEFAULT_VAL_SPLIT,
27
    LARGE_VAL,
28
)
29
from catenets.models.torch.base import DEVICE, BaseCATEEstimator
30
from catenets.models.torch.utils.model_utils import make_val_split
31
import pandas as pd
32
# Hydra
33
from omegaconf import DictConfig
34
import json
35
import datetime
36
37
from .src.main_model_table import TabCSDI
38
from .src.utils_table import train
39
from .dataset_acic import get_dataloader
40
41
from .PropensityNet import load_data
42
43
44
# Fix the global torch RNG at import time so model initialization and
# diffusion sampling are reproducible across runs.
torch.manual_seed(0)
45
46
class AverageMeter(object):
    """Running (weighted) average tracker for a scalar metric.

    Attributes:
        val: the most recently recorded value.
        sum: weighted sum of all recorded values.
        count: total weight recorded so far.
        avg: current weighted mean, i.e. ``sum / count``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero out every tracked statistic."""
        for attr in ("val", "avg", "sum", "count"):
            setattr(self, attr, 0)

    def update(self, val, n=1):
        """Fold in ``val`` with weight ``n`` and refresh the running mean."""
        self.val = val
        new_sum = self.sum + val * n
        new_count = self.count + n
        self.sum, self.count = new_sum, new_count
        self.avg = new_sum / new_count
62
63
64
class DiffPOLearner(BaseCATEEstimator):
    """
    CATE estimator wrapping the DiffPO pipeline: a TabCSDI diffusion model
    imputes both potential outcomes, with training weighted by a propensity
    network.

    Data is exchanged with the upstream DiffPO code through CSV/pickle files
    laid out under ``diffpo_path`` in the pipeline's "acic2018" format, so
    both `train` and `predict` perform filesystem side effects.
    """

    def __init__(
        self,
        cfg: DictConfig,
        num_features: int,
        binary_y: bool,  # not read anywhere in this class; kept for a common estimator signature
    ) -> None:
        # DiffPO-specific sub-config of the experiment config.
        self.config = cfg.DiffPOLearner
        self.diffpo_path = cfg.diffpo_path
        self.config.diffusion.cond_dim = num_features+1 # make sure inner dimension matches the dataset
        # NOTE(review): `self.est` is never assigned anywhere in this class;
        # `explain` will fail unless it is set externally.
        self.est = None
        self.propnet = None
        self.device = DEVICE
        self.cate_cis = None # confidence intervals, dim: 2, n, num_T-1, dim_Y (filled by `evaluate`)
        self.pred_outcomes = None  # (n, 2) tensor of predicted potential outcomes (filled by `evaluate`)

        # create folder if diffpo_path + 'data' does not exist
        if not os.path.exists(self.diffpo_path):
            os.makedirs(self.diffpo_path)
        
        # Store data for their pipeline
        self.data_dir = self.diffpo_path+'/data/'
        if not os.path.exists(self.data_dir):
            os.makedirs(self.data_dir)

        return None

    def reshape_data(self, X: np.ndarray, w: np.ndarray, outcomes: np.ndarray) -> tuple:
        """
        Arrange treatment, outcomes and covariates into the column layout the
        DiffPO pipeline expects, together with the matching observation mask.

        Column layout: [0] treatment w; [1-2] the two outcomes with only the
        factual one unmasked (mask = w and 1-w respectively); [3-4] the same
        outcomes fully masked (the imputation targets); [5:] covariates,
        always observed.

        Returns:
            (data_df, mask_df): two DataFrames of identical shape.
        """
        # NOTE(review): each `outcomes[:, i]` must itself be 2-D for this
        # axis=1 concatenation to work, i.e. `outcomes` is presumably
        # (n, 2, 1) or similar — confirm against the callers.
        data = np.concatenate([w.reshape(-1,1),outcomes[:,0],outcomes[:,1],outcomes[:,0],outcomes[:,1],X], axis=1)
        data_df = pd.DataFrame(data)
        # Create masking array of same shape as pp_data and initialize with 1s
        mask = np.ones(data_df.shape)
        mask[:,1] = w
        mask[:,2] = 1-w
        mask[:,3] = 0
        mask[:,4] = 0
        mask_df = pd.DataFrame(mask)

        return data_df, mask_df

    def train(self, X: np.ndarray, y: np.ndarray, w: np.ndarray, outcomes:np.ndarray) -> None:
        """
        Prepare data on disk and train the DiffPO (TabCSDI) model.

        Writes the reshaped data/mask CSVs under ``data_dir``, builds the
        dataloaders, trains a propensity net, then trains TabCSDI and saves
        its weights under ``diffpo_path + "/save_model/data_pp"``.
        """
        log.info("Training data shapes: X: {}, Y: {}, T: {}".format(X.shape, y.shape, w.shape))

        if not os.path.exists(self.data_dir):
            os.makedirs(self.data_dir)
        data, mask = self.reshape_data(X, w, outcomes)
        
        # create destination folders if not exist
        if not os.path.exists(self.data_dir+"acic2018_norm_data/"):
            os.makedirs(self.data_dir+"acic2018_norm_data/")
        if not os.path.exists(self.data_dir+"acic2018_mask/"):
            os.makedirs(self.data_dir+"acic2018_mask/")

        # save intermediate data
        data.to_csv(self.data_dir+"acic2018_norm_data/data_pp.csv", index=False)
        mask.to_csv(self.data_dir+"acic2018_mask/data_pp.csv", index=False)

        # Remove stale cached pickles so the dataloader rebuilds them from
        # the CSVs written above.
        if os.path.exists(self.data_dir+"missing_ratio-0.2_seed-1_current_id-data_max-min_norm.pk"):
            os.remove(self.data_dir+"missing_ratio-0.2_seed-1_current_id-data_max-min_norm.pk")
        if os.path.exists(self.data_dir+"missing_ratio-0.2_seed-1.pk"):
            os.remove(self.data_dir+"missing_ratio-0.2_seed-1.pk")

        # Timestamp for a unique save folder.
        current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        
        # Pipeline knobs, hard-coded to the upstream DiffPO defaults.
        # NOTE(review): `config`, `unconditional`, `modelfolder`, `nsample`
        # and `perform_training` are assigned but never used below.
        nfold = 1
        config = "acic2018.yaml"
        current_id = "data_pp"
        device = DEVICE
        seed = 1
        testmissingratio = 0.2
        unconditional = 0
        modelfolder = ""
        nsample = 1
        perform_training = 1

        foldername = self.diffpo_path + "/save/acic_fold" + str(nfold) + "_" + current_time + "/"
        # print("model folder:", foldername)
        os.makedirs(foldername, exist_ok=True)

        current_id = "data_pp"
        # print('Start exe_acic on current_id', current_id)

        # Every loader contains "observed_data", "observed_mask", "gt_mask", "timepoints"
        training_size = 1
        
        train_loader, valid_loader, _ = get_dataloader(
            seed=seed,
            nfold=nfold,
            batch_size=self.config["train"]["batch_size"],
            missing_ratio=testmissingratio,
            dataset_name = self.config["dataset"]["data_name"],
            current_id = current_id,
            training_size = training_size,
            data_path=self.data_dir,
            x_dim=X.shape[1],
        )

        #=======================First train and fix propnet======================
        # Train a propensity net on this dataset (despite its name,
        # `load_data` returns the trained propensity model).

        propnet = load_data(dataset_name = self.config["dataset"]["data_name"], current_id=current_id, x_dim=X.shape[1], data_path=self.data_dir)

        # Freeze the trained propnet (eval mode) before passing it to training.
        # print('Finish training propnet and fix the parameters')
        propnet.eval()
        # ========================================================================

        propnet = propnet.to(device)

        model = TabCSDI(self.config, self.device).to(self.device)
        # Train the diffusion model. `train` here resolves to the module-level
        # function imported from .src.utils_table, not this method (globals
        # are consulted since there is no local binding).
        train(
            model,
            self.config["train"],
            train_loader,
            valid_loader=valid_loader,
            valid_epoch_interval=self.config["train"]["valid_epoch_interval"],
            foldername=foldername,
            propnet = propnet
        )

        directory = self.diffpo_path + "/save_model/" + current_id
        if not os.path.exists(directory):
            os.makedirs(directory)
        
        # # load model
        # model.load_state_dict(torch.load(directory + "/model_weights.pth"))

        # Persist trained weights for `predict` to reload.
        torch.save(model.state_dict(), directory + "/model_weights.pth")
        
        

    # predict function with bool return_po and return potential outcome if true
    def predict(self, X: np.ndarray, T0: np.ndarray = None, T1: np.ndarray = None, outcomes: np.ndarray = None) -> np.ndarray:
        """
        Predict treatment effects (ITEs) with the trained DiffPO model.

        Writes the test split to disk, reloads the weights saved by `train`,
        and delegates to `evaluate`. NOTE(review): `T0` is passed to
        `reshape_data` in the treatment-array slot, i.e. it is treated as the
        observed treatment indicator; `T1` is unused. The returned value is
        whatever `evaluate` returns (a torch tensor of ITEs, not an ndarray
        despite the annotation) — confirm with callers.
        """
        # Store data for their pipeline
        data_dir = self.data_dir
        
        data, mask = self.reshape_data(X, T0, outcomes)
        
        data.to_csv(data_dir+"acic2018_norm_data/data_pp_test.csv", index=False)
        mask.to_csv(data_dir+"acic2018_mask/data_pp_test.csv", index=False)

        # Remove stale cached pickles so the dataloader rebuilds them.
        if os.path.exists(data_dir+"missing_ratio-0.2_seed-1_current_id-data_max-min_norm.pk"):
            os.remove(data_dir+"missing_ratio-0.2_seed-1_current_id-data_max-min_norm.pk")
        if os.path.exists(data_dir+"missing_ratio-0.2_seed-1.pk"):
            os.remove(data_dir+"missing_ratio-0.2_seed-1.pk")

        # Timestamp for a unique save folder.
        current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        
        # Pipeline knobs; `nsample` is the number of diffusion samples drawn
        # per instance. NOTE(review): `perform_training` is unused.
        nfold = 1
        current_id = "data_pp_test"
        current_id_train = "data_pp"
        seed = 1
        testmissingratio = 0.2
        nsample = 50
        perform_training = 1

        # NOTE(review): unlike `train`, this folder is relative to the CWD
        # rather than under self.diffpo_path — confirm this is intended.
        foldername = "./save/acic_fold" + str(nfold) + "_" + current_time + "/"
        # print("model folder:", foldername)
        os.makedirs(foldername, exist_ok=True)

        # Every loader contains "observed_data", "observed_mask", "gt_mask", "timepoints"
        training_size = 0
        _,_,test_loader = get_dataloader(
            seed=seed,
            nfold=nfold,
            batch_size=1,
            missing_ratio=testmissingratio,
            dataset_name = self.config["dataset"]["data_name"],
            current_id = current_id,
            training_size = training_size,
            data_path=data_dir,
            x_dim=X.shape[1],
        )

        # Reload the weights saved at the end of `train`.
        directory = self.diffpo_path + "/save_model/" + current_id_train
        os.makedirs(directory, exist_ok=True)
        model = TabCSDI(self.config, self.device).to(self.device)
        model.load_state_dict(torch.load(directory + "/model_weights.pth"))

        # get cates
        return self.evaluate(model, test_loader, nsample, foldername=foldername)
    
    def predict_outcomes(self, X: np.ndarray, T0: np.ndarray = None, T1: np.ndarray = None, outcomes: np.ndarray = None) -> np.ndarray:
        """
        Return the potential outcomes cached by the last `predict` call.

        All arguments are ignored; `predict` must have been called first so
        that ``self.pred_outcomes`` is populated. Returns an ndarray of shape
        (n, 2, 1): per-instance [y0_hat, y1_hat] with a trailing outcome dim.
        """
        # add outer dimension to self.pred_outcomes
        return self.pred_outcomes.cpu().numpy().reshape(self.pred_outcomes.shape[0], self.pred_outcomes.shape[1], 1)

    def explain(self, X: np.ndarray, background_samples: np.ndarray = None, explainer_limit: int = None) -> np.ndarray:
        """
        Compute SHAP explanations via ``self.est``.

        NOTE(review): ``self.est`` is initialized to None and never set in
        this class, so this raises AttributeError unless an estimator is
        attached externally. ``background_samples`` is accepted but a
        hard-coded None is forwarded.
        """
        if explainer_limit is None:
            explainer_limit = X.shape[0]

        return self.est.shap_values(X[:explainer_limit], background_samples=None)
    
    def infer_effect_ci(self, X, T0) -> np.ndarray:
        """
        Return treatment-effect confidence intervals, orienting them by T0.

        For rows with T0 != 0 the interval is negated and its bounds swapped
        (lb, ub) -> (-ub, -lb), flipping the effect direction. NOTE(review):
        the slices are views into ``self.cate_cis`` (numpy), so this mutates
        the stored intervals in place — calling it twice double-flips.
        """
        cates_conf_lbs = self.cate_cis[0]
        cates_conf_ups = self.cate_cis[1]

        temp = cates_conf_lbs[T0 != 0]
        cates_conf_lbs[T0 != 0] = -cates_conf_ups[T0 != 0]
        cates_conf_ups[T0 != 0] = -temp
        return np.array([cates_conf_lbs, cates_conf_ups])
    
    def evaluate(self, model, test_loader, nsample=100, scaler=1, mean_scaler=0, foldername=""):
        """
        Sample potential outcomes from the model over ``test_loader`` and
        derive point estimates and 95% intervals.

        Side effects: fills ``self.cate_cis`` (ndarray, shape (2, n, 1)) and
        ``self.pred_outcomes`` (tensor, shape (n, 2)). Returns a 1-D tensor
        of predicted ITEs (median y1 - median y0 per instance).
        ``scaler``/``mean_scaler``/``foldername`` are currently unused.
        """
        # Control random seed in the current script.
        torch.manual_seed(0)
        np.random.seed(0)

        with torch.no_grad():
            model.eval()
            # NOTE(review): these accumulators/meters are leftovers from the
            # upstream evaluation script and are never updated below.
            mse_total = 0
            mae_total = 0
            evalpoints_total = 0

            pehe_test = AverageMeter()
            y0_test = AverageMeter()
            y1_test = AverageMeter()

            # for uncertainty
            y0_samples = []
            y1_samples = []
            y0_true_list = []
            y1_true_list = []
            ite_samples = []
            ite_true_list = []
            pred_ites = []
            pred_y0s = []
            pred_y1s = []
            
            for batch_no, test_batch in enumerate(test_loader, start=1):
                # Get model outputs: `samples` holds nsample draws per
                # instance; the remaining entries mirror the loader fields.
                output = model.evaluate(test_batch, nsample) 
                samples, observed_data, target_mask, observed_mask, observed_tp = output

                # Per-draw potential outcomes (columns 0/1) and their difference.
                y0_samples.append(samples[:,:,0]) 
                y1_samples.append(samples[:,:,1]) 
                ite_samples.append(samples[:,:,1] - samples[:,:,0])

                # Point estimate = median over the sample dimension.
                est_data = torch.median(samples, dim=1).values

                # True ITE from the observed data (columns 1/2 are the
                # outcome columns in the reshape_data layout).
                obs_data = observed_data.squeeze(1)
                true_ite = obs_data[:, 2] - obs_data[:, 1] 
                ite_true_list.append(true_ite)

                # Predicted ITE from the median potential outcomes.
                pred_y0 = est_data[:, 0]
                pred_y1 = est_data[:, 1]
                pred_y0s.append(pred_y0)
                pred_y1s.append(pred_y1)
                y0_true_list.append(obs_data[:, 1])
                y1_true_list.append(obs_data[:, 2])
                pred_ite = pred_y1 - pred_y0
                pred_ites.append(pred_ite)

                #y0_test.update(diff_y0, obs_data.size(0))    
                #diff_y0 = np.mean((pred_y0.cpu().numpy()-obs_data[:, 1].cpu().numpy())**2)
                #y1_test.update(diff_y1, obs_data.size(0)) 
                #diff_y1 = np.mean((pred_y1.cpu().numpy()-obs_data[:, 2].cpu().numpy())**2)
                #pehe_test.update(diff_ite, obs_data.size(0))    
                #diff_ite = np.mean((true_ite.cpu().numpy()-est_ite.cpu().numpy())**2)

            # ---------------uncertainty estimation-------------------------
            pred_samples_y0 = torch.cat(y0_samples, dim=0)
            pred_samples_y1 = torch.cat(y1_samples, dim=0)
            pred_samples_ite = torch.cat(ite_samples, dim=0)

            truth_y0 = torch.cat(y0_true_list, dim=0) 
            truth_y1 = torch.cat(y1_true_list, dim=0) 
            truth_ite = torch.cat(ite_true_list, dim=0)

            # Coverage/width diagnostics; results are currently not consumed.
            prob_0, median_width_0 = self.compute_interval(pred_samples_y0, truth_y0)
            prob_1, median_width_1 = self.compute_interval(pred_samples_y1, truth_y1)
            prob_ite, median_width_ite = self.compute_interval(pred_samples_ite, truth_ite)

            # Per-instance 95% quantile interval of the sampled ITEs.
            self.cate_cis = torch.zeros(2, pred_samples_ite.shape[0], 1) # confidence intervals, dim: 2, n, dim_Y
            for i in range(pred_samples_ite.shape[0]):
                lower_quantile, upper_quantile, in_quantiles = self.check_intervel(confidence_level=0.95, y_pred= pred_samples_ite[i, :], y_true=truth_ite[i])
                self.cate_cis[0, i, 0] = lower_quantile
                self.cate_cis[1, i, 0] = upper_quantile
           
        # ----------------------------------------------------------------
        pred_ites = torch.cat(pred_ites, dim=0)
        pred_y0s = torch.cat(pred_y0s, dim=0)
        pred_y1s = torch.cat(pred_y1s, dim=0)

        #np.zeros((X.shape[0], self.cfg.simulator.num_T, self.cfg.simulator.dim_Y))
        # Cache (n, 2) potential outcomes for `predict_outcomes`.
        self.pred_outcomes = torch.cat([pred_y0s.unsqueeze(1), pred_y1s.unsqueeze(1)], dim=1)
        self.cate_cis = self.cate_cis.cpu().numpy()
        
        return pred_ites

    def check_intervel(self, confidence_level, y_pred, y_true):
        """
        Compute the central ``confidence_level`` quantile interval of the
        samples ``y_pred`` and whether ``y_true`` falls inside it.

        Returns (lower_quantile, upper_quantile, in_quantiles) as tensors.
        (Method name typo "intervel" kept for compatibility with callers.)
        """
        lower = (1 - confidence_level) / 2
        upper = 1 - lower
        lower_quantile = torch.quantile(y_pred, lower)
        upper_quantile = torch.quantile(y_pred, upper)
        in_quantiles = torch.logical_and(y_true >= lower_quantile, y_true <= upper_quantile)
        return lower_quantile, upper_quantile, in_quantiles

    def compute_interval(self, po_samples, y_true):
        """
        Empirical coverage and median width of per-instance 95% intervals.

        Args:
            po_samples: (n, nsample) tensor of sampled values per instance.
            y_true: (n,) tensor of ground-truth values.

        Returns:
            (prob, median_width): fraction of instances whose truth lies in
            its interval, and the median interval width across instances.
        """
        counter = 0
        width_list = []
        for i in range(po_samples.shape[0]):
            lower_quantile, upper_quantile, in_quantiles = self.check_intervel(confidence_level=0.95, y_pred= po_samples[i, :], y_true=y_true[i])
            if in_quantiles == True:
                counter+=1
            width = upper_quantile - lower_quantile
            width_list.append(width.unsqueeze(0))
        prob = (counter/po_samples.shape[0])
        all_width = torch.cat(width_list, dim=0)
        median_width = torch.median(all_width, dim=0).values
        return prob, median_width