# eval.py
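# Fine-tuning / linear-evaluation script: loads a self-supervised checkpoint
# (SimCLR, MoCo, SwAV, BYOL or CPC), attaches a classification head and
# trains/evaluates it on PTB-XL-style ECG data.
#
# Illustrative invocation (paths and epoch counts below are placeholders, not
# values prescribed by the repository):
#   python eval.py --model_file ./checkpoints/model.ckpt --method simclr \
#       --dataset ./data/ptb_xl_fs100 --use_pretrained --linear_evaluation \
#       --l_epochs 90 --f_epochs 0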
import yaml
import tensorboard
import torch
import os
import shutil
import sys
import csv
import argparse
import pickle
from models.resnet_simclr import ResNetSimCLR
from clinical_ts.cpc import CPCModel
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.metrics import roc_auc_score
from clinical_ts.simclr_dataset_wrapper import SimCLRDataSetWrapper
from clinical_ts.eval_utils_cafa import eval_scores, eval_scores_bootstrap
from clinical_ts.timeseries_utils import aggregate_predictions
import pdb
from copy import deepcopy
from os.path import join, isdir

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def parse_args():
    parser = argparse.ArgumentParser("Finetuning tests")
    parser.add_argument("--model_file")
    parser.add_argument("--method")
    parser.add_argument("--dataset", nargs="+", default="./data/ptb_xl_fs100")
    parser.add_argument("--batch_size", type=int, default=512)
    parser.add_argument("--discriminative_lr", default=False, action="store_true")
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("--hidden", default=False, action="store_true")
    parser.add_argument("--lr_schedule", default="{}")
    parser.add_argument("--use_pretrained", default=False, action="store_true")
    parser.add_argument("--linear_evaluation",
                        default=False, action="store_true", help="use linear evaluation")
    parser.add_argument("--test_noised", default=False, action="store_true", help="validate also on a noisy dataset")
    parser.add_argument("--noise_level", default=1, type=int, help="level of noise induced to the second validation set")
    parser.add_argument("--folds", default=8, type=int, help="number of folds used in finetuning (between 1-8)")
    parser.add_argument("--tag", default="")
    parser.add_argument("--eval_only", action="store_true", default=False, help="only evaluate mode")
    parser.add_argument("--load_finetuned", action="store_true", default=False)
    parser.add_argument("--test", action="store_true", default=False)
    parser.add_argument("--verbose", action="store_true", default=False)
    parser.add_argument("--cpc", action="store_true", default=False)
    parser.add_argument("--model_location")
    parser.add_argument("--l_epochs", type=int, default=0, help="number of head-only epochs (these are performed first)")
    parser.add_argument("--f_epochs", type=int, default=0, help="number of finetuning epochs (these are performed after head-only training)")
    parser.add_argument("--normalize", action="store_true", default=False, help="normalize dataset with ptbxl mean and std")
    parser.add_argument("--bn_head", action="store_true", default=False)
    parser.add_argument("--ps_head", type=float, default=0.0)
    parser.add_argument("--conv_encoder", action="store_true", default=False)
    parser.add_argument("--base_model", default="xresnet1d50")
    parser.add_argument("--widen", default=1, type=int, help="use wide xresnet1d50")
    args = parser.parse_args()
    return args


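# get_new_state_dict maps weights from a PyTorch Lightning checkpoint onto the
# downstream model: each backbone key is looked up under the method-specific
# prefix ("encoder_q." for MoCo, "encoder.features" for SimCLR,
# "model.features" for SwAV, "online_network.encoder." for BYOL,
# "model_cpc." for CPC), while the freshly initialized head weights
# (l1/l2, or "head" for CPC) are kept.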
def get_new_state_dict(init_state_dict, lightning_state_dict, method="simclr"):
    # in case of moco model
    from collections import OrderedDict
    # lightning_state_dict = lightning_state_dict["state_dict"]
    new_state_dict = OrderedDict()
    if method != "cpc":
        if method == "moco":
            for key in init_state_dict:
                l_key = "encoder_q." + key
                if l_key in lightning_state_dict.keys():
                    new_state_dict[key] = lightning_state_dict[l_key]
        elif method == "simclr":
            for key in init_state_dict:
                if "features" in key:
                    l_key = key.replace("features", "encoder.features")
                    if l_key in lightning_state_dict.keys():
                        new_state_dict[key] = lightning_state_dict[l_key]
        elif method == "swav":
            for key in init_state_dict:
                if "features" in key:
                    l_key = key.replace("features", "model.features")
                    if l_key in lightning_state_dict.keys():
                        new_state_dict[key] = lightning_state_dict[l_key]
        elif method == "byol":
            for key in init_state_dict:
                l_key = "online_network.encoder." + key
                if l_key in lightning_state_dict.keys():
                    new_state_dict[key] = lightning_state_dict[l_key]
        else:
            raise Exception("method unknown")
        new_state_dict["l1.weight"] = init_state_dict["l1.weight"]
        new_state_dict["l1.bias"] = init_state_dict["l1.bias"]
        if "l2.weight" in init_state_dict.keys():
            new_state_dict["l2.weight"] = init_state_dict["l2.weight"]
            new_state_dict["l2.bias"] = init_state_dict["l2.bias"]

        assert(len(init_state_dict) == len(new_state_dict))
    else:
        for key in init_state_dict:
            l_key = "model_cpc." + key
            if l_key in lightning_state_dict.keys():
                new_state_dict[key] = lightning_state_dict[l_key]
            if "head" in key:
                new_state_dict[key] = init_state_dict[key]
    return new_state_dict


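# adjust replaces the projection head of a ResNetSimCLR-style model with a
# classification layer of num_classes outputs (optionally with an extra hidden
# linear layer followed by ReLU) and rebinds model.forward accordingly.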
def adjust(model, num_classes, hidden=False):
    in_features = model.l1.in_features
    last_layer = torch.nn.modules.linear.Linear(
        in_features, num_classes).to(device)
    if hidden:
        model.l1 = torch.nn.modules.linear.Linear(
            in_features, in_features).to(device)
        model.l2 = last_layer
    else:
        model.l1 = last_layer

    def def_forward(self):
        def new_forward(x):
            h = self.features(x)
            h = h.squeeze()

            x = self.l1(h)
            if hidden:
                x = F.relu(x)
                x = self.l2(x)
            return x
        return new_forward

    model.forward = def_forward(model)


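# configure_optimizer returns the multi-label BCE-with-logits loss together
# with an AdamW (or, for CPC, optionally SGD) optimizer. For xresnet1d,
# head-only training scales the learning rate with the batch size; with
# discriminative learning rates, parameter groups closer to the input receive
# a learning rate divided by 4 per group.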
def configure_optimizer(model, batch_size, head_only=False, discriminative_lr=False, base_model="xresnet1d", optimizer="adam", discriminative_lr_factor=1):
    loss_fn = F.binary_cross_entropy_with_logits
    if base_model == "xresnet1d":
        wd = 1e-1
        if head_only:
            lr = (8e-3*(batch_size/256))
            optimizer = torch.optim.AdamW(
                model.l1.parameters(), lr=lr, weight_decay=wd)
        else:
            lr = 0.01
            if not discriminative_lr:
                optimizer = torch.optim.AdamW(
                    model.parameters(), lr=lr, weight_decay=wd)
            else:
                param_dict = dict(model.named_parameters())
                keys = param_dict.keys()
                weight_layer_nrs = set()
                for key in keys:
                    if "features" in key:
                        # parameter names have the form features.x
                        weight_layer_nrs.add(key[9])
                weight_layer_nrs = sorted(weight_layer_nrs, reverse=True)
                features_groups = []
                while len(weight_layer_nrs) > 0:
                    if len(weight_layer_nrs) > 1:
                        features_groups.append(list(filter(
                            lambda x: "features." + weight_layer_nrs[0] in x or "features." + weight_layer_nrs[1] in x, keys)))
                        del weight_layer_nrs[:2]
                    else:
                        features_groups.append(
                            list(filter(lambda x: "features." + weight_layer_nrs[0] in x, keys)))
                        del weight_layer_nrs[0]
                # filter linear layers
                linears = list(filter(lambda x: "l" in x, keys))
                groups = [linears] + features_groups
                optimizer_param_list = []
                tmp_lr = lr

                for layers in groups:
                    layer_params = [param_dict[param_name]
                                    for param_name in layers]
                    optimizer_param_list.append(
                        {"params": layer_params, "lr": tmp_lr})
                    tmp_lr /= 4
                optimizer = torch.optim.AdamW(
                    optimizer_param_list, lr=lr, weight_decay=wd)

        print("lr", lr)
        print("wd", wd)
        print("batch size", batch_size)

    elif base_model == "cpc":
        if(optimizer == "sgd"):
            opt = torch.optim.SGD
        elif(optimizer == "adam"):
            opt = torch.optim.AdamW
        else:
            raise NotImplementedError("Unknown Optimizer.")
        lr = 1e-4
        wd = 1e-3
        if(head_only):
            lr = 1e-3
            print("Linear eval: model head", model.head)
            optimizer = opt(model.head.parameters(), lr, weight_decay=wd)
        elif(discriminative_lr_factor != 1.):  # discriminative lrs
            optimizer = opt([{"params": model.encoder.parameters(), "lr": lr*discriminative_lr_factor*discriminative_lr_factor}, {
                "params": model.rnn.parameters(), "lr": lr*discriminative_lr_factor}, {"params": model.head.parameters(), "lr": lr}], lr, weight_decay=wd)
            print("Finetuning: model head", model.head)
            print("discriminative lr: ", discriminative_lr_factor)
        else:
            lr = 1e-3
            print("normal supervised training")
            optimizer = opt(model.parameters(), lr, weight_decay=wd)
    else:
        raise Exception("model unknown")
    return loss_fn, optimizer


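# load_model builds either a ResNetSimCLR (xresnet1d) or a CPCModel. With
# use_pretrained it loads weights from `location`, remapping a Lightning
# pretraining checkpoint via get_new_state_dict or loading an already
# fine-tuned state dict directly; otherwise a randomly initialized model with
# a classification head for num_classes outputs is returned.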
def load_model(linear_evaluation, num_classes, use_pretrained, discriminative_lr=False, hidden=False, conv_encoder=False, bn_head=False, ps_head=0.5, location="./checkpoints/moco_baselinewonder200.ckpt", method="simclr", base_model="xresnet1d50", out_dim=16, widen=1):
    discriminative_lr_factor = 1
    if use_pretrained:
        print("load model from " + location)
        discriminative_lr_factor = 0.1
        if base_model == "cpc":
            lightning_state_dict = torch.load(location, map_location=device)

            # num_head = np.sum([1 if 'proj' in f else 0 for f in lightning_state_dict.keys()])
            if linear_evaluation:
                lin_ftrs_head = []
                bn_head = False
                ps_head = 0.0
            else:
                if hidden:
                    lin_ftrs_head = [512]
                else:
                    lin_ftrs_head = []

            if conv_encoder:
                strides = [2, 2, 2, 2]
                kss = [10, 4, 4, 4]
            else:
                strides = [1]*4
                kss = [1]*4

            model = CPCModel(input_channels=12, strides=strides, kss=kss, features=[512]*4, n_hidden=512, n_layers=2, mlp=False, lstm=True, bias_proj=False,
                             num_classes=num_classes, skip_encoder=False, bn_encoder=True, lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_head=bn_head).to(device)

            if "state_dict" in lightning_state_dict.keys():
                print("load pretrained model")
                model_state_dict = get_new_state_dict(
                    model.state_dict(), lightning_state_dict["state_dict"], method="cpc")
            else:
                print("load already finetuned model")
                model_state_dict = lightning_state_dict
            model.load_state_dict(model_state_dict)
        else:
            model = ResNetSimCLR(base_model, out_dim, hidden=hidden, widen=widen).to(device)
            model_state_dict = torch.load(location, map_location=device)
            if "state_dict" in model_state_dict.keys():
                model_state_dict = model_state_dict["state_dict"]
            if "l1.weight" in model_state_dict.keys():  # load already fine-tuned model
                model_classes = model_state_dict["l1.weight"].shape[0]
                if model_classes != num_classes:
                    raise Exception("Loaded model has different output dim ({}) than needed ({})".format(
                        model_classes, num_classes))
                adjust(model, num_classes, hidden=hidden)
                if not hidden and "l2.weight" in model_state_dict:
                    del model_state_dict["l2.weight"]
                    del model_state_dict["l2.bias"]
                model.load_state_dict(model_state_dict)
            else:  # load pretrained model
                base_dict = model.state_dict()
                model_state_dict = get_new_state_dict(
                    base_dict, model_state_dict, method=method)
                model.load_state_dict(model_state_dict)
                adjust(model, num_classes, hidden=hidden)

    else:
        if "xresnet1d" in base_model:
            model = ResNetSimCLR(base_model, out_dim, hidden=hidden, widen=widen).to(device)
            adjust(model, num_classes, hidden=hidden)
        elif base_model == "cpc":
            if linear_evaluation:
                lin_ftrs_head = []
                bn_head = False
                ps_head = 0.0
            else:
                if hidden:
                    lin_ftrs_head = [512]
                else:
                    lin_ftrs_head = []

            if conv_encoder:
                strides = [2, 2, 2, 2]
                kss = [10, 4, 4, 4]
            else:
                strides = [1]*4
                kss = [1]*4

            model = CPCModel(input_channels=12, strides=strides, kss=kss, features=[512]*4, n_hidden=512, n_layers=2, mlp=False, lstm=True, bias_proj=False,
                             num_classes=num_classes, skip_encoder=False, bn_encoder=True, lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_head=bn_head).to(device)

        else:
            raise Exception("model unknown")

    return model


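# evaluate computes the macro AUC on a dataloader, both per crop and after
# aggregating crop-level predictions per record via idmap
# (aggregate_predictions).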
def evaluate(model, dataloader, idmap, lbl_itos, cpc=False):
    preds, targs = eval_model(model, dataloader, cpc=cpc)
    scores = eval_scores(targs, preds, classes=lbl_itos, parallel=True)
    preds_agg, targs_agg = aggregate_predictions(preds, targs, idmap)
    scores_agg = eval_scores(targs_agg, preds_agg,
                             classes=lbl_itos, parallel=True)
    macro = scores["label_AUC"]["macro"]
    macro_agg = scores_agg["label_AUC"]["macro"]
    return preds, macro, macro_agg


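# set_train_eval keeps the frozen backbone in eval mode during linear
# evaluation (so batch-norm statistics are not updated) and otherwise puts the
# whole model into train mode.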
def set_train_eval(model, cpc, linear_evaluation):
    if linear_evaluation:
        if cpc:
            model.encoder.eval()
        else:
            model.features.eval()
    else:
        model.train()


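# train_model runs the training loop: it optionally freezes everything except
# the head (head_only), divides the learning rate at the epochs given in
# lr_schedule, keeps the checkpoint with the best aggregated validation macro
# AUC (evaluating it on the test set at that point), and finally runs
# sanity_check to verify that frozen weights did not change.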
def train_model(model, train_loader, valid_loader, test_loader, epochs, loss_fn, optimizer, head_only=True, linear_evaluation=False, percentage=1, lr_schedule=None, save_model_at=None, val_idmap=None, test_idmap=None, lbl_itos=None, cpc=False):
    if head_only:
        if linear_evaluation:
            print("linear evaluation for {} epochs".format(epochs))
        else:
            print("head-only for {} epochs".format(epochs))
    else:
        print("fine tuning for {} epochs".format(epochs))

    if head_only:
        for key, param in model.named_parameters():
            if "l1." not in key and "head." not in key:
                param.requires_grad = False
        print("copying state dict before training for sanity check after training")

    else:
        for param in model.parameters():
            param.requires_grad = True
    if cpc:
        data_type = model.encoder[0][0].weight.type()
    else:
        data_type = model.features[0][0].weight.type()

    set_train_eval(model, cpc, linear_evaluation)
    state_dict_pre = deepcopy(model.state_dict())
    print("epoch", "batch", "loss\n========================")
    loss_per_epoch = []
    macro_agg_per_epoch = []
    max_batches = len(train_loader)
    break_point = int(percentage*max_batches)
    best_macro = 0
    best_macro_agg = 0
    best_epoch = 0
    best_preds = None
    test_macro = 0
    test_macro_agg = 0
    for epoch in tqdm(range(epochs)):
        if type(lr_schedule) == dict:
            if epoch in lr_schedule.keys():
                for param_group in optimizer.param_groups:
                    param_group['lr'] /= lr_schedule[epoch]
        total_loss_one_epoch = 0
        for batch_idx, samples in enumerate(train_loader):
            if batch_idx == break_point:
                print("break at batch nr.", batch_idx)
                break
            data = samples[0].to(device).type(data_type)
            labels = samples[1].to(device).type(data_type)
            optimizer.zero_grad()
            preds = model(data)
            loss = loss_fn(preds, labels)
            loss.backward()
            optimizer.step()
            total_loss_one_epoch += loss.item()
            if(batch_idx % 100 == 0):
                print(epoch, batch_idx, loss.item())
        loss_per_epoch.append(total_loss_one_epoch)

        preds, macro, macro_agg = evaluate(
            model, valid_loader, val_idmap, lbl_itos, cpc=cpc)
        macro_agg_per_epoch.append(macro_agg)

        print("loss:", total_loss_one_epoch)
        print("aggregated macro:", macro_agg)
        if macro_agg > best_macro_agg:
            torch.save(model.state_dict(), save_model_at)
            best_macro_agg = macro_agg
            best_macro = macro
            best_epoch = epoch
            best_preds = preds
            _, test_macro, test_macro_agg = evaluate(
                model, test_loader, test_idmap, lbl_itos, cpc=cpc)

        set_train_eval(model, cpc, linear_evaluation)

    if epochs > 0:
        sanity_check(model, state_dict_pre, linear_evaluation, head_only)
    return loss_per_epoch, macro_agg_per_epoch, best_macro, best_macro_agg, test_macro, test_macro_agg, best_epoch, best_preds


def sanity_check(model, state_dict_pre, linear_evaluation, head_only):
    """
    Linear classifier should not change any weights other than the linear layer.
    This sanity check asserts nothing wrong happens (e.g., BN stats updated).
    """
    print("=> loading state dict for sanity check")
    state_dict = model.state_dict()
    if linear_evaluation:
        for k in list(state_dict.keys()):
            # only ignore fc layer
            if 'fc.' in k or 'head.' in k or 'l1.' in k:
                continue

            equals = (state_dict[k].cpu() == state_dict_pre[k].cpu()).all()
            if (linear_evaluation != equals):
                raise Exception(
                    '=> failed sanity check in {}'.format("linear_evaluation"))
    elif head_only:
        for k in list(state_dict.keys()):
            # only ignore fc layer
            if 'fc.' in k or 'head.' in k:
                continue

            equals = (state_dict[k].cpu() == state_dict_pre[k].cpu()).all()
            if (equals and "running_mean" in k):
                raise Exception(
                    '=> failed sanity check in {}'.format("head-only"))
    # else:
    #     for k in list(state_dict.keys()):
    #         equals=(state_dict[k].cpu() == state_dict_pre[k].cpu()).all()
    #         if equals:
    #             pdb.set_trace()
    #             raise Exception('=> failed sanity check in {}'.format("fine_tuning"))

    print("=> sanity check passed.")


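# eval_model collects sigmoid outputs and targets over a dataloader with
# gradients disabled and returns them as numpy arrays.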
def eval_model(model, valid_loader, cpc=False):
    if cpc:
        data_type = model.encoder[0][0].weight.type()
    else:
        data_type = model.features[0][0].weight.type()
    model.eval()
    preds = []
    targs = []
    with torch.no_grad():
        for batch_idx, samples in tqdm(enumerate(valid_loader)):
            data = samples[0].to(device).type(data_type)
            preds_tmp = torch.sigmoid(model(data))
            targs.append(samples[1])
            preds.append(preds_tmp.cpu())
        preds = torch.cat(preds).numpy()
        targs = torch.cat(targs).numpy()

    return preds, targs


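# get_dataset wraps SimCLRDataSetWrapper in "linear_evaluation" mode for the
# given folder; apply_noise adds baseline wander, powerline, EM noise and
# baseline shift transformations (parameterized by t_params), and normalize
# adds a Normalize transformation using PTB-XL statistics.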
def get_dataset(batch_size, num_workers, target_folder, apply_noise=False, percentage=1.0, folds=8, t_params=None, test=False, normalize=False):
    if apply_noise:
        transformations = ["BaselineWander",
                           "PowerlineNoise", "EMNoise", "BaselineShift"]
        if normalize:
            transformations.append("Normalize")
        dataset = SimCLRDataSetWrapper(batch_size, num_workers, None, "(12, 250)", None, target_folder, [target_folder], None, None,
                                       mode="linear_evaluation", transformations=transformations, percentage=percentage, folds=folds, t_params=t_params, test=test, ptb_xl_label="label_all")
    else:
        if normalize:
            # always use PTB-XL stats
            transformations = ["Normalize"]
            dataset = SimCLRDataSetWrapper(batch_size, num_workers, None, "(12, 250)", None, target_folder, [target_folder], None, None,
                                           mode="linear_evaluation", percentage=percentage, folds=folds, test=test, transformations=transformations, ptb_xl_label="label_all")
        else:
            dataset = SimCLRDataSetWrapper(batch_size, num_workers, None, "(12, 250)", None, target_folder, [target_folder], None, None,
                                           mode="linear_evaluation", percentage=percentage, folds=folds, test=test, ptb_xl_label="label_all")

    train_loader, valid_loader = dataset.get_data_loaders()
    return dataset, train_loader, valid_loader


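# Main script flow: build train/valid/test loaders, construct the model
# (71 output classes, matching the PTB-XL "label_all" setup used above),
# run head-only training for --l_epochs, optionally full fine-tuning for
# --f_epochs, optionally evaluate on a noised validation set, and pickle all
# results next to --model_file.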
if __name__ == "__main__":
    args = parse_args()
    dataset, train_loader, _ = get_dataset(
        args.batch_size, args.num_workers, args.dataset, folds=args.folds, test=args.test, normalize=args.normalize)
    _, _, valid_loader = get_dataset(
        args.batch_size, args.num_workers, args.dataset, folds=args.folds, test=False, normalize=args.normalize)
    val_idmap = dataset.val_ds_idmap
    dataset, _, test_loader = get_dataset(
        args.batch_size, args.num_workers, args.dataset, test=True, normalize=args.normalize)
    test_idmap = dataset.val_ds_idmap
    lbl_itos = dataset.lbl_itos
    tag = "f=" + str(args.folds) + "_" + args.tag
    tag = tag if args.use_pretrained else "ran_" + tag
    tag = "eval_" + tag if args.eval_only else tag
    model_tag = "finetuned" if args.load_finetuned else "ckpt"
    if args.test_noised:
        t_params_by_level = {
            1: {"bw_cmax": 0.05, "em_cmax": 0.25, "pl_cmax": 0.1, "bs_cmax": 0.5},
            2: {"bw_cmax": 0.1, "em_cmax": 0.5, "pl_cmax": 0.2, "bs_cmax": 1},
            3: {"bw_cmax": 0.1, "em_cmax": 1, "pl_cmax": 0.2, "bs_cmax": 2},
            4: {"bw_cmax": 0.2, "em_cmax": 1, "pl_cmax": 0.4, "bs_cmax": 2},
            5: {"bw_cmax": 0.2, "em_cmax": 1.5, "pl_cmax": 0.4, "bs_cmax": 2.5},
            6: {"bw_cmax": 0.3, "em_cmax": 2, "pl_cmax": 0.5, "bs_cmax": 3},
        }
        if args.noise_level not in t_params_by_level.keys():
            raise Exception("noise level does not exist")
        t_params = t_params_by_level[args.noise_level]
        dataset, _, noise_valid_loader = get_dataset(
            args.batch_size, args.num_workers, args.dataset, apply_noise=True, t_params=t_params, test=args.test)
    else:
        noise_valid_loader = None
    losses, macros, predss, result_macros, result_macros_agg, test_macros, test_macros_agg, noised_macros, noised_macros_agg = (
        [], [], [], [], [], [], [], [], [])
    ckpt_epoch_lin = 0
    ckpt_epoch_fin = 0
    if args.f_epochs == 0:
        save_model_at = os.path.join(os.path.dirname(
            args.model_file), "n=" + str(args.noise_level) + "_"+tag + "lin_finetuned")
        filename = os.path.join(os.path.dirname(
            args.model_file), "n=" + str(args.noise_level) + "_"+tag + "res_lin.pkl")
    else:
        save_model_at = os.path.join(os.path.dirname(
            args.model_file), "n=" + str(args.noise_level) + "_"+tag + "fin_finetuned")
        filename = os.path.join(os.path.dirname(
            args.model_file), "n=" + str(args.noise_level) + "_"+tag + "res_fin.pkl")

    model = load_model(
        args.linear_evaluation, 71, args.use_pretrained or args.load_finetuned, hidden=args.hidden,
        location=args.model_file, discriminative_lr=args.discriminative_lr, method=args.method)
    loss_fn, optimizer = configure_optimizer(
        model, args.batch_size, head_only=True, discriminative_lr=args.discriminative_lr, discriminative_lr_factor=0.1 if args.use_pretrained and args.discriminative_lr else 1)
    if not args.eval_only:
        print("train model...")
        if not isdir(save_model_at):
            os.mkdir(save_model_at)

        l1, m1, bm, bm_agg, tm, tm_agg, ckpt_epoch_lin, preds = train_model(model, train_loader, valid_loader, test_loader, args.l_epochs, loss_fn,
                                                                            optimizer, head_only=True, linear_evaluation=args.linear_evaluation, lr_schedule=args.lr_schedule, save_model_at=join(save_model_at, "finetuned.pt"),
                                                                            val_idmap=val_idmap, test_idmap=test_idmap, lbl_itos=lbl_itos, cpc=(args.method == "cpc"))
        if bm != 0:
            print("best macro after head-only training:", bm_agg)
        l2 = []
        m2 = []
        if args.f_epochs != 0:
            if args.l_epochs != 0:
                model = load_model(
                    False, 71, True, hidden=args.hidden,
                    location=join(save_model_at, "finetuned.pt"), discriminative_lr=args.discriminative_lr, method=args.method)
            loss_fn, optimizer = configure_optimizer(
                model, args.batch_size, head_only=False, discriminative_lr=args.discriminative_lr, discriminative_lr_factor=0.1 if args.use_pretrained and args.discriminative_lr else 1)
            l2, m2, bm, bm_agg, tm, tm_agg, ckpt_epoch_fin, preds = train_model(model, train_loader, valid_loader, test_loader, args.f_epochs, loss_fn,
                                                                                optimizer, head_only=False, linear_evaluation=False, lr_schedule=args.lr_schedule, save_model_at=join(save_model_at, "finetuned.pt"),
                                                                                val_idmap=val_idmap, test_idmap=test_idmap, lbl_itos=lbl_itos, cpc=(args.method == "cpc"))
        losses.append(l1+l2)
        macros.append(m1+m2)
        test_macros.append(tm)
        test_macros_agg.append(tm_agg)
        result_macros.append(bm)
        result_macros_agg.append(bm_agg)

    else:
        preds, eval_macro, eval_macro_agg = evaluate(
            model, test_loader, test_idmap, lbl_itos, cpc=(args.method == "cpc"))
        result_macros.append(eval_macro)
        result_macros_agg.append(eval_macro_agg)
        if args.verbose:
            print("macro:", eval_macro)
    predss.append(preds)

    if noise_valid_loader is not None:
        _, noise_macro, noise_macro_agg = evaluate(
            model, noise_valid_loader, val_idmap, lbl_itos)
        noised_macros.append(noise_macro)
        noised_macros_agg.append(noise_macro_agg)
    res = {"filename": filename, "epochs": args.l_epochs+args.f_epochs, "model_location": args.model_location,
           "losses": losses, "macros": macros, "predss": predss, "result_macros": result_macros, "result_macros_agg": result_macros_agg,
           "test_macros": test_macros, "test_macros_agg": test_macros_agg, "noised_macros": noised_macros, "noised_macros_agg": noised_macros_agg, "ckpt_epoch_lin": ckpt_epoch_lin, "ckpt_epoch_fin": ckpt_epoch_fin,
           "discriminative_lr": args.discriminative_lr, "hidden": args.hidden, "lr_schedule": args.lr_schedule,
           "use_pretrained": args.use_pretrained, "linear_evaluation": args.linear_evaluation, "loaded_finetuned": args.load_finetuned,
           "eval_only": args.eval_only, "noise_level": args.noise_level, "test_noised": args.test_noised, "normalized": args.normalize}
    pickle.dump(res, open(filename, "wb"))
    print("dumped results to", filename)
    print(res)
    print("Done!")