|
a |
|
b/catenets/models/diffpo/diffpo_learner.py |
|
|
1 |
from typing import Any, Callable, List |
|
|
2 |
|
|
|
3 |
import numpy as np |
|
|
4 |
import torch |
|
|
5 |
from torch import nn |
|
|
6 |
import os |
|
|
7 |
import tqdm |
|
|
8 |
import catenets.logger as log |
|
|
9 |
from catenets.models.constants import ( |
|
|
10 |
DEFAULT_BATCH_SIZE, |
|
|
11 |
DEFAULT_DIM_P_OUT, |
|
|
12 |
DEFAULT_DIM_P_R, |
|
|
13 |
DEFAULT_DIM_S_OUT, |
|
|
14 |
DEFAULT_DIM_S_R, |
|
|
15 |
DEFAULT_LAYERS_OUT, |
|
|
16 |
DEFAULT_LAYERS_R, |
|
|
17 |
DEFAULT_N_ITER, |
|
|
18 |
DEFAULT_N_ITER_MIN, |
|
|
19 |
DEFAULT_N_ITER_PRINT, |
|
|
20 |
DEFAULT_PATIENCE, |
|
|
21 |
DEFAULT_PENALTY_L2, |
|
|
22 |
DEFAULT_PENALTY_ORTHOGONAL, |
|
|
23 |
DEFAULT_SEED, |
|
|
24 |
DEFAULT_NJOBS, |
|
|
25 |
DEFAULT_STEP_SIZE, |
|
|
26 |
DEFAULT_VAL_SPLIT, |
|
|
27 |
LARGE_VAL, |
|
|
28 |
) |
|
|
29 |
from catenets.models.torch.base import DEVICE, BaseCATEEstimator |
|
|
30 |
from catenets.models.torch.utils.model_utils import make_val_split |
|
|
31 |
import pandas as pd |
|
|
32 |
# Hydra |
|
|
33 |
from omegaconf import DictConfig |
|
|
34 |
import json |
|
|
35 |
import datetime |
|
|
36 |
|
|
|
37 |
from .src.main_model_table import TabCSDI |
|
|
38 |
from .src.utils_table import train |
|
|
39 |
from .dataset_acic import get_dataloader |
|
|
40 |
|
|
|
41 |
from .PropensityNet import load_data |
|
|
42 |
|
|
|
43 |
|
|
|
44 |
# Seed torch's global RNG at import time for reproducibility.
# NOTE(review): this runs on module import and affects the whole process,
# not just this module — verify that is intended.
torch.manual_seed(0)
|
|
45 |
|
|
|
46 |
class AverageMeter(object):
    """Track a stream of values: latest value, running sum/count, and mean."""

    def __init__(self):
        # Delegate to reset() so construction and clearing share one code path.
        self.reset()

    def reset(self):
        # Zero out all statistics; avg is defined as 0 until updates arrive.
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        # Record the latest value and fold `n` occurrences of it into the
        # running mean.
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
|
|
62 |
|
|
|
63 |
|
|
|
64 |
class DiffPOLearner(BaseCATEEstimator):
    """
    CATE estimator wrapping the DiffPO pipeline: a TabCSDI diffusion model
    for potential outcomes plus a frozen propensity network.

    The wrapped pipeline is file based: `train`/`predict` serialize the data
    to CSV files under ``diffpo_path/data/`` and the diffusion model weights
    are saved to / loaded from ``diffpo_path/save_model/``.

    NOTE(review): the original docstring described this as "based on the
    EconML framework"; only `explain` still references an EconML-style
    ``self.est``, which is never assigned here — confirm before relying on it.
    """

    def __init__(
        self,
        cfg: DictConfig,
        num_features: int,
        binary_y: bool,  # NOTE(review): accepted but unused in this class
    ) -> None:
        # Sub-config for this learner; mutated below to match the data width.
        self.config = cfg.DiffPOLearner
        self.diffpo_path = cfg.diffpo_path
        self.config.diffusion.cond_dim = num_features+1 # make sure inner dimension matches the dataset
        # Placeholder for an EconML-style estimator; never set in this class
        # (only read by `explain`).
        self.est = None
        self.propnet = None
        self.device = DEVICE
        self.cate_cis = None # confidence intervals, dim: 2, n, num_T-1, dim_Y
        # Point-estimate potential outcomes, filled by `evaluate`.
        self.pred_outcomes = None

        # create folder if diffpo_path + 'data' does not exist
        if not os.path.exists(self.diffpo_path):
            os.makedirs(self.diffpo_path)

        # Store data for their pipeline
        self.data_dir = self.diffpo_path+'/data/'
        if not os.path.exists(self.data_dir):
            os.makedirs(self.data_dir)

        return None

    def reshape_data(self, X: np.ndarray, w: np.ndarray, outcomes: np.ndarray) -> "tuple[pd.DataFrame, pd.DataFrame]":
        """Pack (w, y0, y1, y0, y1, X) columns plus the observation mask
        expected by the DiffPO data loader.

        Column layout: 0 = treatment, 1/2 = factual outcome slots
        (masked by w / 1-w), 3/4 = potential-outcome targets (always masked
        out so the model must impute them), 5.. = covariates.
        NOTE(review): `outcomes[:, 0]` must be 2-D for axis=1 concatenation —
        presumably `outcomes` has shape (n, 2, 1); confirm against callers.
        """
        data = np.concatenate([w.reshape(-1,1),outcomes[:,0],outcomes[:,1],outcomes[:,0],outcomes[:,1],X], axis=1)
        data_df = pd.DataFrame(data)
        # Create masking array of same shape as pp_data and initialize with 1s
        mask = np.ones(data_df.shape)
        mask[:,1] = w      # y0 observed only for controls
        mask[:,2] = 1-w    # y1 observed only for treated
        mask[:,3] = 0      # true potential outcomes are never observed
        mask[:,4] = 0
        mask_df = pd.DataFrame(mask)

        return data_df, mask_df

    def train(self, X: np.ndarray, y: np.ndarray, w: np.ndarray, outcomes:np.ndarray) -> None:
        """
        Prepare data and train DiffPO Learner

        Writes the packed data/mask CSVs, trains the propensity net and the
        TabCSDI diffusion model, and saves the model weights under
        ``diffpo_path/save_model/data_pp/``.
        NOTE(review): `y` is only used for the shape log below — the loader
        reads outcomes from the CSV instead.
        """
        log.info("Training data shapes: X: {}, Y: {}, T: {}".format(X.shape, y.shape, w.shape))

        if not os.path.exists(self.data_dir):
            os.makedirs(self.data_dir)
        data, mask = self.reshape_data(X, w, outcomes)

        # create destination folders if not exist
        if not os.path.exists(self.data_dir+"acic2018_norm_data/"):
            os.makedirs(self.data_dir+"acic2018_norm_data/")
        if not os.path.exists(self.data_dir+"acic2018_mask/"):
            os.makedirs(self.data_dir+"acic2018_mask/")

        # save intermediate data
        data.to_csv(self.data_dir+"acic2018_norm_data/data_pp.csv", index=False)
        mask.to_csv(self.data_dir+"acic2018_mask/data_pp.csv", index=False)

        # Remove old files
        # (stale preprocessed pickles would shadow the CSVs just written)
        if os.path.exists(self.data_dir+"missing_ratio-0.2_seed-1_current_id-data_max-min_norm.pk"):
            os.remove(self.data_dir+"missing_ratio-0.2_seed-1_current_id-data_max-min_norm.pk")
        if os.path.exists(self.data_dir+"missing_ratio-0.2_seed-1.pk"):
            os.remove(self.data_dir+"missing_ratio-0.2_seed-1.pk")

        # Create folder
        current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

        # define these as variables
        # NOTE(review): `config`, `unconditional`, `modelfolder`, `nsample`
        # and `perform_training` are assigned but never used below.
        nfold = 1
        config = "acic2018.yaml"
        current_id = "data_pp"
        device = DEVICE
        seed = 1
        testmissingratio = 0.2
        unconditional = 0
        modelfolder = ""
        nsample = 1
        perform_training = 1

        foldername = self.diffpo_path + "/save/acic_fold" + str(nfold) + "_" + current_time + "/"
        # print("model folder:", foldername)
        os.makedirs(foldername, exist_ok=True)

        current_id = "data_pp"
        # print('Start exe_acic on current_id', current_id)

        # Every loader contains "observed_data", "observed_mask", "gt_mask", "timepoints"
        training_size = 1

        train_loader, valid_loader, _ = get_dataloader(
            seed=seed,
            nfold=nfold,
            batch_size=self.config["train"]["batch_size"],
            missing_ratio=testmissingratio,
            dataset_name = self.config["dataset"]["data_name"],
            current_id = current_id,
            training_size = training_size,
            data_path=self.data_dir,
            x_dim=X.shape[1],
        )

        #=======================First train and fix propnet======================
        # Train a propensitynet on this dataset

        propnet = load_data(dataset_name = self.config["dataset"]["data_name"], current_id=current_id, x_dim=X.shape[1], data_path=self.data_dir)

        # frozen the trained_propnet
        # print('Finish training propnet and fix the parameters')
        propnet.eval()
        # ========================================================================

        propnet = propnet.to(device)

        model = TabCSDI(self.config, self.device).to(self.device)
        # Train the model
        train(
            model,
            self.config["train"],
            train_loader,
            valid_loader=valid_loader,
            valid_epoch_interval=self.config["train"]["valid_epoch_interval"],
            foldername=foldername,
            propnet = propnet
        )

        directory = self.diffpo_path + "/save_model/" + current_id
        if not os.path.exists(directory):
            os.makedirs(directory)

        # # load model
        # model.load_state_dict(torch.load(directory + "/model_weights.pth"))

        # save model
        torch.save(model.state_dict(), directory + "/model_weights.pth")



    # predict function with bool return_po and return potential outcome if true
    def predict(self, X: np.ndarray, T0: np.ndarray = None, T1: np.ndarray = None, outcomes: np.ndarray = None) -> np.ndarray:
        """
        Predict the treatment effect using the DiffPO estimator.

        Writes the test CSVs, reloads the weights saved by `train`, and
        delegates to `evaluate` (which also fills `self.pred_outcomes` and
        `self.cate_cis` as side effects). `T1` is unused; `T0` is used as
        the treatment indicator when packing the data.
        """
        # Store data for their pipeline
        data_dir = self.data_dir

        data, mask = self.reshape_data(X, T0, outcomes)

        data.to_csv(data_dir+"acic2018_norm_data/data_pp_test.csv", index=False)
        mask.to_csv(data_dir+"acic2018_mask/data_pp_test.csv", index=False)

        # Remove old files
        if os.path.exists(data_dir+"missing_ratio-0.2_seed-1_current_id-data_max-min_norm.pk"):
            os.remove(data_dir+"missing_ratio-0.2_seed-1_current_id-data_max-min_norm.pk")
        if os.path.exists(data_dir+"missing_ratio-0.2_seed-1.pk"):
            os.remove(data_dir+"missing_ratio-0.2_seed-1.pk")

        # Create folder
        current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

        # define these as variables
        # NOTE(review): `perform_training` is unused here.
        nfold = 1
        current_id = "data_pp_test"
        current_id_train = "data_pp"
        seed = 1
        testmissingratio = 0.2
        nsample = 50
        perform_training = 1

        # NOTE(review): unlike `train`, this writes under the current working
        # directory ("./save/...") instead of self.diffpo_path — confirm the
        # inconsistency is intended.
        foldername = "./save/acic_fold" + str(nfold) + "_" + current_time + "/"
        # print("model folder:", foldername)
        os.makedirs(foldername, exist_ok=True)

        # Every loader contains "observed_data", "observed_mask", "gt_mask", "timepoints"
        training_size = 0
        _,_,test_loader = get_dataloader(
            seed=seed,
            nfold=nfold,
            batch_size=1,
            missing_ratio=testmissingratio,
            dataset_name = self.config["dataset"]["data_name"],
            current_id = current_id,
            training_size = training_size,
            data_path=data_dir,
            x_dim=X.shape[1],
        )

        # load model
        directory = self.diffpo_path + "/save_model/" + current_id_train
        os.makedirs(directory, exist_ok=True)
        model = TabCSDI(self.config, self.device).to(self.device)
        model.load_state_dict(torch.load(directory + "/model_weights.pth"))

        # get cates
        return self.evaluate(model, test_loader, nsample, foldername=foldername)

    def predict_outcomes(self, X: np.ndarray, T0: np.ndarray = None, T1: np.ndarray = None, outcomes: np.ndarray = None) -> np.ndarray:
        """
        Predict the potential outcomes using the DiffPO estimator.

        Returns the outcomes cached by the last `predict`/`evaluate` call as
        an (n, 2, 1) array; the arguments are ignored. Raises AttributeError
        if called before `predict`.
        """
        # add outer dimension to self.pred_outcomes
        return self.pred_outcomes.cpu().numpy().reshape(self.pred_outcomes.shape[0], self.pred_outcomes.shape[1], 1)

    def explain(self, X: np.ndarray, background_samples: np.ndarray = None, explainer_limit: int = None) -> np.ndarray:
        """
        Explain the treatment effect using the EconML estimator.

        NOTE(review): `self.est` is initialized to None and never assigned in
        this class, so this will fail unless an estimator is injected
        externally. `background_samples` is accepted but not forwarded.
        """
        if explainer_limit is None:
            explainer_limit = X.shape[0]

        return self.est.shap_values(X[:explainer_limit], background_samples=None)

    def infer_effect_ci(self, X, T0) -> np.ndarray:
        """
        Infer the confidence interval of the treatment effect using the EconML estimator.

        Flips (and negates) the cached CATE interval bounds for units with
        T0 != 0, so the interval refers to the effect of switching away from
        the received treatment. Mutates `self.cate_cis` in place; requires
        `predict` to have been called first (it fills `self.cate_cis`).
        """
        cates_conf_lbs = self.cate_cis[0]
        cates_conf_ups = self.cate_cis[1]

        # Fancy indexing copies, so `temp` survives the overwrite below.
        temp = cates_conf_lbs[T0 != 0]
        cates_conf_lbs[T0 != 0] = -cates_conf_ups[T0 != 0]
        cates_conf_ups[T0 != 0] = -temp
        return np.array([cates_conf_lbs, cates_conf_ups])

    def evaluate(self, model, test_loader, nsample=100, scaler=1, mean_scaler=0, foldername=""):
        """Draw posterior samples per unit, compute median potential outcomes
        and ITEs, and cache 95% CATE intervals.

        Returns the per-unit predicted ITE tensor. Side effects: fills
        `self.pred_outcomes` (n, 2) and `self.cate_cis` (2, n, 1, as numpy).
        NOTE(review): `scaler`, `mean_scaler` and `foldername` are accepted
        but unused.
        """
        # Control random seed in the current script.
        torch.manual_seed(0)
        np.random.seed(0)

        with torch.no_grad():
            model.eval()
            # NOTE(review): these accumulators and AverageMeters are never
            # updated (the metric updates below are commented out).
            mse_total = 0
            mae_total = 0
            evalpoints_total = 0

            pehe_test = AverageMeter()
            y0_test = AverageMeter()
            y1_test = AverageMeter()

            # for uncertainty
            y0_samples = []
            y1_samples = []
            y0_true_list = []
            y1_true_list = []
            ite_samples = []
            ite_true_list = []
            pred_ites = []
            pred_y0s = []
            pred_y1s = []

            for batch_no, test_batch in enumerate(test_loader, start=1):
                # Get model outputs
                output = model.evaluate(test_batch, nsample)
                samples, observed_data, target_mask, observed_mask, observed_tp = output

                # Extract relevant quantities
                # (assumes samples is (batch, nsample, target_dim) with
                # columns 0/1 = y0/y1 — TODO confirm against TabCSDI.evaluate)
                y0_samples.append(samples[:,:,0])
                y1_samples.append(samples[:,:,1])
                ite_samples.append(samples[:,:,1] - samples[:,:,0])

                # Get point estimation through median
                est_data = torch.median(samples, dim=1).values

                # Get true ite
                # Columns follow the reshape_data layout: 1 = y0, 2 = y1.
                obs_data = observed_data.squeeze(1)
                true_ite = obs_data[:, 2] - obs_data[:, 1]
                ite_true_list.append(true_ite)

                # Get predicted ite
                pred_y0 = est_data[:, 0]
                pred_y1 = est_data[:, 1]
                pred_y0s.append(pred_y0)
                pred_y1s.append(pred_y1)
                y0_true_list.append(obs_data[:, 1])
                y1_true_list.append(obs_data[:, 2])
                pred_ite = pred_y1 - pred_y0
                pred_ites.append(pred_ite)

                #y0_test.update(diff_y0, obs_data.size(0))
                #diff_y0 = np.mean((pred_y0.cpu().numpy()-obs_data[:, 1].cpu().numpy())**2)
                #y1_test.update(diff_y1, obs_data.size(0))
                #diff_y1 = np.mean((pred_y1.cpu().numpy()-obs_data[:, 2].cpu().numpy())**2)
                #pehe_test.update(diff_ite, obs_data.size(0))
                #diff_ite = np.mean((true_ite.cpu().numpy()-est_ite.cpu().numpy())**2)

            #---------------uncertainty estimation-------------------------
            pred_samples_y0 = torch.cat(y0_samples, dim=0)
            pred_samples_y1 = torch.cat(y1_samples, dim=0)
            pred_samples_ite = torch.cat(ite_samples, dim=0)

            truth_y0 = torch.cat(y0_true_list, dim=0)
            truth_y1 = torch.cat(y1_true_list, dim=0)
            truth_ite = torch.cat(ite_true_list, dim=0)

            # Coverage probability and median interval width per quantity.
            # NOTE(review): the results are computed but not returned or
            # stored — only the ITE intervals below are kept.
            prob_0, median_width_0 = self.compute_interval(pred_samples_y0, truth_y0)
            prob_1, median_width_1 = self.compute_interval(pred_samples_y1, truth_y1)
            prob_ite, median_width_ite = self.compute_interval(pred_samples_ite, truth_ite)

            self.cate_cis = torch.zeros(2, pred_samples_ite.shape[0], 1) # confidence intervals, dim: 2, n, dim_Y
            for i in range(pred_samples_ite.shape[0]):
                lower_quantile, upper_quantile, in_quantiles = self.check_intervel(confidence_level=0.95, y_pred= pred_samples_ite[i, :], y_true=truth_ite[i])
                self.cate_cis[0, i, 0] = lower_quantile
                self.cate_cis[1, i, 0] = upper_quantile

            #----------------------------------------------------------------
            pred_ites = torch.cat(pred_ites, dim=0)
            pred_y0s = torch.cat(pred_y0s, dim=0)
            pred_y1s = torch.cat(pred_y1s, dim=0)

            #np.zeros((X.shape[0], self.cfg.simulator.num_T, self.cfg.simulator.dim_Y))
            self.pred_outcomes = torch.cat([pred_y0s.unsqueeze(1), pred_y1s.unsqueeze(1)], dim=1)
            self.cate_cis = self.cate_cis.cpu().numpy()

            return pred_ites

    # NOTE(review): "intervel" is a typo for "interval"; kept to preserve the
    # public interface.
    def check_intervel(self, confidence_level: float, y_pred: torch.Tensor, y_true: torch.Tensor):
        """Return the central `confidence_level` quantile interval of the
        sample vector `y_pred` and whether `y_true` falls inside it."""
        lower = (1 - confidence_level) / 2
        upper = 1 - lower
        lower_quantile = torch.quantile(y_pred, lower)
        upper_quantile = torch.quantile(y_pred, upper)
        in_quantiles = torch.logical_and(y_true >= lower_quantile, y_true <= upper_quantile)
        return lower_quantile, upper_quantile, in_quantiles

    def compute_interval(self, po_samples: torch.Tensor, y_true: torch.Tensor):
        """Empirical 95% coverage probability and median interval width over
        all units.

        `po_samples` is (n_units, n_samples); `y_true` is (n_units,).
        """
        counter = 0
        width_list = []
        for i in range(po_samples.shape[0]):
            lower_quantile, upper_quantile, in_quantiles = self.check_intervel(confidence_level=0.95, y_pred= po_samples[i, :], y_true=y_true[i])
            if in_quantiles == True:
                counter+=1
            width = upper_quantile - lower_quantile
            width_list.append(width.unsqueeze(0))
        prob = (counter/po_samples.shape[0])
        all_width = torch.cat(width_list, dim=0)
        median_width = torch.median(all_width, dim=0).values
        return prob, median_width