Diff of /train.py [000000] .. [1928b6]

--- a
+++ b/train.py
@@ -0,0 +1,253 @@
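+"""Training script for PET lymphoma classification (single view).
+
+Illustrative invocation (all flags are defined by the argparse options below):
+
+    python train.py --output results --split_index 0 --run 1 \
+        --optimizer sgd --lr 1e-3 --nepochs 40 --augm 12 --balance
+"""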
+import argparse
+import os
+
+import numpy as np
+import torch
+import torch.backends.cudnn as cudnn
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.transforms as transforms
+
+import dataset
+import utils
+from utils import EarlyStopping, LRScheduler
+
+parser = argparse.ArgumentParser(description='PET lymphoma classification')
+
+#I/O PARAMS
+parser.add_argument('--output', type=str, default='results', help='name of output folder (default: "results")')
+
+#MODEL PARAMS
+parser.add_argument('--normalize', action='store_true', default=False, help='normalize images')
+parser.add_argument('--checkpoint', default='', type=str, help='model checkpoint if any (default: none)')
+parser.add_argument('--resume', action='store_true', default=False, help='resume from checkpoint')
+
+#OPTIMIZATION PARAMS
+parser.add_argument('--optimizer', default='sgd', type=str, help='The optimizer to use (default: sgd)')
+parser.add_argument('--lr', type=float, default=1e-3, help='learning rate (default: 1e-3)')
+parser.add_argument('--lr_anneal', type=int, default=15, help='period for lr annealing (default: 15). Only works for SGD')
+parser.add_argument('--momentum', default=0.9, type=float, help='momentum (default: 0.9)')
+parser.add_argument('--wd', default=1e-4, type=float, help='weight decay (default: 1e-4)')
+
+#TRAINING PARAMS
+parser.add_argument('--split_index', default=0, type=int, metavar='INT', choices=list(range(0,20)),help='which split index (default: 0)')   
+parser.add_argument('--run', default=1, type=int, metavar='INT', help='repetition run with same settings (default: 1)')   
+parser.add_argument('--batch_size', type=int, default=50, help='how many images to sample per slide (default: 50)')
+parser.add_argument('--nepochs', type=int, default=40, help='number of epochs (default: 40)')
+parser.add_argument('--workers', default=10, type=int, help='number of data loading workers (default: 10)')
+parser.add_argument('--augm', default=0, type=int, choices=[0,1,2,3,4,5,12,14,34,45], help='augmentation procedure: 0=none, 1=flip, 2=rot90, 3=flipLR, 4=scale, 5=noise, 12=flip+rot90, 14=flip+scale, 34=flipLR+scale, 45=scale+noise (default: 0)')
+parser.add_argument('--balance', action='store_true', default=False, help='balance dataset (balance loss)')
+parser.add_argument('--lr_scheduler', action='store_true', default=False, help='decrease LR on plateau')
+parser.add_argument('--early_stopping', action='store_true',default=False, help='use early stopping')
+
+def main():
+    ### Get user input
+    global args
+    args = parser.parse_args()
+    print(args)
+
+    ### Output directory
+    os.makedirs(args.output, exist_ok=True)
+    
+    ### Get model
+    model = utils.get_model()
+    def load_weights(path):
+        ## Load all matching keys from a saved state_dict into the model
+        ch = torch.load(path)
+        model_dict = model.state_dict()
+        pretrained_dict = {k: v for k, v in ch['state_dict'].items() if k in model_dict}
+        print('Loaded [{}/{}] keys from checkpoint'.format(len(pretrained_dict), len(model_dict)))
+        model_dict.update(pretrained_dict)
+        model.load_state_dict(model_dict)
+        return ch
+    if args.checkpoint:
+        ch = load_weights(args.checkpoint)
+    if args.resume:
+        ch = load_weights(os.path.join(args.output, 'checkpoint_split{}_run{}.pth'.format(args.split_index, args.run)))
+    
+    ### Set optimizer
+    optimizer = utils.create_optimizer(model, args.optimizer, args.lr, args.momentum, args.wd)
+    if args.resume and 'optimizer' in ch:
+        optimizer.load_state_dict(ch['optimizer'])
+        print('Loaded optimizer state')
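+    # cudnn.benchmark lets cuDNN autotune convolution algorithms; this speeds
+    # training up when input sizes are constant across batches.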
+    cudnn.benchmark = True
+    
+    ### Augmentations
+    flipHorVer = dataset.RandomFlip()
+    flipLR     = dataset.RandomFlipLeftRight()
+    rot90      = dataset.RandomRot90()
+    scale      = dataset.RandomScale()
+    noise      = dataset.RandomNoise()
+    augmentations = {
+        0:  None,
+        1:  [flipHorVer],
+        2:  [rot90],
+        3:  [flipLR],
+        4:  [scale],
+        5:  [noise],
+        12: [flipHorVer, rot90],
+        14: [flipHorVer, scale],
+        34: [flipLR, scale],
+        45: [scale, noise],
+    }
+    ops = augmentations[args.augm]
+    transform = transforms.Compose(ops) if ops else None
+    
+    ### Set datasets
+    train_dset,trainval_dset,val_dset,_,balance_weight_neg_pos = dataset.get_datasets_singleview(transform,args.normalize,args.balance,args.split_index)
+    print('Datasets train:{}, val:{}'.format(len(train_dset.df),len(val_dset.df))) 
+    
+    ### Set loss criterion
+    if args.balance:
+        w = torch.Tensor(balance_weight_neg_pos)
+        print('Balance loss with weights:',balance_weight_neg_pos)
+        criterion = nn.BCEWithLogitsLoss(pos_weight=w).cuda()
+    else:
+        criterion = nn.BCEWithLogitsLoss().cuda()
+    
+    ### LR scheduler and early stopping
+    if args.lr_scheduler:
+        print('INFO: Initializing learning rate scheduler')
+        lr_scheduler = LRScheduler(optimizer)
+        if args.resume and 'lr_scheduler' in ch:
+            lr_scheduler.lr_scheduler.load_state_dict(ch['lr_scheduler'])
+            print('Loaded lr_scheduler state')
+    if args.early_stopping:
+        print('INFO: Initializing early stopping')
+        early_stopping = EarlyStopping()
+        if args.resume and 'early_stopping' in ch:
+            early_stopping.best_loss = ch['early_stopping']['best_loss']
+            early_stopping.counter = ch['early_stopping']['counter']
+            early_stopping.min_delta = ch['early_stopping']['min_delta']
+            early_stopping.patience = ch['early_stopping']['patience']
+            early_stopping.early_stop = ch['early_stopping']['early_stop']
+            print('Loaded early_stopping state')
+        
+    ### Set loaders
+    train_loader = torch.utils.data.DataLoader(train_dset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)
+    trainval_loader = torch.utils.data.DataLoader(trainval_dset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
+    val_loader = torch.utils.data.DataLoader(val_dset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
+    
+    ### Set output files
+    convergence_name = 'convergence_split{}_run{}.csv'.format(args.split_index, args.run)
+    if not args.resume:
+        with open(os.path.join(args.output, convergence_name), 'w') as fconv:
+            fconv.write('epoch,split,metric,value\n')
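+    # Note: the convergence log is a tidy CSV with one (epoch,split,metric,value)
+    # row per measurement; each epoch below appends its train/validation metrics.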
+    
+    ### Main training loop (epoch 0 is an evaluation-only pass, see below)
+    if args.resume:
+        epochs = range(ch['epoch'] + 1, args.nepochs + 1)
+    else:
+        epochs = range(args.nepochs + 1)
+    
+    for epoch in epochs:
+        if args.optimizer == 'sgd':
+            utils.adjust_learning_rate(optimizer, epoch, args.lr_anneal, args.lr)
+        
+        ### Training logic
+        if epoch > 0:
+            loss = train(epoch, train_loader, model, criterion, optimizer)
+        else:
+            loss = np.nan
+        ### Log training loss
+        with open(os.path.join(args.output, convergence_name), 'a') as fconv:
+            fconv.write('{},train,loss,{}\n'.format(epoch, loss))
+        
+        ### Validation logic
+        # Evaluate on train data
+        train_probs = test(epoch, trainval_loader, model)
+        train_auc, train_ber, train_fpr, train_fnr = train_dset.errors(train_probs)
+        # Evaluate on validation set
+        val_probs = test(epoch, val_loader, model)
+        val_auc, val_ber, val_fpr, val_fnr = val_dset.errors(val_probs)
+        
+        print('Epoch: [{}/{}]\tLoss: {:.6f}\tTrain AUC: {:.4f}\tVal AUC: {:.4f}'.format(epoch, args.nepochs, loss, train_auc, val_auc))
+        
+        with open(os.path.join(args.output, convergence_name), 'a') as fconv:
+            fconv.write('{},train,auc,{}\n'.format(epoch, train_auc))
+            fconv.write('{},train,ber,{}\n'.format(epoch, train_ber))
+            fconv.write('{},train,fpr,{}\n'.format(epoch, train_fpr))
+            fconv.write('{},train,fnr,{}\n'.format(epoch, train_fnr))
+            fconv.write('{},validation,auc,{}\n'.format(epoch, val_auc))
+            fconv.write('{},validation,ber,{}\n'.format(epoch, val_ber))
+            fconv.write('{},validation,fpr,{}\n'.format(epoch, val_fpr))
+            fconv.write('{},validation,fnr,{}\n'.format(epoch, val_fnr))
+        
+        ### Create checkpoint dictionary (scheduler/stopper state only if enabled,
+        ### since those objects do not exist otherwise)
+        obj = {
+            'epoch': epoch,
+            'state_dict': model.state_dict(),
+            'optimizer': optimizer.state_dict(),
+            'auc': val_auc,
+        }
+        if args.lr_scheduler:
+            obj['lr_scheduler'] = lr_scheduler.lr_scheduler.state_dict()
+        if args.early_stopping:
+            obj['early_stopping'] = {'best_loss': early_stopping.best_loss,
+                                     'counter': early_stopping.counter,
+                                     'early_stop': early_stopping.early_stop,
+                                     'min_delta': early_stopping.min_delta,
+                                     'patience': early_stopping.patience}
+        ### Save checkpoint
+        torch.save(obj, os.path.join(args.output, 'checkpoint_split{}_run{}.pth'.format(args.split_index, args.run)))
+        
+        ### LR scheduling and early stopping (both monitor a quantity to
+        ### minimize, so pass the negated validation AUC)
+        if args.lr_scheduler:
+            lr_scheduler(-val_auc)
+        if args.early_stopping:
+            early_stopping(-val_auc)
+            if early_stopping.early_stop:
+                break
+
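+# Note: test() assumes the loader preserves dataset order (shuffle=False, as set
+# above for trainval_loader and val_loader), since probabilities are written
+# positionally into a preallocated buffer indexed by batch number.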
+def test(epoch, loader, model):
+    # Set model in test mode
+    model.eval()
+    # Initialize probability vector
+    probs = torch.FloatTensor(len(loader.dataset)).cuda()
+    # Loop through batches
+    with torch.no_grad():
+        for i, (input,_) in enumerate(loader):
+            ## Copy batch to GPU
+            input = input.cuda()
+            ## Forward pass
+            y = model(input)  # logits over the two classes
+            p = F.softmax(y, dim=1)
+            ## Write positive-class probabilities into the output vector
+            probs[i*args.batch_size:i*args.batch_size+input.size(0)] = p[:, 1]
+    return probs.cpu().numpy()
+
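+# Note: train() one-hot encodes the binary target to match the model's two
+# logits, as expected by BCEWithLogitsLoss (independent per-class sigmoid).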
+def train(epoch, loader, model, criterion, optimizer):
+    # Set model in training mode
+    model.train()
+    # Initialize loss
+    running_loss = 0.
+    # Loop through batches
+    for i, (input,target) in enumerate(loader):
+        ## Copy to GPU
+        input = input.cuda()
+        target_1hot = F.one_hot(target.long(),num_classes=2).cuda()
+        ## Forward pass
+        y = model(input)  # logits over the two classes
+        ## Calculate loss
+        loss = criterion(y, target_1hot.float())
+        ## Optimization step
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+        ## Store loss
+        running_loss += loss.item()*input.size(0)
+    return running_loss/len(loader.dataset)
+
+if __name__ == '__main__':
+    main()