Diff of /eval.py [000000] .. [134fd7]

import yaml
import tensorboard
import torch
import os
import shutil
import sys
import csv
import argparse
import pickle
from models.resnet_simclr import ResNetSimCLR
from clinical_ts.cpc import CPCModel
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.metrics import roc_auc_score
from clinical_ts.simclr_dataset_wrapper import SimCLRDataSetWrapper
from clinical_ts.eval_utils_cafa import eval_scores, eval_scores_bootstrap
from clinical_ts.timeseries_utils import aggregate_predictions
import pdb
from copy import deepcopy
from os.path import join, isdir

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def parse_args():
    parser = argparse.ArgumentParser("Finetuning tests")
    parser.add_argument("--model_file")
    parser.add_argument("--method")
    parser.add_argument("--dataset", nargs="+", default="./data/ptb_xl_fs100")
    parser.add_argument("--batch_size", type=int, default=512)
    parser.add_argument("--discriminative_lr", default=False, action="store_true")
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("--hidden", default=False, action="store_true")
    parser.add_argument("--lr_schedule", default="{}")
    parser.add_argument("--use_pretrained", default=False, action="store_true")
    parser.add_argument("--linear_evaluation",
                        default=False, action="store_true", help="use linear evaluation")
    parser.add_argument("--test_noised", default=False, action="store_true", help="also validate on a noisy dataset")
    parser.add_argument("--noise_level", default=1, type=int, help="level of noise applied to the second validation set")
    parser.add_argument("--folds", default=8, type=int, help="number of folds used in finetuning (between 1 and 8)")
    parser.add_argument("--tag", default="")
    parser.add_argument("--eval_only", action="store_true", default=False, help="only evaluate the model")
    parser.add_argument("--load_finetuned", action="store_true", default=False)
    parser.add_argument("--test", action="store_true", default=False)
    parser.add_argument("--verbose", action="store_true", default=False)
    parser.add_argument("--cpc", action="store_true", default=False)
    parser.add_argument("--model_location")
    parser.add_argument("--l_epochs", type=int, default=0, help="number of head-only epochs (these are performed first)")
    parser.add_argument("--f_epochs", type=int, default=0, help="number of finetuning epochs (these are performed after head-only training)")
    parser.add_argument("--normalize", action="store_true", default=False, help="normalize dataset with PTB-XL mean and std")
    parser.add_argument("--bn_head", action="store_true", default=False)
    parser.add_argument("--ps_head", type=float, default=0.0)
    parser.add_argument("--conv_encoder", action="store_true", default=False)
    parser.add_argument("--base_model", default="xresnet1d50")
    parser.add_argument("--widen", default=1, type=int, help="widening factor for a wide xresnet1d50")
    args = parser.parse_args()
    return args

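# Example invocation (illustrative only; the checkpoint path is hypothetical,
# the flags are the ones defined above):
#   python eval.py --model_file ./checkpoints/pretrained.ckpt --method simclr \
#       --use_pretrained --linear_evaluation --l_epochs 50 --f_epochs 0 \
#       --dataset ./data/ptb_xl_fs100
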
def get_new_state_dict(init_state_dict, lightning_state_dict, method="simclr"):
    # map the keys of a pretrained lightning checkpoint onto the keys of the
    # freshly initialized model (the prefixes differ between moco/simclr/swav/byol/cpc)
    from collections import OrderedDict
    # lightning_state_dict = lightning_state_dict["state_dict"]
    new_state_dict = OrderedDict()
    if method != "cpc":
        if method == "moco":
            for key in init_state_dict:
                l_key = "encoder_q." + key
                if l_key in lightning_state_dict.keys():
                    new_state_dict[key] = lightning_state_dict[l_key]
        elif method == "simclr":
            for key in init_state_dict:
                if "features" in key:
                    l_key = key.replace("features", "encoder.features")
                    if l_key in lightning_state_dict.keys():
                        new_state_dict[key] = lightning_state_dict[l_key]
        elif method == "swav":
            for key in init_state_dict:
                if "features" in key:
                    l_key = key.replace("features", "model.features")
                    if l_key in lightning_state_dict.keys():
                        new_state_dict[key] = lightning_state_dict[l_key]
        elif method == "byol":
            for key in init_state_dict:
                l_key = "online_network.encoder." + key
                if l_key in lightning_state_dict.keys():
                    new_state_dict[key] = lightning_state_dict[l_key]
        else:
            raise Exception("method unknown")
        new_state_dict["l1.weight"] = init_state_dict["l1.weight"]
        new_state_dict["l1.bias"] = init_state_dict["l1.bias"]
        if "l2.weight" in init_state_dict.keys():
            new_state_dict["l2.weight"] = init_state_dict["l2.weight"]
            new_state_dict["l2.bias"] = init_state_dict["l2.bias"]

        assert(len(init_state_dict) == len(new_state_dict))
    else:
        for key in init_state_dict:
            l_key = "model_cpc." + key
            if l_key in lightning_state_dict.keys():
                new_state_dict[key] = lightning_state_dict[l_key]
            if "head" in key:
                new_state_dict[key] = init_state_dict[key]
    return new_state_dict

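# Minimal sketch of the remapping above (illustrative only, with made-up keys):
# a moco checkpoint stores the backbone under "encoder_q.", so an entry like
#   lightning_state_dict["encoder_q.features.0.weight"]
# is copied to
#   new_state_dict["features.0.weight"]
# while the freshly initialized classification head ("l1.*"/"l2.*") is taken
# from init_state_dict.
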
def adjust(model, num_classes, hidden=False):
    in_features = model.l1.in_features
    last_layer = torch.nn.modules.linear.Linear(
        in_features, num_classes).to(device)
    if hidden:
        model.l1 = torch.nn.modules.linear.Linear(
            in_features, in_features).to(device)
        model.l2 = last_layer
    else:
        model.l1 = last_layer

    def def_forward(self):
        def new_forward(x):
            h = self.features(x)
            h = h.squeeze()

            x = self.l1(h)
            if hidden:
                x = F.relu(x)
                x = self.l2(x)
            return x
        return new_forward

    model.forward = def_forward(model)

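# Note (added for clarity, not part of the original logic): adjust() replaces the
# projection head of the pretrained model with a task head of `num_classes`
# outputs and monkey-patches model.forward so that predictions come from
# features -> l1 (-> ReLU -> l2 when hidden=True) instead of the SimCLR head.
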
def configure_optimizer(model, batch_size, head_only=False, discriminative_lr=False, base_model="xresnet1d", optimizer="adam", discriminative_lr_factor=1):
    loss_fn = F.binary_cross_entropy_with_logits
    if base_model == "xresnet1d":
        wd = 1e-1
        if head_only:
            lr = (8e-3*(batch_size/256))
            optimizer = torch.optim.AdamW(
                model.l1.parameters(), lr=lr, weight_decay=wd)
        else:
            lr = 0.01
            if not discriminative_lr:
                optimizer = torch.optim.AdamW(
                    model.parameters(), lr=lr, weight_decay=wd)
            else:
                param_dict = dict(model.named_parameters())
                keys = param_dict.keys()
                weight_layer_nrs = set()
                for key in keys:
                    if "features" in key:
                        # parameter names have the form features.x
                        weight_layer_nrs.add(key[9])
                weight_layer_nrs = sorted(weight_layer_nrs, reverse=True)
                features_groups = []
                while len(weight_layer_nrs) > 0:
                    if len(weight_layer_nrs) > 1:
                        features_groups.append(list(filter(
                            lambda x: "features." + weight_layer_nrs[0] in x or "features." + weight_layer_nrs[1] in x,  keys)))
                        del weight_layer_nrs[:2]
                    else:
                        features_groups.append(
                            list(filter(lambda x: "features." + weight_layer_nrs[0] in x,  keys)))
                        del weight_layer_nrs[0]
                # filter linear layers
                linears = list(filter(lambda x: "l" in x, keys))
                groups = [linears] + features_groups
                optimizer_param_list = []
                tmp_lr = lr

                for layers in groups:
                    layer_params = [param_dict[param_name]
                                    for param_name in layers]
                    optimizer_param_list.append(
                        {"params": layer_params, "lr": tmp_lr})
                    tmp_lr /= 4
                optimizer = torch.optim.AdamW(
                    optimizer_param_list, lr=lr, weight_decay=wd)

        print("lr", lr)
        print("wd", wd)
        print("batch size", batch_size)

    elif base_model == "cpc":
        if(optimizer == "sgd"):
            opt = torch.optim.SGD
        elif(optimizer == "adam"):
            opt = torch.optim.AdamW
        else:
            raise NotImplementedError("Unknown Optimizer.")
        lr = 1e-4
        wd = 1e-3
        if(head_only):
            lr = 1e-3
            print("Linear eval: model head", model.head)
            optimizer = opt(model.head.parameters(), lr, weight_decay=wd)
        elif(discriminative_lr_factor != 1.):  # discriminative lrs
            optimizer = opt([{"params": model.encoder.parameters(), "lr": lr*discriminative_lr_factor*discriminative_lr_factor}, {
                            "params": model.rnn.parameters(), "lr": lr*discriminative_lr_factor}, {"params": model.head.parameters(), "lr": lr}], lr, weight_decay=wd)
            print("Finetuning: model head", model.head)
            print("discriminative lr: ", discriminative_lr_factor)
        else:
            lr = 1e-3
            print("normal supervised training")
            optimizer = opt(model.parameters(), lr, weight_decay=wd)
    else:
        raise Exception("model unknown")
    return loss_fn, optimizer

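# Illustration of the discriminative-lr grouping above (assumes the xresnet1d
# "features.x" naming; added for clarity, not part of the original script):
# with lr=0.01, the linear head gets 0.01, the two deepest feature stages get
# 0.01/4, the next two 0.01/16, and so on, i.e. each earlier group of layers is
# trained with a learning rate four times smaller than the group after it.
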
def load_model(linear_evaluation, num_classes, use_pretrained, discriminative_lr=False, hidden=False, conv_encoder=False, bn_head=False, ps_head=0.5, location="./checkpoints/moco_baselinewonder200.ckpt", method="simclr", base_model="xresnet1d50", out_dim=16, widen=1):
    discriminative_lr_factor = 1
    if use_pretrained:
        print("load model from " + location)
        discriminative_lr_factor = 0.1
        if base_model == "cpc":
            lightning_state_dict = torch.load(location, map_location=device)

            # num_head = np.sum([1 if 'proj' in f else 0 for f in lightning_state_dict.keys()])
            if linear_evaluation:
                lin_ftrs_head = []
                bn_head = False
                ps_head = 0.0
            else:
                if hidden:
                    lin_ftrs_head = [512]
                else:
                    lin_ftrs_head = []

            if conv_encoder:
                strides = [2, 2, 2, 2]
                kss = [10, 4, 4, 4]
            else:
                strides = [1]*4
                kss = [1]*4

            model = CPCModel(input_channels=12, strides=strides, kss=kss, features=[512]*4, n_hidden=512, n_layers=2, mlp=False, lstm=True, bias_proj=False,
                             num_classes=num_classes, skip_encoder=False, bn_encoder=True, lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_head=bn_head).to(device)

            if "state_dict" in lightning_state_dict.keys():
                print("load pretrained model")
                model_state_dict = get_new_state_dict(
                    model.state_dict(), lightning_state_dict["state_dict"], method="cpc")
            else:
                print("load already finetuned model")
                model_state_dict = lightning_state_dict
            model.load_state_dict(model_state_dict)
        else:
            model = ResNetSimCLR(base_model, out_dim, hidden=hidden, widen=widen).to(device)
            model_state_dict = torch.load(location, map_location=device)
            if "state_dict" in model_state_dict.keys():
                model_state_dict = model_state_dict["state_dict"]
            if "l1.weight" in model_state_dict.keys():  # load already fine-tuned model
                model_classes = model_state_dict["l1.weight"].shape[0]
                if model_classes != num_classes:
                    raise Exception("Loaded model has different output dim ({}) than needed ({})".format(
                        model_classes, num_classes))
                adjust(model, num_classes, hidden=hidden)
                if not hidden and "l2.weight" in model_state_dict:
                    del model_state_dict["l2.weight"]
                    del model_state_dict["l2.bias"]
                model.load_state_dict(model_state_dict)
            else:  # load pretrained model
                base_dict = model.state_dict()
                model_state_dict = get_new_state_dict(
                    base_dict, model_state_dict, method=method)
                model.load_state_dict(model_state_dict)
                adjust(model, num_classes, hidden=hidden)

    else:
        if "xresnet1d" in base_model:
            model = ResNetSimCLR(base_model, out_dim, hidden=hidden, widen=widen).to(device)
            adjust(model, num_classes, hidden=hidden)
        elif base_model == "cpc":
            if linear_evaluation:
                lin_ftrs_head = []
                bn_head = False
                ps_head = 0.0
            else:
                if hidden:
                    lin_ftrs_head = [512]
                else:
                    lin_ftrs_head = []

            if conv_encoder:
                strides = [2, 2, 2, 2]
                kss = [10, 4, 4, 4]
            else:
                strides = [1]*4
                kss = [1]*4

            model = CPCModel(input_channels=12, strides=strides, kss=kss, features=[512]*4, n_hidden=512, n_layers=2, mlp=False, lstm=True, bias_proj=False,
                             num_classes=num_classes, skip_encoder=False, bn_encoder=True, lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_head=bn_head).to(device)

        else:
            raise Exception("model unknown")

    return model

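# Example (illustrative only; the checkpoint path is hypothetical):
#   model = load_model(linear_evaluation=True, num_classes=71, use_pretrained=True,
#                      location="./checkpoints/pretrained.ckpt", method="simclr",
#                      base_model="xresnet1d50")
# yields an xresnet1d50 with pretrained encoder weights and a fresh 71-way head.
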
def evaluate(model, dataloader, idmap, lbl_itos, cpc=False):
    preds, targs = eval_model(model, dataloader, cpc=cpc)
    scores = eval_scores(targs, preds, classes=lbl_itos, parallel=True)
    preds_agg, targs_agg = aggregate_predictions(preds, targs, idmap)
    scores_agg = eval_scores(targs_agg, preds_agg,
                             classes=lbl_itos, parallel=True)
    macro = scores["label_AUC"]["macro"]
    macro_agg = scores_agg["label_AUC"]["macro"]
    return preds, macro, macro_agg

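# Note (added for clarity): `idmap` maps the individual crops back to their
# source recordings, so the "_agg" scores are computed after aggregating the
# per-crop predictions of each recording, while the plain scores are per crop.
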
def set_train_eval(model, cpc, linear_evaluation):
    if linear_evaluation:
        if cpc:
            model.encoder.eval()
        else:
            model.features.eval()
    else:
        model.train()

def train_model(model, train_loader, valid_loader, test_loader, epochs, loss_fn, optimizer, head_only=True, linear_evaluation=False, percentage=1, lr_schedule=None, save_model_at=None, val_idmap=None, test_idmap=None, lbl_itos=None, cpc=False):
    if head_only:
        if linear_evaluation:
            print("linear evaluation for {} epochs".format(epochs))
        else:
            print("head-only for {} epochs".format(epochs))
    else:
        print("fine tuning for {} epochs".format(epochs))

    if head_only:
        for key, param in model.named_parameters():
            if "l1." not in key and "head." not in key:
                param.requires_grad = False
        print("copying state dict before training for sanity check after training")

    else:
        for param in model.parameters():
            param.requires_grad = True
    if cpc:
        data_type = model.encoder[0][0].weight.type()
    else:
        data_type = model.features[0][0].weight.type()

    set_train_eval(model, cpc, linear_evaluation)
    state_dict_pre = deepcopy(model.state_dict())
    print("epoch", "batch", "loss\n========================")
    loss_per_epoch = []
    macro_agg_per_epoch = []
    max_batches = len(train_loader)
    break_point = int(percentage*max_batches)
    best_macro = 0
    best_macro_agg = 0
    best_epoch = 0
    best_preds = None
    test_macro = 0
    test_macro_agg = 0
    for epoch in tqdm(range(epochs)):
        if type(lr_schedule) == dict:
            if epoch in lr_schedule.keys():
                for param_group in optimizer.param_groups:
                    param_group['lr'] /= lr_schedule[epoch]
        total_loss_one_epoch = 0
        for batch_idx, samples in enumerate(train_loader):
            if batch_idx == break_point:
                print("break at batch nr.", batch_idx)
                break
            data = samples[0].to(device).type(data_type)
            labels = samples[1].to(device).type(data_type)
            optimizer.zero_grad()
            preds = model(data)
            loss = loss_fn(preds, labels)
            loss.backward()
            optimizer.step()
            total_loss_one_epoch += loss.item()
            if(batch_idx % 100 == 0):
                print(epoch, batch_idx, loss.item())
        loss_per_epoch.append(total_loss_one_epoch)

        preds, macro, macro_agg = evaluate(
            model, valid_loader, val_idmap, lbl_itos, cpc=cpc)
        macro_agg_per_epoch.append(macro_agg)

        print("loss:", total_loss_one_epoch)
        print("aggregated macro:", macro_agg)
        if macro_agg > best_macro_agg:
            torch.save(model.state_dict(), save_model_at)
            best_macro_agg = macro_agg
            best_macro = macro
            best_epoch = epoch
            best_preds = preds
            _, test_macro, test_macro_agg = evaluate(
                model, test_loader, test_idmap, lbl_itos, cpc=cpc)

        set_train_eval(model, cpc, linear_evaluation)

    if epochs > 0:
        sanity_check(model, state_dict_pre, linear_evaluation, head_only)
    return loss_per_epoch, macro_agg_per_epoch, best_macro, best_macro_agg, test_macro, test_macro_agg, best_epoch, best_preds

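# Note (added for clarity): train_model returns the per-epoch losses and
# aggregated macro AUCs, the best validation macro/macro_agg, the test
# macro/macro_agg at the best-validation epoch, the index of that epoch, and
# the corresponding validation predictions.
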
def sanity_check(model, state_dict_pre, linear_evaluation, head_only):
    """
    A linear classifier should not change any weights other than the final
    linear layer. This sanity check asserts that nothing else was accidentally
    updated (e.g., BatchNorm running statistics).
    """
    print("=> loading state dict for sanity check")
    state_dict = model.state_dict()
    if linear_evaluation:
        for k in list(state_dict.keys()):
            # only ignore fc layer
            if 'fc.' in k or 'head.' in k or 'l1.' in k:
                continue

            equals = (state_dict[k].cpu() == state_dict_pre[k].cpu()).all()
            if (linear_evaluation != equals):
                raise Exception(
                    '=> failed sanity check in {}'.format("linear_evaluation"))
    elif head_only:
        for k in list(state_dict.keys()):
            # only ignore fc layer
            if 'fc.' in k or 'head.' in k:
                continue

            equals = (state_dict[k].cpu() == state_dict_pre[k].cpu()).all()
            if (equals and "running_mean" in k):
                raise Exception(
                    '=> failed sanity check in {}'.format("head-only"))
    # else:
    #     for k in list(state_dict.keys()):
    #         equals=(state_dict[k].cpu() == state_dict_pre[k].cpu()).all()
    #         if equals:
    #             pdb.set_trace()
    #             raise Exception('=> failed sanity check in {}'.format("fine_tuning"))

    print("=> sanity check passed.")

def eval_model(model, valid_loader, cpc=False):
    if cpc:
        data_type = model.encoder[0][0].weight.type()
    else:
        data_type = model.features[0][0].weight.type()
    model.eval()
    preds = []
    targs = []
    with torch.no_grad():
        for batch_idx, samples in tqdm(enumerate(valid_loader)):
            data = samples[0].to(device).type(data_type)
            preds_tmp = torch.sigmoid(model(data))
            targs.append(samples[1])
            preds.append(preds_tmp.cpu())
        preds = torch.cat(preds).numpy()
        targs = torch.cat(targs).numpy()

    return preds, targs

def get_dataset(batch_size, num_workers, target_folder, apply_noise=False, percentage=1.0, folds=8, t_params=None, test=False, normalize=False):
    if apply_noise:
        transformations = ["BaselineWander",
                           "PowerlineNoise", "EMNoise", "BaselineShift"]
        if normalize:
            transformations.append("Normalize")
        dataset = SimCLRDataSetWrapper(batch_size, num_workers, None, "(12, 250)", None, target_folder, [target_folder], None, None,
                                       mode="linear_evaluation", transformations=transformations, percentage=percentage, folds=folds, t_params=t_params, test=test, ptb_xl_label="label_all")
    else:
        if normalize:
            # always use PTB-XL stats
            transformations = ["Normalize"]
            dataset = SimCLRDataSetWrapper(batch_size, num_workers, None, "(12, 250)", None, target_folder, [target_folder], None, None,
                                           mode="linear_evaluation", percentage=percentage, folds=folds, test=test, transformations=transformations, ptb_xl_label="label_all")
        else:
            dataset = SimCLRDataSetWrapper(batch_size, num_workers, None, "(12, 250)", None, target_folder, [target_folder], None, None,
                                           mode="linear_evaluation", percentage=percentage, folds=folds, test=test, ptb_xl_label="label_all")

    train_loader, valid_loader = dataset.get_data_loaders()
    return dataset, train_loader, valid_loader

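# Example (illustrative only; the data folder is the parse_args default):
#   dataset, train_loader, valid_loader = get_dataset(
#       batch_size=512, num_workers=8, target_folder="./data/ptb_xl_fs100",
#       folds=8, test=False, normalize=True)
# As used in __main__ below, with test=True the second returned loader iterates
# over the test split instead of the validation split.
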
if __name__ == "__main__":
    args = parse_args()
    dataset, train_loader, _ = get_dataset(
        args.batch_size, args.num_workers, args.dataset, folds=args.folds, test=args.test, normalize=args.normalize)
    _, _, valid_loader = get_dataset(
        args.batch_size, args.num_workers, args.dataset, folds=args.folds, test=False, normalize=args.normalize)
    val_idmap = dataset.val_ds_idmap
    dataset, _, test_loader = get_dataset(
        args.batch_size, args.num_workers, args.dataset, test=True, normalize=args.normalize)
    test_idmap = dataset.val_ds_idmap
    lbl_itos = dataset.lbl_itos
    tag = "f=" + str(args.folds) + "_" + args.tag
    tag = tag if args.use_pretrained else "ran_" + tag
    tag = "eval_" + tag if args.eval_only else tag
    model_tag = "finetuned" if args.load_finetuned else "ckpt"
    if args.test_noised:
        t_params_by_level = {
            1: {"bw_cmax": 0.05, "em_cmax": 0.25, "pl_cmax": 0.1, "bs_cmax": 0.5},
            2: {"bw_cmax": 0.1, "em_cmax": 0.5, "pl_cmax": 0.2, "bs_cmax": 1},
            3: {"bw_cmax": 0.1, "em_cmax": 1, "pl_cmax": 0.2, "bs_cmax": 2},
            4: {"bw_cmax": 0.2, "em_cmax": 1, "pl_cmax": 0.4, "bs_cmax": 2},
            5: {"bw_cmax": 0.2, "em_cmax": 1.5, "pl_cmax": 0.4, "bs_cmax": 2.5},
            6: {"bw_cmax": 0.3, "em_cmax": 2, "pl_cmax": 0.5, "bs_cmax": 3},
        }
        if args.noise_level not in t_params_by_level.keys():
            raise Exception("noise level does not exist")
        t_params = t_params_by_level[args.noise_level]
        dataset, _, noise_valid_loader = get_dataset(
            args.batch_size, args.num_workers, args.dataset, apply_noise=True, t_params=t_params, test=args.test)
    else:
        noise_valid_loader = None
    losses, macros, predss, result_macros, result_macros_agg, test_macros, test_macros_agg, noised_macros, noised_macros_agg = [
    ], [], [], [], [], [], [], [], []
    ckpt_epoch_lin = 0
    ckpt_epoch_fin = 0
    if args.f_epochs == 0:
        save_model_at = os.path.join(os.path.dirname(
            args.model_file), "n=" + str(args.noise_level) + "_" + tag + "lin_finetuned")
        filename = os.path.join(os.path.dirname(
            args.model_file), "n=" + str(args.noise_level) + "_" + tag + "res_lin.pkl")
    else:
        save_model_at = os.path.join(os.path.dirname(
            args.model_file), "n=" + str(args.noise_level) + "_" + tag + "fin_finetuned")
        filename = os.path.join(os.path.dirname(
            args.model_file), "n=" + str(args.noise_level) + "_" + tag + "res_fin.pkl")

    model = load_model(
        args.linear_evaluation, 71, args.use_pretrained or args.load_finetuned, hidden=args.hidden,
        location=args.model_file, discriminative_lr=args.discriminative_lr, method=args.method)
    loss_fn, optimizer = configure_optimizer(
        model, args.batch_size, head_only=True, discriminative_lr=args.discriminative_lr, discriminative_lr_factor=0.1 if args.use_pretrained and args.discriminative_lr else 1)
    if not args.eval_only:
        print("train model...")
        if not isdir(save_model_at):
            os.mkdir(save_model_at)

        l1, m1, bm, bm_agg, tm, tm_agg, ckpt_epoch_lin, preds = train_model(model, train_loader, valid_loader, test_loader, args.l_epochs, loss_fn,
                                                                            optimizer, head_only=True, linear_evaluation=args.linear_evaluation, lr_schedule=args.lr_schedule, save_model_at=join(save_model_at, "finetuned.pt"),
                                                                            val_idmap=val_idmap, test_idmap=test_idmap, lbl_itos=lbl_itos, cpc=(args.method == "cpc"))
        if bm != 0:
            print("best macro after head-only training:", bm_agg)
        l2 = []
        m2 = []
        if args.f_epochs != 0:
            if args.l_epochs != 0:
                model = load_model(
                    False, 71, True, hidden=args.hidden,
                    location=join(save_model_at, "finetuned.pt"), discriminative_lr=args.discriminative_lr, method=args.method)
            loss_fn, optimizer = configure_optimizer(
                model, args.batch_size, head_only=False, discriminative_lr=args.discriminative_lr, discriminative_lr_factor=0.1 if args.use_pretrained and args.discriminative_lr else 1)
            l2, m2, bm, bm_agg, tm, tm_agg, ckpt_epoch_fin, preds = train_model(model, train_loader, valid_loader, test_loader, args.f_epochs, loss_fn,
                                                                                optimizer, head_only=False, linear_evaluation=False, lr_schedule=args.lr_schedule, save_model_at=join(save_model_at, "finetuned.pt"),
                                                                                val_idmap=val_idmap, test_idmap=test_idmap, lbl_itos=lbl_itos, cpc=(args.method == "cpc"))
        losses.append(l1+l2)
        macros.append(m1+m2)
        test_macros.append(tm)
        test_macros_agg.append(tm_agg)
        result_macros.append(bm)
        result_macros_agg.append(bm_agg)

    else:
        preds, eval_macro, eval_macro_agg = evaluate(
            model, test_loader, test_idmap, lbl_itos, cpc=(args.method == "cpc"))
        result_macros.append(eval_macro)
        result_macros_agg.append(eval_macro_agg)
        if args.verbose:
            print("macro:", eval_macro)
    predss.append(preds)

    if noise_valid_loader is not None:
        _, noise_macro, noise_macro_agg = evaluate(
            model, noise_valid_loader, val_idmap, lbl_itos)
        noised_macros.append(noise_macro)
        noised_macros_agg.append(noise_macro_agg)
    res = {"filename": filename, "epochs": args.l_epochs+args.f_epochs, "model_location": args.model_location,
           "losses": losses, "macros": macros, "predss": predss, "result_macros": result_macros, "result_macros_agg": result_macros_agg,
           "test_macros": test_macros, "test_macros_agg": test_macros_agg, "noised_macros": noised_macros, "noised_macros_agg": noised_macros_agg, "ckpt_epoch_lin": ckpt_epoch_lin, "ckpt_epoch_fin": ckpt_epoch_fin,
           "discriminative_lr": args.discriminative_lr, "hidden": args.hidden, "lr_schedule": args.lr_schedule,
           "use_pretrained": args.use_pretrained, "linear_evaluation": args.linear_evaluation, "loaded_finetuned": args.load_finetuned,
           "eval_only": args.eval_only, "noise_level": args.noise_level, "test_noised": args.test_noised, "normalized": args.normalize}
    pickle.dump(res, open(filename, "wb"))
    print("dumped results to", filename)
    print(res)
    print("Done!")