Diff of /main_cpc_lightning.py [000000] .. [134fd7]


###############
#generic
import torch
from torch import nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader, ConcatDataset
from torchvision import transforms
import torch.nn.functional as F

import torchvision
import os
import argparse
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
import copy

#################
#specific
from clinical_ts.timeseries_utils import *
from clinical_ts.ecg_utils import *

from functools import partial
from pathlib import Path
import pandas as pd
import numpy as np

from clinical_ts.xresnet1d import xresnet1d50,xresnet1d101
from clinical_ts.basic_conv1d import weight_init
from clinical_ts.eval_utils_cafa import eval_scores, eval_scores_bootstrap
from clinical_ts.cpc import *

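# during linear evaluation the backbone is kept fixed; putting BatchNorm layers into eval mode
# prevents their running statistics from being updated while only the classification head is trained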
def _freeze_bn_stats(model, freeze=True):
    for m in model.modules():
        if(isinstance(m,nn.BatchNorm1d)):
            if(freeze):
                m.eval()
            else:
                m.train()

def sanity_check(model, state_dict_pre):
    """
    Linear classification should not change any weights other than the linear head.
    This sanity check asserts that nothing else was modified (e.g., that BN statistics were not updated).
    """
    print("=> loading state dict for sanity check")
    state_dict = model.state_dict()

    for k in list(state_dict.keys()):
        print(k)
        # only ignore the final linear layer of the head
        if 'head.1.weight' in k or 'head.1.bias' in k:
            continue

        assert ((state_dict[k].cpu() == state_dict_pre[k].cpu()).all()), \
            '{} is changed in linear classifier training.'.format(k)

    print("=> sanity check passed.")

class LightningCPC(pl.LightningModule):

    def __init__(self, hparams):
        super(LightningCPC, self).__init__()

        self.hparams = hparams
        self.lr = self.hparams.lr

        #these coincide with the adapted wav2vec2 params
        if(self.hparams.fc_encoder):
            strides = [1]*4
            kss = [1]*4
            features = [512]*4
        else: #strided conv encoder
            strides = [2,2,2,2] #original wav2vec2 [5,2,2,2,2,2] original cpc [5,4,2,2,2]
            kss = [10,4,4,4] #original wav2vec2 [10,3,3,3,3,2] original cpc [18,8,4,4,4]
            features = [512]*4 #wav2vec2 [512]*6 original cpc [512]*5
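        # with the strided encoder the total downsampling factor is prod(strides)=16, i.e. one latent
        # step covers 16 input samples (0.16 s at the 100 Hz target sampling rate); the fc encoder
        # keeps the full temporal resolution (factor 1)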
        
        if(self.hparams.finetune):
            self.criterion = F.cross_entropy if self.hparams.finetune_dataset == "thew" else F.binary_cross_entropy_with_logits
            if(self.hparams.finetune_dataset == "thew"):
                num_classes = 5
            elif(self.hparams.finetune_dataset == "ptbxl_super"):
                num_classes = 5
            elif(self.hparams.finetune_dataset == "ptbxl_all"):
                num_classes = 71
        else:
            num_classes = None

        self.model_cpc = CPCModel(
            input_channels=self.hparams.input_channels, strides=strides, kss=kss, features=features,
            n_hidden=self.hparams.n_hidden, n_layers=self.hparams.n_layers, mlp=self.hparams.mlp,
            lstm=not(self.hparams.gru), bias_proj=self.hparams.bias, num_classes=num_classes,
            skip_encoder=self.hparams.skip_encoder, bn_encoder=not(self.hparams.no_bn_encoder),
            lin_ftrs_head=[] if self.hparams.linear_eval else eval(self.hparams.lin_ftrs_head),
            ps_head=0 if self.hparams.linear_eval else self.hparams.dropout_head,
            bn_head=False if self.hparams.linear_eval else not(self.hparams.no_bn_head))

        target_fs = 100
        if(not(self.hparams.finetune)):
            print("CPC pretraining:\ndownsampling factor:",self.model_cpc.encoder_downsampling_factor,
                  "\nchunk length (s):",self.model_cpc.encoder_downsampling_factor/target_fs,
                  "\npixels predicted ahead:",self.model_cpc.encoder_downsampling_factor*self.hparams.steps_predicted,
                  "\nseconds predicted ahead:",self.model_cpc.encoder_downsampling_factor*self.hparams.steps_predicted/target_fs,
                  "\nRNN input size:",self.hparams.input_size//self.model_cpc.encoder_downsampling_factor)

    def forward(self, x):
        return self.model_cpc(x)

    def _step(self, data_batch, batch_idx, train):
        if(self.hparams.finetune):
            preds = self.forward(data_batch[0])
            loss = self.criterion(preds,data_batch[1])
            self.log("train_loss" if train else "val_loss", loss)
            return {'loss':loss, "preds":preds.detach(), "targs": data_batch[1]}
        else:
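            # contrastive CPC objective (implemented in clinical_ts.cpc): the latent representation
            # steps_predicted steps ahead has to be distinguished from n_false_negatives distractors
            # (optionally drawn from the same sequence only); acc is the corresponding contrastive accuracy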
            loss, acc = self.model_cpc.cpc_loss(
                data_batch[0], steps_predicted=self.hparams.steps_predicted,
                n_false_negatives=self.hparams.n_false_negatives,
                negatives_from_same_seq_only=self.hparams.negatives_from_same_seq_only, eval_acc=True)
            self.log("loss" if train else "val_loss", loss)
            self.log("acc" if train else "val_acc", acc)
            return loss

    def training_step(self, train_batch, batch_idx):
        if(self.hparams.linear_eval):
            _freeze_bn_stats(self)
        return self._step(train_batch,batch_idx,True)

    def validation_step(self, val_batch, batch_idx, dataloader_idx=0):
        return self._step(val_batch,batch_idx,False)

    def validation_epoch_end(self, outputs_all):
        if(self.hparams.finetune):
            for dataloader_idx,outputs in enumerate(outputs_all): #multiple val dataloaders
                preds_all = torch.cat([x['preds'] for x in outputs])
                targs_all = torch.cat([x['targs'] for x in outputs])
                if(self.hparams.finetune_dataset=="thew"):
                    preds_all = F.softmax(preds_all,dim=-1)
                    targs_all = torch.eye(len(self.lbl_itos))[targs_all].to(preds_all.device)
                else:
                    preds_all = torch.sigmoid(preds_all)
                preds_all = preds_all.cpu().numpy()
                targs_all = targs_all.cpu().numpy()
                #instance level score
                res = eval_scores(targs_all,preds_all,classes=self.lbl_itos)

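                #aggregated score: mean over all crops that belong to the same underlying sample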
                idmap = self.val_dataset.get_id_mapping()
                preds_all_agg,targs_all_agg = aggregate_predictions(preds_all,targs_all,idmap,aggregate_fn=np.mean)
                res_agg = eval_scores(targs_all_agg,preds_all_agg,classes=self.lbl_itos)
                self.log_dict({"macro_auc_agg"+str(dataloader_idx):res_agg["label_AUC"]["macro"],
                               "macro_auc_noagg"+str(dataloader_idx):res["label_AUC"]["macro"]})
                print("epoch",self.current_epoch,
                      "macro_auc_agg"+str(dataloader_idx)+":",res_agg["label_AUC"]["macro"],
                      "macro_auc_noagg"+str(dataloader_idx)+":",res["label_AUC"]["macro"])

    def on_fit_start(self):
        if(self.hparams.linear_eval):
            print("copying state dict before training for sanity check after training")
            self.state_dict_pre = copy.deepcopy(self.state_dict().copy())

    def on_fit_end(self):
        if(self.hparams.linear_eval):
            sanity_check(self,self.state_dict_pre)

    def setup(self, stage):
        # configure dataset params
        chunkify_train = False
        chunk_length_train = self.hparams.input_size if chunkify_train else 0
        stride_train = self.hparams.input_size

        chunkify_valtest = True
        chunk_length_valtest = self.hparams.input_size if chunkify_valtest else 0
        stride_valtest = self.hparams.input_size//2

        train_datasets = []
        val_datasets = []
        test_datasets = []

        for i,target_folder in enumerate(self.hparams.data):
            target_folder = Path(target_folder)

            df_mapped, lbl_itos, mean, std = load_dataset(target_folder)
            # always use PTB-XL stats
            mean = np.array([-0.00184586, -0.00130277,  0.00017031, -0.00091313, -0.00148835,  -0.00174687, -0.00077071, -0.00207407,  0.00054329,  0.00155546,  -0.00114379, -0.00035649])
            std = np.array([0.16401004, 0.1647168 , 0.23374124, 0.33767231, 0.33362807,  0.30583013, 0.2731171 , 0.27554379, 0.17128962, 0.14030828,   0.14606956, 0.14656108])

            #specific for PTB-XL
            if(self.hparams.finetune and self.hparams.finetune_dataset.startswith("ptbxl")):
                if(self.hparams.finetune_dataset=="ptbxl_super"):
                    ptb_xl_label = "label_diag_superclass"
                elif(self.hparams.finetune_dataset=="ptbxl_all"):
                    ptb_xl_label = "label_all"

                lbl_itos = np.array(lbl_itos[ptb_xl_label])

                def multihot_encode(x, num_classes):
                    res = np.zeros(num_classes,dtype=np.float32)
                    for y in x:
                        res[y]=1
                    return res

                df_mapped["label"] = df_mapped[ptb_xl_label+"_filtered_numeric"].apply(lambda x: multihot_encode(x,len(lbl_itos)))

            self.lbl_itos = lbl_itos
            tfms_ptb_xl_cpc = ToTensor() if self.hparams.normalize is False else transforms.Compose([Normalize(mean,std),ToTensor()])

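            # fold-based split: for finetuning, folds < max_fold_id-1 are used for training, fold max_fold_id-1
            # for validation and fold max_fold_id for testing; for pretraining, only fold max_fold_id is held out for validation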
            max_fold_id = df_mapped.strat_fold.max() #unfortunately 1-based for PTB-XL; sometimes 100 (Ribeiro)
            df_train = df_mapped[df_mapped.strat_fold<(max_fold_id-1 if self.hparams.finetune else max_fold_id)]
            df_val = df_mapped[df_mapped.strat_fold==(max_fold_id-1 if self.hparams.finetune else max_fold_id)]
            if(self.hparams.finetune):
                df_test = df_mapped[df_mapped.strat_fold==max_fold_id]
            train_datasets.append(TimeseriesDatasetCrops(
                df_train, self.hparams.input_size, num_classes=len(lbl_itos), data_folder=target_folder,
                chunk_length=chunk_length_train, min_chunk_length=self.hparams.input_size, stride=stride_train,
                transforms=tfms_ptb_xl_cpc, annotation=False, col_lbl="label" if self.hparams.finetune else None,
                memmap_filename=target_folder/("memmap.npy")))
            val_datasets.append(TimeseriesDatasetCrops(
                df_val, self.hparams.input_size, num_classes=len(lbl_itos), data_folder=target_folder,
                chunk_length=chunk_length_valtest, min_chunk_length=self.hparams.input_size, stride=stride_valtest,
                transforms=tfms_ptb_xl_cpc, annotation=False, col_lbl="label" if self.hparams.finetune else None,
                memmap_filename=target_folder/("memmap.npy")))
            if(self.hparams.finetune):
                test_datasets.append(TimeseriesDatasetCrops(
                    df_test, self.hparams.input_size, num_classes=len(lbl_itos), data_folder=target_folder,
                    chunk_length=chunk_length_valtest, min_chunk_length=self.hparams.input_size, stride=stride_valtest,
                    transforms=tfms_ptb_xl_cpc, annotation=False, col_lbl="label",
                    memmap_filename=target_folder/("memmap.npy")))

            print("\n",target_folder)
            print("train dataset:",len(train_datasets[-1]),"samples")
            print("val dataset:",len(val_datasets[-1]),"samples")
            if(self.hparams.finetune):
                print("test dataset:",len(test_datasets[-1]),"samples")

        if(len(train_datasets)>1): #multiple data folders
            print("\nCombined:")
            self.train_dataset = ConcatDataset(train_datasets)
            self.val_dataset = ConcatDataset(val_datasets)
            print("train dataset:",len(self.train_dataset),"samples")
            print("val dataset:",len(self.val_dataset),"samples")
            if(self.hparams.finetune):
                self.test_dataset = ConcatDataset(test_datasets)
                print("test dataset:",len(self.test_dataset),"samples")
        else: #just a single data folder
            self.train_dataset = train_datasets[0]
            self.val_dataset = val_datasets[0]
            if(self.hparams.finetune):
                self.test_dataset = test_datasets[0]

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.hparams.batch_size, num_workers=4, shuffle=True, drop_last=True)

    def val_dataloader(self):
        if(self.hparams.finetune): #multiple val dataloaders
            return [DataLoader(self.val_dataset, batch_size=self.hparams.batch_size, num_workers=4),
                    DataLoader(self.test_dataset, batch_size=self.hparams.batch_size, num_workers=4)]
        else:
            return DataLoader(self.val_dataset, batch_size=self.hparams.batch_size, num_workers=4)

    def configure_optimizers(self):
        if(self.hparams.optimizer == "sgd"):
            opt = torch.optim.SGD
        elif(self.hparams.optimizer == "adam"):
            opt = torch.optim.AdamW
        else:
            raise NotImplementedError("Unknown Optimizer.")

        if(self.hparams.finetune and (self.hparams.linear_eval or self.hparams.train_head_only)):
            optimizer = opt(self.model_cpc.head.parameters(), self.lr, weight_decay=self.hparams.weight_decay)
        elif(self.hparams.finetune and self.hparams.discriminative_lr_factor != 1.): #discriminative lrs
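            # three layer groups with decaying learning rates: head at lr, rnn at lr*factor, encoder at lr*factor^2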
            optimizer = opt([{"params":self.model_cpc.encoder.parameters(), "lr":self.lr*self.hparams.discriminative_lr_factor*self.hparams.discriminative_lr_factor},
                             {"params":self.model_cpc.rnn.parameters(), "lr":self.lr*self.hparams.discriminative_lr_factor},
                             {"params":self.model_cpc.head.parameters(), "lr":self.lr}],
                            self.hparams.lr, weight_decay=self.hparams.weight_decay)
        else:
            optimizer = opt(self.parameters(), self.lr, weight_decay=self.hparams.weight_decay)

        return optimizer

    def load_weights_from_checkpoint(self, checkpoint):
        """Loads the weights from a given checkpoint file.
        based on https://github.com/PyTorchLightning/pytorch-lightning/issues/525
        """
        checkpoint = torch.load(checkpoint, map_location=lambda storage, loc: storage)
        pretrained_dict = checkpoint["state_dict"]
        model_dict = self.state_dict()

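        # keep only checkpoint entries that exist in the current model; model parameters missing from the
        # checkpoint (e.g. the classification head after CPC pretraining) retain their current initialization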
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        self.load_state_dict(model_dict)

#####################################################################################################
#ARGPARSERS
#####################################################################################################
def add_model_specific_args(parser):
    parser.add_argument("--input-channels", type=int, default=12)
    parser.add_argument("--normalize", action='store_true', help='Normalize input using PTB-XL stats')
    parser.add_argument('--mlp', action='store_true', help="False: original CPC True: as in SimCLR")
    parser.add_argument('--bias', action='store_true', help="original CPC: no bias")
    parser.add_argument("--n-hidden", type=int, default=512)
    parser.add_argument("--gru", action="store_true")
    parser.add_argument("--n-layers", type=int, default=2)
    parser.add_argument("--steps-predicted", dest="steps_predicted", type=int, default=12)
    parser.add_argument("--n-false-negatives", dest="n_false_negatives", type=int, default=128)
    parser.add_argument("--skip-encoder", action="store_true", help="disable the convolutional encoder i.e. just RNN; for testing")
    parser.add_argument("--fc-encoder", action="store_true", help="use a fully connected encoder (as opposed to an encoder with strided convs)")
    parser.add_argument("--negatives-from-same-seq-only", action="store_true", help="only draw false negatives from same sequence (as opposed to drawing from everywhere)")
    parser.add_argument("--no-bn-encoder", action="store_true", help="switch off batch normalization in encoder")
    parser.add_argument("--dropout-head", type=float, default=0.5)
    parser.add_argument("--train-head-only", action="store_true", help="freeze everything except classification head (note: --linear-eval defaults to no hidden layer in classification head)")
    parser.add_argument("--lin-ftrs-head", type=str, default="[512]", help="hidden layers in the classification head")
    parser.add_argument('--no-bn-head', action='store_true', help="use no batch normalization in classification head")
    return parser

def add_default_args():
    parser = argparse.ArgumentParser(description='PyTorch Lightning CPC Training')
    parser.add_argument('--data', metavar='DIR', type=str, action='append',
                        help='path(s) to dataset')
    parser.add_argument('--epochs', default=30, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('--batch-size', default=64, type=int, metavar='N',
                        help='mini-batch size (default: 64), this is the total '
                             'batch size of all GPUs on the current node when '
                             'using Data Parallel or Distributed Data Parallel')
    parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,
                        metavar='W', help='weight decay (default: 1e-3)',
                        dest='weight_decay')

    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint (default: none)')
    parser.add_argument('--pretrained', default='', type=str, metavar='PATH',
                        help='path to pretrained checkpoint (default: none)')
    parser.add_argument('--optimizer', default='adam', help='sgd/adam') #was sgd
    parser.add_argument('--output-path', default='.', type=str, dest="output_path",
                        help='output path')
    parser.add_argument('--metadata', default='', type=str,
                        help='metadata for output')

    parser.add_argument("--gpus", type=int, default=1, help="number of gpus")
    parser.add_argument("--num-nodes", dest="num_nodes", type=int, default=1, help="number of compute nodes")
    parser.add_argument("--precision", type=int, default=16, help="16/32")
    parser.add_argument("--distributed-backend", dest="distributed_backend", type=str, default=None, help="None/ddp")
    parser.add_argument("--accumulate", type=int, default=1, help="accumulate grad batches (total-bs=accumulate-batches*bs)")

    parser.add_argument("--input-size", dest="input_size", type=int, default=16000)

    parser.add_argument("--finetune", action="store_true", default=False,
                        help="finetuning (downstream classification task)")
    parser.add_argument("--linear-eval", action="store_true", default=False,
                        help="linear evaluation instead of full finetuning")

    parser.add_argument("--finetune-dataset", type=str, default="thew",
                        help="thew/ptbxl_super/ptbxl_all")

    parser.add_argument("--discriminative-lr-factor", type=float, default=0.1,
                        help="factor by which the lr decreases per layer group during finetuning")

    parser.add_argument("--lr-find", action="store_true", default=False,
                        help="run lr finder before training run")

    return parser

###################################################################################################
#MAIN
###################################################################################################
if __name__ == '__main__':
    parser = add_default_args()
    parser = add_model_specific_args(parser)
    hparams = parser.parse_args()
    hparams.executable = "cpc"

    if not os.path.exists(hparams.output_path):
        os.makedirs(hparams.output_path)

    model = LightningCPC(hparams)

    if(hparams.pretrained != ""):
        print("Loading pretrained weights from", hparams.pretrained)
        model.load_weights_from_checkpoint(hparams.pretrained)

    logger = TensorBoardLogger(
        save_dir=hparams.output_path,
        #version="",#hparams.metadata.split(":")[0],
        name="")
    print("Output directory:", logger.log_dir)

    checkpoint_callback = ModelCheckpoint(
        filepath=os.path.join(logger.log_dir,"best_model"), #hparams.output_path
        save_top_k=1,
        save_last=True,
        verbose=True,
        monitor='macro_auc_agg0' if hparams.finetune else 'val_loss', #val_loss/dataloader_idx_0
        mode='max' if hparams.finetune else 'min',
        prefix='')
    lr_monitor = LearningRateMonitor()

    trainer = pl.Trainer(
        #overfit_batches=0.01,
        auto_lr_find=hparams.lr_find,
        accumulate_grad_batches=hparams.accumulate,
        max_epochs=hparams.epochs,
        min_epochs=hparams.epochs,

        default_root_dir=hparams.output_path,

        num_sanity_val_steps=0,

        logger=logger,
        checkpoint_callback=checkpoint_callback,
        callbacks=[], #lr_monitor
        benchmark=True,

        gpus=hparams.gpus,
        num_nodes=hparams.num_nodes,
        precision=hparams.precision,
        distributed_backend=hparams.distributed_backend,

        progress_bar_refresh_rate=0,
        weights_summary='top',
        resume_from_checkpoint=None if hparams.resume == "" else hparams.resume)

    if(hparams.lr_find):
        trainer.tune(model)

    trainer.fit(model)
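
# Example invocations (sketch only; <...> paths are placeholders and must point to folders preprocessed
# with clinical_ts so that load_dataset and the memmap.npy files referenced above are available):
#   CPC pretraining: python main_cpc_lightning.py --data <data_folder> --normalize
#   finetuning:      python main_cpc_lightning.py --data <data_folder> --normalize --finetune \
#                        --finetune-dataset ptbxl_all --pretrained <pretrained_checkpoint>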