"""Train the model"""

import argparse
import logging
import os, shutil

import numpy as np
import pandas as pd
from sklearn.utils.class_weight import compute_class_weight
import torch
import torch.optim as optim
import torchvision.models as models
from torch.autograd import Variable
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
# from torchsummary import summary

import utils
import json
import model.net as net
import model.data_loader as data_loader
from evaluate import evaluate

parser = argparse.ArgumentParser()
parser.add_argument('--data-dir', default='data', help="Directory containing the dataset")
parser.add_argument('--model-dir', default='experiments', help="Directory containing params.json")
parser.add_argument('--setting-dir', default='settings', help="Directory with different settings")
parser.add_argument('--setting', default='collider-prognosticfactor', help="Directory containing setting.json: experimental setting, data generation, regression model, etc.")
parser.add_argument('--fase', default='xybn', help="fase (phase) of training the model, see manuscript for details: x, y, xy, bn, or feature")
parser.add_argument('--experiment', default='', help="Manual name for the experiment for logging, will be a subdir of the setting")
parser.add_argument('--restore-file', default=None,
                    help="Optional, name of the file in --model-dir containing weights to reload before \
                    training")  # 'best' or 'train'
parser.add_argument('--restore-last', action='store_true', help="continue the last run")
parser.add_argument('--restore-warm', action='store_true', help="continue from the run called 'warm-start.pth.tar'")
parser.add_argument('--use-last', action="store_true", help="use the last state dict instead of 'best' (use for manual early stopping)")
parser.add_argument('--cold-start', action='store_true', help="ignore previous state dicts (weights), even if they exist")
parser.add_argument('--warm-start', dest='cold_start', action='store_false', help="start from a previous state dict")
parser.add_argument('--disable-cuda', action='store_true', help="Disable CUDA")
parser.add_argument('--no-parallel', action="store_false", help="do not use multiple GPUs", dest="parallel")
parser.add_argument('--parallel', action="store_true", help="use multiple GPUs", dest="parallel")
parser.add_argument('--gpu', default=0, type=int, help="if not running in parallel (=all GPUs), only use this GPU")
parser.add_argument('--intercept', action="store_true", help="dummy run for getting intercept baseline results")
# parser.add_argument('--visdom', action='store_true', help='generate plots with visdom')
# parser.add_argument('--novisdom', dest='visdom', action='store_false', help='dont plot with visdom')
parser.add_argument('--monitor-grads', action='store_true', help='keep track of the mean norm of gradients')
parser.set_defaults(parallel=False, cold_start=True, use_last=False, intercept=False, restore_last=False, save_preds=False,
                    monitor_grads=False, restore_warm=False
                    # visdom=False
                    )

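# Example invocation (illustrative; 'my-run' is a made-up experiment name, the other values
# are the defaults defined above and must correspond to directories on disk):
#
#   python train.py --setting collider-prognosticfactor --fase xybn --experiment my-run --gpu 0
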
def train(model, optimizer, loss_fn, dataloader, metrics, params, setting, writer=None, epoch=None):
    """Train the model for one epoch, i.e. on all batches in `dataloader`

    Args:
        model: (torch.nn.Module) the neural network
        optimizer: (torch.optim) optimizer for the parameters of the model
        loss_fn: a function that takes batch_output and batch_labels and computes the loss for the batch
        dataloader: (DataLoader) a torch.utils.data.DataLoader object that fetches training data
        metrics: (dict) a dictionary of functions that compute a metric using the output and labels of each batch
        params: (Params) hyperparameters
        setting: (Params) experimental setting (data generation, regression model, etc.)
        writer: (SummaryWriter) TensorBoard writer for logging
        epoch: (int) index of the current epoch
    """
    global train_tensor_keys, logdir

    # set model to training mode
    model.train()

    # summary for current training loop and a running average object for loss
    summ = []
    loss_avg = utils.RunningAverage()

    # create storage for tensors for OLS after the minibatches
    ts = []
    Xs = []
    Xtrues = []
    Ys = []
    Xhats = []
    Yhats = []
    Zhats = []

    # Use tqdm for progress bar
    with tqdm(total=len(dataloader)) as progress_bar:
        for i, batch in enumerate(dataloader):
            summary_batch = {}
            # move batch to the device (cuda or cpu)
            batch = {k: v.to(params.device) for k, v in batch.items()}
            if not (setting.covar_mode and epoch > params.suppress_t_epochs):
                batch["t"] = torch.zeros_like(batch['t'])
            Xs.append(batch['x'].detach().cpu())
            Xtrues.append(batch['x_true'].detach().cpu())

            # compute model output and loss
            output_batch = model(batch['image'], batch['t'].view(-1,1), epoch)
            Yhats.append(output_batch['y'].detach().cpu())

            # calculate loss
            if args.fase == "feature":
                # calculate the loss for z directly, to see how well z can be measured
                loss_fn_z = torch.nn.MSELoss()
                loss_z = loss_fn_z(output_batch["y"].squeeze(), batch["z"])
                loss   = loss_z
                summary_batch["loss_z"] = loss_z.item()
            else:
                loss_fn_y = torch.nn.MSELoss()
                loss_y = loss_fn_y(output_batch["y"].squeeze(), batch["y"])
                loss   = loss_y
                summary_batch["loss_y"] = loss_y.item()

            # calculate loss for collider x
            loss_fn_x = torch.nn.MSELoss()
            loss_x = loss_fn_x(output_batch["bnx"].squeeze(), batch["x"])
            summary_batch["loss_x"] = loss_x.item()
            if not params.alpha == 1:
                # possibly weigh down the contribution of estimating x
                loss_x *= params.alpha
                summary_batch["loss_x_weighted"] = loss_x.item()
            # add x loss to total loss
            loss += loss_x

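            # at this point loss = loss_y (or loss_z in the 'feature' fase) plus the
            # (possibly alpha-weighted) reconstruction loss for the collider x;
            # a bottleneck penalty may be added to it further below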
            # add least squares regression on final layer
            if params.do_least_squares:
                X    = batch["x"].view(-1,1)
                t    = batch["t"].view(-1,1)
                Z    = output_batch["bnz"]
                if Z.ndimension() == 1:
                    Z.unsqueeze_(1)
                Xhat = output_batch["bnx"]
                # add intercept
                Zi = torch.cat([torch.ones_like(t), Z], 1)
                # add treatment info
                Zt = torch.cat([Zi, t], 1)
                Y  = batch["y"].view(-1,1)

                # regress y on final layer, without x
                betas_y = net.cholesky_least_squares(Zt, Y, intercept=False)
                y_hat   = Zt.matmul(betas_y).view(-1,1)
                mse_y   = ((Y - y_hat)**2).mean()

                summary_batch["regr_b_t"] = betas_y[-1].item()
                summary_batch["regr_loss_y"] = mse_y.item()

                # regress x on final layer without x
                betas_x = net.cholesky_least_squares(Zi, Xhat, intercept=False)
                x_hat   = Zi.matmul(betas_x).view(-1,1)
                mse_x   = ((Xhat - x_hat)**2).mean()

                # store all tensors for a single pass after the epoch
                Xhats.append(Xhat.detach().cpu())
                Zhats.append(Z.detach().cpu())
                ts.append(t.detach().cpu())
                Ys.append(Y.detach().cpu())

                summary_batch["regr_loss_x"] = mse_x.item()

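            # For reference: net.cholesky_least_squares is assumed to solve the OLS normal
            # equations (A^T A) beta = A^T y via a Cholesky factorization, roughly
            #     L = torch.linalg.cholesky(A.t() @ A)
            #     beta = torch.cholesky_solve(A.t() @ y, L).squeeze()
            # With Zt = [1, Z, t], betas_y[-1] is the coefficient on the treatment t
            # (logged as "regr_b_t"). The per-batch tensors stored above (Xhats, Zhats,
            # ts, Ys) and mse_x / X are reused by the bottleneck loss below and by the
            # whole-epoch OLS and scatter plots after the loop, so those parts assume
            # params.do_least_squares is enabled.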
            # add loss_bn only after n epochs
            if params.bottleneck_loss and epoch > params.bn_loss_lag_epochs:
                # only add to the loss when bigger than the margin
                if params.bn_loss_margin_type == "dynamic-mean":
                    # for each batch, calculate the loss of just using the mean for predicting x
                    mse_x_mean = ((X - X.mean())**2).mean()
                    loss_bn = torch.max(torch.zeros_like(mse_x), mse_x_mean - mse_x)
                elif params.bn_loss_margin_type == "fixed":
                    mse_diff = params.bn_loss_margin - mse_x
                    loss_bn = torch.max(torch.zeros_like(mse_x), mse_diff)
                else:
                    raise NotImplementedError(f'bottleneck loss margin type not implemented: {params.bn_loss_margin_type}')

                # possibly reweigh bottleneck loss and add to total loss
                summary_batch["loss_bn"] = loss_bn.item()
                # note: is this double?
                if loss_bn > params.bn_loss_margin:
                    loss_bn *= params.bottleneck_loss_wt
                    loss    += loss_bn

            # perform parameter update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            summary_batch['loss'] = loss.item()
            summ.append(summary_batch)

            # if necessary, write out tensors
            if params.monitor_train_tensors and (epoch % params.save_summary_steps == 0):
                tensors = {}
                for tensor_key in train_tensor_keys:
                    if tensor_key in batch.keys():
                        # move to cpu before converting to numpy (the batch lives on params.device)
                        tensors[tensor_key] = batch[tensor_key].detach().cpu().squeeze().numpy()
                    elif tensor_key.endswith("hat"):
                        tensor_key = tensor_key.split("_")[0]
                        if tensor_key in output_batch.keys():
                            tensors[tensor_key+"_hat"] = output_batch[tensor_key].detach().cpu().squeeze().numpy()
                    else:
                        assert False, f"key not found: {tensor_key}"
                # print(tensors)
                df = pd.DataFrame.from_dict(tensors, orient='columns')
                df["epoch"] = epoch

                with open(os.path.join(logdir, 'train-tensors.csv'), 'a') as f:
                    # index=False keeps the rows aligned with the header written in __main__
                    df[["epoch"]+train_tensor_keys].to_csv(f, header=False, index=False)

            # update the average loss
            loss_avg.update(loss.item())

            progress_bar.set_postfix(loss='{:05.3f}'.format(loss_avg()))
            progress_bar.update()

    # visualize gradients
    if epoch % params.save_summary_steps == 0 and args.monitor_grads:
        abs_gradients = {}
        for name, param in model.named_parameters():
            try:  # patch: some named parameters have no gradient (param.grad is None)
                abs_gradients[name] = np.abs(param.grad.cpu().numpy()).mean()
                writer.add_histogram("grad-"+name, param.grad, epoch)
            except Exception:
                pass
        writer.add_scalars("mean-abs-gradients", abs_gradients, epoch)

    # compute mean of all metrics in summary
    metrics_mean = {metric: np.nanmean([x[metric] for x in summ]) for metric in summ[0]}

    # collect tensors
    Xhat  = torch.cat(Xhats, 0).view(-1,1)
    Yhat  = torch.cat(Yhats, 0).view(-1,1)
    Zhat  = torch.cat(Zhats, 0)
    t     = torch.cat(ts, 0)
    X     = torch.cat(Xs, 0)
    Xtrue = torch.cat(Xtrues, 0)
    Y     = torch.cat(Ys, 0)
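
    # the block below fits two OLS regressions on the tensors collected over the epoch:
    # a "biased" one whose design matrix [1, Xhat, Zhat, t] includes the estimated
    # collider Xhat, and a "causal" one that leaves it out ([1, Zhat, t]); since t is
    # the last column, index [-1] is the treatment coefficient and [-2] the coefficient
    # of the last Zhat column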
    if params.do_least_squares:
        # after the minibatches, do a single OLS on the whole data
        Zi = torch.cat([torch.ones_like(t), Zhat], 1)
        # add treatment info
        Zt = torch.cat([Zi, t], 1)
        # add x for the biased version
        XZt = torch.cat([torch.ones_like(t), Xhat, Zhat, t], 1)

        betas_y_bias       = net.cholesky_least_squares(XZt, Y, intercept=False)
        betas_y_causal     = net.cholesky_least_squares(Zt, Y, intercept=False)
        model.betas_bias   = betas_y_bias
        model.betas_causal = betas_y_causal
        metrics_mean["regr_bias_coef_t"]   = betas_y_bias.squeeze()[-1]
        metrics_mean["regr_bias_coef_z"]   = betas_y_bias.squeeze()[-2]
        metrics_mean["regr_causal_coef_t"] = betas_y_causal.squeeze()[-1]
        metrics_mean["regr_causal_coef_z"] = betas_y_causal.squeeze()[-2]

    # create some plots
    xx_scatter     = net.make_scatter_plot(X.numpy(), Xhat.numpy(), xlabel='x', ylabel='xhat')
    xtruex_scatter = net.make_scatter_plot(Xtrue.numpy(), Xhat.numpy(), xlabel='xtrue', ylabel='xhat')
    xyhat_scatter  = net.make_scatter_plot(X.numpy(), Yhat.numpy(), c=t.numpy(), xlabel='x', ylabel='yhat')
    yy_scatter     = net.make_scatter_plot(Y.numpy(), Yhat.numpy(), c=t.numpy(), xlabel='y', ylabel='yhat')
    writer.add_figure('x-xhat/train', xx_scatter, epoch+1)
    writer.add_figure('xtrue-xhat/train', xtruex_scatter, epoch+1)
    writer.add_figure('x-yhat/train', xyhat_scatter, epoch+1)
    writer.add_figure('y-yhat/train', yy_scatter, epoch+1)

    metrics_string = " ; ".join("{}: {:05.3f}".format(k, v) for k, v in metrics_mean.items())
    logging.info("- Train metrics: " + metrics_string)

    return metrics_mean


def train_and_evaluate(model, train_dataloader, val_dataloader, optimizer, loss_fn, metrics, params, setting, args,
                       writer=None, logdir=None, restore_file=None):
    """Train the model and evaluate every epoch.

    Args:
        model: (torch.nn.Module) the neural network
        train_dataloader: (DataLoader) a torch.utils.data.DataLoader object that fetches training data
        val_dataloader: (DataLoader) a torch.utils.data.DataLoader object that fetches validation data
        optimizer: (torch.optim) optimizer for parameters of model
        loss_fn: a function that takes batch_output and batch_labels and computes the loss for the batch
        metrics: (dict) a dictionary of functions that compute a metric using the output and labels of each batch
        params: (Params) hyperparameters
        setting: (Params) experimental setting; setting.covar_mode indicates whether the
            data loader returns covariates / additional data
        args: parsed command-line arguments
        writer: (SummaryWriter) TensorBoard writer for logging
        logdir: (string) directory containing config, weights and log
        restore_file: (string) optional - name of the file to restore from (without its extension .pth.tar)
    """

    # setup directories for data
    setting_home = setting.home
    if args.fase != "feature":
        data_dir = os.path.join(setting_home, "data")
    else:
        if setting.mode3d:
            data_dir = "data"
        else:
            data_dir = "slices"
    covar_mode = setting.covar_mode

    x_frozen = False

    best_val_metric = 0.0
    if "loss" in setting.metrics[0]:
        best_val_metric = 1.0e6

    val_preds = np.zeros((len(val_dataloader.dataset), params.num_epochs))

    for epoch in range(params.num_epochs):

        # Run one epoch
        logging.info(f"Epoch {epoch+1}/{params.num_epochs}; setting: {args.setting}, fase {args.fase}, experiment: {args.experiment}")

        # train for one epoch (one full pass over the training set)
        train_metrics = train(model, optimizer, loss_fn, train_dataloader, metrics, params, setting, writer, epoch)
        print(train_metrics)
        for metric_name in train_metrics.keys():
            metric_vals = {'train': train_metrics[metric_name]}
            writer.add_scalars(metric_name, metric_vals, epoch+1)

        # for name, param in model.named_parameters():
        #     writer.add_histogram(name, param.clone().cpu().data.numpy(), epoch+1)

        if epoch % params.save_summary_steps == 0:

            # Evaluate for one epoch on the validation set
            valid_metrics, outtensors = evaluate(model, loss_fn, val_dataloader, metrics, params, setting, epoch, writer)
            valid_metrics["intercept"] = model.regressor.fc.bias.detach().cpu().numpy()
            print(valid_metrics)

            for name, module in model.regressor.named_children():
                if name == "t":
                    valid_metrics["b_t"] = module.weight.detach().cpu().numpy()
                elif name == "zt":
                    weights = module.weight.detach().cpu().squeeze().numpy().reshape(-1)
                    for i, weight in enumerate(weights):
                        valid_metrics["b_zt"+str(i)] = weight
                else:
                    pass
            for metric_name in valid_metrics.keys():
                metric_vals = {'valid': valid_metrics[metric_name]}
                writer.add_scalars(metric_name, metric_vals, epoch+1)

            # create plots
            val_df = val_dataloader.dataset.df
            xx_scatter     = net.make_scatter_plot(val_df.x.values, outtensors['xhat'], xlabel='x', ylabel='xhat')
            xtruex_scatter = net.make_scatter_plot(val_df.x_true.values, outtensors['xhat'], xlabel='xtrue', ylabel='xhat')
            xyhat_scatter  = net.make_scatter_plot(val_df.x.values, outtensors['predictions'], c=val_df.t, xlabel='x', ylabel='yhat')
            zyhat_scatter  = net.make_scatter_plot(val_df.z.values, outtensors['predictions'], c=val_df.t, xlabel='z', ylabel='yhat')
            yy_scatter     = net.make_scatter_plot(val_df.y.values, outtensors['predictions'], c=val_df.t, xlabel='y', ylabel='yhat')
            writer.add_figure('x-xhat/valid', xx_scatter, epoch+1)
            writer.add_figure('xtrue-xhat/valid', xtruex_scatter, epoch+1)
            writer.add_figure('x-yhat/valid', xyhat_scatter, epoch+1)
            writer.add_figure('z-yhat/valid', zyhat_scatter, epoch+1)
            writer.add_figure('y-yhat/valid', yy_scatter, epoch+1)

            if params.save_preds:
                # writer.add_histogram("predictions", preds)
                if setting.num_classes == 1:
                    val_preds[:, epoch] = np.squeeze(outtensors['predictions'])

                    # append this epoch's predictions to file
                    pred_fname = os.path.join(setting.home, setting.fase+"-fase", "preds_val.csv")
                    with open(pred_fname, 'ab') as f:
                        np.savetxt(f, val_preds[:, epoch].T, newline="")

                np.save(os.path.join(setting.home, setting.fase+"-fase", "preds.npy"), val_preds)

            # track the primary validation metric for checkpointing
            val_metric = valid_metrics[setting.metrics[0]]
            if "loss" in str(setting.metrics[0]):
                is_best = val_metric <= best_val_metric
            else:
                is_best = val_metric >= best_val_metric

            # Save weights
            state_dict = model.state_dict()
            optim_dict = optimizer.state_dict()

            state = {
                'epoch': epoch+1,
                'state_dict': state_dict,
                'optim_dict': optim_dict
            }

            utils.save_checkpoint(state,
                                  is_best=is_best,
                                  checkpoint=logdir)

            # if this is the best model so far, also store its metrics separately
            valid_metrics["epoch"] = epoch
            if is_best:
                logging.info("- Found new best {}: {:.3f}".format(setting.metrics[0], val_metric))
                best_val_metric = val_metric

                # Save best val metrics in a json file in the model directory
                best_json_path = os.path.join(logdir, "metrics_val_best_weights.json")
                utils.save_dict_to_json(valid_metrics, best_json_path)

            # Save latest val metrics in a json file in the model directory
            last_json_path = os.path.join(logdir, "metrics_val_last_weights.json")
            utils.save_dict_to_json(valid_metrics, last_json_path)

    # final evaluation
    # note: export_scalars_to_json exists on tensorboardX's SummaryWriter but not on
    # torch.utils.tensorboard's, so only call it when it is available
    if hasattr(writer, "export_scalars_to_json"):
        writer.export_scalars_to_json(os.path.join(logdir, "all_scalars.json"))

    if args.save_preds:
        np.save(os.path.join(setting.home, setting.fase + "-fase", "val_preds.npy"), val_preds)


if __name__ == '__main__':

    # Load the parameters from json file
    args = parser.parse_args()

    # Load information from last setting if none provided:
    last_defaults = utils.Params("last-defaults.json")
    if args.setting == "":
        print("using last default setting")
        args.setting = last_defaults.dict["setting"]
        for param, value in last_defaults.dict.items():
            print("{}: {}".format(param, value))
    else:
        with open("last-defaults.json", "r+") as jsonFile:
            defaults = json.load(jsonFile)
            tmp = defaults["setting"]
            defaults["setting"] = args.setting
            jsonFile.seek(0)  # rewind
            json.dump(defaults, jsonFile)
            jsonFile.truncate()

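    # Directory layout used below: the setting lives in <setting-dir>/<setting>/setting.json;
    # next to it a <fase>-fase subdirectory is created that receives a copy of params.json
    # and the run directory runs[/<experiment>], which holds the TensorBoard logs,
    # checkpoints and the metrics json files.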
    # setup visdom environment
    # if args.visdom:
        # from visdom import Visdom
        # viz = Visdom(env=f"lidcr_{args.setting}_{args.fase}_{args.experiment}")

    # load setting (data generation, regression model etc)
    setting_home = os.path.join(args.setting_dir, args.setting)
    setting = utils.Params(os.path.join(setting_home, "setting.json"))
    setting.home = setting_home

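    # params.json (loaded below) is expected to define at least the hyperparameters this
    # script reads from `params`; an illustrative example (values are made up, not defaults):
    #
    #   {"num_epochs": 50, "batch_size": 32, "learning_rate": 1e-4, "wd": 1e-5,
    #    "momentum": 0, "optimizer": "adam", "alpha": 1.0,
    #    "do_least_squares": true, "bottleneck_loss": false, "bottleneck_loss_wt": 1.0,
    #    "bn_loss_lag_epochs": 0, "bn_loss_margin_type": "fixed", "bn_loss_margin": 0.0,
    #    "suppress_t_epochs": 0, "save_summary_steps": 1, "monitor_train_tensors": false,
    #    "balance_classes": false, "save_preds": false, "lr_t_factor": 1}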
    # when not specified in call, grab model specification from setting file
    if setting.cnn_model == "":
        json_path = os.path.join(args.model_dir, "t-suppression", args.experiment+".json")
    else:
        json_path = os.path.join(args.model_dir, setting.cnn_model, 'params.json')
    assert os.path.isfile(json_path), "No json configuration file found at {}".format(json_path)
    if not os.path.exists(os.path.join(setting.home, args.fase + "-fase")):
        os.makedirs(os.path.join(setting.home, args.fase + "-fase"))
    shutil.copy(json_path, os.path.join(setting_home, args.fase + "-fase", "params.json"))
    params = utils.Params(json_path)
    # covar_mode = setting.covar_mode
    # mode3d = setting.mode3d
    parallel = args.parallel

    params.device = None
    if not args.disable_cuda and torch.cuda.is_available():
        params.device = torch.device('cuda')
        params.cuda = True
        # switch gpus for better use when running multiple experiments
        if not args.parallel:
            torch.cuda.set_device(int(args.gpu))
    else:
        params.device = torch.device('cpu')
        params.cuda = False  # params.cuda is checked below when seeding

    # adapt fase
    setting.fase = args.fase
    setting.metrics = pd.Series(setting.metrics).drop_duplicates().tolist()
    print("metrics: {}".format(setting.metrics))

    # Set the random seed for reproducible experiments
    torch.manual_seed(230)
    if params.cuda: torch.cuda.manual_seed(230)

    # Set the logger
    logdir = os.path.join(setting_home, setting.fase+"-fase", "runs")
    if not args.experiment == '':
        logdir = os.path.join(logdir, args.experiment)
    if not os.path.isdir(logdir):
        os.makedirs(logdir)

    # copy params as backup to logdir
    shutil.copy(json_path, os.path.join(logdir, "params.json"))

    # utils.set_logger(os.path.join(args.model_dir, 'train.log'))
    utils.set_logger(os.path.join(logdir, 'train.log'))

    # Create the input data pipeline
    logging.info("Loading the datasets...")

    # fetch dataloaders
    dataloaders = data_loader.fetch_dataloader(args, params, setting, ["train", "valid"])
    train_dl = dataloaders['train']
    valid_dl = dataloaders['valid']

    if setting.num_classes > 1 and params.balance_classes:
        train_labels = train_dl.dataset.df[setting.outcome[0]].values
        class_weights = compute_class_weight('balanced', classes=np.unique(train_labels), y=train_labels)
    # valid_dl = train_dl

    logging.info("- done.")

    if args.intercept:
        assert len(setting.outcome) == 1, "Multiple outcomes not implemented for intercept yet"
        print("running intercept mode")
        mu = valid_dl.dataset.df[setting.outcome].values.mean()
        def new_forward(self, x, data, mu=mu):
            intercept = torch.autograd.Variable(mu * torch.ones((x.shape[0],1)), requires_grad=False).to(params.device, non_blocking=True)
            bn_activations = torch.autograd.Variable(torch.zeros((x.shape[0],)), requires_grad=False).to(params.device, non_blocking=True)
            return {setting.outcome[0]: intercept, "bn": bn_activations}

        net.Net3D.forward = new_forward
        params.num_epochs = 1
        setting.metrics = []
        logdir = os.path.join(logdir, "intercept")

    if setting.mode3d:
        model = net.Net3D(params, setting).to(params.device)
    else:
        model = net.CausalNet(params, setting).to(params.device)

    optimizers = {'sgd': optim.SGD, 'adam': optim.Adam}

    if parallel:
        print("parallel mode")
        model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))

    if params.momentum > 0:
        optimizer = optimizers[params.optimizer](model.parameters(), lr=params.learning_rate, weight_decay=params.wd, momentum=params.momentum)
    else:
        optimizer = optimizers[params.optimizer](model.parameters(), lr=params.learning_rate, weight_decay=params.wd)

    # if params.use_mi:
    #     optimizer.add_param_group({'params': mine.parameters()})

    if setting.covar_mode and params.lr_t_factor != 1:
        optimizer = net.speedup_t(model, params)

    if args.restore_last and (not args.cold_start):
        print("Loading state dict from last running setting")
        utils.load_checkpoint(os.path.join(setting.home, args.fase + "-fase", "last.pth.tar"), model, strict=False)
    elif args.restore_warm:
        utils.load_checkpoint(os.path.join(setting.home, 'warm-start.pth.tar'), model, strict=False)
    else:
        pass

    # fetch loss function and metrics
    if setting.num_classes > 1 and params.balance_classes:
        loss_fn = net.get_loss_fn(setting, weights=class_weights)
    else:
        loss_fn = net.get_loss_fn(setting)
    # metrics = {metric:net.all_metrics[metric] for metric in setting.metrics}
    metrics = None

    if params.monitor_train_tensors:
        print("Recording all train tensors")
        import csv
        train_tensor_keys = ['t', 'x', 'z', 'y', 'x_hat', 'z_hat', 'y_hat']
        with open(os.path.join(logdir, 'train-tensors.csv'), 'w') as f:
            writer = csv.writer(f)
            writer.writerow(['epoch']+train_tensor_keys)

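    # train() appends per-sample rows to this train-tensors.csv every params.save_summary_steps
    # epochs, with the column layout written above: epoch, t, x, z, y, x_hat, z_hat, y_hat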
    # Train the model
    # print(model)
    # print(summary(model, (3, 224, 224), batch_size=1))
    logging.info("Starting training for {} epoch(s)".format(params.num_epochs))
    for split, dl in dataloaders.items():
        logging.info("Number of %s samples: %s" % (split, str(len(dl.dataset))))
        # logging.info("Number of valid examples: {}".format(len(valid.dataset)))

    with SummaryWriter(logdir) as writer:
        # train(model, optimizer, loss_fn, train_dl, metrics, params)
        train_and_evaluate(model, train_dl, valid_dl, optimizer, loss_fn, metrics, params, setting, args,
                           writer, logdir, args.restore_file)