import os
import argparse
import ast
import pickle
from collections import OrderedDict
from copy import deepcopy
from os.path import join, isdir

import torch
import torch.nn.functional as F
from tqdm import tqdm

from models.resnet_simclr import ResNetSimCLR
from clinical_ts.cpc import CPCModel
from clinical_ts.simclr_dataset_wrapper import SimCLRDataSetWrapper
from clinical_ts.eval_utils_cafa import eval_scores
from clinical_ts.timeseries_utils import aggregate_predictions
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def parse_args():
    parser = argparse.ArgumentParser(description="Finetuning tests")
parser.add_argument("--model_file")
parser.add_argument("--method")
parser.add_argument("--dataset", nargs="+", default="./data/ptb_xl_fs100")
parser.add_argument("--batch_size", type=int, default=512)
parser.add_argument("--discriminative_lr", default=False, action="store_true")
parser.add_argument("--num_workers", type=int, default=8)
parser.add_argument("--hidden", default=False, action="store_true")
parser.add_argument("--lr_schedule", default="{}")
parser.add_argument("--use_pretrained", default=False, action="store_true")
parser.add_argument("--linear_evaluation",
default=False, action="store_true", help="use linear evaluation")
parser.add_argument("--test_noised", default=False, action="store_true", help="validate also on a noisy dataset")
parser.add_argument("--noise_level", default=1, type=int, help="level of noise induced to the second validations set")
parser.add_argument("--folds", default=8, type=int, help="number of folds used in finetuning (between 1-8)")
parser.add_argument("--tag", default="")
parser.add_argument("--eval_only", action="store_true", default=False, help="only evaluate mode")
parser.add_argument("--load_finetuned", action="store_true", default=False)
parser.add_argument("--test", action="store_true", default=False)
parser.add_argument("--verbose", action="store_true", default=False)
parser.add_argument("--cpc", action="store_true", default=False)
parser.add_argument("--model_location")
parser.add_argument("--l_epochs", type=int, default=0, help="number of head-only epochs (these are performed first)")
parser.add_argument("--f_epochs", type=int, default=0, help="number of finetuning epochs (these are perfomed after head-only training")
parser.add_argument("--normalize", action="store_true", default=False, help="normalize dataset with ptbxl mean and std")
parser.add_argument("--bn_head", action="store_true", default=False)
parser.add_argument("--ps_head", type=float, default=0.0)
parser.add_argument("--conv_encoder", action="store_true", default=False)
parser.add_argument("--base_model", default="xresnet1d50")
parser.add_argument("--widen", default=1, type=int, help="use wide xresnet1d50")
    args = parser.parse_args()
    # parse the lr schedule from its string form (e.g. "{50: 10}") into a dict,
    # since train_model only applies schedules that are actual dicts
    args.lr_schedule = ast.literal_eval(args.lr_schedule)
    return args
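# Example invocation (illustrative only; the script name, checkpoint path, and
# epoch counts below are placeholders, not values prescribed by this repository):
#   python eval.py --model_file ./checkpoints/model.ckpt --method simclr \
#       --dataset ./data/ptb_xl_fs100 --use_pretrained --linear_evaluation \
#       --l_epochs 90 --batch_size 512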
def get_new_state_dict(init_state_dict, lightning_state_dict, method="simclr"):
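    """Map the keys of a pretrained lightning checkpoint onto the keys of a freshly
    initialized model. Each pretraining method stores the encoder under a different
    prefix (e.g. for simclr a key like "features.0.0.weight" is loaded from
    "encoder.features.0.0.weight", for moco from "encoder_q.features.0.0.weight").
    Head weights (l1/l2 or "head.*") keep their freshly initialized values."""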
    new_state_dict = OrderedDict()
if method != "cpc":
if method == "moco":
for key in init_state_dict:
l_key = "encoder_q." + key
if l_key in lightning_state_dict.keys():
new_state_dict[key] = lightning_state_dict[l_key]
elif method == "simclr":
for key in init_state_dict:
if "features" in key:
l_key = key.replace("features", "encoder.features")
if l_key in lightning_state_dict.keys():
new_state_dict[key] = lightning_state_dict[l_key]
elif method == "swav":
for key in init_state_dict:
if "features" in key:
l_key = key.replace("features", "model.features")
if l_key in lightning_state_dict.keys():
new_state_dict[key] = lightning_state_dict[l_key]
elif method == "byol":
for key in init_state_dict:
l_key = "online_network.encoder." + key
if l_key in lightning_state_dict.keys():
new_state_dict[key] = lightning_state_dict[l_key]
        else:
            raise ValueError("method unknown")
new_state_dict["l1.weight"] = init_state_dict["l1.weight"]
new_state_dict["l1.bias"] = init_state_dict["l1.bias"]
if "l2.weight" in init_state_dict.keys():
new_state_dict["l2.weight"] = init_state_dict["l2.weight"]
new_state_dict["l2.bias"] = init_state_dict["l2.bias"]
        assert len(init_state_dict) == len(new_state_dict)
else:
for key in init_state_dict:
l_key = "model_cpc." + key
if l_key in lightning_state_dict.keys():
new_state_dict[key] = lightning_state_dict[l_key]
if "head" in key:
new_state_dict[key] = init_state_dict[key]
return new_state_dict
def adjust(model, num_classes, hidden=False):
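    """Replace the model's projection head with a classification head: a single
    linear layer, or linear -> ReLU -> linear when hidden=True. model.forward is
    patched to route the squeezed features through the new head."""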
in_features = model.l1.in_features
last_layer = torch.nn.modules.linear.Linear(
in_features, num_classes).to(device)
if hidden:
model.l1 = torch.nn.modules.linear.Linear(
in_features, in_features).to(device)
model.l2 = last_layer
else:
model.l1 = last_layer
def def_forward(self):
def new_forward(x):
h = self.features(x)
            h = h.squeeze()  # flatten pooled features (note: also drops the batch dim if batch_size == 1)
x = self.l1(h)
if hidden:
x = F.relu(x)
x = self.l2(x)
return x
return new_forward
model.forward = def_forward(model)
def configure_optimizer(model, batch_size, head_only=False, discriminative_lr=False, base_model="xresnet1d", optimizer="adam", discriminative_lr_factor=1):
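    """Return (loss_fn, optimizer) for finetuning. For xresnet1d, AdamW is configured
    either for the head only or for the full model, optionally with discriminative
    learning rates that decay by a factor of 4 per layer group (head first). For cpc,
    the optimizer and learning rates follow the CPC finetuning defaults."""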
loss_fn = F.binary_cross_entropy_with_logits
if base_model == "xresnet1d":
wd = 1e-1
if head_only:
lr = (8e-3*(batch_size/256))
optimizer = torch.optim.AdamW(
model.l1.parameters(), lr=lr, weight_decay=wd)
else:
lr = 0.01
if not discriminative_lr:
optimizer = torch.optim.AdamW(
model.parameters(), lr=lr, weight_decay=wd)
else:
param_dict = dict(model.named_parameters())
keys = param_dict.keys()
weight_layer_nrs = set()
for key in keys:
if "features" in key:
# parameter names have the form features.x
weight_layer_nrs.add(key[9])
weight_layer_nrs = sorted(weight_layer_nrs, reverse=True)
features_groups = []
while len(weight_layer_nrs) > 0:
if len(weight_layer_nrs) > 1:
features_groups.append(list(filter(
lambda x: "features." + weight_layer_nrs[0] in x or "features." + weight_layer_nrs[1] in x, keys)))
del weight_layer_nrs[:2]
else:
features_groups.append(
list(filter(lambda x: "features." + weight_layer_nrs[0] in x, keys)))
del weight_layer_nrs[0]
                # the remaining parameters belong to the linear head (l1/l2)
                linears = [key for key in keys if not key.startswith("features")]
groups = [linears] + features_groups
optimizer_param_list = []
tmp_lr = lr
for layers in groups:
layer_params = [param_dict[param_name]
for param_name in layers]
optimizer_param_list.append(
{"params": layer_params, "lr": tmp_lr})
tmp_lr /= 4
optimizer = torch.optim.AdamW(
optimizer_param_list, lr=lr, weight_decay=wd)
print("lr", lr)
print("wd", wd)
print("batch size", batch_size)
elif base_model == "cpc":
if(optimizer == "sgd"):
opt = torch.optim.SGD
elif(optimizer == "adam"):
opt = torch.optim.AdamW
else:
raise NotImplementedError("Unknown Optimizer.")
lr = 1e-4
wd = 1e-3
        if head_only:
lr = 1e-3
print("Linear eval: model head", model.head)
optimizer = opt(model.head.parameters(), lr, weight_decay=wd)
        elif discriminative_lr_factor != 1.:  # discriminative lrs
optimizer = opt([{"params": model.encoder.parameters(), "lr": lr*discriminative_lr_factor*discriminative_lr_factor}, {
"params": model.rnn.parameters(), "lr": lr*discriminative_lr_factor}, {"params": model.head.parameters(), "lr": lr}], lr, weight_decay=wd)
print("Finetuning: model head", model.head)
print("discriminative lr: ", discriminative_lr_factor)
else:
lr = 1e-3
print("normal supervised training")
optimizer = opt(model.parameters(), lr, weight_decay=wd)
else:
raise("model unknown")
return loss_fn, optimizer
def load_model(linear_evaluation, num_classes, use_pretrained, discriminative_lr=False, hidden=False, conv_encoder=False, bn_head=False, ps_head=0.5, location="./checkpoints/moco_baselinewonder200.ckpt", method="simclr", base_model="xresnet1d50", out_dim=16, widen=1):
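    """Build an xresnet1d (ResNetSimCLR) or CPC model and optionally initialize it
    from a checkpoint at `location`. Handles three cases: a pretrained lightning
    checkpoint (keys remapped via get_new_state_dict), an already finetuned state
    dict, and a randomly initialized model when use_pretrained is False."""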
discriminative_lr_factor = 1
if use_pretrained:
print("load model from " + location)
discriminative_lr_factor = 0.1
if base_model == "cpc":
lightning_state_dict = torch.load(location, map_location=device)
if linear_evaluation:
lin_ftrs_head = []
bn_head = False
ps_head = 0.0
else:
if hidden:
lin_ftrs_head = [512]
else:
lin_ftrs_head = []
if conv_encoder:
strides = [2, 2, 2, 2]
kss = [10, 4, 4, 4]
else:
strides = [1]*4
kss = [1]*4
model = CPCModel(input_channels=12, strides=strides, kss=kss, features=[512]*4, n_hidden=512, n_layers=2, mlp=False, lstm=True, bias_proj=False,
num_classes=num_classes, skip_encoder=False, bn_encoder=True, lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_head=bn_head).to(device)
if "state_dict" in lightning_state_dict.keys():
print("load pretrained model")
model_state_dict = get_new_state_dict(
model.state_dict(), lightning_state_dict["state_dict"], method="cpc")
else:
print("load already finetuned model")
model_state_dict = lightning_state_dict
model.load_state_dict(model_state_dict)
else:
model = ResNetSimCLR(base_model, out_dim, hidden=hidden, widen=widen).to(device)
model_state_dict = torch.load(location, map_location=device)
if "state_dict" in model_state_dict.keys():
model_state_dict = model_state_dict["state_dict"]
if "l1.weight" in model_state_dict.keys(): # load already fine-tuned model
model_classes = model_state_dict["l1.weight"].shape[0]
if model_classes != num_classes:
raise Exception("Loaded model has different output dim ({}) than needed ({})".format(
model_classes, num_classes))
adjust(model, num_classes, hidden=hidden)
if not hidden and "l2.weight" in model_state_dict:
del model_state_dict["l2.weight"]
del model_state_dict["l2.bias"]
model.load_state_dict(model_state_dict)
else: # load pretrained model
base_dict = model.state_dict()
model_state_dict = get_new_state_dict(
base_dict, model_state_dict, method=method)
model.load_state_dict(model_state_dict)
adjust(model, num_classes, hidden=hidden)
else:
if "xresnet1d" in base_model:
model = ResNetSimCLR(base_model, out_dim, hidden=hidden, widen=widen).to(device)
adjust(model, num_classes, hidden=hidden)
elif base_model == "cpc":
if linear_evaluation:
lin_ftrs_head = []
bn_head = False
ps_head = 0.0
else:
if hidden:
lin_ftrs_head = [512]
else:
lin_ftrs_head = []
if conv_encoder:
strides = [2, 2, 2, 2]
kss = [10, 4, 4, 4]
else:
strides = [1]*4
kss = [1]*4
model = CPCModel(input_channels=12, strides=strides, kss=kss, features=[512]*4, n_hidden=512, n_layers=2, mlp=False, lstm=True, bias_proj=False,
num_classes=num_classes, skip_encoder=False, bn_encoder=True, lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_head=bn_head).to(device)
else:
raise Exception("model unknown")
return model
def evaluate(model, dataloader, idmap, lbl_itos, cpc=False):
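    """Compute sample-wise and id-aggregated label macro AUCs on a dataloader and
    return (predictions, macro, aggregated macro)."""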
preds, targs = eval_model(model, dataloader, cpc=cpc)
scores = eval_scores(targs, preds, classes=lbl_itos, parallel=True)
preds_agg, targs_agg = aggregate_predictions(preds, targs, idmap)
scores_agg = eval_scores(targs_agg, preds_agg,
classes=lbl_itos, parallel=True)
macro = scores["label_AUC"]["macro"]
macro_agg = scores_agg["label_AUC"]["macro"]
return preds, macro, macro_agg
def set_train_eval(model, cpc, linear_evaluation):
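    """During linear evaluation only the frozen feature extractor is put into eval
    mode (keeping BN statistics fixed); otherwise the whole model is set to train."""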
if linear_evaluation:
if cpc:
model.encoder.eval()
else:
model.features.eval()
else:
model.train()
def train_model(model, train_loader, valid_loader, test_loader, epochs, loss_fn, optimizer, head_only=True, linear_evaluation=False, percentage=1, lr_schedule=None, save_model_at=None, val_idmap=None, test_idmap=None, lbl_itos=None, cpc=False):
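    """Train the model (head-only or full finetuning), checkpointing whenever the
    aggregated validation macro AUC improves, and evaluate that checkpoint on the
    test set. Returns per-epoch losses and aggregated macros, the best validation
    scores, the corresponding test scores, the best epoch, and its predictions."""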
if head_only:
if linear_evaluation:
print("linear evaluation for {} epochs".format(epochs))
else:
print("head-only for {} epochs".format(epochs))
else:
print("fine tuning for {} epochs".format(epochs))
if head_only:
        for key, param in model.named_parameters():
            # freeze everything except the classification head (l1/l2 or "head.*")
            if "l1." not in key and "l2." not in key and "head." not in key:
                param.requires_grad = False
print("copying state dict before training for sanity check after training")
else:
for param in model.parameters():
param.requires_grad = True
if cpc:
data_type = model.encoder[0][0].weight.type()
else:
data_type = model.features[0][0].weight.type()
set_train_eval(model, cpc, linear_evaluation)
state_dict_pre = deepcopy(model.state_dict())
print("epoch", "batch", "loss\n========================")
loss_per_epoch = []
macro_agg_per_epoch = []
max_batches = len(train_loader)
break_point = int(percentage*max_batches)
best_macro = 0
best_macro_agg = 0
best_epoch = 0
best_preds = None
test_macro = 0
test_macro_agg = 0
for epoch in tqdm(range(epochs)):
        if isinstance(lr_schedule, dict):
            if epoch in lr_schedule:
for param_group in optimizer.param_groups:
param_group['lr'] /= lr_schedule[epoch]
total_loss_one_epoch = 0
for batch_idx, samples in enumerate(train_loader):
if batch_idx == break_point:
print("break at batch nr.", batch_idx)
break
data = samples[0].to(device).type(data_type)
labels = samples[1].to(device).type(data_type)
optimizer.zero_grad()
preds = model(data)
loss = loss_fn(preds, labels)
loss.backward()
optimizer.step()
total_loss_one_epoch += loss.item()
            if batch_idx % 100 == 0:
print(epoch, batch_idx, loss.item())
loss_per_epoch.append(total_loss_one_epoch)
preds, macro, macro_agg = evaluate(
model, valid_loader, val_idmap, lbl_itos, cpc=cpc)
macro_agg_per_epoch.append(macro_agg)
print("loss:", total_loss_one_epoch)
print("aggregated macro:", macro_agg)
if macro_agg > best_macro_agg:
torch.save(model.state_dict(), save_model_at)
best_macro_agg = macro_agg
best_macro = macro
best_epoch = epoch
best_preds = preds
_, test_macro, test_macro_agg = evaluate(
model, test_loader, test_idmap, lbl_itos, cpc=cpc)
set_train_eval(model, cpc, linear_evaluation)
if epochs > 0:
sanity_check(model, state_dict_pre, linear_evaluation, head_only)
return loss_per_epoch, macro_agg_per_epoch, best_macro, best_macro_agg, test_macro, test_macro_agg, best_epoch, best_preds
def sanity_check(model, state_dict_pre, linear_evaluation, head_only):
"""
Linear classifier should not change any weights other than the linear layer.
This sanity check asserts nothing wrong happens (e.g., BN stats updated).
"""
print("=> loading state dict for sanity check")
state_dict = model.state_dict()
if linear_evaluation:
for k in list(state_dict.keys()):
            # ignore the classification head
            if 'fc.' in k or 'head.' in k or 'l1.' in k:
                continue
            equals = (state_dict[k].cpu() == state_dict_pre[k].cpu()).all()
            # every non-head weight must be unchanged after linear evaluation
            if not equals:
                raise Exception(
                    '=> failed sanity check in {}'.format("linear_evaluation"))
elif head_only:
for k in list(state_dict.keys()):
            # ignore the classification head
            if 'fc.' in k or 'head.' in k:
                continue
            equals = (state_dict[k].cpu() == state_dict_pre[k].cpu()).all()
            # in head-only training the backbone stays in train mode, so BN running
            # statistics are expected to change; unchanged running stats indicate
            # the backbone was accidentally frozen
            if equals and "running_mean" in k:
                raise Exception(
                    '=> failed sanity check in {}'.format("head-only"))
    # (no check for full fine-tuning, where every parameter is expected to change)
print("=> sanity check passed.")
def eval_model(model, valid_loader, cpc=False):
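    """Run the model over a dataloader in eval mode and return sigmoid outputs and
    targets as numpy arrays."""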
if cpc:
data_type = model.encoder[0][0].weight.type()
else:
data_type = model.features[0][0].weight.type()
model.eval()
preds = []
targs = []
with torch.no_grad():
for batch_idx, samples in tqdm(enumerate(valid_loader)):
data = samples[0].to(device).type(data_type)
preds_tmp = torch.sigmoid(model(data))
targs.append(samples[1])
preds.append(preds_tmp.cpu())
preds = torch.cat(preds).numpy()
targs = torch.cat(targs).numpy()
return preds, targs
def get_dataset(batch_size, num_workers, target_folder, apply_noise=False, percentage=1.0, folds=8, t_params=None, test=False, normalize=False):
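    """Wrap target_folder in a SimCLRDataSetWrapper (mode "linear_evaluation") and
    return (dataset, train_loader, valid_loader). With apply_noise, the four ECG
    noise transformations (BaselineWander, PowerlineNoise, EMNoise, BaselineShift)
    are applied; with normalize, normalization with PTB-XL statistics is added."""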
    if apply_noise:
        transformations = ["BaselineWander",
                           "PowerlineNoise", "EMNoise", "BaselineShift"]
        if normalize:
            transformations.append("Normalize")
        dataset = SimCLRDataSetWrapper(batch_size, num_workers, None, "(12, 250)", None, target_folder, [target_folder], None, None,
                                       mode="linear_evaluation", transformations=transformations, percentage=percentage, folds=folds, t_params=t_params, test=test, ptb_xl_label="label_all")
    else:
        if normalize:
            # always use PTB-XL stats
            transformations = ["Normalize"]
            dataset = SimCLRDataSetWrapper(batch_size, num_workers, None, "(12, 250)", None, target_folder, [target_folder], None, None,
                                           mode="linear_evaluation", percentage=percentage, folds=folds, test=test, transformations=transformations, ptb_xl_label="label_all")
        else:
            dataset = SimCLRDataSetWrapper(batch_size, num_workers, None, "(12, 250)", None, target_folder, [target_folder], None, None,
                                           mode="linear_evaluation", percentage=percentage, folds=folds, test=test, ptb_xl_label="label_all")
train_loader, valid_loader = dataset.get_data_loaders()
return dataset, train_loader, valid_loader
if __name__ == "__main__":
args = parse_args()
    # training data (optionally restricted to a subset of folds)
    dataset, train_loader, _ = get_dataset(
        args.batch_size, args.num_workers, args.dataset, folds=args.folds, test=args.test, normalize=args.normalize)
    # validation data
    _, _, valid_loader = get_dataset(
        args.batch_size, args.num_workers, args.dataset, folds=args.folds, test=False, normalize=args.normalize)
    val_idmap = dataset.val_ds_idmap
    # test data
    dataset, _, test_loader = get_dataset(
        args.batch_size, args.num_workers, args.dataset, test=True, normalize=args.normalize)
    test_idmap = dataset.val_ds_idmap
lbl_itos = dataset.lbl_itos
tag = "f=" + str(args.folds) + "_" + args.tag
tag = tag if args.use_pretrained else "ran_" + tag
tag = "eval_" + tag if args.eval_only else tag
model_tag = "finetuned" if args.load_finetuned else "ckpt"
if args.test_noised:
t_params_by_level = {
1: {"bw_cmax": 0.05, "em_cmax": 0.25, "pl_cmax": 0.1, "bs_cmax": 0.5},
2: {"bw_cmax": 0.1, "em_cmax": 0.5, "pl_cmax": 0.2, "bs_cmax": 1},
3: {"bw_cmax": 0.1, "em_cmax": 1, "pl_cmax": 0.2, "bs_cmax": 2},
4: {"bw_cmax": 0.2, "em_cmax": 1, "pl_cmax": 0.4, "bs_cmax": 2},
5: {"bw_cmax": 0.2, "em_cmax": 1.5, "pl_cmax": 0.4, "bs_cmax": 2.5},
6: {"bw_cmax": 0.3, "em_cmax": 2, "pl_cmax": 0.5, "bs_cmax": 3},
}
        if args.noise_level not in t_params_by_level:
            raise ValueError("noise level does not exist")
t_params = t_params_by_level[args.noise_level]
dataset, _, noise_valid_loader = get_dataset(
args.batch_size, args.num_workers, args.dataset, apply_noise=True, t_params=t_params, test=args.test)
else:
noise_valid_loader = None
    losses, macros, predss = [], [], []
    result_macros, result_macros_agg = [], []
    test_macros, test_macros_agg = [], []
    noised_macros, noised_macros_agg = [], []
    ckpt_epoch_lin = 0
    ckpt_epoch_fin = 0
if args.f_epochs == 0:
save_model_at = os.path.join(os.path.dirname(
args.model_file), "n=" + str(args.noise_level) + "_"+tag + "lin_finetuned")
filename = os.path.join(os.path.dirname(
args.model_file), "n=" + str(args.noise_level) + "_"+tag + "res_lin.pkl")
else:
save_model_at = os.path.join(os.path.dirname(
args.model_file), "n=" + str(args.noise_level) + "_"+tag + "fin_finetuned")
filename = os.path.join(os.path.dirname(
args.model_file), "n=" + str(args.noise_level) + "_"+tag + "res_fin.pkl")
    # 71 = number of classes in the PTB-XL "label_all" label set
    model = load_model(
        args.linear_evaluation, 71, args.use_pretrained or args.load_finetuned, hidden=args.hidden,
        location=args.model_file, discriminative_lr=args.discriminative_lr, method=args.method,
        base_model=args.base_model, conv_encoder=args.conv_encoder, bn_head=args.bn_head,
        ps_head=args.ps_head, widen=args.widen)
    loss_fn, optimizer = configure_optimizer(
        model, args.batch_size, head_only=True, discriminative_lr=args.discriminative_lr,
        base_model="cpc" if args.method == "cpc" else "xresnet1d",
        discriminative_lr_factor=0.1 if args.use_pretrained and args.discriminative_lr else 1)
if not args.eval_only:
print("train model...")
if not isdir(save_model_at):
os.mkdir(save_model_at)
l1, m1, bm, bm_agg, tm, tm_agg, ckpt_epoch_lin, preds = train_model(model, train_loader, valid_loader, test_loader, args.l_epochs, loss_fn,
optimizer, head_only=True, linear_evaluation=args.linear_evaluation, lr_schedule=args.lr_schedule, save_model_at=join(save_model_at, "finetuned.pt"),
val_idmap=val_idmap, test_idmap=test_idmap, lbl_itos=lbl_itos, cpc=(args.method == "cpc"))
        if bm != 0:
            print("best aggregated macro after head-only training:", bm_agg)
l2 = []
m2 = []
if args.f_epochs != 0:
if args.l_epochs != 0:
                model = load_model(
                    False, 71, True, hidden=args.hidden,
                    location=join(save_model_at, "finetuned.pt"), discriminative_lr=args.discriminative_lr, method=args.method,
                    base_model=args.base_model, conv_encoder=args.conv_encoder, bn_head=args.bn_head,
                    ps_head=args.ps_head, widen=args.widen)
            loss_fn, optimizer = configure_optimizer(
                model, args.batch_size, head_only=False, discriminative_lr=args.discriminative_lr,
                base_model="cpc" if args.method == "cpc" else "xresnet1d",
                discriminative_lr_factor=0.1 if args.use_pretrained and args.discriminative_lr else 1)
l2, m2, bm, bm_agg, tm, tm_agg, ckpt_epoch_fin, preds = train_model(model, train_loader, valid_loader, test_loader, args.f_epochs, loss_fn,
optimizer, head_only=False, linear_evaluation=False, lr_schedule=args.lr_schedule, save_model_at=join(save_model_at, "finetuned.pt"),
val_idmap=val_idmap, test_idmap=test_idmap, lbl_itos=lbl_itos, cpc=(args.method == "cpc"))
losses.append(l1+l2)
macros.append(m1+m2)
test_macros.append(tm)
test_macros_agg.append(tm_agg)
result_macros.append(bm)
result_macros_agg.append(bm_agg)
else:
preds, eval_macro, eval_macro_agg = evaluate(
model, test_loader, test_idmap, lbl_itos, cpc=(args.method == "cpc"))
result_macros.append(eval_macro)
result_macros_agg.append(eval_macro_agg)
if args.verbose:
print("macro:", eval_macro)
predss.append(preds)
    if noise_valid_loader is not None:
        _, noise_macro, noise_macro_agg = evaluate(
            model, noise_valid_loader, val_idmap, lbl_itos, cpc=(args.method == "cpc"))
noised_macros.append(noise_macro)
noised_macros_agg.append(noise_macro_agg)
res = {"filename": filename, "epochs": args.l_epochs+args.f_epochs, "model_location": args.model_location,
"losses": losses, "macros": macros, "predss": predss, "result_macros": result_macros, "result_macros_agg": result_macros_agg,
"test_macros": test_macros, "test_macros_agg": test_macros_agg, "noised_macros": noised_macros, "noised_macros_agg": noised_macros_agg, "ckpt_epoch_lin": ckpt_epoch_lin, "ckpt_epoch_fin": ckpt_epoch_fin,
"discriminative_lr": args.discriminative_lr, "hidden": args.hidden, "lr_schedule": args.lr_schedule,
"use_pretrained": args.use_pretrained, "linear_evaluation": args.linear_evaluation, "loaded_finetuned": args.load_finetuned,
"eval_only": args.eval_only, "noise_level": args.noise_level, "test_noised": args.test_noised, "normalized": args.normalize}
    with open(filename, "wb") as result_file:
        pickle.dump(res, result_file)
print("dumped results to", filename)
print(res)
print("Done!")