Diff of /evaluate.py [000000] .. [39fb2b]

"""Evaluates the model"""

import argparse
import logging
import os

import numpy as np
import pandas as pd
import torch
from torch.autograd import Variable

import model.data_loader as data_loader
import model.net as net
import utils
from sklearn import linear_model

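# `utils.Params` is defined in the accompanying utils.py, which is not part of
# this diff; presumably it loads hyperparameters from a JSON file and exposes
# them as attributes, along these lines (a sketch, not the actual code):
#
#     params = utils.Params("experiments/base_model/params.json")
#     params.batch_size, params.alpha, params.device   # read as attributes
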
def evaluate(model, loss_fn, dataloader, metrics, params, setting, epoch, writer=None):
    """Evaluate the model on all batches in `dataloader`.

    Args:
        model: (torch.nn.Module) the neural network
        loss_fn: a function that takes batch_output and batch_labels and computes the loss for the batch
        dataloader: (DataLoader) a torch.utils.data.DataLoader object that fetches data
        metrics: (dict) a dictionary of functions that compute a metric using the output and labels of each batch
        params: (Params) hyperparameters
        setting: experiment setting, including covariate mode, outcome variables and metrics
        epoch: (int) the current epoch, used for epoch-dependent losses and logging
        writer: (SummaryWriter, optional) TensorBoard writer for histograms and scalars
    """

    # set model to evaluation mode
    model.eval()
    model.to(params.device)

    # summary for current eval loop
    summ  = []
    preds = []  # for saving last predictions
    bn_activations = []

    # create storage for tensors for OLS after the minibatches
    Xhats = []
    Zhats = []

    # for counterfactuals
    if setting.counterfactuals:
        y0_hats = []
        y1_hats = []

    # compute metrics over the dataset
    for batch in dataloader:
        summary_batch = {}
        batch = {k: v.to(params.device, non_blocking=True) for k, v in batch.items()}
        img_batch    = batch["image"]
        labels_batch = batch["label"]
        if setting.covar_mode and epoch > params.suppress_t_epochs:
            data_batch = batch["t"].view(-1,1)
        else:
            # use the actual batch size rather than params.batch_size,
            # so a smaller final batch does not cause a size mismatch
            data_batch = torch.zeros((img_batch.size(0), 1), requires_grad=False, device=params.device)

        if params.multi_task:
            x_target_batch = batch["x"]
            y_target_batch = batch["y"]
            labels_batch = {'x': x_target_batch, 'y': y_target_batch}

        # compute model output
        output_batch = model(img_batch, data_batch, epoch)
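
        # `model`'s forward pass is defined in model/net.py, which is not part
        # of this diff; judging from the calls and keys used here, it takes the
        # image, the treatment t and the epoch, and returns a dict roughly like
        # this (a sketch under those assumptions, not the actual implementation):
        #
        #     def forward(self, image, t, epoch=None):
        #         z     = self.features(image)               # bottleneck activations "bnz"/"bn"
        #         x_hat = self.x_head(z)                     # estimate of the collider x, "bnx"
        #         y_hat = self.y_head(torch.cat([z, t], 1))  # outcome prediction "y"
        #         return {"y": y_hat, "bnx": x_hat, "bnz": z}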

        # calculate loss
        if setting.fase == "feature":
            # calculate the loss for z directly, to see how well z can be recovered
            loss_fn_z = torch.nn.MSELoss()
            loss_z = loss_fn_z(output_batch["y"].squeeze(), batch["z"])
            loss   = loss_z
            summary_batch["loss_z"] = loss_z.item()
        else:
            loss_fn_y = torch.nn.MSELoss()
            loss_y = loss_fn_y(output_batch["y"].squeeze(), batch["y"])
            loss   = loss_y
            summary_batch["loss_y"] = loss_y.item()

        # calculate loss for the collider x
        loss_fn_x = torch.nn.MSELoss()
        loss_x = loss_fn_x(output_batch["bnx"].squeeze(), batch["x"])
        summary_batch["loss_x"] = loss_x.item()
        if params.alpha != 1:
            # possibly weigh down the contribution of estimating x
            loss_x *= params.alpha
            summary_batch["loss_x_weighted"] = loss_x.item()

        # add x loss to total loss
        loss += loss_x

        # add a least squares regression on the final layer
        if params.do_least_squares:
            X    = batch["x"].view(-1,1)
            t    = batch["t"].view(-1,1)
            Z    = output_batch["bnz"]
            if Z.ndimension() == 1:
                Z.unsqueeze_(1)
            Xhat = output_batch["bnx"]
            # add intercept
            Zi = torch.cat([torch.ones_like(t), Z], 1)
            # add treatment info
            Zt = torch.cat([Zi, t], 1)
            Y  = batch["y"].view(-1,1)

            # regress y on the final layer, without x
            betas_y = net.cholesky_least_squares(Zt, Y, intercept=False)
            y_hat   = Zt.matmul(betas_y).view(-1,1)
            mse_y   = ((Y - y_hat)**2).mean()

            summary_batch["regr_b_t"]    = betas_y[-1].item()
            summary_batch["regr_loss_y"] = mse_y.item()

            # regress x on the final layer, without x
            betas_x = net.cholesky_least_squares(Zi, Xhat, intercept=False)
            x_hat   = Zi.matmul(betas_x).view(-1,1)
            mse_x   = ((Xhat - x_hat)**2).mean()

            # store all tensors for a single pass after the epoch
            Xhats.append(Xhat.detach().cpu())
            Zhats.append(Z.detach().cpu())

            summary_batch["regr_loss_x"] = mse_x.item()

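        # `net.cholesky_least_squares` is defined in model/net.py, which is not
        # part of this diff; presumably it solves the normal equations
        # A'A b = A'y via a Cholesky factorization, roughly like this
        # (a sketch, not the actual implementation):
        #
        #     def cholesky_least_squares(A, y, intercept=True):
        #         if intercept:
        #             A = torch.cat([torch.ones_like(A[:, :1]), A], 1)
        #         AtA = A.t().matmul(A)
        #         Aty = A.t().matmul(y.view(-1, 1))
        #         L   = torch.linalg.cholesky(AtA)
        #         return torch.cholesky_solve(Aty, L).view(-1)
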
        # add loss_bn only after n epochs
        # note: X and mse_x come from the do_least_squares block above, so
        # params.do_least_squares must be enabled when using the bottleneck loss
        if params.bottleneck_loss and epoch > params.bn_loss_lag_epochs:
            # only add to the loss when it exceeds the margin
            if params.bn_loss_margin_type == "dynamic-mean":
                # for each batch, calculate the loss of just using the mean for predicting x
                mse_x_mean = ((X - X.mean())**2).mean()
                loss_bn = torch.max(torch.zeros_like(mse_x), mse_x_mean - mse_x)
            elif params.bn_loss_margin_type == "fixed":
                mse_diff = params.bn_loss_margin - mse_x
                loss_bn = torch.max(torch.zeros_like(mse_x), mse_diff)
            else:
                raise NotImplementedError(f'bottleneck loss margin type not implemented: {params.bn_loss_margin_type}')

            # possibly reweigh the bottleneck loss and add it to the total loss
            summary_batch["loss_bn"] = loss_bn.item()
            # note: is the margin applied twice here?
            if loss_bn > params.bn_loss_margin:
                loss_bn *= params.bottleneck_loss_wt
                loss    += loss_bn

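        # in both margin variants the bottleneck loss is a hinge,
        #     loss_bn = max(0, margin - mse_x),
        # so the model is only penalized while x can still be predicted from the
        # bottleneck with an error below the margin, discouraging it from encoding x
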
        # generate counterfactual predictions
        if setting.counterfactuals:
            batch_t0 = torch.zeros_like(data_batch, dtype=torch.float32)
            batch_t1 = torch.ones_like(data_batch, dtype=torch.float32)
            # pass the epoch, as in the factual forward pass above
            y0_batch = model(img_batch, batch_t0, epoch)
            y1_batch = model(img_batch, batch_t1, epoch)
            y0_hats.append(y0_batch["y"].detach().cpu().numpy())
            y1_hats.append(y1_batch["y"].detach().cpu().numpy())

        # store the activations of the bottleneck layer
        if params.multi_task:
            bn_activations.append(output_batch["bnz"])
        else:
            bn_activations.append(output_batch["bn"])

        # extract data from torch tensors, move to cpu, convert to numpy arrays
        if (len(setting.outcome) > 1) or params.multi_task:
            for var, labels in labels_batch.items():
                labels_batch[var] = labels.data.cpu().numpy()
        else:
            labels_batch = labels_batch.data.cpu().numpy()

        # compute all metrics on this batch
        data_batch = data_batch.data.cpu().numpy()
        for var, out in output_batch.items():
            output_batch[var] = out.detach().cpu().numpy()
        if params.multi_task:
            metrics_xy = {m: net.all_metrics[m] for m in setting.metrics_xy}
            for var in labels_batch:
                for metric, metric_fn in metrics_xy.items():
                    summary_batch[metric+"_"+var] = metric_fn(setting, model, output_batch[var], labels_batch[var], data_batch)
            if "b_t" in setting.metrics:
                summary_batch["b_t"] = net.all_metrics["b_t"](setting, model, None, None)
        else:
            raise NotImplementedError("evaluation is only implemented for the multi-task setting")

        summary_batch["loss"] = loss.item()
        summ.append(summary_batch)
        preds.append(output_batch["y"])


    # compute the mean of all metrics in the summary
    metrics_mean = {metric: np.nanmean([x[metric] for x in summ]) for metric in summ[0]}

    # if "ate" in setting.metrics:
    #     metrics_mean["ate"] = all_metrics["ate"](setting, model, preds, )

    # concatenate the bottleneck activations of all batches
    bn_activations = torch.cat(bn_activations, 0).detach().cpu()

    if params.save_bn_activations and writer is not None:
        # write out the activations
        writer.add_histogram("bn_activations", bn_activations.numpy(), epoch+1)

    # get means and covariances of the bottleneck activations
    if "bottleneck_loss" in setting.metrics:
        bn_means    = bn_activations.mean(dim=0)
        bn_sds      = bn_activations.std(dim=0)
        bn_cov      = net.cov(bn_activations)
        bn_offdiags = net.get_of_diag(bn_cov.detach().cpu().numpy())
        if writer is not None:
            writer.add_histogram("bn_covariances", bn_offdiags, epoch+1)

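    # `net.cov` and `net.get_of_diag` come from model/net.py, which is not part
    # of this diff; presumably they compute a covariance matrix over the
    # activations and pull out its off-diagonal entries, roughly like this
    # (a sketch, not the actual implementation):
    #
    #     def cov(m):  # m: (n_samples, n_features)
    #         mc = m - m.mean(dim=0, keepdim=True)
    #         return mc.t().matmul(mc) / (m.size(0) - 1)
    #
    #     def get_of_diag(c):  # c: a square numpy array
    #         return c[~np.eye(c.shape[0], dtype=bool)]
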
    # export predictions
    preds = np.vstack([x.reshape(-1,1) for x in preds])
    if writer is not None:
        writer.add_histogram('predictions', preds, epoch+1)
    labels = dataloader.dataset.df[setting.outcome[0]].values.astype(np.float32)

    # predict individual treatment effects (only worthwhile when there is an interaction with t)
    if setting.counterfactuals:
        # flatten to 1-d so that subtracting the (N,) ground-truth arrays below
        # does not broadcast (N,1) - (N,) into an (N,N) matrix
        y0_hats  = np.vstack(y0_hats).ravel()
        y1_hats  = np.vstack(y1_hats).ravel()
        ite_hats = y1_hats - y0_hats
        metrics_mean["ite_mean"] = ite_hats.mean()

        y0s  = dataloader.dataset.df["y0"].values.astype(np.float32)
        y1s  = dataloader.dataset.df["y1"].values.astype(np.float32)
        ites = y1s - y0s
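        # PEHE, the Precision in Estimation of Heterogeneous Effects (Hill, 2011),
        # is the root mean squared error between the estimated and the true
        # individual treatment effects:
        #     PEHE = sqrt( mean_i [ ((y1_hat_i - y0_hat_i) - (y1_i - y0_i))^2 ] )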
        metrics_mean["pehe"] = np.sqrt(np.mean(np.power((ite_hats - ites), 2)))
237
238
        metrics_mean["loss_y1"] = ((y1s - y1_hats)**2).mean()
239
        metrics_mean["loss_y0"] = ((y0s - y0_hats)**2).mean()
240
241
    # when x is part of a single final layer, run the regression on this layer
    if params.bn_place == "single-regressor" and params.do_least_squares:
        Xhat  = torch.cat(Xhats, 0).view(-1,1).float()
        Zhat  = torch.cat(Zhats, 0).float()
        t     = torch.tensor(dataloader.dataset.df["t"].values).view(-1,1).float()
        Y     = torch.tensor(dataloader.dataset.df["y"].values).view(-1,1).float()

        betas_bias   = model.betas_bias.detach().cpu()
        betas_causal = model.betas_causal.detach().cpu()

        # the 'bias' regression conditions on the collider estimate Xhat,
        # the 'causal' regression leaves it out
        y_hat_bias   = torch.cat([torch.ones_like(t), Xhat, Zhat, t], 1).matmul(betas_bias).view(-1,1)
        y_hat_causal = torch.cat([torch.ones_like(t), Zhat, t], 1).matmul(betas_causal).view(-1,1)

        reg_mse_bias   = ((y_hat_bias - Y)**2).mean()
        reg_mse_causal = ((y_hat_causal - Y)**2).mean()

        metrics_mean["regr_bias_loss_y"]   = reg_mse_bias.item()
        metrics_mean["regr_causal_loss_y"] = reg_mse_causal.item()

        if setting.counterfactuals:
            # convert to flat numpy arrays, to line up with ites, y0s and y1s
            y0_hat_bias   = torch.cat([torch.ones_like(t), Xhat, Zhat, torch.zeros_like(t)], 1).matmul(betas_bias).numpy().ravel()
            y1_hat_bias   = torch.cat([torch.ones_like(t), Xhat, Zhat, torch.ones_like(t)], 1).matmul(betas_bias).numpy().ravel()
            y0_hat_causal = torch.cat([torch.ones_like(t), Zhat, torch.zeros_like(t)], 1).matmul(betas_causal).numpy().ravel()
            y1_hat_causal = torch.cat([torch.ones_like(t), Zhat, torch.ones_like(t)], 1).matmul(betas_causal).numpy().ravel()

            ite_hats_bias   = y1_hat_bias - y0_hat_bias
            ite_hats_causal = y1_hat_causal - y0_hat_causal

            if writer is not None:
                writer.add_scalars("pehe", {"regr_bias": np.sqrt(((ite_hats_bias - ites)**2).mean())}, epoch+1)
                writer.add_scalars("pehe", {"regr_causal": np.sqrt(((ite_hats_causal - ites)**2).mean())}, epoch+1)
                writer.add_scalars("loss_y1", {"regr_bias": ((y1s - y1_hat_bias)**2).mean()}, epoch+1)
                writer.add_scalars("loss_y0", {"regr_bias": ((y0s - y0_hat_bias)**2).mean()}, epoch+1)
                writer.add_scalars("loss_y1", {"regr_causal": ((y1s - y1_hat_causal)**2).mean()}, epoch+1)
                writer.add_scalars("loss_y0", {"regr_causal": ((y0s - y0_hat_causal)**2).mean()}, epoch+1)

    outtensors = {
        'bn_activations': bn_activations.numpy(),
        'predictions': preds,
        # Xhats is only filled when params.do_least_squares is enabled
        'xhat': torch.cat(Xhats, 0).numpy() if Xhats else None
    }

    metrics_string = " ; ".join("{}: {:05.3f}".format(k, v) for k, v in metrics_mean.items())
    logging.info("- Eval metrics : " + metrics_string)

    return metrics_mean, outtensors
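
# A minimal sketch of how `evaluate` might be called from a training loop; the
# names `params`, `setting`, `val_dataloader` and the log directory here are
# hypothetical, not taken from this file:
#
#     from torch.utils.tensorboard import SummaryWriter
#
#     writer = SummaryWriter(log_dir="experiments/base_model")
#     for epoch in range(params.num_epochs):
#         # ... one training pass over the training dataloader ...
#         val_metrics, val_tensors = evaluate(
#             model, loss_fn, val_dataloader, net.all_metrics,
#             params, setting, epoch, writer=writer)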