Diff of /main_cpc_lightning.py [000000] .. [134fd7]

--- a
+++ b/main_cpc_lightning.py
@@ -0,0 +1,413 @@
+###############
+#generic
+import torch
+from torch import nn
+import pytorch_lightning as pl
+from torch.utils.data import DataLoader, ConcatDataset
+from torchvision import transforms
+import torch.nn.functional as F
+
+import torchvision
+import os
+import argparse
+from pytorch_lightning.loggers import TensorBoardLogger
+from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
+import copy
+
+#################
+#specific
+from clinical_ts.timeseries_utils import *
+from clinical_ts.ecg_utils import *
+
+from functools import partial
+from pathlib import Path
+import pandas as pd
+import numpy as np
+
+from clinical_ts.xresnet1d import xresnet1d50,xresnet1d101
+from clinical_ts.basic_conv1d import weight_init
+from clinical_ts.eval_utils_cafa import eval_scores, eval_scores_bootstrap
+from clinical_ts.cpc import *
+
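+# helper: with freeze=True all BatchNorm1d layers are put into eval mode so their
+# running statistics are not updated (used during linear evaluation, where only the
+# classification head is trained); freeze=False restores train mode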
+def _freeze_bn_stats(model, freeze=True):
+    for m in model.modules():
+        if(isinstance(m,nn.BatchNorm1d)):
+            if(freeze):
+                m.eval()
+            else:
+                m.train()
+                
+def sanity_check(model, state_dict_pre):
+    """
+    Linear classifier should not change any weights other than the linear layer.
+    This sanity check asserts nothing wrong happens (e.g., BN stats updated).
+    """
+    print("=> loading state dict for sanity check")
+    state_dict = model.state_dict()
+
+    for k in list(state_dict.keys()):
+        print(k)
+        # only skip the final linear layer of the classification head
+        if 'head.1.weight' in k or 'head.1.bias' in k:
+            continue
+
+
+        assert ((state_dict[k].cpu() == state_dict_pre[k].cpu()).all()), \
+            '{} is changed in linear classifier training.'.format(k)
+
+    print("=> sanity check passed.")
+
+class LightningCPC(pl.LightningModule):
+
+    def __init__(self, hparams):
+        super(LightningCPC, self).__init__()
+        
+        self.hparams = hparams
+        self.lr = self.hparams.lr
+        
+        #these coincide with the adapted wav2vec2 params
+        if(self.hparams.fc_encoder):
+            strides=[1]*4 
+            kss = [1]*4 
+            features = [512]*4
+        else: #strided conv encoder
+            strides=[2,2,2,2] #original wav2vec2 [5,2,2,2,2,2] original cpc [5,4,2,2,2]
+            kss = [10,4,4,4] #original wav2vec2 [10,3,3,3,3,2] original cpc [18,8,4,4,4]
+            features = [512]*4 #wav2vec2 [512]*6 original cpc [512]*5
+        
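+        # finetuning: THEW is treated as single-label (softmax + cross entropy),
+        # the PTB-XL tasks as multi-label (sigmoid + BCE with logits)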
+        if(self.hparams.finetune):
+            self.criterion = F.cross_entropy if self.hparams.finetune_dataset == "thew" else F.binary_cross_entropy_with_logits
+            if(self.hparams.finetune_dataset == "thew"):
+                num_classes = 5
+            elif(self.hparams.finetune_dataset == "ptbxl_super"):
+                num_classes = 5
+            if(self.hparams.finetune_dataset == "ptbxl_all"):
+                num_classes = 71
+        else:
+            num_classes = None
+
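+        # full CPC model: encoder (strided conv or fc) + LSTM/GRU autoregressive model
+        # + optional classification head (finetuning only); linear evaluation uses a
+        # plain linear head (no hidden layers, no dropout, no batch norm)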
+        self.model_cpc = CPCModel(
+            input_channels=self.hparams.input_channels,
+            strides=strides,
+            kss=kss,
+            features=features,
+            n_hidden=self.hparams.n_hidden,
+            n_layers=self.hparams.n_layers,
+            mlp=self.hparams.mlp,
+            lstm=not(self.hparams.gru),
+            bias_proj=self.hparams.bias,
+            num_classes=num_classes,
+            skip_encoder=self.hparams.skip_encoder,
+            bn_encoder=not(self.hparams.no_bn_encoder),
+            lin_ftrs_head=[] if self.hparams.linear_eval else eval(self.hparams.lin_ftrs_head),
+            ps_head=0 if self.hparams.linear_eval else self.hparams.dropout_head,
+            bn_head=False if self.hparams.linear_eval else not(self.hparams.no_bn_head))
+        
+        target_fs=100
+        if(not(self.hparams.finetune)):
+            print("CPC pretraining:\ndownsampling factor:",self.model_cpc.encoder_downsampling_factor,"\nchunk length(s)",self.model_cpc.encoder_downsampling_factor/target_fs,"\npixels predicted ahead:",self.model_cpc.encoder_downsampling_factor*self.hparams.steps_predicted,"\nseconds predicted ahead:",self.model_cpc.encoder_downsampling_factor*self.hparams.steps_predicted/target_fs,"\nRNN input size:",self.hparams.input_size//self.model_cpc.encoder_downsampling_factor)
+
+    def forward(self, x):
+        return self.model_cpc(x)
+        
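+    # shared step: during CPC pretraining this computes the contrastive cpc_loss
+    # (predicting steps_predicted future encoder frames against n_false_negatives
+    # sampled negatives); during finetuning it computes the supervised loss on the head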
+    def _step(self,data_batch, batch_idx, train):       
+        if(self.hparams.finetune):
+            preds = self.forward(data_batch[0])
+            loss = self.criterion(preds,data_batch[1])
+            self.log("train_loss" if train else "val_loss", loss)
+            return {'loss':loss, "preds":preds.detach(), "targs": data_batch[1]}
+        else:
+            loss, acc = self.model_cpc.cpc_loss(
+                data_batch[0],
+                steps_predicted=self.hparams.steps_predicted,
+                n_false_negatives=self.hparams.n_false_negatives,
+                negatives_from_same_seq_only=self.hparams.negatives_from_same_seq_only,
+                eval_acc=True)
+            self.log("loss" if train else "val_loss", loss)
+            self.log("acc" if train else "val_acc", acc)
+            return loss
+      
+    def training_step(self, train_batch, batch_idx):
+        if(self.hparams.linear_eval):
+            _freeze_bn_stats(self)
+        return self._step(train_batch,batch_idx,True)
+        
+    def validation_step(self, val_batch, batch_idx, dataloader_idx=0):
+        return self._step(val_batch,batch_idx,False)
+        
+    def validation_epoch_end(self, outputs_all):
+        if(self.hparams.finetune):
+            for dataloader_idx,outputs in enumerate(outputs_all): #multiple val dataloaders
+                preds_all = torch.cat([x['preds'] for x in outputs])
+                targs_all = torch.cat([x['targs'] for x in outputs])
+                if(self.hparams.finetune_dataset=="thew"):
+                    preds_all = F.softmax(preds_all,dim=-1)
+                    targs_all = torch.eye(len(self.lbl_itos))[targs_all].to(preds_all.device)
+                else:
+                    preds_all = torch.sigmoid(preds_all)
+                preds_all = preds_all.cpu().numpy()
+                targs_all = targs_all.cpu().numpy()
+                #instance level score
+                res = eval_scores(targs_all,preds_all,classes=self.lbl_itos)
+                
+                idmap = (self.val_dataset if dataloader_idx==0 else self.test_dataset).get_id_mapping()
+                preds_all_agg,targs_all_agg = aggregate_predictions(preds_all,targs_all,idmap,aggregate_fn=np.mean)
+                res_agg = eval_scores(targs_all_agg,preds_all_agg,classes=self.lbl_itos)
+                self.log_dict({
+                    "macro_auc_agg"+str(dataloader_idx): res_agg["label_AUC"]["macro"],
+                    "macro_auc_noagg"+str(dataloader_idx): res["label_AUC"]["macro"]})
+                print("epoch", self.current_epoch,
+                      "macro_auc_agg"+str(dataloader_idx)+":", res_agg["label_AUC"]["macro"],
+                      "macro_auc_noagg"+str(dataloader_idx)+":", res["label_AUC"]["macro"])
+
+
+    def on_fit_start(self):
+        if(self.hparams.linear_eval):
+            print("copying state dict before training for sanity check after training")   
+            self.state_dict_pre = copy.deepcopy(self.state_dict().copy())
+
+    
+    def on_fit_end(self):
+        if(self.hparams.linear_eval):
+            sanity_check(self,self.state_dict_pre)
+            
+            
+    def setup(self, stage):
+        # configure dataset params
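+        # training uses unchunked records (chunk_length 0) with stride = input_size;
+        # val/test are chunked into windows of input_size with stride input_size//2
+        # (50% overlap), whose predictions are aggregated per record in validation_epoch_end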
+        chunkify_train = False
+        chunk_length_train = self.hparams.input_size if chunkify_train else 0
+        stride_train = self.hparams.input_size
+        
+        chunkify_valtest = True
+        chunk_length_valtest = self.hparams.input_size if chunkify_valtest else 0
+        stride_valtest = self.hparams.input_size//2
+
+        train_datasets = []
+        val_datasets = []
+        test_datasets = []
+        
+        for i,target_folder in enumerate(self.hparams.data):
+            target_folder = Path(target_folder)           
+            
+            df_mapped, lbl_itos,  mean, std = load_dataset(target_folder)
+            # always use PTB-XL stats
+            mean = np.array([-0.00184586, -0.00130277,  0.00017031, -0.00091313, -0.00148835,  -0.00174687, -0.00077071, -0.00207407,  0.00054329,  0.00155546,  -0.00114379, -0.00035649])
+            std = np.array([0.16401004, 0.1647168 , 0.23374124, 0.33767231, 0.33362807,  0.30583013, 0.2731171 , 0.27554379, 0.17128962, 0.14030828,   0.14606956, 0.14656108])
+            
+            #specific for PTB-XL
+            if(self.hparams.finetune and self.hparams.finetune_dataset.startswith("ptbxl")):
+                if(self.hparams.finetune_dataset=="ptbxl_super"):
+                    ptb_xl_label = "label_diag_superclass"
+                elif(self.hparams.finetune_dataset=="ptbxl_all"):
+                    ptb_xl_label = "label_all"
+                    
+                lbl_itos= np.array(lbl_itos[ptb_xl_label])
+                
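+                # multi-hot encode a list of label indices,
+                # e.g. multihot_encode([0, 2], 5) -> [1., 0., 1., 0., 0.]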
+                def multihot_encode(x, num_classes):
+                    res = np.zeros(num_classes,dtype=np.float32)
+                    for y in x:
+                        res[y]=1
+                    return res
+                    
+                df_mapped["label"]= df_mapped[ptb_xl_label+"_filtered_numeric"].apply(lambda x: multihot_encode(x,len(lbl_itos)))
+                    
+            
+            self.lbl_itos = lbl_itos
+            tfms_ptb_xl_cpc = ToTensor() if self.hparams.normalize is False else transforms.Compose([Normalize(mean,std),ToTensor()])
+            
+            max_fold_id = df_mapped.strat_fold.max() #unfortunately 1-based for PTB-XL; sometimes 100 (Ribeiro)
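+            # split by stratified fold: finetuning uses the last fold as test and the
+            # second-to-last as validation; pretraining keeps the last fold for validation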
+            df_train = df_mapped[df_mapped.strat_fold<(max_fold_id-1 if self.hparams.finetune else max_fold_id)]
+            df_val = df_mapped[df_mapped.strat_fold==(max_fold_id-1 if self.hparams.finetune else max_fold_id)]
+            if(self.hparams.finetune):
+                df_test = df_mapped[df_mapped.strat_fold==max_fold_id]
+            train_datasets.append(TimeseriesDatasetCrops(
+                df_train, self.hparams.input_size, num_classes=len(lbl_itos),
+                data_folder=target_folder, chunk_length=chunk_length_train,
+                min_chunk_length=self.hparams.input_size, stride=stride_train,
+                transforms=tfms_ptb_xl_cpc, annotation=False,
+                col_lbl="label" if self.hparams.finetune else None,
+                memmap_filename=target_folder/("memmap.npy")))
+            val_datasets.append(TimeseriesDatasetCrops(
+                df_val, self.hparams.input_size, num_classes=len(lbl_itos),
+                data_folder=target_folder, chunk_length=chunk_length_valtest,
+                min_chunk_length=self.hparams.input_size, stride=stride_valtest,
+                transforms=tfms_ptb_xl_cpc, annotation=False,
+                col_lbl="label" if self.hparams.finetune else None,
+                memmap_filename=target_folder/("memmap.npy")))
+            if(self.hparams.finetune):
+                test_datasets.append(TimeseriesDatasetCrops(
+                    df_test, self.hparams.input_size, num_classes=len(lbl_itos),
+                    data_folder=target_folder, chunk_length=chunk_length_valtest,
+                    min_chunk_length=self.hparams.input_size, stride=stride_valtest,
+                    transforms=tfms_ptb_xl_cpc, annotation=False,
+                    col_lbl="label",
+                    memmap_filename=target_folder/("memmap.npy")))
+            
+            print("\n",target_folder)
+            print("train dataset:",len(train_datasets[-1]),"samples")
+            print("val dataset:",len(val_datasets[-1]),"samples")
+            if(self.hparams.finetune):
+                print("test dataset:",len(test_datasets[-1]),"samples")
+
+        if(len(train_datasets)>1): #multiple data folders
+            print("\nCombined:")
+            self.train_dataset = ConcatDataset(train_datasets)
+            self.val_dataset = ConcatDataset(val_datasets)
+            print("train dataset:",len(self.train_dataset),"samples")
+            print("val dataset:",len(self.val_dataset),"samples")
+            if(self.hparams.finetune):
+                self.test_dataset = ConcatDataset(test_datasets)
+                print("test dataset:",len(self.test_dataset),"samples")
+        else: #just a single data folder
+            self.train_dataset = train_datasets[0]
+            self.val_dataset = val_datasets[0]
+            if(self.hparams.finetune):
+                self.test_dataset = test_datasets[0]
+    
+    def train_dataloader(self):
+        return DataLoader(self.train_dataset, batch_size=self.hparams.batch_size, num_workers=4, shuffle=True, drop_last = True)
+        
+    def val_dataloader(self):
+        if(self.hparams.finetune):#multiple val dataloaders
+            return [DataLoader(self.val_dataset, batch_size=self.hparams.batch_size, num_workers=4),DataLoader(self.test_dataset, batch_size=self.hparams.batch_size, num_workers=4)]
+        else:
+            return DataLoader(self.val_dataset, batch_size=self.hparams.batch_size, num_workers=4)
+
+    def configure_optimizers(self):
+        if(self.hparams.optimizer == "sgd"):
+            opt = torch.optim.SGD
+        elif(self.hparams.optimizer == "adam"):
+            opt = torch.optim.AdamW
+        else:
+            raise NotImplementedError("Unknown Optimizer.")
+        
+        if(self.hparams.finetune and (self.hparams.linear_eval or self.hparams.train_head_only)):
+            optimizer = opt(self.model_cpc.head.parameters(), self.lr, weight_decay=self.hparams.weight_decay)
+        elif(self.hparams.finetune and self.hparams.discriminative_lr_factor != 1.):#discriminative lrs
+            optimizer = opt([
+                {"params":self.model_cpc.encoder.parameters(), "lr":self.lr*self.hparams.discriminative_lr_factor*self.hparams.discriminative_lr_factor},
+                {"params":self.model_cpc.rnn.parameters(), "lr":self.lr*self.hparams.discriminative_lr_factor},
+                {"params":self.model_cpc.head.parameters(), "lr":self.lr}],
+                self.hparams.lr, weight_decay=self.hparams.weight_decay)
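+            # layer-group learning rates: encoder lr*factor^2, RNN lr*factor, head lr
+            # (factor < 1 assigns smaller learning rates to earlier layer groups)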
+        else:
+            optimizer = opt(self.parameters(), self.lr, weight_decay=self.hparams.weight_decay)
+
+        return optimizer
+        
+    def load_weights_from_checkpoint(self, checkpoint):
+        """ Function that loads the weights from a given checkpoint file. 
+        based on https://github.com/PyTorchLightning/pytorch-lightning/issues/525
+        """
+        checkpoint = torch.load(checkpoint, map_location=lambda storage, loc: storage,)
+        pretrained_dict = checkpoint["state_dict"]
+        model_dict = self.state_dict()
+            
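+        # keep only keys that also exist in the current model, so parameters missing
+        # from the checkpoint (e.g. a freshly initialized classification head during
+        # finetuning) keep their current values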
+        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
+        model_dict.update(pretrained_dict)
+        self.load_state_dict(model_dict)
+
+#####################################################################################################
+#ARGPARSERS
+#####################################################################################################
+def add_model_specific_args(parser):
+    parser.add_argument("--input-channels", type=int, default=12)
+    parser.add_argument("--normalize", action='store_true', help='Normalize input using PTB-XL stats')
+    parser.add_argument('--mlp', action='store_true', help="use an MLP projection head as in SimCLR (default: linear projection as in the original CPC)")
+    parser.add_argument('--bias', action='store_true', help="add a bias to the projection layer (the original CPC uses no bias)")
+    parser.add_argument("--n-hidden", type=int, default=512)
+    parser.add_argument("--gru", action="store_true")
+    parser.add_argument("--n-layers", type=int, default=2)
+    parser.add_argument("--steps-predicted", dest="steps_predicted", type=int, default=12)
+    parser.add_argument("--n-false-negatives", dest="n_false_negatives", type=int, default=128)
+    parser.add_argument("--skip-encoder", action="store_true", help="disable the convolutional encoder i.e. just RNN; for testing")
+    parser.add_argument("--fc-encoder", action="store_true", help="use a fully connected encoder (as opposed to an encoder with strided convs)")
+    parser.add_argument("--negatives-from-same-seq-only", action="store_true", help="only draw false negatives from same sequence (as opposed to drawing from everywhere)")
+    parser.add_argument("--no-bn-encoder", action="store_true", help="switch off batch normalization in encoder")
+    parser.add_argument("--dropout-head", type=float, default=0.5)
+    parser.add_argument("--train-head-only", action="store_true", help="freeze everything except classification head (note: --linear-eval defaults to no hidden layer in classification head)")
+    parser.add_argument("--lin-ftrs-head", type=str, default="[512]", help="hidden layers in the classification head")
+    parser.add_argument('--no-bn-head', action='store_true', help="use no batch normalization in classification head")
+    return parser
+
+def add_default_args():
+    parser = argparse.ArgumentParser(description='PyTorch Lightning CPC Training')
+    parser.add_argument('--data', metavar='DIR',type=str,
+                        help='path(s) to dataset',action='append')
+    parser.add_argument('--epochs', default=30, type=int, metavar='N',
+                        help='number of total epochs to run')
+    parser.add_argument('--batch-size', default=64, type=int,
+                        metavar='N',
+                        help='mini-batch size (default: 64), this is the total '
+                             'batch size of all GPUs on the current node when '
+                             'using Data Parallel or Distributed Data Parallel')
+    parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float,
+                        metavar='LR', help='initial learning rate', dest='lr')
+    parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,
+                        metavar='W', help='weight decay (default: 1e-3)',
+                        dest='weight_decay')
+
+    parser.add_argument('--resume', default='', type=str, metavar='PATH',
+                    help='path to latest checkpoint (default: none)')
+
+    parser.add_argument('--pretrained', default='', type=str, metavar='PATH',
+                    help='path to pretrained checkpoint (default: none)')
+    parser.add_argument('--optimizer', default='adam', help='sgd/adam')#was sgd
+    parser.add_argument('--output-path', default='.', type=str,dest="output_path",
+                        help='output path')
+    parser.add_argument('--metadata', default='', type=str,
+                        help='metadata for output')
+    
+    parser.add_argument("--gpus", type=int, default=1, help="number of gpus")
+    parser.add_argument("--num-nodes", dest="num_nodes", type=int, default=1, help="number of compute nodes")
+    parser.add_argument("--precision", type=int, default=16, help="16/32")
+    parser.add_argument("--distributed-backend", dest="distributed_backend", type=str, default=None, help="None/ddp")
+    parser.add_argument("--accumulate", type=int, default=1, help="accumulate grad batches (total-bs=accumulate-batches*bs)")
+        
+    parser.add_argument("--input-size", dest="input_size", type=int, default=16000)
+    
+    parser.add_argument("--finetune", action="store_true", help="finetuning (downstream classification task)",  default=False )
+    parser.add_argument("--linear-eval", action="store_true", help="linear evaluation instead of full finetuning",  default=False )
+    
+    parser.add_argument(
+        "--finetune-dataset",
+        type=str,
+        help="thew/ptbxl_super/ptbxl_all",
+        default="thew"
+    )
+    
+    parser.add_argument(
+        "--discriminative-lr-factor",
+        type=float,
+        help="factor by which the lr decreases per layer group during finetuning",
+        default=0.1
+    )
+    
+    
+    parser.add_argument(
+        "--lr-find",
+        action="store_true",
+        help="run lr finder before training run",
+        default=False
+    )
+    
+    return parser
+             
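+# Example invocations (illustrative only; <data_folder> and <ckpt> are placeholders):
+#   CPC pretraining: python main_cpc_lightning.py --data <data_folder>
+#   linear eval:     python main_cpc_lightning.py --data <data_folder> --finetune \
+#                        --finetune-dataset ptbxl_all --linear-eval --pretrained <ckpt>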
+###################################################################################################
+#MAIN
+###################################################################################################
+if __name__ == '__main__':
+    parser = add_default_args()
+    parser = add_model_specific_args(parser)
+    hparams = parser.parse_args()
+    hparams.executable = "cpc"
+
+    if not os.path.exists(hparams.output_path):
+        os.makedirs(hparams.output_path)
+        
+    model = LightningCPC(hparams)
+    
+    if(hparams.pretrained!=""):
+        print("Loading pretrained weights from",hparams.pretrained)
+        model.load_weights_from_checkpoint(hparams.pretrained)
+
+
+    logger = TensorBoardLogger(
+        save_dir=hparams.output_path,
+        #version="",#hparams.metadata.split(":")[0],
+        name="")
+    print("Output directory:",logger.log_dir)    
+    checkpoint_callback = ModelCheckpoint(
+        filepath=os.path.join(logger.log_dir,"best_model"),#hparams.output_path
+        save_top_k=1,
+        save_last=True,
+        verbose=True,
+        monitor='macro_auc_agg0' if hparams.finetune else 'val_loss',#val_loss/dataloader_idx_0
+        mode='max' if hparams.finetune else 'min',
+        prefix='')
+    lr_monitor = LearningRateMonitor()
+
+    trainer = pl.Trainer(
+        #overfit_batches=0.01,
+        auto_lr_find = hparams.lr_find,
+        accumulate_grad_batches=hparams.accumulate,
+        max_epochs=hparams.epochs,
+        min_epochs=hparams.epochs,
+        
+        default_root_dir=hparams.output_path,
+        
+        num_sanity_val_steps=0,
+        
+        logger=logger,
+        checkpoint_callback=checkpoint_callback,
+        callbacks = [],#lr_monitor],
+        benchmark=True,
+    
+        gpus=hparams.gpus,
+        num_nodes=hparams.num_nodes,
+        precision=hparams.precision,
+        distributed_backend=hparams.distributed_backend,
+        
+        progress_bar_refresh_rate=0,
+        weights_summary='top',
+        resume_from_checkpoint= None if hparams.resume=="" else hparams.resume)
+        
+    if(hparams.lr_find):#lr find
+        trainer.tune(model)
+        
+    trainer.fit(model)
+    
+