--- a/eval.py
+++ b/eval.py
@@ -0,0 +1,590 @@
+
+import yaml
+import tensorboard
+import torch
+import os
+import shutil
+import sys
+import csv
+import argparse
+import pickle
+from models.resnet_simclr import ResNetSimCLR
+from clinical_ts.cpc import CPCModel
+import torch.nn.functional as F
+from tqdm import tqdm
+import numpy as np
+import matplotlib.pyplot as plt
+from sklearn.decomposition import PCA
+from sklearn.manifold import TSNE
+from sklearn.metrics import roc_auc_score
+from clinical_ts.simclr_dataset_wrapper import SimCLRDataSetWrapper
+from clinical_ts.eval_utils_cafa import eval_scores, eval_scores_bootstrap
+from clinical_ts.timeseries_utils import aggregate_predictions
+import pdb
+from copy import deepcopy
+from os.path import join, isdir
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+
+def parse_args():
+    parser = argparse.ArgumentParser("Finetuning tests")
+    parser.add_argument("--model_file")
+    parser.add_argument("--method")
+    parser.add_argument("--dataset", nargs="+", default="./data/ptb_xl_fs100")
+    parser.add_argument("--batch_size", type=int, default=512)
+    parser.add_argument("--discriminative_lr", default=False, action="store_true")
+    parser.add_argument("--num_workers", type=int, default=8)
+    parser.add_argument("--hidden", default=False, action="store_true")
+    parser.add_argument("--lr_schedule", default="{}")
+    parser.add_argument("--use_pretrained", default=False, action="store_true")
+    parser.add_argument("--linear_evaluation",
+                        default=False, action="store_true", help="use linear evaluation")
+    parser.add_argument("--test_noised", default=False, action="store_true", help="also validate on a noisy dataset")
+    parser.add_argument("--noise_level", default=1, type=int, help="level of noise applied to the second validation set")
+    parser.add_argument("--folds", default=8, type=int, help="number of folds used in finetuning (between 1 and 8)")
+    parser.add_argument("--tag", default="")
+    parser.add_argument("--eval_only", action="store_true", default=False, help="evaluation-only mode (no training)")
+    parser.add_argument("--load_finetuned", action="store_true", default=False)
+    parser.add_argument("--test", action="store_true", default=False)
+    parser.add_argument("--verbose", action="store_true", default=False)
+    parser.add_argument("--cpc", action="store_true", default=False)
+    parser.add_argument("--model_location")
+    parser.add_argument("--l_epochs", type=int, default=0, help="number of head-only epochs (these are performed first)")
+    parser.add_argument("--f_epochs", type=int, default=0, help="number of finetuning epochs (these are performed after head-only training)")
+    parser.add_argument("--normalize", action="store_true", default=False, help="normalize dataset with PTB-XL mean and std")
+    parser.add_argument("--bn_head", action="store_true", default=False)
+    parser.add_argument("--ps_head", type=float, default=0.0)
+    parser.add_argument("--conv_encoder", action="store_true", default=False)
+    parser.add_argument("--base_model", default="xresnet1d50")
+    parser.add_argument("--widen", default=1, type=int, help="widening factor for xresnet1d50")
+    args = parser.parse_args()
+    return args
+
+
+def get_new_state_dict(init_state_dict, lightning_state_dict, method="simclr"):
+    # map the pretrained (lightning) weights onto the fine-tuning model's state dict;
+    # the key mapping depends on the pretraining method (moco/simclr/swav/byol/cpc)
+    from collections import OrderedDict
+    # lightning_state_dict = lightning_state_dict["state_dict"]
+    new_state_dict = OrderedDict()
+    if method != "cpc":
+        if method == "moco":
+            for key in init_state_dict:
+                l_key = "encoder_q." + key
+                if l_key in lightning_state_dict.keys():
+                    new_state_dict[key] = lightning_state_dict[l_key]
+        elif method == "simclr":
+            for key in init_state_dict:
+                if "features" in key:
+                    l_key = key.replace("features", "encoder.features")
+                    if l_key in lightning_state_dict.keys():
+                        new_state_dict[key] = lightning_state_dict[l_key]
+        elif method == "swav":
+
+            for key in init_state_dict:
+                if "features" in key:
+                    l_key = key.replace("features", "model.features")
+                    if l_key in lightning_state_dict.keys():
+                        new_state_dict[key] = lightning_state_dict[l_key]
+        elif method == "byol":
+            for key in init_state_dict:
+                l_key = "online_network.encoder." + key
+                if l_key in lightning_state_dict.keys():
+                    new_state_dict[key] = lightning_state_dict[l_key]
+        else:
+            raise ValueError("method unknown")
+        new_state_dict["l1.weight"] = init_state_dict["l1.weight"]
+        new_state_dict["l1.bias"] = init_state_dict["l1.bias"]
+        if "l2.weight" in init_state_dict.keys():
+            new_state_dict["l2.weight"] = init_state_dict["l2.weight"]
+            new_state_dict["l2.bias"] = init_state_dict["l2.bias"]
+
+        assert(len(init_state_dict) == len(new_state_dict))
+    else:
+        for key in init_state_dict:
+            l_key = "model_cpc." + key
+            if l_key in lightning_state_dict.keys():
+                new_state_dict[key] = lightning_state_dict[l_key]
+            if "head" in key:
+                new_state_dict[key] = init_state_dict[key]
+    return new_state_dict
+
+
+def adjust(model, num_classes, hidden=False):
+    # replace the SSL projection head by a classification head
+    # (optionally with one hidden layer)
+    in_features = model.l1.in_features
+    last_layer = torch.nn.modules.linear.Linear(
+        in_features, num_classes).to(device)
+    if hidden:
+        model.l1 = torch.nn.modules.linear.Linear(
+            in_features, in_features).to(device)
+        model.l2 = last_layer
+    else:
+        model.l1 = last_layer
+
+    def def_forward(self):
+        def new_forward(x):
+            h = self.features(x)
+            h = h.squeeze()
+
+            x = self.l1(h)
+            if hidden:
+                x = F.relu(x)
+                x = self.l2(x)
+            return x
+        return new_forward
+
+    model.forward = def_forward(model)
+
+
+def configure_optimizer(model, batch_size, head_only=False, discriminative_lr=False, base_model="xresnet1d", optimizer="adam", discriminative_lr_factor=1):
+    loss_fn = F.binary_cross_entropy_with_logits
+    if base_model == "xresnet1d":
+        wd = 1e-1
+        if head_only:
+            lr = (8e-3*(batch_size/256))
+            optimizer = torch.optim.AdamW(
+                model.l1.parameters(), lr=lr, weight_decay=wd)
+        else:
+            lr = 0.01
+            if not discriminative_lr:
+                optimizer = torch.optim.AdamW(
+                    model.parameters(), lr=lr, weight_decay=wd)
+            else:
+                param_dict = dict(model.named_parameters())
+                keys = param_dict.keys()
+                weight_layer_nrs = set()
+                for key in keys:
+                    if "features" in key:
+                        # parameter names have the form features.x
+                        weight_layer_nrs.add(key[9])
+                weight_layer_nrs = sorted(weight_layer_nrs, reverse=True)
+                features_groups = []
+                while len(weight_layer_nrs) > 0:
+                    if len(weight_layer_nrs) > 1:
+                        features_groups.append(list(filter(
+                            lambda x: "features." + weight_layer_nrs[0] in x or "features." + weight_layer_nrs[1] in x, keys)))
+                        del weight_layer_nrs[:2]
+                    else:
+                        features_groups.append(
+                            list(filter(lambda x: "features." + weight_layer_nrs[0] in x, keys)))
+                        del weight_layer_nrs[0]
+                # filter linear layers
+                linears = list(filter(lambda x: "l" in x, keys))
+                groups = [linears] + features_groups
+                optimizer_param_list = []
+                tmp_lr = lr
+
+                for layers in groups:
+                    layer_params = [param_dict[param_name]
+                                    for param_name in layers]
+                    optimizer_param_list.append(
+                        {"params": layer_params, "lr": tmp_lr})
+                    tmp_lr /= 4
+                optimizer = torch.optim.AdamW(
+                    optimizer_param_list, lr=lr, weight_decay=wd)
+
+        print("lr", lr)
+        print("wd", wd)
+        print("batch size", batch_size)
+
+    elif base_model == "cpc":
+        if(optimizer == "sgd"):
+            opt = torch.optim.SGD
+        elif(optimizer == "adam"):
+            opt = torch.optim.AdamW
+        else:
+            raise NotImplementedError("Unknown Optimizer.")
+        lr = 1e-4
+        wd = 1e-3
+        if(head_only):
+            lr = 1e-3
+            print("Linear eval: model head", model.head)
+            optimizer = opt(model.head.parameters(), lr, weight_decay=wd)
+        elif(discriminative_lr_factor != 1.):  # discriminative lrs
+            optimizer = opt([{"params": model.encoder.parameters(), "lr": lr*discriminative_lr_factor*discriminative_lr_factor}, {
+                "params": model.rnn.parameters(), "lr": lr*discriminative_lr_factor}, {"params": model.head.parameters(), "lr": lr}], lr, weight_decay=wd)
+            print("Finetuning: model head", model.head)
+            print("discriminative lr: ", discriminative_lr_factor)
+        else:
+            lr = 1e-3
+            print("normal supervised training")
+            optimizer = opt(model.parameters(), lr, weight_decay=wd)
+    else:
+        raise ValueError("model unknown")
+    return loss_fn, optimizer
+
+
+def load_model(linear_evaluation, num_classes, use_pretrained, discriminative_lr=False, hidden=False, conv_encoder=False, bn_head=False, ps_head=0.5, location="./checkpoints/moco_baselinewonder200.ckpt", method="simclr", base_model="xresnet1d50", out_dim=16, widen=1):
+    discriminative_lr_factor = 1
+    if use_pretrained:
+        print("load model from " + location)
+        discriminative_lr_factor = 0.1
+        if base_model == "cpc":
+            lightning_state_dict = torch.load(location, map_location=device)
+
+            # num_head = np.sum([1 if 'proj' in f else 0 for f in lightning_state_dict.keys()])
+            if linear_evaluation:
+                lin_ftrs_head = []
+                bn_head = False
+                ps_head = 0.0
+            else:
+                if hidden:
+                    lin_ftrs_head = [512]
+                else:
+                    lin_ftrs_head = []
+
+            if conv_encoder:
+                strides = [2, 2, 2, 2]
+                kss = [10, 4, 4, 4]
+            else:
+                strides = [1]*4
+                kss = [1]*4
+
+            model = CPCModel(input_channels=12, strides=strides, kss=kss, features=[512]*4, n_hidden=512, n_layers=2, mlp=False, lstm=True, bias_proj=False,
+                             num_classes=num_classes, skip_encoder=False, bn_encoder=True, lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_head=bn_head).to(device)
+
+            if "state_dict" in lightning_state_dict.keys():
+                print("load pretrained model")
+                model_state_dict = get_new_state_dict(
+                    model.state_dict(), lightning_state_dict["state_dict"], method="cpc")
+            else:
+                print("load already finetuned model")
+                model_state_dict = lightning_state_dict
+            model.load_state_dict(model_state_dict)
+        else:
+            model = ResNetSimCLR(base_model, out_dim, hidden=hidden, widen=widen).to(device)
+            model_state_dict = torch.load(location, map_location=device)
+            if "state_dict" in model_state_dict.keys():
+                model_state_dict = model_state_dict["state_dict"]
+            if "l1.weight" in model_state_dict.keys():  # load already fine-tuned model
+                model_classes = model_state_dict["l1.weight"].shape[0]
+                if model_classes != num_classes:
+                    raise Exception("Loaded model has different output dim ({}) than needed ({})".format(
+                        model_classes, num_classes))
+                adjust(model, num_classes, hidden=hidden)
+ if not hidden and "l2.weight" in model_state_dict: + del model_state_dict["l2.weight"] + del model_state_dict["l2.bias"] + model.load_state_dict(model_state_dict) + else: # load pretrained model + base_dict = model.state_dict() + model_state_dict = get_new_state_dict( + base_dict, model_state_dict, method=method) + model.load_state_dict(model_state_dict) + adjust(model, num_classes, hidden=hidden) + + else: + if "xresnet1d" in base_model: + model = ResNetSimCLR(base_model, out_dim, hidden=hidden, widen=widen).to(device) + adjust(model, num_classes, hidden=hidden) + elif base_model == "cpc": + if linear_evaluation: + lin_ftrs_head = [] + bn_head = False + ps_head = 0.0 + else: + if hidden: + lin_ftrs_head = [512] + else: + lin_ftrs_head = [] + + if conv_encoder: + strides = [2, 2, 2, 2] + kss = [10, 4, 4, 4] + else: + strides = [1]*4 + kss = [1]*4 + + model = CPCModel(input_channels=12, strides=strides, kss=kss, features=[512]*4, n_hidden=512, n_layers=2, mlp=False, lstm=True, bias_proj=False, + num_classes=num_classes, skip_encoder=False, bn_encoder=True, lin_ftrs_head=lin_ftrs_head, ps_head=ps_head, bn_head=bn_head).to(device) + + else: + raise Exception("model unknown") + + return model + + +def evaluate(model, dataloader, idmap, lbl_itos, cpc=False): + preds, targs = eval_model(model, dataloader, cpc=cpc) + scores = eval_scores(targs, preds, classes=lbl_itos, parallel=True) + preds_agg, targs_agg = aggregate_predictions(preds, targs, idmap) + scores_agg = eval_scores(targs_agg, preds_agg, + classes=lbl_itos, parallel=True) + macro = scores["label_AUC"]["macro"] + macro_agg = scores_agg["label_AUC"]["macro"] + return preds, macro, macro_agg + + +def set_train_eval(model, cpc, linear_evaluation): + if linear_evaluation: + if cpc: + model.encoder.eval() + else: + model.features.eval() + else: + model.train() + + +def train_model(model, train_loader, valid_loader, test_loader, epochs, loss_fn, optimizer, head_only=True, linear_evaluation=False, percentage=1, lr_schedule=None, save_model_at=None, val_idmap=None, test_idmap=None, lbl_itos=None, cpc=False): + if head_only: + if linear_evaluation: + print("linear evaluation for {} epochs".format(epochs)) + else: + print("head-only for {} epochs".format(epochs)) + else: + print("fine tuning for {} epochs".format(epochs)) + + if head_only: + for key, param in model.named_parameters(): + if "l1." not in key and "head." 
not in key: + param.requires_grad = False + print("copying state dict before training for sanity check after training") + + else: + for param in model.parameters(): + param.requires_grad = True + if cpc: + data_type = model.encoder[0][0].weight.type() + else: + data_type = model.features[0][0].weight.type() + + set_train_eval(model, cpc, linear_evaluation) + state_dict_pre = deepcopy(model.state_dict()) + print("epoch", "batch", "loss\n========================") + loss_per_epoch = [] + macro_agg_per_epoch = [] + max_batches = len(train_loader) + break_point = int(percentage*max_batches) + best_macro = 0 + best_macro_agg = 0 + best_epoch = 0 + best_preds = None + test_macro = 0 + test_macro_agg = 0 + for epoch in tqdm(range(epochs)): + if type(lr_schedule) == dict: + if epoch in lr_schedule.keys(): + for param_group in optimizer.param_groups: + param_group['lr'] /= lr_schedule[epoch] + total_loss_one_epoch = 0 + for batch_idx, samples in enumerate(train_loader): + if batch_idx == break_point: + print("break at batch nr.", batch_idx) + break + data = samples[0].to(device).type(data_type) + labels = samples[1].to(device).type(data_type) + optimizer.zero_grad() + preds = model(data) + loss = loss_fn(preds, labels) + loss.backward() + optimizer.step() + total_loss_one_epoch += loss.item() + if(batch_idx % 100 == 0): + print(epoch, batch_idx, loss.item()) + loss_per_epoch.append(total_loss_one_epoch) + + preds, macro, macro_agg = evaluate( + model, valid_loader, val_idmap, lbl_itos, cpc=cpc) + macro_agg_per_epoch.append(macro_agg) + + print("loss:", total_loss_one_epoch) + print("aggregated macro:", macro_agg) + if macro_agg > best_macro_agg: + torch.save(model.state_dict(), save_model_at) + best_macro_agg = macro_agg + best_macro = macro + best_epoch = epoch + best_preds = preds + _, test_macro, test_macro_agg = evaluate( + model, test_loader, test_idmap, lbl_itos, cpc=cpc) + + set_train_eval(model, cpc, linear_evaluation) + + if epochs > 0: + sanity_check(model, state_dict_pre, linear_evaluation, head_only) + return loss_per_epoch, macro_agg_per_epoch, best_macro, best_macro_agg, test_macro, test_macro_agg, best_epoch, best_preds + + +def sanity_check(model, state_dict_pre, linear_evaluation, head_only): + """ + Linear classifier should not change any weights other than the linear layer. + This sanity check asserts nothing wrong happens (e.g., BN stats updated). + """ + print("=> loading state dict for sanity check") + state_dict = model.state_dict() + if linear_evaluation: + for k in list(state_dict.keys()): + # only ignore fc layer + if 'fc.' in k or 'head.' in k or 'l1.' in k: + continue + + equals = (state_dict[k].cpu() == state_dict_pre[k].cpu()).all() + if (linear_evaluation != equals): + raise Exception( + '=> failed sanity check in {}'.format("linear_evaluation")) + elif head_only: + for k in list(state_dict.keys()): + # only ignore fc layer + if 'fc.' in k or 'head.' 
in k: + continue + + equals = (state_dict[k].cpu() == state_dict_pre[k].cpu()).all() + if (equals and "running_mean" in k): + raise Exception( + '=> failed sanity check in {}'.format("head-only")) + # else: + # for k in list(state_dict.keys()): + # equals=(state_dict[k].cpu() == state_dict_pre[k].cpu()).all() + # if equals: + # pdb.set_trace() + # raise Exception('=> failed sanity check in {}'.format("fine_tuning")) + + print("=> sanity check passed.") + + +def eval_model(model, valid_loader, cpc=False): + if cpc: + data_type = model.encoder[0][0].weight.type() + else: + data_type = model.features[0][0].weight.type() + model.eval() + preds = [] + targs = [] + with torch.no_grad(): + for batch_idx, samples in tqdm(enumerate(valid_loader)): + data = samples[0].to(device).type(data_type) + preds_tmp = torch.sigmoid(model(data)) + targs.append(samples[1]) + preds.append(preds_tmp.cpu()) + preds = torch.cat(preds).numpy() + targs = torch.cat(targs).numpy() + + return preds, targs + + +def get_dataset(batch_size, num_workers, target_folder, apply_noise=False, percentage=1.0, folds=8, t_params=None, test=False, normalize=False): + if apply_noise: + transformations = ["BaselineWander", + "PowerlineNoise", "EMNoise", "BaselineShift"] + if normalize: + transformations.append("Normalize") + dataset = SimCLRDataSetWrapper(batch_size,num_workers,None,"(12, 250)",None,target_folder,[target_folder],None,None, + mode="linear_evaluation", transformations=transformations, percentage=percentage, folds=folds, t_params=t_params, test=test, ptb_xl_label="label_all") + else: + if normalize: + # always use PTB-XL stats + transformations = ["Normalize"] + dataset = SimCLRDataSetWrapper(batch_size,num_workers,None,"(12, 250)",None,target_folder,[target_folder],None,None, + mode="linear_evaluation", percentage=percentage, folds=folds, test=test, transformations=transformations, ptb_xl_label="label_all") + else: + dataset = SimCLRDataSetWrapper(batch_size,num_workers,None,"(12, 250)",None,target_folder,[target_folder],None,None, + mode="linear_evaluation", percentage=percentage, folds=folds, test=test, ptb_xl_label="label_all") + + train_loader, valid_loader = dataset.get_data_loaders() + return dataset, train_loader, valid_loader + + +if __name__ == "__main__": + args = parse_args() + dataset, train_loader, _ = get_dataset( + args.batch_size, args.num_workers, args.dataset, folds=args.folds, test=args.test, normalize=args.normalize) + _, _, valid_loader = get_dataset( + args.batch_size, args.num_workers, args.dataset, folds=args.folds, test=False, normalize=args.normalize) + val_idmap = dataset.val_ds_idmap + dataset, _, test_loader = get_dataset( + args.batch_size, args.num_workers, args.dataset, test=True, normalize=args.normalize) + test_idmap = dataset.val_ds_idmap + lbl_itos = dataset.lbl_itos + tag = "f=" + str(args.folds) + "_" + args.tag + tag = tag if args.use_pretrained else "ran_" + tag + tag = "eval_" + tag if args.eval_only else tag + model_tag = "finetuned" if args.load_finetuned else "ckpt" + if args.test_noised: + t_params_by_level = { + 1: {"bw_cmax": 0.05, "em_cmax": 0.25, "pl_cmax": 0.1, "bs_cmax": 0.5}, + 2: {"bw_cmax": 0.1, "em_cmax": 0.5, "pl_cmax": 0.2, "bs_cmax": 1}, + 3: {"bw_cmax": 0.1, "em_cmax": 1, "pl_cmax": 0.2, "bs_cmax": 2}, + 4: {"bw_cmax": 0.2, "em_cmax": 1, "pl_cmax": 0.4, "bs_cmax": 2}, + 5: {"bw_cmax": 0.2, "em_cmax": 1.5, "pl_cmax": 0.4, "bs_cmax": 2.5}, + 6: {"bw_cmax": 0.3, "em_cmax": 2, "pl_cmax": 0.5, "bs_cmax": 3}, + } + if args.noise_level not in t_params_by_level.keys(): 
+            raise ValueError("noise level does not exist")
+        t_params = t_params_by_level[args.noise_level]
+        dataset, _, noise_valid_loader = get_dataset(
+            args.batch_size, args.num_workers, args.dataset, apply_noise=True, t_params=t_params, test=args.test)
+    else:
+        noise_valid_loader = None
+    losses, macros, predss, result_macros, result_macros_agg = [], [], [], [], []
+    test_macros, test_macros_agg, noised_macros, noised_macros_agg = [], [], [], []
+    ckpt_epoch_lin = 0
+    ckpt_epoch_fin = 0
+    if args.f_epochs == 0:
+        save_model_at = os.path.join(os.path.dirname(
+            args.model_file), "n=" + str(args.noise_level) + "_"+tag + "lin_finetuned")
+        filename = os.path.join(os.path.dirname(
+            args.model_file), "n=" + str(args.noise_level) + "_"+tag + "res_lin.pkl")
+    else:
+        save_model_at = os.path.join(os.path.dirname(
+            args.model_file), "n=" + str(args.noise_level) + "_"+tag + "fin_finetuned")
+        filename = os.path.join(os.path.dirname(
+            args.model_file), "n=" + str(args.noise_level) + "_"+tag + "res_fin.pkl")
+
+    model = load_model(
+        args.linear_evaluation, 71, args.use_pretrained or args.load_finetuned, hidden=args.hidden,
+        location=args.model_file, discriminative_lr=args.discriminative_lr, method=args.method)
+    loss_fn, optimizer = configure_optimizer(
+        model, args.batch_size, head_only=True, discriminative_lr=args.discriminative_lr, discriminative_lr_factor=0.1 if args.use_pretrained and args.discriminative_lr else 1)
+    if not args.eval_only:
+        print("train model...")
+        if not isdir(save_model_at):
+            os.mkdir(save_model_at)
+
+        l1, m1, bm, bm_agg, tm, tm_agg, ckpt_epoch_lin, preds = train_model(model, train_loader, valid_loader, test_loader, args.l_epochs, loss_fn,
+                                                                            optimizer, head_only=True, linear_evaluation=args.linear_evaluation, lr_schedule=args.lr_schedule, save_model_at=join(save_model_at, "finetuned.pt"),
+                                                                            val_idmap=val_idmap, test_idmap=test_idmap, lbl_itos=lbl_itos, cpc=(args.method == "cpc"))
+        if bm != 0:
+            print("best macro after head-only training:", bm_agg)
+        l2 = []
+        m2 = []
+        if args.f_epochs != 0:
+            if args.l_epochs != 0:
+                model = load_model(
+                    False, 71, True, hidden=args.hidden,
+                    location=join(save_model_at, "finetuned.pt"), discriminative_lr=args.discriminative_lr, method=args.method)
+            loss_fn, optimizer = configure_optimizer(
+                model, args.batch_size, head_only=False, discriminative_lr=args.discriminative_lr, discriminative_lr_factor=0.1 if args.use_pretrained and args.discriminative_lr else 1)
+            l2, m2, bm, bm_agg, tm, tm_agg, ckpt_epoch_fin, preds = train_model(model, train_loader, valid_loader, test_loader, args.f_epochs, loss_fn,
+                                                                                optimizer, head_only=False, linear_evaluation=False, lr_schedule=args.lr_schedule, save_model_at=join(save_model_at, "finetuned.pt"),
+                                                                                val_idmap=val_idmap, test_idmap=test_idmap, lbl_itos=lbl_itos, cpc=(args.method == "cpc"))
+        losses.append(l1+l2)
+        macros.append(m1+m2)
+        test_macros.append(tm)
+        test_macros_agg.append(tm_agg)
+        result_macros.append(bm)
+        result_macros_agg.append(bm_agg)
+
+    else:
+        preds, eval_macro, eval_macro_agg = evaluate(
+            model, test_loader, test_idmap, lbl_itos, cpc=(args.method == "cpc"))
+        result_macros.append(eval_macro)
+        result_macros_agg.append(eval_macro_agg)
+        if args.verbose:
+            print("macro:", eval_macro)
+    predss.append(preds)
+
+    if noise_valid_loader is not None:
+        _, noise_macro, noise_macro_agg = evaluate(
+            model, noise_valid_loader, val_idmap, lbl_itos)
+        noised_macros.append(noise_macro)
+        noised_macros_agg.append(noise_macro_agg)
"model_location": args.model_location, + "losses": losses, "macros": macros, "predss": predss, "result_macros": result_macros, "result_macros_agg": result_macros_agg, + "test_macros": test_macros, "test_macros_agg": test_macros_agg, "noised_macros": noised_macros, "noised_macros_agg": noised_macros_agg, "ckpt_epoch_lin": ckpt_epoch_lin, "ckpt_epoch_fin": ckpt_epoch_fin, + "discriminative_lr": args.discriminative_lr, "hidden": args.hidden, "lr_schedule": args.lr_schedule, + "use_pretrained": args.use_pretrained, "linear_evaluation": args.linear_evaluation, "loaded_finetuned": args.load_finetuned, + "eval_only": args.eval_only, "noise_level": args.noise_level, "test_noised": args.test_noised, "normalized": args.normalize} + pickle.dump(res, open(filename, "wb")) + print("dumped results to", filename) + print(res) + print("Done!") +