evaluate.py
"""Evaluates the model"""

import argparse
import logging
import os

import numpy as np
import pandas as pd
import torch
from torch.autograd import Variable

import model.data_loader as data_loader
import model.net as net
import utils
from sklearn import linear_model
|
|
|
def evaluate(model, loss_fn, dataloader, metrics, params, setting, epoch, writer=None):
    """Evaluate the model on all batches of `dataloader`.

    Args:
        model: (torch.nn.Module) the neural network
        loss_fn: a function that takes batch_output and batch_labels and computes the loss for the batch
        dataloader: (DataLoader) a torch.utils.data.DataLoader object that fetches data
        metrics: (dict) a dictionary of functions that compute a metric using the output and labels of each batch
        params: (Params) hyperparameters
        setting: experiment setting (phase, outcome, covariate mode, counterfactuals, metrics)
        epoch: (int) current epoch, used to decide which losses are active
        writer: (SummaryWriter, optional) TensorBoard writer for histograms and scalars
    """
|
|
|
    # set model to evaluation mode
    model.eval()
    model.to(params.device)

    # summary for current eval loop
    summ = []
    preds = []  # for saving last predictions
    bn_activations = []

    # create storage for tensors for OLS after the minibatches
    Xhats = []
    Zhats = []

    # for counterfactuals
    if setting.counterfactuals:
        y0_hats = []
        y1_hats = []
|
|
|
    # compute metrics over the dataset
    for batch in dataloader:
        summary_batch = {}
        batch = {k: v.to(params.device) for k, v in batch.items()}
        img_batch = batch["image"].to(params.device, non_blocking=True)
        labels_batch = batch["label"].to(params.device, non_blocking=True)
        if setting.covar_mode and epoch > params.suppress_t_epochs:
            data_batch = batch["t"].to(params.device, non_blocking=True).view(-1, 1)
        else:
            # feed a zero treatment while t is suppressed; note this assumes the
            # loader yields full batches (a short final batch would mismatch)
            data_batch = torch.zeros((params.batch_size, 1), requires_grad=False).to(params.device, non_blocking=True)
|
|
|
        if params.multi_task:
            # x_target_batch = Variable(batch["x"].to(params.device)).type(torch.cuda.LongTensor)
            x_target_batch = batch["x"].to(params.device)
            y_target_batch = batch["y"].to(params.device)
            labels_batch = {'x': x_target_batch, 'y': y_target_batch}

        # compute model output
        # output_batch, bn_batch = model(img_batch, data_batch)
        output_batch = model(img_batch, data_batch, epoch)
|
|
|
        # calculate loss
        if setting.fase == "feature":
            # compute the loss on z directly, to see how well z itself can be recovered
            loss_fn_z = torch.nn.MSELoss()
            loss_z = loss_fn_z(output_batch["y"].squeeze(), batch["z"])
            loss = loss_z
            summary_batch["loss_z"] = loss_z.item()
        else:
            loss_fn_y = torch.nn.MSELoss()
            loss_y = loss_fn_y(output_batch["y"].squeeze(), batch["y"])
            loss = loss_y
            summary_batch["loss_y"] = loss_y.item()
|
|
|
        # calculate loss for collider x
        loss_fn_x = torch.nn.MSELoss()
        loss_x = loss_fn_x(output_batch["bnx"].squeeze(), batch["x"])
        summary_batch["loss_x"] = loss_x.item()
        if params.alpha != 1:
            # optionally down-weight the contribution of estimating x
            loss_x *= params.alpha
            summary_batch["loss_x_weighted"] = loss_x.item()

        # add x loss to total loss
        loss += loss_x
|
|
|
        # add least squares regression on the final layer
        if params.do_least_squares:
            X = batch["x"].view(-1, 1)
            t = batch["t"].view(-1, 1)
            Z = output_batch["bnz"]
            if Z.ndimension() == 1:
                Z.unsqueeze_(1)
            Xhat = output_batch["bnx"]
            # add intercept
            Zi = torch.cat([torch.ones_like(t), Z], 1)
            # add treatment info
            Zt = torch.cat([Zi, t], 1)
            Y = batch["y"].view(-1, 1)

            # regress y on the final layer, without x
            betas_y = net.cholesky_least_squares(Zt, Y, intercept=False)
            y_hat = Zt.matmul(betas_y).view(-1, 1)
            mse_y = ((Y - y_hat)**2).mean()

            # t is the last column of Zt, so betas_y[-1] is the treatment coefficient
            summary_batch["regr_b_t"] = betas_y[-1].item()
            summary_batch["regr_loss_y"] = mse_y.item()

            # regress x on the final layer, without x
            betas_x = net.cholesky_least_squares(Zi, Xhat, intercept=False)
            x_hat = Zi.matmul(betas_x).view(-1, 1)
            mse_x = ((Xhat - x_hat)**2).mean()

            # store all tensors for a single pass after the epoch
            Xhats.append(Xhat.detach().cpu())
            Zhats.append(Z.detach().cpu())

            summary_batch["regr_loss_x"] = mse_x.item()
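            # aside: `net.cholesky_least_squares` is assumed (its source is not
            # shown here) to solve the normal equations (A^T A) b = A^T y via a
            # Cholesky factorization of the Gram matrix; a minimal torch sketch
            # of that idea is given after this function.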
|
|
|
            # add loss_bn only after `bn_loss_lag_epochs` epochs
            if params.bottleneck_loss and epoch > params.bn_loss_lag_epochs:
                # only add to the loss when above the margin
                if params.bn_loss_margin_type == "dynamic-mean":
                    # per batch, the loss of predicting x with just its mean
                    mse_x_mean = ((X - X.mean())**2).mean()
                    loss_bn = torch.max(torch.zeros_like(mse_x), mse_x_mean - mse_x)
                elif params.bn_loss_margin_type == "fixed":
                    mse_diff = params.bn_loss_margin - mse_x
                    loss_bn = torch.max(torch.zeros_like(mse_x), mse_diff)
                else:
                    raise NotImplementedError(f"bottleneck loss margin type not implemented: {params.bn_loss_margin_type}")

                # possibly reweigh the bottleneck loss and add it to the total loss
                summary_batch["loss_bn"] = loss_bn.item()
                # note: the margin is applied a second time here; possibly redundant
                if loss_bn > params.bn_loss_margin:
                    loss_bn *= params.bottleneck_loss_wt
                    loss += loss_bn
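            # both margin types implement a hinge: loss_bn = max(0, margin - mse_x),
            # so the bottleneck is penalized only while x is still predictable from
            # it better than the margin allows ("dynamic-mean" sets the margin to
            # the loss of a mean-only predictor on the current batch).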
|
|
|
        # generate counterfactual predictions by toggling the treatment input
        if setting.counterfactuals:
            batch_t0 = torch.zeros_like(data_batch, dtype=torch.float32)
            batch_t1 = torch.ones_like(data_batch, dtype=torch.float32)
            y0_batch = model(img_batch, batch_t0, epoch)
            y1_batch = model(img_batch, batch_t1, epoch)
            y0_hats.append(y0_batch["y"].detach().cpu().numpy())
            y1_hats.append(y1_batch["y"].detach().cpu().numpy())
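        # the ITE for one unit is estimated as the difference between the two
        # forward passes above: tau_hat_i = f(img_i, t=1) - f(img_i, t=0);
        # the per-batch predictions are aggregated after the loop.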
|
|
|
        # store activations of the bottleneck layer
        if params.multi_task:
            bn_activations.append(output_batch["bnz"])
        else:
            bn_activations.append(output_batch["bn"])
|
|
|
        # extract data from torch tensors, move to cpu, convert to numpy arrays
        if (len(setting.outcome) > 1) or params.multi_task:
            for var, labels in labels_batch.items():
                labels_batch[var] = labels.data.cpu().numpy()
        else:
            labels_batch = labels_batch.data.cpu().numpy()
|
|
|
        # compute all metrics on this batch
        data_batch = data_batch.data.cpu().numpy()
        for var, out in output_batch.items():
            output_batch[var] = out.detach().cpu().numpy()
        if params.multi_task:
            metrics_xy = {m: net.all_metrics[m] for m in setting.metrics_xy}
            for var, labels in labels_batch.items():
                for metric, metric_fn in metrics_xy.items():
                    summary_batch[metric + "_" + var] = metric_fn(setting, model, output_batch[var], labels, data_batch)
            if "b_t" in setting.metrics:
                summary_batch["b_t"] = net.all_metrics["b_t"](setting, model, None, None)
        else:
            raise NotImplementedError("evaluation without multi-task outputs is not implemented")
            # summary_batch = {metric: metrics[metric](setting, model, output_batch[setting.outcome[0]], labels_batch, data_batch)
            #                  for metric in metrics}
|
|
|
        summary_batch["loss"] = loss.item()
        summ.append(summary_batch)
        # if "y" in setting.outcome:
        preds.append(output_batch["y"])
        # else:
        #     preds.append(output_batch[setting.outcome[0]])
|
|
|
    # compute the mean of all metrics in the summary
    metrics_mean = {metric: np.nanmean([x[metric] for x in summ]) for metric in summ[0]}

    # if "ate" in setting.metrics:
    #     metrics_mean["ate"] = all_metrics["ate"](setting, model, preds, )
|
|
|
    if params.save_bn_activations:
        # collect activations over all batches; keep a tensor for the statistics
        bn_activations = torch.cat(bn_activations, 0).detach().cpu()

        # get means and covariances
        if "bottleneck_loss" in setting.metrics:
            bn_means = bn_activations.mean(dim=0)
            bn_sds = bn_activations.std(dim=0)
            bn_cov = net.cov(bn_activations)
            bn_offdiags = net.get_of_diag(bn_cov.detach().cpu().numpy())
            writer.add_histogram("bn_covariances", bn_offdiags, epoch + 1)

        # write out batch activations
        bn_activations = bn_activations.numpy()
        writer.add_histogram("bn_activations", bn_activations, epoch + 1)
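    # the off-diagonal entries of the activation covariance indicate how strongly
    # the bottleneck units co-vary; logging them as a histogram makes it visible
    # whether the bottleneck stays decorrelated over training (`net.cov` and
    # `net.get_of_diag` are project helpers, assumed to return the feature
    # covariance matrix and its off-diagonal elements).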
|
|
|
    # export predictions
    preds = np.vstack([x.reshape(-1, 1) for x in preds])
    writer.add_histogram("predictions", preds, epoch + 1)
    labels = dataloader.dataset.df[setting.outcome[0]].values.astype(np.float32)
|
|
|
    # predict individual treatment effects (only worthwhile when there is an interaction with t)
    if setting.counterfactuals:
        y0_hats = np.vstack(y0_hats)
        y1_hats = np.vstack(y1_hats)
        ite_hats = y1_hats - y0_hats
        metrics_mean["ite_mean"] = ite_hats.mean()

        # reshape the ground truth to column vectors so the differences below
        # broadcast elementwise against the (n, 1) predictions, not to (n, n)
        y0s = dataloader.dataset.df["y0"].values.astype(np.float32).reshape(-1, 1)
        y1s = dataloader.dataset.df["y1"].values.astype(np.float32).reshape(-1, 1)
        ites = y1s - y0s
        metrics_mean["pehe"] = np.sqrt(np.mean((ite_hats - ites)**2))

        metrics_mean["loss_y1"] = ((y1s - y1_hats)**2).mean()
        metrics_mean["loss_y0"] = ((y0s - y0_hats)**2).mean()
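        # PEHE (precision in estimation of heterogeneous effects) is the root mean
        # squared error between estimated and true individual treatment effects:
        # PEHE = sqrt( mean_i (tau_hat_i - tau_i)^2 ), with tau_i = y1_i - y0_i.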
|
|
|
    # when x feeds into a single final regression layer, evaluate that regression directly
    if params.bn_place == "single-regressor" and params.do_least_squares:
        Xhat = torch.cat(Xhats, 0).view(-1, 1).float()
        Zhat = torch.cat(Zhats, 0).float()
        t = torch.tensor(dataloader.dataset.df["t"].values).view(-1, 1).float()
        Y = torch.tensor(dataloader.dataset.df["y"].values).view(-1, 1).float()

        betas_bias = model.betas_bias.detach().cpu()
        betas_causal = model.betas_causal.detach().cpu()

        y_hat_bias = torch.cat([torch.ones_like(t), Xhat, Zhat, t], 1).matmul(betas_bias).view(-1, 1)
        y_hat_causal = torch.cat([torch.ones_like(t), Zhat, t], 1).matmul(betas_causal).view(-1, 1)

        reg_mse_bias = ((y_hat_bias - Y)**2).mean()
        reg_mse_causal = ((y_hat_causal - Y)**2).mean()

        metrics_mean["regr_bias_loss_y"] = reg_mse_bias.item()
        metrics_mean["regr_causal_loss_y"] = reg_mse_causal.item()
|
|
|
        if setting.counterfactuals:
            y0_hat_bias = torch.cat([torch.ones_like(t), Xhat, Zhat, torch.zeros_like(t)], 1).matmul(betas_bias).view(-1, 1).numpy()
            y1_hat_bias = torch.cat([torch.ones_like(t), Xhat, Zhat, torch.ones_like(t)], 1).matmul(betas_bias).view(-1, 1).numpy()
            y0_hat_causal = torch.cat([torch.ones_like(t), Zhat, torch.zeros_like(t)], 1).matmul(betas_causal).view(-1, 1).numpy()
            y1_hat_causal = torch.cat([torch.ones_like(t), Zhat, torch.ones_like(t)], 1).matmul(betas_causal).view(-1, 1).numpy()

            ite_hats_bias = y1_hat_bias - y0_hat_bias
            ite_hats_causal = y1_hat_causal - y0_hat_causal

            writer.add_scalars("pehe", {"regr_bias": np.sqrt(((ite_hats_bias - ites)**2).mean())}, epoch + 1)
            writer.add_scalars("pehe", {"regr_causal": np.sqrt(((ite_hats_causal - ites)**2).mean())}, epoch + 1)
            writer.add_scalars("loss_y1", {"regr_bias": ((y1s - y1_hat_bias)**2).mean()}, epoch + 1)
            writer.add_scalars("loss_y0", {"regr_bias": ((y0s - y0_hat_bias)**2).mean()}, epoch + 1)
            writer.add_scalars("loss_y1", {"regr_causal": ((y1s - y1_hat_causal)**2).mean()}, epoch + 1)
            writer.add_scalars("loss_y0", {"regr_causal": ((y0s - y0_hat_causal)**2).mean()}, epoch + 1)
|
|
|
    outtensors = {
        'bn_activations': bn_activations,
        'predictions': preds,
        'xhat': np.vstack(Xhats) if Xhats else None,
    }
|
|
|
    metrics_string = " ; ".join("{}: {:05.3f}".format(k, v) for k, v in metrics_mean.items())
    logging.info("- Eval metrics : " + metrics_string)

    return metrics_mean, outtensors
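
# --- illustrative additions (assumptions, not part of the project API) ---

def _cholesky_least_squares_sketch(A, y):
    """Minimal sketch of what `net.cholesky_least_squares` is assumed to do:
    solve the normal equations (A^T A) b = A^T y via a Cholesky factorization.
    The project's own implementation may differ (e.g. in how the `intercept`
    flag is handled)."""
    AtA = A.t().matmul(A)                # Gram matrix, (p, p)
    Aty = A.t().matmul(y)                # (p, 1)
    L = torch.linalg.cholesky(AtA)       # AtA = L L^T; requires AtA positive definite
    return torch.cholesky_solve(Aty, L)  # coefficients b, (p, 1)


# A hedged usage sketch for `evaluate`. The helper names below follow the
# cs230-style project layout this file imports from (`utils.Params`,
# `data_loader.fetch_dataloader`, `net.Net`, `net.loss_fn`, `net.metrics`),
# but their exact signatures are assumptions, as is `setting`:
#
#   from torch.utils.tensorboard import SummaryWriter
#
#   params = utils.Params("experiments/base_model/params.json")
#   params.device = "cuda" if torch.cuda.is_available() else "cpu"
#   dataloaders = data_loader.fetch_dataloader(["val"], "data/", params)
#   model = net.Net(params).to(params.device)
#   writer = SummaryWriter("runs/eval")
#   metrics_mean, tensors = evaluate(model, net.loss_fn, dataloaders["val"],
#                                    net.metrics, params, setting, epoch=0,
#                                    writer=writer)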