train.py

"""Train the model"""

import argparse
import logging
import os, shutil

import numpy as np
import pandas as pd
from sklearn.utils.class_weight import compute_class_weight
import torch
import torch.optim as optim
import torchvision.models as models
from torch.autograd import Variable
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
# from torchsummary import summary

import utils
import json
import model.net as net
import model.data_loader as data_loader
from evaluate import evaluate

parser = argparse.ArgumentParser()
parser.add_argument('--data-dir', default='data', help="Directory containing the dataset")
parser.add_argument('--model-dir', default='experiments', help="Directory containing params.json")
parser.add_argument('--setting-dir', default='settings', help="Directory with different settings")
parser.add_argument('--setting', default='collider-prognosticfactor', help="Directory containing setting.json: experimental setting, data generation, regression model etc")
parser.add_argument('--fase', default='xybn', help="fase (phase) of training the model, see manuscript for details: x, y, xy, bn, or feature")
parser.add_argument('--experiment', default='', help="Manual name for the experiment for logging, will be a subdir of the setting")
parser.add_argument('--restore-file', default=None,
                    help="Optional, name of the file in --model-dir containing weights to reload before training")  # 'best' or 'train'
parser.add_argument('--restore-last', action='store_true', help="continue from the last run")
parser.add_argument('--restore-warm', action='store_true', help="continue from the run called 'warm-start.pth'")
parser.add_argument('--use-last', action="store_true", help="use the last state dict instead of 'best' (use for manual early stopping)")
parser.add_argument('--cold-start', action='store_true', help="ignore previous state dicts (weights), even if they exist")
parser.add_argument('--warm-start', dest='cold_start', action='store_false', help="start from a previous state dict")
parser.add_argument('--disable-cuda', action='store_true', help="disable CUDA")
parser.add_argument('--no-parallel', action="store_false", help="do not use multiple GPUs", dest="parallel")
parser.add_argument('--parallel', action="store_true", help="use multiple GPUs", dest="parallel")
parser.add_argument('--gpu', default=0, type=int, help="if not running in parallel (=all GPUs), only use this GPU")
parser.add_argument('--intercept', action="store_true", help="dummy run for getting intercept baseline results")
# parser.add_argument('--visdom', action='store_true', help='generate plots with visdom')
# parser.add_argument('--novisdom', dest='visdom', action='store_false', help="don't plot with visdom")
parser.add_argument('--monitor-grads', action='store_true', help='keep track of the mean norm of the gradients')
parser.set_defaults(parallel=False, cold_start=True, use_last=False, intercept=False, restore_last=False, save_preds=False,
                    monitor_grads=False, restore_warm=False
                    # visdom=False
                    )
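
# Example invocation (illustrative only; all flags are defined above, values are arbitrary):
#   python train.py --setting collider-prognosticfactor --fase xybn --experiment test-run --gpu 0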


def train(model, optimizer, loss_fn, dataloader, metrics, params, setting, writer=None, epoch=None):
    """Train the model for one epoch (one full pass over the training data)

    Args:
        model: (torch.nn.Module) the neural network
        optimizer: (torch.optim) optimizer for the parameters of the model
        loss_fn: a function that takes batch_output and batch_labels and computes the loss for the batch
        dataloader: (DataLoader) a torch.utils.data.DataLoader object that fetches training data
        metrics: (dict) a dictionary of functions that compute a metric using the output and labels of each batch
        params: (Params) hyperparameters
        setting: (Params) experimental setting (data generation, regression model etc)
        writer: (SummaryWriter) writer for TensorBoard scalars and figures
        epoch: (int) index of the current epoch

    Returns:
        metrics_mean: (dict) mean of the per-batch summary values over the epoch
    """
    global train_tensor_keys, logdir

    # set model to training mode
    model.train()

    # summary for current training loop and a running average object for loss
    summ = []
    loss_avg = utils.RunningAverage()

    # create storage for tensors for OLS after the minibatches
    ts = []
    Xs = []
    Xtrues = []
    Ys = []
    Xhats = []
    Yhats = []
    Zhats = []

    # Use tqdm for progress bar
    with tqdm(total=len(dataloader)) as progress_bar:
        for i, batch in enumerate(dataloader):
            summary_batch = {}
            # put batch on cuda
            batch = {k: v.to(params.device) for k, v in batch.items()}
            if not (setting.covar_mode and epoch > params.suppress_t_epochs):
                batch["t"] = torch.zeros_like(batch['t'])
            Xs.append(batch['x'].detach().cpu())
            Xtrues.append(batch['x_true'].detach().cpu())

            # compute model output
            output_batch = model(batch['image'], batch['t'].view(-1, 1), epoch)
            Yhats.append(output_batch['y'].detach().cpu())

            # calculate loss
            if args.fase == "feature":
                # calculate loss for z directly, to see how well it can be estimated
                loss_fn_z = torch.nn.MSELoss()
                loss_z = loss_fn_z(output_batch["y"].squeeze(), batch["z"])
                loss = loss_z
                summary_batch["loss_z"] = loss_z.item()
            else:
                loss_fn_y = torch.nn.MSELoss()
                loss_y = loss_fn_y(output_batch["y"].squeeze(), batch["y"])
                loss = loss_y
                summary_batch["loss_y"] = loss_y.item()

                # calculate loss for collider x
                loss_fn_x = torch.nn.MSELoss()
                loss_x = loss_fn_x(output_batch["bnx"].squeeze(), batch["x"])
                summary_batch["loss_x"] = loss_x.item()
                if not params.alpha == 1:
                    # possibly weigh down the contribution of estimating x
                    loss_x *= params.alpha
                    summary_batch["loss_x_weighted"] = loss_x.item()
                # add x loss to total loss
                loss += loss_x
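
                # Note on the block below: net.cholesky_least_squares is assumed to solve the
                # ordinary least-squares problem (normal equations A^T A beta = A^T y, via a
                # Cholesky factorization) and return the coefficient vector, so betas_y[-1] is
                # the coefficient of the treatment t (logged as regr_b_t).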

                # add least squares regression on final layer
                if params.do_least_squares:
                    X = batch["x"].view(-1, 1)
                    t = batch["t"].view(-1, 1)
                    Z = output_batch["bnz"]
                    if Z.ndimension() == 1:
                        Z.unsqueeze_(1)
                    Xhat = output_batch["bnx"]
                    # add intercept
                    Zi = torch.cat([torch.ones_like(t), Z], 1)
                    # add treatment info
                    Zt = torch.cat([Zi, t], 1)
                    Y = batch["y"].view(-1, 1)

                    # regress y on final layer, without x
                    betas_y = net.cholesky_least_squares(Zt, Y, intercept=False)
                    y_hat = Zt.matmul(betas_y).view(-1, 1)
                    mse_y = ((Y - y_hat)**2).mean()

                    summary_batch["regr_b_t"] = betas_y[-1].item()
                    summary_batch["regr_loss_y"] = mse_y.item()

                    # regress x on final layer without x
                    betas_x = net.cholesky_least_squares(Zi, Xhat, intercept=False)
                    x_hat = Zi.matmul(betas_x).view(-1, 1)
                    mse_x = ((Xhat - x_hat)**2).mean()

                    # store all tensors for single pass after epoch
                    Xhats.append(Xhat.detach().cpu())
                    Zhats.append(Z.detach().cpu())
                    ts.append(t.detach().cpu())
                    Ys.append(Y.detach().cpu())

                    summary_batch["regr_loss_x"] = mse_x.item()

                    # add loss_bn only after n epochs
                    if params.bottleneck_loss and epoch > params.bn_loss_lag_epochs:
                        # only add to loss when bigger than margin
                        if params.bn_loss_margin_type == "dynamic-mean":
                            # for each batch, calculate loss of just using mean for predicting x
                            mse_x_mean = ((X - X.mean())**2).mean()
                            loss_bn = torch.max(torch.zeros_like(mse_x), mse_x_mean - mse_x)
                        elif params.bn_loss_margin_type == "fixed":
                            mse_diff = params.bn_loss_margin - mse_x
                            loss_bn = torch.max(torch.zeros_like(mse_x), mse_diff)
                        else:
                            raise NotImplementedError(f'bottleneck loss margin type not implemented: {params.bn_loss_margin_type}')

                        # possibly reweigh bottleneck loss and add to total loss
                        summary_batch["loss_bn"] = loss_bn.item()
                        # note: is this double?
                        if loss_bn > params.bn_loss_margin:
                            loss_bn *= params.bottleneck_loss_wt
                            loss += loss_bn

            # perform parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            summary_batch['loss'] = loss.item()
            summ.append(summary_batch)

            # if necessary, write out tensors
            if params.monitor_train_tensors and (epoch % params.save_summary_steps == 0):
                tensors = {}
                for tensor_key in train_tensor_keys:
                    if tensor_key in batch.keys():
                        # the batch lives on params.device, so move to cpu before converting to numpy
                        tensors[tensor_key] = batch[tensor_key].detach().cpu().squeeze().numpy()
                    elif tensor_key.endswith("hat"):
                        tensor_key = tensor_key.split("_")[0]
                        if tensor_key in output_batch.keys():
                            tensors[tensor_key+"_hat"] = output_batch[tensor_key].detach().cpu().squeeze().numpy()
                    else:
                        assert False, f"key not found: {tensor_key}"
                # print(tensors)
                df = pd.DataFrame.from_dict(tensors, orient='columns')
                df["epoch"] = epoch

                with open(os.path.join(logdir, 'train-tensors.csv'), 'a') as f:
                    df[["epoch"]+train_tensor_keys].to_csv(f, header=False)

            # update the average loss
            loss_avg.update(loss.item())

            progress_bar.set_postfix(loss='{:05.3f}'.format(loss_avg()))
            progress_bar.update()

    # visualize gradients
    if epoch % params.save_summary_steps == 0 and args.monitor_grads:
        abs_gradients = {}
        for name, param in model.named_parameters():
            try:  # patch here, there were names / params that were 'none'
                abs_gradients[name] = np.abs(param.grad.cpu().numpy()).mean()
                writer.add_histogram("grad-"+name, param.grad, epoch)
                writer.add_scalars("mean-abs-gradients", abs_gradients, epoch)
            except:
                pass

    # compute mean of all metrics in summary
    metrics_mean = {metric: np.nanmean([x[metric] for x in summ]) for metric in summ[0]}

    # collect tensors
    Xhat = torch.cat(Xhats, 0).view(-1, 1)
    Yhat = torch.cat(Yhats, 0).view(-1, 1)
    Zhat = torch.cat(Zhats, 0)
    t = torch.cat(ts, 0)
    X = torch.cat(Xs, 0)
    Xtrue = torch.cat(Xtrues, 0)
    Y = torch.cat(Ys, 0)

    if params.do_least_squares:
        # after the minibatches, do a single OLS on the whole data
        Zi = torch.cat([torch.ones_like(t), Zhat], 1)
        # add treatment info
        Zt = torch.cat([Zi, t], 1)
        # add x for biased version
        XZt = torch.cat([torch.ones_like(t), Xhat, Zhat, t], 1)

        betas_y_bias = net.cholesky_least_squares(XZt, Y, intercept=False)
        betas_y_causal = net.cholesky_least_squares(Zt, Y, intercept=False)
        model.betas_bias = betas_y_bias
        model.betas_causal = betas_y_causal
        metrics_mean["regr_bias_coef_t"] = betas_y_bias.squeeze()[-1]
        metrics_mean["regr_bias_coef_z"] = betas_y_bias.squeeze()[-2]
        metrics_mean["regr_causal_coef_t"] = betas_y_causal.squeeze()[-1]
        metrics_mean["regr_causal_coef_z"] = betas_y_causal.squeeze()[-2]

    # create some plots
    xx_scatter = net.make_scatter_plot(X.numpy(), Xhat.numpy(), xlabel='x', ylabel='xhat')
    xtruex_scatter = net.make_scatter_plot(Xtrue.numpy(), Xhat.numpy(), xlabel='xtrue', ylabel='xhat')
    xyhat_scatter = net.make_scatter_plot(X.numpy(), Yhat.numpy(), c=t.numpy(), xlabel='x', ylabel='yhat')
    yy_scatter = net.make_scatter_plot(Y.numpy(), Yhat.numpy(), c=t.numpy(), xlabel='y', ylabel='yhat')
    writer.add_figure('x-xhat/train', xx_scatter, epoch+1)
    writer.add_figure('xtrue-xhat/train', xtruex_scatter, epoch+1)
    writer.add_figure('x-yhat/train', xyhat_scatter, epoch+1)
    writer.add_figure('y-yhat/train', yy_scatter, epoch+1)

    metrics_string = " ; ".join("{}: {:05.3f}".format(k, v) for k, v in metrics_mean.items())
    logging.info("- Train metrics: " + metrics_string)

    return metrics_mean


def train_and_evaluate(model, train_dataloader, val_dataloader, optimizer, loss_fn, metrics, params, setting, args,
                       writer=None, logdir=None, restore_file=None):
    """Train the model and evaluate every epoch.

    Args:
        model: (torch.nn.Module) the neural network
        train_dataloader: (DataLoader) a torch.utils.data.DataLoader object that fetches training data
        val_dataloader: (DataLoader) a torch.utils.data.DataLoader object that fetches validation data
        optimizer: (torch.optim) optimizer for the parameters of the model
        loss_fn: a function that takes batch_output and batch_labels and computes the loss for the batch
        metrics: (dict) a dictionary of functions that compute a metric using the output and labels of each batch
        params: (Params) hyperparameters
        setting: (Params) experimental setting; setting.covar_mode indicates whether the data-loader
            returns covariates / additional data
        args: (Namespace) parsed command-line arguments
        writer: (SummaryWriter) writer for TensorBoard scalars and figures
        logdir: (string) directory containing config, weights and logs
        restore_file: (string) optional - name of file to restore from (without its extension .pth.tar)
    """

    # setup directories for data
    setting_home = setting.home
    if not args.fase == "feature":
        data_dir = os.path.join(setting_home, "data")
    else:
        if setting.mode3d:
            data_dir = "data"
        else:
            data_dir = "slices"
    covar_mode = setting.covar_mode

    x_frozen = False

    best_val_metric = 0.0
    if "loss" in setting.metrics[0]:
        best_val_metric = 1.0e6

    val_preds = np.zeros((len(val_dataloader.dataset), params.num_epochs))

    for epoch in range(params.num_epochs):

        # Run one epoch
        logging.info(f"Epoch {epoch+1}/{params.num_epochs}; setting: {args.setting}, fase {args.fase}, experiment: {args.experiment}")

        # train for one full pass over the training set
        train_metrics = train(model, optimizer, loss_fn, train_dataloader, metrics, params, setting, writer, epoch)
        print(train_metrics)
        for metric_name in train_metrics.keys():
            metric_vals = {'train': train_metrics[metric_name]}
            writer.add_scalars(metric_name, metric_vals, epoch+1)

        # for name, param in model.named_parameters():
        #     writer.add_histogram(name, param.clone().cpu().data.numpy(), epoch+1)

        if epoch % params.save_summary_steps == 0:

            # Evaluate for one epoch on the validation set
            valid_metrics, outtensors = evaluate(model, loss_fn, val_dataloader, metrics, params, setting, epoch, writer)
            valid_metrics["intercept"] = model.regressor.fc.bias.detach().cpu().numpy()
            print(valid_metrics)

            for name, module in model.regressor.named_children():
                if name == "t":
                    valid_metrics["b_t"] = module.weight.detach().cpu().numpy()
                elif name == "zt":
                    weights = module.weight.detach().cpu().squeeze().numpy().reshape(-1)
                    for i, weight in enumerate(weights):
                        valid_metrics["b_zt"+str(i)] = weight
                else:
                    pass
            for metric_name in valid_metrics.keys():
                metric_vals = {'valid': valid_metrics[metric_name]}
                writer.add_scalars(metric_name, metric_vals, epoch+1)

            # create plots
            val_df = val_dataloader.dataset.df
            xx_scatter = net.make_scatter_plot(val_df.x.values, outtensors['xhat'], xlabel='x', ylabel='xhat')
            xtruex_scatter = net.make_scatter_plot(val_df.x_true.values, outtensors['xhat'], xlabel='xtrue', ylabel='xhat')
            xyhat_scatter = net.make_scatter_plot(val_df.x.values, outtensors['predictions'], c=val_df.t, xlabel='x', ylabel='yhat')
            zyhat_scatter = net.make_scatter_plot(val_df.z.values, outtensors['predictions'], c=val_df.t, xlabel='z', ylabel='yhat')
            yy_scatter = net.make_scatter_plot(val_df.y.values, outtensors['predictions'], c=val_df.t, xlabel='y', ylabel='yhat')
            writer.add_figure('x-xhat/valid', xx_scatter, epoch+1)
            writer.add_figure('xtrue-xhat/valid', xtruex_scatter, epoch+1)
            writer.add_figure('x-yhat/valid', xyhat_scatter, epoch+1)
            writer.add_figure('z-yhat/valid', zyhat_scatter, epoch+1)
            writer.add_figure('y-yhat/valid', yy_scatter, epoch+1)

            if params.save_preds:
                # writer.add_histogram("predictions", preds)
                if setting.num_classes == 1:
                    preds = np.squeeze(outtensors['predictions'])
                    val_preds[:, epoch] = preds

                    # write preds to file
                    pred_fname = os.path.join(setting.home, setting.fase+"-fase", "preds_val.csv")
                    with open(pred_fname, 'ab') as f:
                        np.savetxt(f, preds.T, newline="")

                    np.save(os.path.join(setting.home, setting.fase+"-fase", "preds.npy"), preds)

            else:
                val_metric = valid_metrics[setting.metrics[0]]
                if "loss" in str(setting.metrics[0]):
                    is_best = val_metric <= best_val_metric
                else:
                    is_best = val_metric >= best_val_metric

                # Save weights
                state_dict = model.state_dict()
                optim_dict = optimizer.state_dict()

                state = {
                    'epoch': epoch+1,
                    'state_dict': state_dict,
                    'optim_dict': optim_dict
                }

                utils.save_checkpoint(state,
                                      is_best=is_best,
                                      checkpoint=logdir)

                # If best_eval, best_save_path
                valid_metrics["epoch"] = epoch
                if is_best:
                    logging.info("- Found new best {}: {:.3f}".format(setting.metrics[0], val_metric))
                    best_val_metric = val_metric

                    # Save best val metrics in a json file in the model directory
                    best_json_path = os.path.join(logdir, "metrics_val_best_weights.json")
                    utils.save_dict_to_json(valid_metrics, best_json_path)

                # Save latest val metrics in a json file in the model directory
                last_json_path = os.path.join(logdir, "metrics_val_last_weights.json")
                utils.save_dict_to_json(valid_metrics, last_json_path)

    # final evaluation
    # export_scalars_to_json exists on tensorboardX's SummaryWriter but not on
    # torch.utils.tensorboard's, so only call it when it is available
    if hasattr(writer, "export_scalars_to_json"):
        writer.export_scalars_to_json(os.path.join(logdir, "all_scalars.json"))

    if args.save_preds:
        np.save(os.path.join(setting.home, setting.fase + "-fase", "val_preds.npy"), val_preds)


if __name__ == '__main__':

    # Load the parameters from json file
    args = parser.parse_args()

    # Load information from last setting if none provided:
    last_defaults = utils.Params("last-defaults.json")
    if args.setting == "":
        print("using last default setting")
        args.setting = last_defaults.dict["setting"]
        for param, value in last_defaults.dict.items():
            print("{}: {}".format(param, value))
    else:
        with open("last-defaults.json", "r+") as jsonFile:
            defaults = json.load(jsonFile)
            tmp = defaults["setting"]
            defaults["setting"] = args.setting
            jsonFile.seek(0)  # rewind
            json.dump(defaults, jsonFile)
            jsonFile.truncate()

    # setup visdom environment
    # if args.visdom:
    #     from visdom import Visdom
    #     viz = Visdom(env=f"lidcr_{args.setting}_{args.fase}_{args.experiment}")

    # load setting (data generation, regression model etc)
    setting_home = os.path.join(args.setting_dir, args.setting)
    setting = utils.Params(os.path.join(setting_home, "setting.json"))
    setting.home = setting_home

    # when not specified in call, grab model specification from setting file
    if setting.cnn_model == "":
        json_path = os.path.join(args.model_dir, "t-suppression", args.experiment+".json")
    else:
        json_path = os.path.join(args.model_dir, setting.cnn_model, 'params.json')
    assert os.path.isfile(json_path), "No json configuration file found at {}".format(json_path)
    if not os.path.exists(os.path.join(setting.home, args.fase + "-fase")):
        os.makedirs(os.path.join(setting.home, args.fase + "-fase"))
    shutil.copy(json_path, os.path.join(setting_home, args.fase + "-fase", "params.json"))
    params = utils.Params(json_path)
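
    # For reference, an illustrative sketch of the kinds of fields this script reads from params.json
    # (values are placeholders; the actual file is the one resolved into json_path above):
    # {
    #   "optimizer": "adam", "learning_rate": 1e-4, "wd": 0.0, "momentum": 0.0,
    #   "num_epochs": 100, "batch_size": 16, "save_summary_steps": 1,
    #   "alpha": 1.0, "do_least_squares": true, "suppress_t_epochs": 0,
    #   "bottleneck_loss": false, "bn_loss_lag_epochs": 0, "bn_loss_margin_type": "fixed",
    #   "bn_loss_margin": 0.0, "bottleneck_loss_wt": 1.0, "lr_t_factor": 1,
    #   "balance_classes": false, "save_preds": false, "monitor_train_tensors": false
    # }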

    # covar_mode = setting.covar_mode
    # mode3d = setting.mode3d
    parallel = args.parallel

    params.device = None
    params.cuda = False
    if not args.disable_cuda and torch.cuda.is_available():
        params.device = torch.device('cuda')
        params.cuda = True
        # switch gpus for better use when running multiple experiments
        if not args.parallel:
            torch.cuda.set_device(int(args.gpu))
    else:
        params.device = torch.device('cpu')

    # adapt fase
    setting.fase = args.fase
    setting.metrics = pd.Series(setting.metrics).drop_duplicates().tolist()
    print("metrics: {}".format(setting.metrics))

    # Set the random seed for reproducible experiments
    torch.manual_seed(230)
    if params.cuda:
        torch.cuda.manual_seed(230)

    # Set the logger
    logdir = os.path.join(setting_home, setting.fase+"-fase", "runs")
    if not args.experiment == '':
        logdir = os.path.join(logdir, args.experiment)
    if not os.path.isdir(logdir):
        os.makedirs(logdir)

    # copy params as backup to logdir
    shutil.copy(json_path, os.path.join(logdir, "params.json"))

    # utils.set_logger(os.path.join(args.model_dir, 'train.log'))
    utils.set_logger(os.path.join(logdir, 'train.log'))

    # Create the input data pipeline
    logging.info("Loading the datasets...")

    # fetch dataloaders
    dataloaders = data_loader.fetch_dataloader(args, params, setting, ["train", "valid"])
    train_dl = dataloaders['train']
    valid_dl = dataloaders['valid']

    if setting.num_classes > 1 and params.balance_classes:
        train_labels = train_dl.dataset.df[setting.outcome[0]].values
        class_weights = compute_class_weight('balanced', classes=np.unique(train_labels), y=train_labels)
    # valid_dl = train_dl

    logging.info("- done.")

    if args.intercept:
        assert len(setting.outcome) == 1, "Multiple outcomes not implemented for intercept yet"
        print("running intercept mode")
        mu = valid_dl.dataset.df[setting.outcome].values.mean()

        def new_forward(self, x, data, mu=mu):
            intercept = torch.autograd.Variable(mu * torch.ones((x.shape[0], 1)), requires_grad=False).to(params.device, non_blocking=True)
            bn_activations = torch.autograd.Variable(torch.zeros((x.shape[0],)), requires_grad=False).to(params.device, non_blocking=True)
            return {setting.outcome[0]: intercept, "bn": bn_activations}

        net.Net3D.forward = new_forward
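        # note: only Net3D.forward is patched here; in 2D mode (setting.mode3d is False)
        # the CausalNet model constructed below is left unchanged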
        params.num_epochs = 1
        setting.metrics = []
        logdir = os.path.join(logdir, "intercept")

    if setting.mode3d:
        model = net.Net3D(params, setting).to(params.device)
    else:
        model = net.CausalNet(params, setting).to(params.device)

    optimizers = {'sgd': optim.SGD, 'adam': optim.Adam}

    if parallel:
        print("parallel mode")
        model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))

    if params.momentum > 0:
        optimizer = optimizers[params.optimizer](model.parameters(), lr=params.learning_rate, weight_decay=params.wd, momentum=params.momentum)
    else:
        optimizer = optimizers[params.optimizer](model.parameters(), lr=params.learning_rate, weight_decay=params.wd)

    # if params.use_mi:
    #     optimizer.add_param_group({'params': mine.parameters()})

    if setting.covar_mode and params.lr_t_factor != 1:
        optimizer = net.speedup_t(model, params)
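        # assumption based on the name and params.lr_t_factor: net.speedup_t presumably rebuilds
        # the optimizer so that the treatment-related parameters get a scaled learning rate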

    if args.restore_last and (not args.cold_start):
        print("Loading state dict from last running setting")
        utils.load_checkpoint(os.path.join(setting.home, args.fase + "-fase", "last.pth.tar"), model, strict=False)
    elif args.restore_warm:
        utils.load_checkpoint(os.path.join(setting.home, 'warm-start.pth.tar'), model, strict=False)
    else:
        pass

    # fetch loss function and metrics
    if setting.num_classes > 1 and params.balance_classes:
        loss_fn = net.get_loss_fn(setting, weights=class_weights)
    else:
        loss_fn = net.get_loss_fn(setting)
    # metrics = {metric: net.all_metrics[metric] for metric in setting.metrics}
    metrics = None

    if params.monitor_train_tensors:
        print("Recording all train tensors")
        import csv
        train_tensor_keys = ['t', 'x', 'z', 'y', 'x_hat', 'z_hat', 'y_hat']
        with open(os.path.join(logdir, 'train-tensors.csv'), 'w') as f:
            writer = csv.writer(f)
            writer.writerow(['epoch'] + train_tensor_keys)

    # Train the model
    # print(model)
    # print(summary(model, (3, 224, 224), batch_size=1))
    logging.info("Starting training for {} epoch(s)".format(params.num_epochs))
    for split, dl in dataloaders.items():
        logging.info("Number of %s samples: %s" % (split, str(len(dl.dataset))))
    # logging.info("Number of valid examples: {}".format(len(valid.dataset)))

    with SummaryWriter(logdir) as writer:
        # train(model, optimizer, loss_fn, train_dl, metrics, params)
        train_and_evaluate(model, train_dl, valid_dl, optimizer, loss_fn, metrics, params, setting, args,
                           writer, logdir, args.restore_file)
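
# Training curves and figures are written via SummaryWriter to the run directory
# (plus the experiment name, when --experiment is given); they can be inspected with, e.g.:
#   tensorboard --logdir settings/<setting>/<fase>-fase/runs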