Switch to side-by-side view

--- a
+++ b/catenets/models/diffpo/diffpo_learner.py
@@ -0,0 +1,404 @@
+from typing import Any, Callable, List
+
+import numpy as np
+import torch
+from torch import nn
+import os
+import tqdm
+import catenets.logger as log
+from catenets.models.constants import (
+    DEFAULT_BATCH_SIZE,
+    DEFAULT_DIM_P_OUT,
+    DEFAULT_DIM_P_R,
+    DEFAULT_DIM_S_OUT,
+    DEFAULT_DIM_S_R,
+    DEFAULT_LAYERS_OUT,
+    DEFAULT_LAYERS_R,
+    DEFAULT_N_ITER,
+    DEFAULT_N_ITER_MIN,
+    DEFAULT_N_ITER_PRINT,
+    DEFAULT_PATIENCE,
+    DEFAULT_PENALTY_L2,
+    DEFAULT_PENALTY_ORTHOGONAL,
+    DEFAULT_SEED,
+    DEFAULT_NJOBS,
+    DEFAULT_STEP_SIZE,
+    DEFAULT_VAL_SPLIT,
+    LARGE_VAL,
+)
+from catenets.models.torch.base import DEVICE, BaseCATEEstimator
+from catenets.models.torch.utils.model_utils import make_val_split
+import pandas as pd
+# Hydra
+from omegaconf import DictConfig
+import json
+import datetime
+
+from .src.main_model_table import TabCSDI
+from .src.utils_table import train
+from .dataset_acic import get_dataloader
+
+from .PropensityNet import load_data
+
+
+torch.manual_seed(0)
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+
+class DiffPOLearner(BaseCATEEstimator):
+    """
+    A flexible treatment effect estimator based on the EconML framework.
+    """
+
+    def __init__(
+        self,
+        cfg: DictConfig,
+        num_features: int,
+        binary_y: bool,
+    ) -> None:
+        self.config = cfg.DiffPOLearner
+        self.diffpo_path = cfg.diffpo_path
+        self.config.diffusion.cond_dim = num_features+1 # make sure inner dimension matches the dataset
+        self.est = None
+        self.propnet = None
+        self.device = DEVICE
+        self.cate_cis = None # confidence intervals, dim: 2, n, num_T-1, dim_Y
+        self.pred_outcomes = None
+
+        # create folder if diffpo_path + 'data' does not exist
+        if not os.path.exists(self.diffpo_path):
+            os.makedirs(self.diffpo_path)
+        
+        # Store data for their pipeline
+        self.data_dir = self.diffpo_path+'/data/'
+        if not os.path.exists(self.data_dir):
+            os.makedirs(self.data_dir)
+
+        return None
+
+    def reshape_data(self, X: np.ndarray, w: np.ndarray, outcomes: np.ndarray) -> None:
+        data = np.concatenate([w.reshape(-1,1),outcomes[:,0],outcomes[:,1],outcomes[:,0],outcomes[:,1],X], axis=1)
+        data_df = pd.DataFrame(data)
+        # Create masking array of same shape as pp_data and initialize with 1s
+        mask = np.ones(data_df.shape)
+        mask[:,1] = w
+        mask[:,2] = 1-w
+        mask[:,3] = 0
+        mask[:,4] = 0
+        mask_df = pd.DataFrame(mask)
+
+        return data_df, mask_df
+
+    def train(self, X: np.ndarray, y: np.ndarray, w: np.ndarray, outcomes:np.ndarray) -> None:
+        """
+        Prepare data and train DiffPO Learner
+        """
+        log.info("Training data shapes: X: {}, Y: {}, T: {}".format(X.shape, y.shape, w.shape))
+
+        if not os.path.exists(self.data_dir):
+            os.makedirs(self.data_dir)
+        data, mask = self.reshape_data(X, w, outcomes)
+        
+        # create destination folders if not exist
+        if not os.path.exists(self.data_dir+"acic2018_norm_data/"):
+            os.makedirs(self.data_dir+"acic2018_norm_data/")
+        if not os.path.exists(self.data_dir+"acic2018_mask/"):
+            os.makedirs(self.data_dir+"acic2018_mask/")
+
+        # save intermediate data
+        data.to_csv(self.data_dir+"acic2018_norm_data/data_pp.csv", index=False)
+        mask.to_csv(self.data_dir+"acic2018_mask/data_pp.csv", index=False)
+
+        # Remove old files
+        if os.path.exists(self.data_dir+"missing_ratio-0.2_seed-1_current_id-data_max-min_norm.pk"):
+            os.remove(self.data_dir+"missing_ratio-0.2_seed-1_current_id-data_max-min_norm.pk")
+        if os.path.exists(self.data_dir+"missing_ratio-0.2_seed-1.pk"):
+            os.remove(self.data_dir+"missing_ratio-0.2_seed-1.pk")
+
+        # Create folder
+        current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+        
+        # define these as variables
+        nfold = 1
+        config = "acic2018.yaml"
+        current_id = "data_pp"
+        device = DEVICE
+        seed = 1
+        testmissingratio = 0.2
+        unconditional = 0
+        modelfolder = ""
+        nsample = 1
+        perform_training = 1
+
+        foldername = self.diffpo_path + "/save/acic_fold" + str(nfold) + "_" + current_time + "/"
+        # print("model folder:", foldername)
+        os.makedirs(foldername, exist_ok=True)
+
+        current_id = "data_pp"
+        # print('Start exe_acic on current_id', current_id)
+
+        # Every loader contains "observed_data", "observed_mask", "gt_mask", "timepoints"
+        training_size = 1
+        
+        train_loader, valid_loader, _ = get_dataloader(
+            seed=seed,
+            nfold=nfold,
+            batch_size=self.config["train"]["batch_size"],
+            missing_ratio=testmissingratio,
+            dataset_name = self.config["dataset"]["data_name"],
+            current_id = current_id,
+            training_size = training_size,
+            data_path=self.data_dir,
+            x_dim=X.shape[1],
+        )
+
+        #=======================First train and fix propnet======================
+        # Train a propensitynet on this dataset
+
+        propnet = load_data(dataset_name = self.config["dataset"]["data_name"], current_id=current_id, x_dim=X.shape[1], data_path=self.data_dir)
+
+        # frozen the trained_propnet
+        # print('Finish training propnet and fix the parameters')
+        propnet.eval()
+        # ========================================================================
+
+        propnet = propnet.to(device)
+
+        model = TabCSDI(self.config, self.device).to(self.device)
+        # Train the model
+        train(
+            model,
+            self.config["train"],
+            train_loader,
+            valid_loader=valid_loader,
+            valid_epoch_interval=self.config["train"]["valid_epoch_interval"],
+            foldername=foldername,
+            propnet = propnet
+        )
+
+        directory = self.diffpo_path + "/save_model/" + current_id
+        if not os.path.exists(directory):
+            os.makedirs(directory)
+        
+        # # load model
+        # model.load_state_dict(torch.load(directory + "/model_weights.pth"))
+
+        # save model
+        torch.save(model.state_dict(), directory + "/model_weights.pth")
+        
+        
+
+    # predict function with bool return_po and return potential outcome if true
+    def predict(self, X: np.ndarray, T0: np.ndarray = None, T1: np.ndarray = None, outcomes: np.ndarray = None) -> np.ndarray:
+        """
+        Predict the treatment effect using the DiffPO estimator.
+        """
+        # Store data for their pipeline
+        data_dir = self.data_dir
+        
+        data, mask = self.reshape_data(X, T0, outcomes)
+        
+        data.to_csv(data_dir+"acic2018_norm_data/data_pp_test.csv", index=False)
+        mask.to_csv(data_dir+"acic2018_mask/data_pp_test.csv", index=False)
+
+        # Remove old files
+        if os.path.exists(data_dir+"missing_ratio-0.2_seed-1_current_id-data_max-min_norm.pk"):
+            os.remove(data_dir+"missing_ratio-0.2_seed-1_current_id-data_max-min_norm.pk")
+        if os.path.exists(data_dir+"missing_ratio-0.2_seed-1.pk"):
+            os.remove(data_dir+"missing_ratio-0.2_seed-1.pk")
+
+        # Create folder
+        current_time = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+        
+        # define these as variables
+        nfold = 1
+        current_id = "data_pp_test"
+        current_id_train = "data_pp"
+        seed = 1
+        testmissingratio = 0.2
+        nsample = 50
+        perform_training = 1
+
+        foldername = "./save/acic_fold" + str(nfold) + "_" + current_time + "/"
+        # print("model folder:", foldername)
+        os.makedirs(foldername, exist_ok=True)
+
+        # Every loader contains "observed_data", "observed_mask", "gt_mask", "timepoints"
+        training_size = 0
+        _,_,test_loader = get_dataloader(
+            seed=seed,
+            nfold=nfold,
+            batch_size=1,
+            missing_ratio=testmissingratio,
+            dataset_name = self.config["dataset"]["data_name"],
+            current_id = current_id,
+            training_size = training_size,
+            data_path=data_dir,
+            x_dim=X.shape[1],
+        )
+
+        # load model
+        directory = self.diffpo_path + "/save_model/" + current_id_train
+        os.makedirs(directory, exist_ok=True)
+        model = TabCSDI(self.config, self.device).to(self.device)
+        model.load_state_dict(torch.load(directory + "/model_weights.pth"))
+
+        # get cates
+        return self.evaluate(model, test_loader, nsample, foldername=foldername)
+    
+    def predict_outcomes(self, X: np.ndarray, T0: np.ndarray = None, T1: np.ndarray = None, outcomes: np.ndarray = None) -> np.ndarray:
+        """
+        Predict the potential outcomes using the DiffPO estimator.
+        """
+        # add outer dimension to self.pred_outcomes
+        return self.pred_outcomes.cpu().numpy().reshape(self.pred_outcomes.shape[0], self.pred_outcomes.shape[1], 1)
+
+    def explain(self, X: np.ndarray, background_samples: np.ndarray = None, explainer_limit: int = None) -> np.ndarray:
+        """
+        Explain the treatment effect using the EconML estimator.
+        """
+        if explainer_limit is None:
+            explainer_limit = X.shape[0]
+
+        return self.est.shap_values(X[:explainer_limit], background_samples=None)
+    
+    def infer_effect_ci(self, X, T0) -> np.ndarray:
+        """
+        Infer the confidence interval of the treatment effect using the EconML estimator.
+        """
+        cates_conf_lbs = self.cate_cis[0]
+        cates_conf_ups = self.cate_cis[1]
+
+        temp = cates_conf_lbs[T0 != 0]
+        cates_conf_lbs[T0 != 0] = -cates_conf_ups[T0 != 0]
+        cates_conf_ups[T0 != 0] = -temp
+        return np.array([cates_conf_lbs, cates_conf_ups])
+    
+    def evaluate(self, model, test_loader, nsample=100, scaler=1, mean_scaler=0, foldername=""):
+        # Control random seed in the current script.
+        torch.manual_seed(0)
+        np.random.seed(0)
+
+        with torch.no_grad():
+            model.eval()
+            mse_total = 0
+            mae_total = 0
+            evalpoints_total = 0
+
+            pehe_test = AverageMeter()
+            y0_test = AverageMeter()
+            y1_test = AverageMeter()
+
+            # for uncertainty
+            y0_samples = []
+            y1_samples = []
+            y0_true_list = []
+            y1_true_list = []
+            ite_samples = []
+            ite_true_list = []
+            pred_ites = []
+            pred_y0s = []
+            pred_y1s = []
+            
+            for batch_no, test_batch in enumerate(test_loader, start=1):
+                # Get model outputs
+                output = model.evaluate(test_batch, nsample) 
+                samples, observed_data, target_mask, observed_mask, observed_tp = output
+
+                # Extract relevant quantities
+                y0_samples.append(samples[:,:,0]) 
+                y1_samples.append(samples[:,:,1]) 
+                ite_samples.append(samples[:,:,1] - samples[:,:,0])
+
+                # Get point estimation through median
+                est_data = torch.median(samples, dim=1).values
+
+                # Get true ite
+                obs_data = observed_data.squeeze(1)
+                true_ite = obs_data[:, 2] - obs_data[:, 1] 
+                ite_true_list.append(true_ite)
+
+                # Get predicted ite
+                pred_y0 = est_data[:, 0]
+                pred_y1 = est_data[:, 1]
+                pred_y0s.append(pred_y0)
+                pred_y1s.append(pred_y1)
+                y0_true_list.append(obs_data[:, 1])
+                y1_true_list.append(obs_data[:, 2])
+                pred_ite = pred_y1 - pred_y0
+                pred_ites.append(pred_ite)
+
+                #y0_test.update(diff_y0, obs_data.size(0))    
+                #diff_y0 = np.mean((pred_y0.cpu().numpy()-obs_data[:, 1].cpu().numpy())**2)
+                #y1_test.update(diff_y1, obs_data.size(0)) 
+                #diff_y1 = np.mean((pred_y1.cpu().numpy()-obs_data[:, 2].cpu().numpy())**2)
+                #pehe_test.update(diff_ite, obs_data.size(0))    
+                #diff_ite = np.mean((true_ite.cpu().numpy()-est_ite.cpu().numpy())**2)
+
+#---------------uncertainty estimation-------------------------
+            pred_samples_y0 = torch.cat(y0_samples, dim=0)
+            pred_samples_y1 = torch.cat(y1_samples, dim=0)
+            pred_samples_ite = torch.cat(ite_samples, dim=0)
+
+            truth_y0 = torch.cat(y0_true_list, dim=0) 
+            truth_y1 = torch.cat(y1_true_list, dim=0) 
+            truth_ite = torch.cat(ite_true_list, dim=0)
+
+            prob_0, median_width_0 = self.compute_interval(pred_samples_y0, truth_y0)
+            prob_1, median_width_1 = self.compute_interval(pred_samples_y1, truth_y1)
+            prob_ite, median_width_ite = self.compute_interval(pred_samples_ite, truth_ite)
+
+            self.cate_cis = torch.zeros(2, pred_samples_ite.shape[0], 1) # confidence intervals, dim: 2, n, dim_Y
+            for i in range(pred_samples_ite.shape[0]):
+                lower_quantile, upper_quantile, in_quantiles = self.check_intervel(confidence_level=0.95, y_pred= pred_samples_ite[i, :], y_true=truth_ite[i])
+                self.cate_cis[0, i, 0] = lower_quantile
+                self.cate_cis[1, i, 0] = upper_quantile
+           
+    #----------------------------------------------------------------
+        pred_ites = torch.cat(pred_ites, dim=0)
+        pred_y0s = torch.cat(pred_y0s, dim=0)
+        pred_y1s = torch.cat(pred_y1s, dim=0)
+
+        #np.zeros((X.shape[0], self.cfg.simulator.num_T, self.cfg.simulator.dim_Y))
+        self.pred_outcomes = torch.cat([pred_y0s.unsqueeze(1), pred_y1s.unsqueeze(1)], dim=1)
+        self.cate_cis = self.cate_cis.cpu().numpy()
+        
+        return pred_ites
+
+    def check_intervel(self, confidence_level, y_pred, y_true):
+        lower = (1 - confidence_level) / 2
+        upper = 1 - lower
+        lower_quantile = torch.quantile(y_pred, lower)
+        upper_quantile = torch.quantile(y_pred, upper)
+        in_quantiles = torch.logical_and(y_true >= lower_quantile, y_true <= upper_quantile)
+        return lower_quantile, upper_quantile, in_quantiles
+
+    def compute_interval(self, po_samples, y_true):
+        counter = 0
+        width_list = []
+        for i in range(po_samples.shape[0]):
+            lower_quantile, upper_quantile, in_quantiles = self.check_intervel(confidence_level=0.95, y_pred= po_samples[i, :], y_true=y_true[i])
+            if in_quantiles == True:
+                counter+=1
+            width = upper_quantile - lower_quantile
+            width_list.append(width.unsqueeze(0))
+        prob = (counter/po_samples.shape[0])
+        all_width = torch.cat(width_list, dim=0)
+        median_width = torch.median(all_width, dim=0).values
+        return prob, median_width
\ No newline at end of file