[39fb2b]: / evaluate.py

Download this file

286 lines (228 with data), 12.3 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
"""Evaluates the model"""
import argparse
import logging
import os
import numpy as np
import pandas as pd
import torch
from torch.autograd import Variable
import model.data_loader as data_loader
import model.net as net
import utils
from sklearn import linear_model
def evaluate(model, loss_fn, dataloader, metrics, params, setting, epoch, writer=None):
    """Evaluate the model on every batch produced by `dataloader`.

    Args:
        model: (torch.nn.Module) the neural network. It is called as
            `model(img_batch, data_batch, epoch)` and must return a dict of tensors
            holding at least "y", "bnx" and either "bn" or "bnz".
        loss_fn: not used by this function (losses are built with `torch.nn.MSELoss`);
            kept so existing callers keep working.
        dataloader: (DataLoader) iterable yielding dicts of tensors. `dataloader.dataset.df`
            is read for the counterfactual columns ("y0", "y1") and for the
            whole-dataset regression ("t", "y").
        metrics: not used by this function; kept so existing callers keep working.
        params: (Params) hyperparameters
        setting: experiment settings; the attributes read here are `counterfactuals`,
            `covar_mode`, `fase`, `outcome`, `metrics` and `metrics_xy`.
        epoch: (int) current epoch, forwarded to the model and used as logging step (`epoch+1`)
        writer: optional summary writer exposing `add_histogram` / `add_scalars`;
            all writer output is skipped when it is None.

    Returns:
        metrics_mean: (dict) per-metric mean over batches plus dataset-level metrics
        outtensors: (dict) with keys 'bn_activations', 'predictions' and 'xhat'
    """
    # set model to evaluation mode
    model.eval()
    model.to(params.device)

    # summary for current eval loop
    summ = []
    preds = []  # for saving last predictions
    bn_activations = []

    # storage for tensors for OLS after minibatches
    Xhats = []
    Zhats = []

    # for counterfactuals (only filled when setting.counterfactuals is set)
    y0_hats = []
    y1_hats = []

    # MSELoss keeps no state, so a single instance serves every loss below
    mse = torch.nn.MSELoss()

    # nothing is back-propagated during evaluation, so skip building autograd graphs
    with torch.no_grad():
        for batch in dataloader:
            summary_batch = {}
            batch = {k: v.to(params.device) for k, v in batch.items()}
            img_batch = batch["image"]
            labels_batch = batch["label"]

            if setting.covar_mode and epoch > params.suppress_t_epochs:
                data_batch = batch["t"].view(-1, 1)
            else:
                # size by the actual batch: the last batch can be smaller than params.batch_size
                data_batch = torch.zeros((img_batch.shape[0], 1), device=params.device)

            if params.multi_task:
                labels_batch = {'x': batch["x"], 'y': batch["y"]}

            # compute model output
            output_batch = model(img_batch, data_batch, epoch)

            # calculate loss
            if setting.fase == "feature":
                # calculate loss for z directly, to get clear how well this can be measured
                loss = mse(output_batch["y"].squeeze(), batch["z"])
                summary_batch["loss_z"] = loss.item()
            else:
                loss = mse(output_batch["y"].squeeze(), batch["y"])
                summary_batch["loss_y"] = loss.item()

            # calculate loss for collider x
            loss_x = mse(output_batch["bnx"].squeeze(), batch["x"])
            summary_batch["loss_x"] = loss_x.item()
            if params.alpha != 1:
                # possibly weigh down contribution of estimating x
                loss_x = loss_x * params.alpha
                summary_batch["loss_x_weighted"] = loss_x.item()
            # add x loss to total loss
            loss = loss + loss_x

            # add least squares regression on final layer
            if params.do_least_squares:
                X = batch["x"].view(-1, 1)
                t = batch["t"].view(-1, 1)
                Z = output_batch["bnz"]
                if Z.ndimension() == 1:
                    # in place on purpose: output_batch["bnz"] becomes 2-d as well
                    Z.unsqueeze_(1)
                Xhat = output_batch["bnx"]
                # add intercept
                Zi = torch.cat([torch.ones_like(t), Z], 1)
                # add treatment info
                Zt = torch.cat([Zi, t], 1)
                Y = batch["y"].view(-1, 1)

                # regress y on final layer, without x
                betas_y = net.cholesky_least_squares(Zt, Y, intercept=False)
                y_hat = Zt.matmul(betas_y).view(-1, 1)
                mse_y = ((Y - y_hat) ** 2).mean()
                summary_batch["regr_b_t"] = betas_y[-1].item()
                summary_batch["regr_loss_y"] = mse_y.item()

                # regress x on final layer without x
                betas_x = net.cholesky_least_squares(Zi, Xhat, intercept=False)
                x_hat = Zi.matmul(betas_x).view(-1, 1)
                mse_x = ((Xhat - x_hat) ** 2).mean()
                summary_batch["regr_loss_x"] = mse_x.item()

                # store all tensors for single pass after epoch
                Xhats.append(Xhat.detach().cpu())
                Zhats.append(Z.detach().cpu())

                # add loss_bn only after n epochs; this reads X and mse_x computed
                # just above, which is why it sits inside the least-squares branch
                if params.bottleneck_loss and epoch > params.bn_loss_lag_epochs:
                    # only add to loss when bigger than margin
                    if params.bn_loss_margin_type == "dynamic-mean":
                        # for each batch, calculate loss of just using mean for predicting x
                        mse_x_mean = ((X - X.mean()) ** 2).mean()
                        loss_bn = torch.max(torch.zeros_like(mse_x), mse_x_mean - mse_x)
                    elif params.bn_loss_margin_type == "fixed":
                        mse_diff = params.bn_loss_margin - mse_x
                        loss_bn = torch.max(torch.zeros_like(mse_x), mse_diff)
                    else:
                        raise NotImplementedError(f'bottleneck loss margin type not implemented: {params.bn_loss_margin_type}')

                    summary_batch["loss_bn"] = loss_bn.item()
                    # possibly reweigh bottleneck loss and add to total loss
                    # NOTE(review): the original author asked whether this margin check is
                    # redundant with the clamping above; it is kept unchanged
                    if loss_bn > params.bn_loss_margin:
                        loss_bn = loss_bn * params.bottleneck_loss_wt
                        loss = loss + loss_bn

            # generate counterfactual predictions
            if setting.counterfactuals:
                batch_t0 = torch.zeros_like(data_batch, dtype=torch.float32)
                batch_t1 = torch.ones_like(data_batch, dtype=torch.float32)
                y0_hats.append(model(img_batch, batch_t0)["y"].detach().cpu().numpy())
                y1_hats.append(model(img_batch, batch_t1)["y"].detach().cpu().numpy())

            # keep activations of bottleneck layer (still tensors at this point)
            bn_activations.append(output_batch["bnz"] if params.multi_task else output_batch["bn"])

            # move labels, covariate and outputs to cpu and convert to numpy arrays
            if (len(setting.outcome) > 1) or params.multi_task:
                labels_batch = {var: lab.detach().cpu().numpy() for var, lab in labels_batch.items()}
            else:
                labels_batch = labels_batch.detach().cpu().numpy()
            data_batch = data_batch.detach().cpu().numpy()
            output_batch = {var: out.detach().cpu().numpy() for var, out in output_batch.items()}

            # compute all metrics on this batch
            if params.multi_task:
                metrics_xy = {m: net.all_metrics[m] for m in setting.metrics_xy}
                for var in labels_batch:
                    for metric, metric_fn in metrics_xy.items():
                        summary_batch[metric + "_" + var] = metric_fn(
                            setting, model, output_batch[var], labels_batch[var], data_batch)
                if "b_t" in setting.metrics:
                    summary_batch["b_t"] = net.all_metrics["b_t"](setting, model, None, None)
            # else: single-task metrics are not implemented; the original code had a
            # bare (non-raising) `NotImplementedError` expression here, so only the
            # losses are summarised in that case

            summary_batch["loss"] = loss.item()
            summ.append(summary_batch)
            preds.append(output_batch["y"])

    # compute mean of all metrics in summary (keys are taken from the first batch)
    metrics_mean = {metric: np.nanmean([s[metric] for s in summ]) for metric in summ[0]}

    if params.save_bn_activations:
        # write out batch activations
        bn_tensor = torch.cat(bn_activations, 0).detach().cpu()
        bn_activations = bn_tensor.numpy()
        if writer is not None:
            writer.add_histogram("bn_activations", bn_activations, epoch + 1)
            # get covariances
            if "bottleneck_loss" in setting.metrics:
                # NOTE(review): net.cov is handed the tensor, not the numpy copy, because
                # its result is `.detach()`ed below -- confirm against net.cov
                bn_cov = net.cov(bn_tensor)
                bn_offdiags = net.get_of_diag(bn_cov.detach().cpu().numpy())
                writer.add_histogram("bn_covariances", bn_offdiags, epoch + 1)

    # export predictions
    preds = np.vstack([p.reshape(-1, 1) for p in preds])
    if writer is not None:
        writer.add_histogram('predictions', preds, epoch + 1)

    # predict individual treatment effects (only worth-while when there is an interaction with t)
    if setting.counterfactuals:
        # flatten to 1-d so the differences with the dataframe columns are element-wise
        # instead of broadcasting (N,1) against (N,) into an (N,N) matrix
        y0_hats = np.concatenate([a.reshape(-1) for a in y0_hats])
        y1_hats = np.concatenate([a.reshape(-1) for a in y1_hats])
        ite_hats = y1_hats - y0_hats
        metrics_mean["ite_mean"] = ite_hats.mean()

        y0s = dataloader.dataset.df["y0"].values.astype(np.float32)
        y1s = dataloader.dataset.df["y1"].values.astype(np.float32)
        ites = y1s - y0s
        metrics_mean["pehe"] = np.sqrt(np.mean(np.power((ite_hats - ites), 2)))
        metrics_mean["loss_y1"] = ((y1s - y1_hats) ** 2).mean()
        metrics_mean["loss_y0"] = ((y0s - y0_hats) ** 2).mean()

    # in case of single last layer where x is part of, do regression on this layer
    if params.bn_place == "single-regressor" and params.do_least_squares:
        # NOTE(review): rows of the dataframe are paired with the concatenated batches,
        # which presumes the dataloader visits the dataset in order -- confirm
        Xhat = torch.cat(Xhats, 0).view(-1, 1).float()
        Zhat = torch.cat(Zhats, 0).float()
        t = torch.tensor(dataloader.dataset.df["t"].values).view(-1, 1).float()
        Y = torch.tensor(dataloader.dataset.df["y"].values).view(-1, 1).float()
        ones = torch.ones_like(t)

        betas_bias = model.betas_bias.detach().cpu()
        betas_causal = model.betas_causal.detach().cpu()

        def predict_bias(treatment):
            # design matrix: intercept, x-hat, z-hat, treatment
            return torch.cat([ones, Xhat, Zhat, treatment], 1).matmul(betas_bias).view(-1, 1)

        def predict_causal(treatment):
            # design matrix: intercept, z-hat, treatment (x-hat left out)
            return torch.cat([ones, Zhat, treatment], 1).matmul(betas_causal).view(-1, 1)

        metrics_mean["regr_bias_loss_y"] = ((predict_bias(t) - Y) ** 2).mean()
        metrics_mean["regr_causal_loss_y"] = ((predict_causal(t) - Y) ** 2).mean()

        if setting.counterfactuals and writer is not None:
            zeros = torch.zeros_like(t)
            for tag, predict in (("regr_bias", predict_bias), ("regr_causal", predict_causal)):
                # 1-d numpy arrays so they line up element-wise with y0s / y1s / ites
                y0_reg = predict(zeros).reshape(-1).numpy()
                y1_reg = predict(ones).reshape(-1).numpy()
                ite_reg = y1_reg - y0_reg
                writer.add_scalars("pehe", {tag: np.sqrt(((ite_reg - ites) ** 2).mean())}, epoch + 1)
                writer.add_scalars("loss_y1", {tag: ((y1s - y1_reg) ** 2).mean()}, epoch + 1)
                writer.add_scalars("loss_y0", {tag: ((y0s - y0_reg) ** 2).mean()}, epoch + 1)

    outtensors = {
        'bn_activations': bn_activations,
        'predictions': preds,
        # Xhats stays empty when do_least_squares is off; np.vstack([]) would raise
        'xhat': np.vstack([x.numpy() for x in Xhats]) if Xhats else np.empty((0, 1), dtype=np.float32),
    }

    metrics_string = " ; ".join("{}: {:05.3f}".format(k, v) for k, v in metrics_mean.items())
    logging.info("- Eval metrics : " + metrics_string)

    return metrics_mean, outtensors