# train.py
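"""Training script for PET lymphoma classification (single view).

Parses command-line options, trains a binary classifier built in utils.py on
data from dataset.py, logs per-epoch metrics (loss, AUC, BER, FPR, FNR) for
the train and validation splits to a convergence CSV, and checkpoints the
model after every epoch.
"""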
import dataset
import utils
from utils import EarlyStopping, LRScheduler
import os
import pandas as pd
import argparse
import torch.backends.cudnn as cudnn
import torch
import torch.utils.data
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import numpy as np
import time

parser = argparse.ArgumentParser(description='PET lymphoma classification')

# I/O PARAMS
parser.add_argument('--output', type=str, default='results', help='name of output folder (default: "results")')

# MODEL PARAMS
parser.add_argument('--normalize', action='store_true', default=False, help='normalize images')
parser.add_argument('--checkpoint', default='', type=str, help='model checkpoint if any (default: none)')
parser.add_argument('--resume', action='store_true', default=False, help='resume from checkpoint')

# OPTIMIZATION PARAMS
parser.add_argument('--optimizer', default='sgd', type=str, help='the optimizer to use (default: sgd)')
parser.add_argument('--lr', type=float, default=1e-3, help='learning rate (default: 1e-3)')
parser.add_argument('--lr_anneal', type=int, default=15, help='period for lr annealing (default: 15); only used with SGD')
parser.add_argument('--momentum', default=0.9, type=float, help='momentum (default: 0.9)')
parser.add_argument('--wd', default=1e-4, type=float, help='weight decay (default: 1e-4)')

# TRAINING PARAMS
parser.add_argument('--split_index', default=0, type=int, metavar='INT', choices=list(range(20)), help='which split index (default: 0)')
parser.add_argument('--run', default=1, type=int, metavar='INT', help='repetition run with same settings (default: 1)')
parser.add_argument('--batch_size', type=int, default=50, help='how many images to sample per slide (default: 50)')
parser.add_argument('--nepochs', type=int, default=40, help='number of epochs (default: 40)')
parser.add_argument('--workers', default=10, type=int, help='number of data loading workers (default: 10)')
parser.add_argument('--augm', default=0, type=int, choices=[0,1,2,3,12,4,5,14,34,45], help='augmentation procedure: 0=none, 1=flip, 2=rot90, 3=flip LR, 12=flip+rot90, 4=scale, 5=noise, 14=flip+scale, 34=flipLR+scale, 45=scale+noise (default: 0)')
parser.add_argument('--balance', action='store_true', default=False, help='balance dataset (weighted loss)')
parser.add_argument('--lr_scheduler', action='store_true', default=False, help='decrease LR on plateau')
parser.add_argument('--early_stopping', action='store_true', default=False, help='use early stopping')
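
# Example invocation (one possible configuration; all flags are defined above):
#   python train.py --output results --split_index 0 --augm 12 --balance --lr_scheduler --early_stopping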
def main():
    ### Get user input
    global args
    args = parser.parse_args()
    print(args)
    best_auc = 0.

    ### Output directory and files
    if not os.path.isdir(args.output):
        try:
            os.mkdir(args.output)
        except OSError:
            print('Creation of the output directory "{}" failed.'.format(args.output))
        else:
            print('Successfully created the output directory "{}".'.format(args.output))

    ### Get model
    model = utils.get_model()
    if args.checkpoint:
        # Warm-start from an external checkpoint: keep only matching keys
        ch = torch.load(args.checkpoint)
        model_dict = model.state_dict()
        pretrained_dict = {k: v for k, v in ch['state_dict'].items() if k in model_dict}
        print('Loaded [{}/{}] keys from checkpoint'.format(len(pretrained_dict), len(model_dict)))
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
    if args.resume:
        # Resume from this run's own checkpoint in the output folder
        ch = torch.load(os.path.join(args.output, 'checkpoint_split'+str(args.split_index)+'_run'+str(args.run)+'.pth'))
        model_dict = model.state_dict()
        pretrained_dict = {k: v for k, v in ch['state_dict'].items() if k in model_dict}
        print('Loaded [{}/{}] keys from checkpoint'.format(len(pretrained_dict), len(model_dict)))
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

    ### Set optimizer
    optimizer = utils.create_optimizer(model, args.optimizer, args.lr, args.momentum, args.wd)
    if args.resume and 'optimizer' in ch:
        optimizer.load_state_dict(ch['optimizer'])
        print('Loaded optimizer state')
    cudnn.benchmark = True

    ### Augmentations
    flipHorVer = dataset.RandomFlip()
    flipLR     = dataset.RandomFlipLeftRight()
    rot90      = dataset.RandomRot90()
    scale      = dataset.RandomScale()
    noise      = dataset.RandomNoise()
    if args.augm == 0:
        transform = None
    elif args.augm == 1:
        transform = transforms.Compose([flipHorVer])
    elif args.augm == 2:
        transform = transforms.Compose([rot90])
    elif args.augm == 3:
        transform = transforms.Compose([flipLR])
    elif args.augm == 12:
        transform = transforms.Compose([flipHorVer, rot90])
    elif args.augm == 4:
        transform = transforms.Compose([scale])
    elif args.augm == 5:
        transform = transforms.Compose([noise])
    elif args.augm == 14:
        transform = transforms.Compose([flipHorVer, scale])  # was `flip`, an undefined name
    elif args.augm == 34:
        transform = transforms.Compose([flipLR, scale])
    elif args.augm == 45:
        transform = transforms.Compose([scale, noise])

    ### Set datasets
    train_dset, trainval_dset, val_dset, _, balance_weight_neg_pos = dataset.get_datasets_singleview(transform, args.normalize, args.balance, args.split_index)
    print('Datasets train:{}, val:{}'.format(len(train_dset.df), len(val_dset.df)))

    ### Set loss criterion
    if args.balance:
        # Per-output weights on the positive targets to compensate class imbalance
        w = torch.Tensor(balance_weight_neg_pos)
        print('Balance loss with weights:', balance_weight_neg_pos)
        criterion = nn.BCEWithLogitsLoss(pos_weight=w).cuda()
    else:
        criterion = nn.BCEWithLogitsLoss().cuda()

    ### LR scheduler and early stopping
    if args.lr_scheduler:
        print('INFO: Initializing learning rate scheduler')
        lr_scheduler = LRScheduler(optimizer)
        if args.resume and 'lr_scheduler' in ch:
            lr_scheduler.lr_scheduler.load_state_dict(ch['lr_scheduler'])
            print('Loaded lr_scheduler state')
    if args.early_stopping:
        print('INFO: Initializing early stopping')
        early_stopping = EarlyStopping()
        if args.resume and 'early_stopping' in ch:
            early_stopping.best_loss = ch['early_stopping']['best_loss']
            early_stopping.counter = ch['early_stopping']['counter']
            early_stopping.min_delta = ch['early_stopping']['min_delta']
            early_stopping.patience = ch['early_stopping']['patience']
            early_stopping.early_stop = ch['early_stopping']['early_stop']
            print('Loaded early_stopping state')

    ### Set loaders
    train_loader = torch.utils.data.DataLoader(train_dset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)
    trainval_loader = torch.utils.data.DataLoader(trainval_dset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    val_loader = torch.utils.data.DataLoader(val_dset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    ### Set output files
    convergence_name = 'convergence_split'+str(args.split_index)+'_run'+str(args.run)+'.csv'
    if not args.resume:
        fconv = open(os.path.join(args.output, convergence_name), 'w')
        fconv.write('epoch,split,metric,value\n')
        fconv.close()

    ### Main training loop
    if args.resume:
        epochs = range(ch['epoch']+1, args.nepochs+1)
    else:
        # Epoch 0 only evaluates the untrained model; training starts at epoch 1
        epochs = range(args.nepochs+1)

    for epoch in epochs:
        if args.optimizer == 'sgd':
            utils.adjust_learning_rate(optimizer, epoch, args.lr_anneal, args.lr)

        ### Training logic
        if epoch > 0:
            loss = train(epoch, train_loader, model, criterion, optimizer)
        else:
            loss = np.nan
        ### Printing stats
        fconv = open(os.path.join(args.output, convergence_name), 'a')
        fconv.write('{},train,loss,{}\n'.format(epoch, loss))
        fconv.close()

        ### Validation logic
        # Evaluate on train data
        train_probs = test(epoch, trainval_loader, model)
        train_auc, train_ber, train_fpr, train_fnr = train_dset.errors(train_probs)
        # Evaluate on validation set
        val_probs = test(epoch, val_loader, model)
        val_auc, val_ber, val_fpr, val_fnr = val_dset.errors(val_probs)

        print('Epoch: [{}/{}]\tLoss: {:.6f}\tTrain AUC: {:.4f}\tVal AUC: {:.4f}'.format(epoch, args.nepochs, loss, train_auc, val_auc))

        fconv = open(os.path.join(args.output, convergence_name), 'a')
        fconv.write('{},train,auc,{}\n'.format(epoch, train_auc))
        fconv.write('{},train,ber,{}\n'.format(epoch, train_ber))
        fconv.write('{},train,fpr,{}\n'.format(epoch, train_fpr))
        fconv.write('{},train,fnr,{}\n'.format(epoch, train_fnr))
        fconv.write('{},validation,auc,{}\n'.format(epoch, val_auc))
        fconv.write('{},validation,ber,{}\n'.format(epoch, val_ber))
        fconv.write('{},validation,fpr,{}\n'.format(epoch, val_fpr))
        fconv.write('{},validation,fnr,{}\n'.format(epoch, val_fnr))
        fconv.close()

        ### Create checkpoint dictionary
        obj = {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'auc': val_auc,
        }
        # Only store scheduler/stopper state when they exist; unconditional
        # access raised a NameError when the corresponding flags were off
        if args.lr_scheduler:
            obj['lr_scheduler'] = lr_scheduler.lr_scheduler.state_dict()
        if args.early_stopping:
            obj['early_stopping'] = {'best_loss': early_stopping.best_loss,
                                     'counter': early_stopping.counter,
                                     'early_stop': early_stopping.early_stop,
                                     'min_delta': early_stopping.min_delta,
                                     'patience': early_stopping.patience}
        ### Save checkpoint
        torch.save(obj, os.path.join(args.output, 'checkpoint_split'+str(args.split_index)+'_run'+str(args.run)+'.pth'))

        ### LR scheduling and early stopping
        # Both expect a quantity to minimize, so pass the negated validation AUC
        if args.lr_scheduler:
            lr_scheduler(-val_auc)
        if args.early_stopping:
            early_stopping(-val_auc)
            if early_stopping.early_stop:
                break

def test(epoch, loader, model):
    # Set model in eval mode
    model.eval()
    # Initialize probability vector
    probs = torch.FloatTensor(len(loader.dataset)).cuda()
    # Loop through batches
    with torch.no_grad():
        for i, (input, _) in enumerate(loader):
            ## Copy batch to GPU
            input = input.cuda()
            ## Forward pass
            y = model(input)  # logits over the two classes
            p = F.softmax(y, dim=1)
            ## Store the positive-class probabilities in the output vector
            probs[i*args.batch_size:i*args.batch_size+input.size(0)] = p.detach()[:, 1].clone()
    return probs.cpu().numpy()

def train(epoch, loader, model, criterion, optimizer):
    # Set model in training mode
    model.train()
    # Initialize running loss
    running_loss = 0.
    # Loop through batches
    for i, (input, target) in enumerate(loader):
        ## Copy to GPU
        input = input.cuda()
        target_1hot = F.one_hot(target.long(), num_classes=2).cuda()
        ## Forward pass
        y = model(input)  # logits over the two classes
        ## Calculate loss against the one-hot targets
        loss = criterion(y, target_1hot.float())
        ## Optimization step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ## Accumulate loss, weighted by batch size
        running_loss += loss.item()*input.size(0)
    return running_loss/len(loader.dataset)

if __name__ == '__main__':
    main()
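
# To inspect the learning curves afterwards, the convergence CSV can be read
# back with pandas (a sketch; the file name follows the pattern used above):
#   df = pd.read_csv('results/convergence_split0_run1.csv')
#   val_auc = df[(df.split == 'validation') & (df.metric == 'auc')]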