--- a
+++ b/train.py
@@ -0,0 +1,253 @@
+import dataset
+import utils
+from utils import EarlyStopping, LRScheduler
+import os
+import pandas as pd
+import argparse
+import torch.backends.cudnn as cudnn
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.transforms as transforms
+import numpy as np
+import time
+
+parser = argparse.ArgumentParser(description='PET lymphoma classification')
+
+# I/O PARAMS
+parser.add_argument('--output', type=str, default='results', help='name of output folder (default: "results")')
+
+# MODEL PARAMS
+parser.add_argument('--normalize', action='store_true', default=False, help='normalize images')
+parser.add_argument('--checkpoint', default='', type=str, help='model checkpoint if any (default: none)')
+parser.add_argument('--resume', action='store_true', default=False, help='resume from checkpoint')
+
+# OPTIMIZATION PARAMS
+parser.add_argument('--optimizer', default='sgd', type=str, help='optimizer to use (default: sgd)')
+parser.add_argument('--lr', type=float, default=1e-3, help='learning rate (default: 1e-3)')
+parser.add_argument('--lr_anneal', type=int, default=15, help='period for lr annealing (default: 15); only used with SGD')
+parser.add_argument('--momentum', default=0.9, type=float, help='momentum (default: 0.9)')
+parser.add_argument('--wd', default=1e-4, type=float, help='weight decay (default: 1e-4)')
+
+# TRAINING PARAMS
+parser.add_argument('--split_index', default=0, type=int, metavar='INT', choices=list(range(0, 20)), help='which split index (default: 0)')
+parser.add_argument('--run', default=1, type=int, metavar='INT', help='repetition run with same settings (default: 1)')
+parser.add_argument('--batch_size', type=int, default=50, help='how many images to sample per slide (default: 50)')
+parser.add_argument('--nepochs', type=int, default=40, help='number of epochs (default: 40)')
+parser.add_argument('--workers', default=10, type=int, help='number of data loading workers (default: 10)')
+parser.add_argument('--augm', default=0, type=int, choices=[0, 1, 2, 3, 12, 4, 5, 14, 34, 45], help='augmentation procedure: 0=none, 1=flip, 2=rot90, 3=flip LR, 12=flip+rot90, 4=scale, 5=noise, 14=flip+scale, 34=flipLR+scale, 45=scale+noise (default: 0)')
+parser.add_argument('--balance', action='store_true', default=False, help='balance dataset (weight the loss)')
+parser.add_argument('--lr_scheduler', action='store_true', default=False, help='decrease LR on plateau')
+parser.add_argument('--early_stopping', action='store_true', default=False, help='use early stopping')
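+
+# Example invocation (all flags are defined above; paths and indices are illustrative):
+#   python train.py --output results --split_index 0 --run 1 --augm 12 --balance --lr_scheduler --early_stopping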
+
+def main():
+    ### Get user input
+    global args
+    args = parser.parse_args()
+    print(args)
+
+    ### Output directory
+    if not os.path.isdir(args.output):
+        try:
+            os.mkdir(args.output)
+        except OSError:
+            print('Creation of the output directory "{}" failed.'.format(args.output))
+        else:
+            print('Successfully created the output directory "{}".'.format(args.output))
+
+    ### Get model
+    model = utils.get_model()
+    if args.checkpoint:
+        ch = torch.load(args.checkpoint)
+        model_dict = model.state_dict()
+        pretrained_dict = {k: v for k, v in ch['state_dict'].items() if k in model_dict}
+        print('Loaded [{}/{}] keys from checkpoint'.format(len(pretrained_dict), len(model_dict)))
+        model_dict.update(pretrained_dict)
+        model.load_state_dict(model_dict)
+    if args.resume:
+        ch = torch.load(os.path.join(args.output, 'checkpoint_split{}_run{}.pth'.format(args.split_index, args.run)))
+        model_dict = model.state_dict()
+        pretrained_dict = {k: v for k, v in ch['state_dict'].items() if k in model_dict}
+        print('Loaded [{}/{}] keys from checkpoint'.format(len(pretrained_dict), len(model_dict)))
+        model_dict.update(pretrained_dict)
+        model.load_state_dict(model_dict)
+
+    ### Set optimizer
+    optimizer = utils.create_optimizer(model, args.optimizer, args.lr, args.momentum, args.wd)
+    if args.resume and 'optimizer' in ch:
+        optimizer.load_state_dict(ch['optimizer'])
+        print('Loaded optimizer state')
+    cudnn.benchmark = True
+
+    ### Augmentations
+    flipHorVer = dataset.RandomFlip()
+    flipLR = dataset.RandomFlipLeftRight()
+    rot90 = dataset.RandomRot90()
+    scale = dataset.RandomScale()
+    noise = dataset.RandomNoise()
+    if args.augm == 0:
+        transform = None
+    elif args.augm == 1:
+        transform = transforms.Compose([flipHorVer])
+    elif args.augm == 2:
+        transform = transforms.Compose([rot90])
+    elif args.augm == 3:
+        transform = transforms.Compose([flipLR])
+    elif args.augm == 12:
+        transform = transforms.Compose([flipHorVer, rot90])
+    elif args.augm == 4:
+        transform = transforms.Compose([scale])
+    elif args.augm == 5:
+        transform = transforms.Compose([noise])
+    elif args.augm == 14:
+        transform = transforms.Compose([flipHorVer, scale])
+    elif args.augm == 34:
+        transform = transforms.Compose([flipLR, scale])
+    elif args.augm == 45:
+        transform = transforms.Compose([scale, noise])
+
+    ### Set datasets
+    train_dset, trainval_dset, val_dset, _, balance_weight_neg_pos = dataset.get_datasets_singleview(transform, args.normalize, args.balance, args.split_index)
+    print('Datasets train:{}, val:{}'.format(len(train_dset.df), len(val_dset.df)))
+
+    ### Set loss criterion
+    if args.balance:
+        w = torch.Tensor(balance_weight_neg_pos)
+        print('Balance loss with weights:', balance_weight_neg_pos)
+        criterion = nn.BCEWithLogitsLoss(pos_weight=w).cuda()
+    else:
+        criterion = nn.BCEWithLogitsLoss().cuda()
+
+    ### LR scheduler and early stopping
+    if args.lr_scheduler:
+        print('INFO: Initializing learning rate scheduler')
+        lr_scheduler = LRScheduler(optimizer)
+        if args.resume and 'lr_scheduler' in ch:
+            lr_scheduler.lr_scheduler.load_state_dict(ch['lr_scheduler'])
+            print('Loaded lr_scheduler state')
+    if args.early_stopping:
+        print('INFO: Initializing early stopping')
+        early_stopping = EarlyStopping()
+        if args.resume and 'early_stopping' in ch:
+            early_stopping.best_loss = ch['early_stopping']['best_loss']
+            early_stopping.counter = ch['early_stopping']['counter']
+            early_stopping.min_delta = ch['early_stopping']['min_delta']
+            early_stopping.patience = ch['early_stopping']['patience']
+            early_stopping.early_stop = ch['early_stopping']['early_stop']
+            print('Loaded early_stopping state')
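+
+    # Only the training loader shuffles; trainval/val keep dataset order so the
+    # flat probability vector filled in test() lines up with dataset rows.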
+    ### Set loaders
+    train_loader = torch.utils.data.DataLoader(train_dset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)
+    trainval_loader = torch.utils.data.DataLoader(trainval_dset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
+    val_loader = torch.utils.data.DataLoader(val_dset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
+
+    ### Set output files
+    convergence_name = 'convergence_split{}_run{}.csv'.format(args.split_index, args.run)
+    if not args.resume:
+        with open(os.path.join(args.output, convergence_name), 'w') as fconv:
+            fconv.write('epoch,split,metric,value\n')
+
+    ### Main training loop
+    if args.resume:
+        epochs = range(ch['epoch'] + 1, args.nepochs + 1)
+    else:
+        epochs = range(args.nepochs + 1)  # epoch 0 is an evaluation-only pass
+
+    for epoch in epochs:
+        if args.optimizer == 'sgd':
+            utils.adjust_learning_rate(optimizer, epoch, args.lr_anneal, args.lr)
+
+        ### Training logic
+        if epoch > 0:
+            loss = train(epoch, train_loader, model, criterion, optimizer)
+        else:
+            loss = np.nan
+        ### Log training loss
+        with open(os.path.join(args.output, convergence_name), 'a') as fconv:
+            fconv.write('{},train,loss,{}\n'.format(epoch, loss))
+
+        ### Validation logic
+        # Evaluate on train data
+        train_probs = test(epoch, trainval_loader, model)
+        train_auc, train_ber, train_fpr, train_fnr = train_dset.errors(train_probs)
+        # Evaluate on validation set
+        val_probs = test(epoch, val_loader, model)
+        val_auc, val_ber, val_fpr, val_fnr = val_dset.errors(val_probs)
+
+        print('Epoch: [{}/{}]\tLoss: {:.6f}\tTrain AUC: {:.4f}\tVal AUC: {:.4f}'.format(epoch, args.nepochs, loss, train_auc, val_auc))
+
+        with open(os.path.join(args.output, convergence_name), 'a') as fconv:
+            fconv.write('{},train,auc,{}\n'.format(epoch, train_auc))
+            fconv.write('{},train,ber,{}\n'.format(epoch, train_ber))
+            fconv.write('{},train,fpr,{}\n'.format(epoch, train_fpr))
+            fconv.write('{},train,fnr,{}\n'.format(epoch, train_fnr))
+            fconv.write('{},validation,auc,{}\n'.format(epoch, val_auc))
+            fconv.write('{},validation,ber,{}\n'.format(epoch, val_ber))
+            fconv.write('{},validation,fpr,{}\n'.format(epoch, val_fpr))
+            fconv.write('{},validation,fnr,{}\n'.format(epoch, val_fnr))
+
+        ### Create checkpoint dictionary; scheduler/stopper states only when enabled
+        obj = {
+            'epoch': epoch,
+            'state_dict': model.state_dict(),
+            'optimizer': optimizer.state_dict(),
+            'auc': val_auc,
+        }
+        if args.lr_scheduler:
+            obj['lr_scheduler'] = lr_scheduler.lr_scheduler.state_dict()
+        if args.early_stopping:
+            obj['early_stopping'] = {
+                'best_loss': early_stopping.best_loss,
+                'counter': early_stopping.counter,
+                'early_stop': early_stopping.early_stop,
+                'min_delta': early_stopping.min_delta,
+                'patience': early_stopping.patience,
+            }
+        ### Save checkpoint
+        torch.save(obj, os.path.join(args.output, 'checkpoint_split{}_run{}.pth'.format(args.split_index, args.run)))
+
+        ### LR scheduling and early stopping (both monitor -val_auc, i.e. maximize AUC)
+        if args.lr_scheduler:
+            lr_scheduler(-val_auc)
+        if args.early_stopping:
+            early_stopping(-val_auc)
+            if early_stopping.early_stop:
+                break
+
+def test(epoch, loader, model):
+    # Set model in eval mode
+    model.eval()
+    # Flat vector of positive-class probabilities, one entry per dataset item
+    probs = torch.FloatTensor(len(loader.dataset)).cuda()
+    # Loop through batches
+    with torch.no_grad():
+        for i, (input, _) in enumerate(loader):
+            ## Copy batch to GPU
+            input = input.cuda()
+            ## Forward pass: class logits of shape (batch, 2)
+            y = model(input)
+            p = F.softmax(y, dim=1)
+            ## Store the positive-class probability in the output vector
+            probs[i * args.batch_size:i * args.batch_size + input.size(0)] = p[:, 1]
+    return probs.cpu().numpy()
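+
+# Note: train() pairs the model's two logits with one-hot targets under
+# BCEWithLogitsLoss, i.e. an independent sigmoid cross-entropy per class;
+# test() above collapses the same two logits into one positive-class score
+# via softmax.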
+
+def train(epoch, loader, model, criterion, optimizer):
+    # Set model in training mode
+    model.train()
+    # Initialize running loss
+    running_loss = 0.
+    # Loop through batches
+    for i, (input, target) in enumerate(loader):
+        ## Copy to GPU
+        input = input.cuda()
+        target_1hot = F.one_hot(target.long(), num_classes=2).cuda()
+        ## Forward pass: class logits of shape (batch, 2)
+        y = model(input)
+        ## Calculate loss against one-hot targets
+        loss = criterion(y, target_1hot.float())
+        ## Optimization step
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+        ## Accumulate loss, weighted by batch size
+        running_loss += loss.item() * input.size(0)
+    return running_loss / len(loader.dataset)
+
+if __name__ == '__main__':
+    main()
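
For reference, a minimal sketch of the two `utils` helpers as train.py appears to use them. This is an assumption inferred from the call sites above (both are called with `-val_auc`, the scheduler exposes a nested `.lr_scheduler` with `state_dict`/`load_state_dict`, and the stopper's `best_loss`, `counter`, `min_delta`, `patience`, and `early_stop` attributes are checkpointed); the actual implementations in `utils.py` may differ:

    import torch

    class LRScheduler:
        """Assumed thin wrapper around ReduceLROnPlateau (hence the nested
        .lr_scheduler attribute that train.py saves and restores)."""
        def __init__(self, optimizer, patience=5, factor=0.5, min_lr=1e-6):
            self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode='min', patience=patience, factor=factor, min_lr=min_lr)

        def __call__(self, metric):
            # train.py passes -val_auc, so lower is better
            self.lr_scheduler.step(metric)

    class EarlyStopping:
        """Assumed stop-on-plateau helper; attribute names match what
        train.py checkpoints. Default patience/min_delta are guesses."""
        def __init__(self, patience=10, min_delta=0.):
            self.patience = patience
            self.min_delta = min_delta
            self.counter = 0
            self.best_loss = None
            self.early_stop = False

        def __call__(self, metric):
            # Reset the counter on improvement; otherwise count toward patience
            if self.best_loss is None or metric < self.best_loss - self.min_delta:
                self.best_loss = metric
                self.counter = 0
            else:
                self.counter += 1
                if self.counter >= self.patience:
                    self.early_stop = True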