--- a +++ b/main_cpc_lightning.py @@ -0,0 +1,413 @@ +############### +#generic +import torch +from torch import nn +import pytorch_lightning as pl +from torch.utils.data import DataLoader, ConcatDataset +from torchvision import transforms +import torch.nn.functional as F + +import torchvision +import os +import argparse +from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor +import copy + +################# +#specific +from clinical_ts.timeseries_utils import * +from clinical_ts.ecg_utils import * + +from functools import partial +from pathlib import Path +import pandas as pd +import numpy as np + +from clinical_ts.xresnet1d import xresnet1d50,xresnet1d101 +from clinical_ts.basic_conv1d import weight_init +from clinical_ts.eval_utils_cafa import eval_scores, eval_scores_bootstrap +from clinical_ts.cpc import * + +def _freeze_bn_stats(model, freeze=True): + for m in model.modules(): + if(isinstance(m,nn.BatchNorm1d)): + if(freeze): + m.eval() + else: + m.train() + +def sanity_check(model, state_dict_pre): + """ + Linear classifier should not change any weights other than the linear layer. + This sanity check asserts nothing wrong happens (e.g., BN stats updated). + """ + print("=> loading state dict for sanity check") + state_dict = model.state_dict() + + for k in list(state_dict.keys()): + print(k) + # only ignore fc layer + if 'head.1.weight' in k or 'head.1.bias' in k: + continue + + + assert ((state_dict[k].cpu() == state_dict_pre[k].cpu()).all()), \ + '{} is changed in linear classifier training.'.format(k) + + print("=> sanity check passed.") + +class LightningCPC(pl.LightningModule): + + def __init__(self, hparams): + super(LightningCPC, self).__init__() + + self.hparams = hparams + self.lr = self.hparams.lr + + #these coincide with the adapted wav2vec2 params + if(self.hparams.fc_encoder): + strides=[1]*4 + kss = [1]*4 + features = [512]*4 + else: #strided conv encoder + strides=[2,2,2,2] #original wav2vec2 [5,2,2,2,2,2] original cpc [5,4,2,2,2] + kss = [10,4,4,4] #original wav2vec2 [10,3,3,3,3,2] original cpc [18,8,4,4,4] + features = [512]*4 #wav2vec2 [512]*6 original cpc [512]*5 + + if(self.hparams.finetune): + self.criterion = F.cross_entropy if self.hparams.finetune_dataset == "thew" else F.binary_cross_entropy_with_logits + if(self.hparams.finetune_dataset == "thew"): + num_classes = 5 + elif(self.hparams.finetune_dataset == "ptbxl_super"): + num_classes = 5 + if(self.hparams.finetune_dataset == "ptbxl_all"): + num_classes = 71 + else: + num_classes = None + + self.model_cpc = CPCModel(input_channels=self.hparams.input_channels, strides=strides,kss=kss,features=features,n_hidden=self.hparams.n_hidden,n_layers=self.hparams.n_layers,mlp=self.hparams.mlp,lstm=not(self.hparams.gru),bias_proj=self.hparams.bias,num_classes=num_classes,skip_encoder=self.hparams.skip_encoder,bn_encoder=not(self.hparams.no_bn_encoder),lin_ftrs_head=[] if self.hparams.linear_eval else eval(self.hparams.lin_ftrs_head),ps_head=0 if self.hparams.linear_eval else self.hparams.dropout_head,bn_head=False if self.hparams.linear_eval else not(self.hparams.no_bn_head)) + + target_fs=100 + if(not(self.hparams.finetune)): + print("CPC pretraining:\ndownsampling factor:",self.model_cpc.encoder_downsampling_factor,"\nchunk length(s)",self.model_cpc.encoder_downsampling_factor/target_fs,"\npixels predicted ahead:",self.model_cpc.encoder_downsampling_factor*self.hparams.steps_predicted,"\nseconds predicted 
ahead:",self.model_cpc.encoder_downsampling_factor*self.hparams.steps_predicted/target_fs,"\nRNN input size:",self.hparams.input_size//self.model_cpc.encoder_downsampling_factor) + + def forward(self, x): + return self.model_cpc(x) + + def _step(self,data_batch, batch_idx, train): + if(self.hparams.finetune): + preds = self.forward(data_batch[0]) + loss = self.criterion(preds,data_batch[1]) + self.log("train_loss" if train else "val_loss", loss) + return {'loss':loss, "preds":preds.detach(), "targs": data_batch[1]} + else: + loss, acc = self.model_cpc.cpc_loss(data_batch[0],steps_predicted=self.hparams.steps_predicted,n_false_negatives=self.hparams.n_false_negatives, negatives_from_same_seq_only=self.hparams.negatives_from_same_seq_only, eval_acc=True) + self.log("loss" if train else "val_loss", loss) + self.log("acc" if train else "val_acc", acc) + return loss + + def training_step(self, train_batch, batch_idx): + if(self.hparams.linear_eval): + _freeze_bn_stats(self) + return self._step(train_batch,batch_idx,True) + + def validation_step(self, val_batch, batch_idx, dataloader_idx=0): + return self._step(val_batch,batch_idx,False) + + def validation_epoch_end(self, outputs_all): + if(self.hparams.finetune): + for dataloader_idx,outputs in enumerate(outputs_all): #multiple val dataloaders + preds_all = torch.cat([x['preds'] for x in outputs]) + targs_all = torch.cat([x['targs'] for x in outputs]) + if(self.hparams.finetune_dataset=="thew"): + preds_all = F.softmax(preds_all,dim=-1) + targs_all = torch.eye(len(self.lbl_itos))[targs_all].to(preds.device) + else: + preds_all = torch.sigmoid(preds_all) + preds_all = preds_all.cpu().numpy() + targs_all = targs_all.cpu().numpy() + #instance level score + res = eval_scores(targs_all,preds_all,classes=self.lbl_itos) + + idmap = self.val_dataset.get_id_mapping() + preds_all_agg,targs_all_agg = aggregate_predictions(preds_all,targs_all,idmap,aggregate_fn=np.mean) + res_agg = eval_scores(targs_all_agg,preds_all_agg,classes=self.lbl_itos) + self.log_dict({"macro_auc_agg"+str(dataloader_idx):res_agg["label_AUC"]["macro"], "macro_auc_noagg"+str(dataloader_idx):res["label_AUC"]["macro"]}) + print("epoch",self.current_epoch,"macro_auc_agg"+str(dataloader_idx)+":",res_agg["label_AUC"]["macro"],"macro_auc_noagg"+str(dataloader_idx)+":",res["label_AUC"]["macro"]) + + + def on_fit_start(self): + if(self.hparams.linear_eval): + print("copying state dict before training for sanity check after training") + self.state_dict_pre = copy.deepcopy(self.state_dict().copy()) + + + def on_fit_end(self): + if(self.hparams.linear_eval): + sanity_check(self,self.state_dict_pre) + + + def setup(self, stage): + # configure dataset params + chunkify_train = False + chunk_length_train = self.hparams.input_size if chunkify_train else 0 + stride_train = self.hparams.input_size + + chunkify_valtest = True + chunk_length_valtest = self.hparams.input_size if chunkify_valtest else 0 + stride_valtest = self.hparams.input_size//2 + + train_datasets = [] + val_datasets = [] + test_datasets = [] + + for i,target_folder in enumerate(self.hparams.data): + target_folder = Path(target_folder) + + df_mapped, lbl_itos, mean, std = load_dataset(target_folder) + # always use PTB-XL stats + mean = np.array([-0.00184586, -0.00130277, 0.00017031, -0.00091313, -0.00148835, -0.00174687, -0.00077071, -0.00207407, 0.00054329, 0.00155546, -0.00114379, -0.00035649]) + std = np.array([0.16401004, 0.1647168 , 0.23374124, 0.33767231, 0.33362807, 0.30583013, 0.2731171 , 0.27554379, 0.17128962, 0.14030828, 
0.14606956, 0.14656108]) + + #specific for PTB-XL + if(self.hparams.finetune and self.hparams.finetune_dataset.startswith("ptbxl")): + if(self.hparams.finetune_dataset=="ptbxl_super"): + ptb_xl_label = "label_diag_superclass" + elif(self.hparams.finetune_dataset=="ptbxl_all"): + ptb_xl_label = "label_all" + + lbl_itos= np.array(lbl_itos[ptb_xl_label]) + + def multihot_encode(x, num_classes): + res = np.zeros(num_classes,dtype=np.float32) + for y in x: + res[y]=1 + return res + + df_mapped["label"]= df_mapped[ptb_xl_label+"_filtered_numeric"].apply(lambda x: multihot_encode(x,len(lbl_itos))) + + + self.lbl_itos = lbl_itos + tfms_ptb_xl_cpc = ToTensor() if self.hparams.normalize is False else transforms.Compose([Normalize(mean,std),ToTensor()]) + + max_fold_id = df_mapped.strat_fold.max() #unfortunately 1-based for PTB-XL; sometimes 100 (Ribeiro) + df_train = df_mapped[df_mapped.strat_fold<(max_fold_id-1 if self.hparams.finetune else max_fold_id)] + df_val = df_mapped[df_mapped.strat_fold==(max_fold_id-1 if self.hparams.finetune else max_fold_id)] + if(self.hparams.finetune): + df_test = df_mapped[df_mapped.strat_fold==max_fold_id] + train_datasets.append(TimeseriesDatasetCrops(df_train,self.hparams.input_size,num_classes=len(lbl_itos),data_folder=target_folder,chunk_length=chunk_length_train,min_chunk_length=self.hparams.input_size, stride=stride_train,transforms=tfms_ptb_xl_cpc,annotation=False,col_lbl ="label" if self.hparams.finetune else None,memmap_filename=target_folder/("memmap.npy"))) + val_datasets.append(TimeseriesDatasetCrops(df_val,self.hparams.input_size,num_classes=len(lbl_itos),data_folder=target_folder,chunk_length=chunk_length_valtest,min_chunk_length=self.hparams.input_size, stride=stride_valtest,transforms=tfms_ptb_xl_cpc,annotation=False,col_lbl ="label" if self.hparams.finetune else None,memmap_filename=target_folder/("memmap.npy"))) + if(self.hparams.finetune): + test_datasets.append(TimeseriesDatasetCrops(df_test,self.hparams.input_size,num_classes=len(lbl_itos),data_folder=target_folder,chunk_length=chunk_length_valtest,min_chunk_length=self.hparams.input_size, stride=stride_valtest,transforms=tfms_ptb_xl_cpc,annotation=False,col_lbl ="label",memmap_filename=target_folder/("memmap.npy"))) + + print("\n",target_folder) + print("train dataset:",len(train_datasets[-1]),"samples") + print("val dataset:",len(val_datasets[-1]),"samples") + if(self.hparams.finetune): + print("test dataset:",len(test_datasets[-1]),"samples") + + if(len(train_datasets)>1): #multiple data folders + print("\nCombined:") + self.train_dataset = ConcatDataset(train_datasets) + self.val_dataset = ConcatDataset(val_datasets) + print("train dataset:",len(self.train_dataset),"samples") + print("val dataset:",len(self.val_dataset),"samples") + if(self.hparams.finetune): + self.test_dataset = ConcatDataset(test_datasets) + print("test dataset:",len(self.test_dataset),"samples") + else: #just a single data folder + self.train_dataset = train_datasets[0] + self.val_dataset = val_datasets[0] + if(self.hparams.finetune): + self.test_dataset = test_datasets[0] + + def train_dataloader(self): + return DataLoader(self.train_dataset, batch_size=self.hparams.batch_size, num_workers=4, shuffle=True, drop_last = True) + + def val_dataloader(self): + if(self.hparams.finetune):#multiple val dataloaders + return [DataLoader(self.val_dataset, batch_size=self.hparams.batch_size, num_workers=4),DataLoader(self.test_dataset, batch_size=self.hparams.batch_size, num_workers=4)] + else: + return DataLoader(self.val_dataset, 
batch_size=self.hparams.batch_size, num_workers=4)
+
+    def configure_optimizers(self):
+        if(self.hparams.optimizer == "sgd"):
+            opt = torch.optim.SGD
+        elif(self.hparams.optimizer == "adam"):
+            opt = torch.optim.AdamW #"adam" maps to AdamW (decoupled weight decay)
+        else:
+            raise NotImplementedError("Unknown Optimizer.")
+
+        if(self.hparams.finetune and (self.hparams.linear_eval or self.hparams.train_head_only)):
+            optimizer = opt(self.model_cpc.head.parameters(), self.lr, weight_decay=self.hparams.weight_decay)
+        elif(self.hparams.finetune and self.hparams.discriminative_lr_factor != 1.): #discriminative lrs: encoder lr*factor^2, rnn lr*factor, head lr
+            optimizer = opt([{"params":self.model_cpc.encoder.parameters(), "lr":self.lr*self.hparams.discriminative_lr_factor*self.hparams.discriminative_lr_factor},{"params":self.model_cpc.rnn.parameters(), "lr":self.lr*self.hparams.discriminative_lr_factor},{"params":self.model_cpc.head.parameters(), "lr":self.lr}], self.hparams.lr, weight_decay=self.hparams.weight_decay)
+        else:
+            optimizer = opt(self.parameters(), self.lr, weight_decay=self.hparams.weight_decay)
+
+        return optimizer
+
+    def load_weights_from_checkpoint(self, checkpoint):
+        """Loads the weights from a given checkpoint file (non-strict: only keys that exist in the model are restored).
+        Based on https://github.com/PyTorchLightning/pytorch-lightning/issues/525
+        """
+        checkpoint = torch.load(checkpoint, map_location=lambda storage, loc: storage)
+        pretrained_dict = checkpoint["state_dict"]
+        model_dict = self.state_dict()
+
+        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
+        model_dict.update(pretrained_dict)
+        self.load_state_dict(model_dict)
+
+#####################################################################################################
+#ARGPARSERS
+#####################################################################################################
+def add_model_specific_args(parser):
+    parser.add_argument("--input-channels", type=int, default=12)
+    parser.add_argument("--normalize", action='store_true', help='Normalize input using PTB-XL stats')
+    parser.add_argument('--mlp', action='store_true', help="False: original CPC; True: MLP as in SimCLR")
+    parser.add_argument('--bias', action='store_true', help="use bias in the projection layer (original CPC: no bias)")
+    parser.add_argument("--n-hidden", type=int, default=512)
+    parser.add_argument("--gru", action="store_true")
+    parser.add_argument("--n-layers", type=int, default=2)
+    parser.add_argument("--steps-predicted", dest="steps_predicted", type=int, default=12)
+    parser.add_argument("--n-false-negatives", dest="n_false_negatives", type=int, default=128)
+    parser.add_argument("--skip-encoder", action="store_true", help="disable the convolutional encoder i.e. just RNN; for testing")
just RNN; for testing") + parser.add_argument("--fc-encoder", action="store_true", help="use a fully connected encoder (as opposed to an encoder with strided convs)") + parser.add_argument("--negatives-from-same-seq-only", action="store_true", help="only draw false negatives from same sequence (as opposed to drawing from everywhere)") + parser.add_argument("--no-bn-encoder", action="store_true", help="switch off batch normalization in encoder") + parser.add_argument("--dropout-head", type=float, default=0.5) + parser.add_argument("--train-head-only", action="store_true", help="freeze everything except classification head (note: --linear-eval defaults to no hidden layer in classification head)") + parser.add_argument("--lin-ftrs-head", type=str, default="[512]", help="hidden layers in the classification head") + parser.add_argument('--no-bn-head', action='store_true', help="use no batch normalization in classification head") + return parser + +def add_default_args(): + parser = argparse.ArgumentParser(description='PyTorch Lightning CPC Training') + parser.add_argument('--data', metavar='DIR',type=str, + help='path(s) to dataset',action='append') + parser.add_argument('--epochs', default=30, type=int, metavar='N', + help='number of total epochs to run') + parser.add_argument('--batch-size', default=64, type=int, + metavar='N', + help='mini-batch size (default: 256), this is the total ' + 'batch size of all GPUs on the current node when ' + 'using Data Parallel or Distributed Data Parallel') + parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float, + metavar='LR', help='initial learning rate', dest='lr') + parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float, + metavar='W', help='weight decay (default: 0.)', + dest='weight_decay') + + parser.add_argument('--resume', default='', type=str, metavar='PATH', + help='path to latest checkpoint (default: none)') + + parser.add_argument('--pretrained', default='', type=str, metavar='PATH', + help='path to pretrained checkpoint (default: none)') + parser.add_argument('--optimizer', default='adam', help='sgd/adam')#was sgd + parser.add_argument('--output-path', default='.', type=str,dest="output_path", + help='output path') + parser.add_argument('--metadata', default='', type=str, + help='metadata for output') + + parser.add_argument("--gpus", type=int, default=1, help="number of gpus") + parser.add_argument("--num-nodes", dest="num_nodes", type=int, default=1, help="number of compute nodes") + parser.add_argument("--precision", type=int, default=16, help="16/32") + parser.add_argument("--distributed-backend", dest="distributed_backend", type=str, default=None, help="None/ddp") + parser.add_argument("--accumulate", type=int, default=1, help="accumulate grad batches (total-bs=accumulate-batches*bs)") + + parser.add_argument("--input-size", dest="input_size", type=int, default=16000) + + parser.add_argument("--finetune", action="store_true", help="finetuning (downstream classification task)", default=False ) + parser.add_argument("--linear-eval", action="store_true", help="linear evaluation instead of full finetuning", default=False ) + + parser.add_argument( + "--finetune-dataset", + type=str, + help="thew/ptbxl_super/ptbxl_all", + default="thew" + ) + + parser.add_argument( + "--discriminative-lr-factor", + type=float, + help="factor by which the lr decreases per layer group during finetuning", + default=0.1 + ) + + + parser.add_argument( + "--lr-find", + action="store_true", + help="run lr finder before training run", 
+ default=False + ) + + return parser + +################################################################################################### +#MAIN +################################################################################################### +if __name__ == '__main__': + parser = add_default_args() + parser = add_model_specific_args(parser) + hparams = parser.parse_args() + hparams.executable = "cpc" + + if not os.path.exists(hparams.output_path): + os.makedirs(hparams.output_path) + + model = LightningCPC(hparams) + + if(hparams.pretrained!=""): + print("Loading pretrained weights from",hparams.pretrained) + model.load_weights_from_checkpoint(hparams.pretrained) + + + logger = TensorBoardLogger( + save_dir=hparams.output_path, + #version="",#hparams.metadata.split(":")[0], + name="") + print("Output directory:",logger.log_dir) + checkpoint_callback = ModelCheckpoint( + filepath=os.path.join(logger.log_dir,"best_model"),#hparams.output_path + save_top_k=1, + save_last=True, + verbose=True, + monitor='macro_auc_agg0' if hparams.finetune else 'val_loss',#val_loss/dataloader_idx_0 + mode='max' if hparams.finetune else 'min', + prefix='') + lr_monitor = LearningRateMonitor() + + trainer = pl.Trainer( + #overfit_batches=0.01, + auto_lr_find = hparams.lr_find, + accumulate_grad_batches=hparams.accumulate, + max_epochs=hparams.epochs, + min_epochs=hparams.epochs, + + default_root_dir=hparams.output_path, + + num_sanity_val_steps=0, + + logger=logger, + checkpoint_callback=checkpoint_callback, + callbacks = [],#lr_monitor], + benchmark=True, + + gpus=hparams.gpus, + num_nodes=hparams.num_nodes, + precision=hparams.precision, + distributed_backend=hparams.distributed_backend, + + progress_bar_refresh_rate=0, + weights_summary='top', + resume_from_checkpoint= None if hparams.resume=="" else hparams.resume) + + if(hparams.lr_find):#lr find + trainer.tune(model) + + trainer.fit(model) + +
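+
+###################################################################################################
+#EXAMPLE USAGE (illustrative sketch only)
+###################################################################################################
+# The commands below merely exercise the CLI flags defined above; all paths and hyperparameter
+# values are placeholders/assumptions, not recommended or published settings.
+#
+# CPC pretraining on one (or several, via repeated --data) preprocessed dataset folder(s):
+#   python main_cpc_lightning.py --data <path/to/preprocessed_dataset> --normalize \
+#       --epochs 100 --batch-size 32 --lr 1e-3 --output-path ./runs/cpc_pretrain
+#
+# Downstream finetuning on PTB-XL (all statements), starting from a pretrained checkpoint:
+#   python main_cpc_lightning.py --data <path/to/preprocessed_ptb_xl> --normalize \
+#       --finetune --finetune-dataset ptbxl_all --pretrained <path/to/pretrained.ckpt> \
+#       --epochs 50 --batch-size 128 --lr 1e-3 --output-path ./runs/cpc_finetune
+#
+# Linear evaluation: add --linear-eval to the finetuning command (freezes BN statistics and trains
+# only a linear classification head); --train-head-only instead freezes the backbone but keeps the
+# configured MLP head.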