# train.py

import dataset
import utils
from utils import EarlyStopping, LRScheduler
import os
import pandas as pd
import argparse
import torch.backends.cudnn as cudnn
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import numpy as np
import time

parser = argparse.ArgumentParser(description='PET lymphoma classification')

#I/O PARAMS
parser.add_argument('--output', type=str, default='results', help='name of output folder (default: "results")')

#MODEL PARAMS
parser.add_argument('--normalize', action='store_true', default=False, help='normalize images')
parser.add_argument('--checkpoint', default='', type=str, help='model checkpoint, if any (default: none)')
parser.add_argument('--resume', action='store_true', default=False, help='resume from checkpoint')

#OPTIMIZATION PARAMS
parser.add_argument('--optimizer', default='sgd', type=str, help='the optimizer to use (default: sgd)')
parser.add_argument('--lr', type=float, default=1e-3, help='learning rate (default: 1e-3)')
parser.add_argument('--lr_anneal', type=int, default=15, help='period for lr annealing (default: 15); only used with SGD')
parser.add_argument('--momentum', default=0.9, type=float, help='momentum (default: 0.9)')
parser.add_argument('--wd', default=1e-4, type=float, help='weight decay (default: 1e-4)')

#TRAINING PARAMS
parser.add_argument('--split_index', default=0, type=int, metavar='INT', choices=list(range(0, 20)), help='which split index (default: 0)')
parser.add_argument('--run', default=1, type=int, metavar='INT', help='repetition run with same settings (default: 1)')
parser.add_argument('--batch_size', type=int, default=50, help='how many images to sample per slide (default: 50)')
parser.add_argument('--nepochs', type=int, default=40, help='number of epochs (default: 40)')
parser.add_argument('--workers', default=10, type=int, help='number of data loading workers (default: 10)')
parser.add_argument('--augm', default=0, type=int, choices=[0, 1, 2, 3, 12, 4, 5, 14, 34, 45], help='augmentation procedure: 0=none, 1=flip, 2=rot, 3=flip LR, 12=flip+rot, 4=scale, 5=noise, 14=flip+scale, 34=flip LR+scale, 45=scale+noise (default: 0)')
parser.add_argument('--balance', action='store_true', default=False, help='balance dataset (balance loss)')
parser.add_argument('--lr_scheduler', action='store_true', default=False, help='decrease LR on plateau')
parser.add_argument('--early_stopping', action='store_true', default=False, help='use early stopping')
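
# Example invocation (illustrative flag values; adjust to your experiment):
#   python train.py --output results --optimizer sgd --lr 1e-3 --augm 12 \
#       --balance --lr_scheduler --early_stopping --split_index 0 --run 1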

def main():
    ### Get user input
    global args
    args = parser.parse_args()
    print(args)
    best_auc = 0.

    ### Output directory and files
    if not os.path.isdir(args.output):
        try:
            os.mkdir(args.output)
        except OSError:
            print('Creation of the output directory "{}" failed.'.format(args.output))
        else:
            print('Successfully created the output directory "{}".'.format(args.output))

    ### Get model
    model = utils.get_model()
    if args.checkpoint:
        ch = torch.load(args.checkpoint)
        model_dict = model.state_dict()
        pretrained_dict = {k: v for k, v in ch['state_dict'].items() if k in model_dict}
        print('Loaded [{}/{}] keys from checkpoint'.format(len(pretrained_dict), len(model_dict)))
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
    if args.resume:
        ch = torch.load(os.path.join(args.output, 'checkpoint_split'+str(args.split_index)+'_run'+str(args.run)+'.pth'))
        model_dict = model.state_dict()
        pretrained_dict = {k: v for k, v in ch['state_dict'].items() if k in model_dict}
        print('Loaded [{}/{}] keys from checkpoint'.format(len(pretrained_dict), len(model_dict)))
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
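    # Both branches copy only the checkpoint keys that also exist in the current
    # model, so weights for renamed or removed layers are skipped silently.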

    ### Set optimizer
    optimizer = utils.create_optimizer(model, args.optimizer, args.lr, args.momentum, args.wd)
    if args.resume and 'optimizer' in ch:
        optimizer.load_state_dict(ch['optimizer'])
        print('Loaded optimizer state')
    cudnn.benchmark = True

    ### Augmentations
    flipHorVer = dataset.RandomFlip()
    flipLR = dataset.RandomFlipLeftRight()
    rot90 = dataset.RandomRot90()
    scale = dataset.RandomScale()
    noise = dataset.RandomNoise()
    if args.augm == 0:
        transform = None
    elif args.augm == 1:
        transform = transforms.Compose([flipHorVer])
    elif args.augm == 2:
        transform = transforms.Compose([rot90])
    elif args.augm == 3:
        transform = transforms.Compose([flipLR])
    elif args.augm == 12:
        transform = transforms.Compose([flipHorVer, rot90])
    elif args.augm == 4:
        transform = transforms.Compose([scale])
    elif args.augm == 5:
        transform = transforms.Compose([noise])
    elif args.augm == 14:
        transform = transforms.Compose([flipHorVer, scale])
    elif args.augm == 34:
        transform = transforms.Compose([flipLR, scale])
    elif args.augm == 45:
        transform = transforms.Compose([scale, noise])
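    # Composite codes (12, 14, 34, 45) chain two of the base transforms above
    # via transforms.Compose; code 0 disables augmentation entirely.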

    ### Set datasets
    train_dset, trainval_dset, val_dset, _, balance_weight_neg_pos = dataset.get_datasets_singleview(transform, args.normalize, args.balance, args.split_index)
    print('Datasets train:{}, val:{}'.format(len(train_dset.df), len(val_dset.df)))
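    # trainval_dset mirrors the training split for evaluation (its predictions
    # are scored with train_dset.errors below); the '_' return value is unused here.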

    ### Set loss criterion
    if args.balance:
        w = torch.Tensor(balance_weight_neg_pos)
        print('Balance loss with weights:', balance_weight_neg_pos)
        criterion = nn.BCEWithLogitsLoss(pos_weight=w).cuda()
    else:
        criterion = nn.BCEWithLogitsLoss().cuda()
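    # pos_weight is a per-class vector: with the two-column one-hot targets used
    # in train(), w[0] rescales the loss on samples whose true class is negative
    # (column 0) and w[1] on samples whose true class is positive (column 1).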

    ### LR scheduler and early stopping
    if args.lr_scheduler:
        print('INFO: Initializing learning rate scheduler')
        lr_scheduler = LRScheduler(optimizer)
        if args.resume and 'lr_scheduler' in ch:
            lr_scheduler.lr_scheduler.load_state_dict(ch['lr_scheduler'])
            print('Loaded lr_scheduler state')
    if args.early_stopping:
        print('INFO: Initializing early stopping')
        early_stopping = EarlyStopping()
        if args.resume and 'early_stopping' in ch:
            early_stopping.best_loss = ch['early_stopping']['best_loss']
            early_stopping.counter = ch['early_stopping']['counter']
            early_stopping.min_delta = ch['early_stopping']['min_delta']
            early_stopping.patience = ch['early_stopping']['patience']
            early_stopping.early_stop = ch['early_stopping']['early_stop']
            print('Loaded early_stopping state')

    ### Set loaders
    train_loader = torch.utils.data.DataLoader(train_dset, batch_size=args.batch_size, shuffle=True, num_workers=args.workers)
    trainval_loader = torch.utils.data.DataLoader(trainval_dset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    val_loader = torch.utils.data.DataLoader(val_dset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
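    # The evaluation loaders must keep shuffle=False: test() writes each batch's
    # probabilities into the output vector by batch index.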

    ### Set output files
    convergence_name = 'convergence_split'+str(args.split_index)+'_run'+str(args.run)+'.csv'
    if not args.resume:
        fconv = open(os.path.join(args.output, convergence_name), 'w')
        fconv.write('epoch,split,metric,value\n')
        fconv.close()

    ### Main training loop
    if args.resume:
        epochs = range(ch['epoch']+1, args.nepochs+1)
    else:
        epochs = range(args.nepochs+1)
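    # Epoch 0 is an evaluation-only pass (no weight update in the loop below);
    # training starts at epoch 1, and resumed runs continue after the saved epoch.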

    for epoch in epochs:
        if args.optimizer == 'sgd':
            utils.adjust_learning_rate(optimizer, epoch, args.lr_anneal, args.lr)

        ### Training logic
        if epoch > 0:
            loss = train(epoch, train_loader, model, criterion, optimizer)
        else:
            loss = np.nan
        ### Printing stats
        fconv = open(os.path.join(args.output, convergence_name), 'a')
        fconv.write('{},train,loss,{}\n'.format(epoch, loss))
        fconv.close()

        ### Validation logic
        # Evaluate on train data
        train_probs = test(epoch, trainval_loader, model)
        train_auc, train_ber, train_fpr, train_fnr = train_dset.errors(train_probs)
        # Evaluate on validation set
        val_probs = test(epoch, val_loader, model)
        val_auc, val_ber, val_fpr, val_fnr = val_dset.errors(val_probs)

        print('Epoch: [{}/{}]\tLoss: {:.6f}\tAUC: {:.4f}\t{:.4f}'.format(epoch, args.nepochs, loss, train_auc, val_auc))

        fconv = open(os.path.join(args.output, convergence_name), 'a')
        fconv.write('{},train,auc,{}\n'.format(epoch, train_auc))
        fconv.write('{},train,ber,{}\n'.format(epoch, train_ber))
        fconv.write('{},train,fpr,{}\n'.format(epoch, train_fpr))
        fconv.write('{},train,fnr,{}\n'.format(epoch, train_fnr))
        fconv.write('{},validation,auc,{}\n'.format(epoch, val_auc))
        fconv.write('{},validation,ber,{}\n'.format(epoch, val_ber))
        fconv.write('{},validation,fpr,{}\n'.format(epoch, val_fpr))
        fconv.write('{},validation,fnr,{}\n'.format(epoch, val_fnr))
        fconv.close()

        ### Create checkpoint dictionary
        obj = {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'auc': val_auc,
        }
        # Scheduler/early-stopping state is only defined when those flags are set
        if args.lr_scheduler:
            obj['lr_scheduler'] = lr_scheduler.lr_scheduler.state_dict()
        if args.early_stopping:
            obj['early_stopping'] = {'best_loss': early_stopping.best_loss, 'counter': early_stopping.counter, 'early_stop': early_stopping.early_stop, 'min_delta': early_stopping.min_delta, 'patience': early_stopping.patience}
        ### Save checkpoint
        torch.save(obj, os.path.join(args.output, 'checkpoint_split'+str(args.split_index)+'_run'+str(args.run)+'.pth'))

        ### LR scheduling and early stopping track the negated AUC, so a higher
        ### validation AUC registers as an improving "loss"
        if args.lr_scheduler:
            lr_scheduler(-val_auc)
        if args.early_stopping:
            early_stopping(-val_auc)
            if early_stopping.early_stop:
                break

def test(epoch, loader, model):
    # Set model in test mode
    model.eval()
    # Initialize probability vector (filled batch by batch below)
    probs = torch.FloatTensor(len(loader.dataset)).cuda()
    # Loop through batches
    with torch.no_grad():
        for i, (input, _) in enumerate(loader):
            ## Copy batch to GPU
            input = input.cuda()
            ## Forward pass: 2-class logits
            y = model(input)
            p = F.softmax(y, dim=1)
            ## Copy the positive-class probability into the output vector
            probs[i*args.batch_size:i*args.batch_size+input.size(0)] = p.detach()[:, 1].clone()
    return probs.cpu().numpy()

def train(epoch, loader, model, criterion, optimizer):
    # Set model in training mode
    model.train()
    # Initialize running loss
    running_loss = 0.
    # Loop through batches
    for i, (input, target) in enumerate(loader):
        ## Copy batch to GPU
        input = input.cuda()
        target_1hot = F.one_hot(target.long(), num_classes=2).cuda()
        ## Forward pass: 2-class logits
        y = model(input)
        ## Calculate loss against the one-hot targets
        loss = criterion(y, target_1hot.float())
        ## Optimization step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ## Accumulate the batch loss weighted by batch size
        running_loss += loss.item()*input.size(0)
    return running_loss/len(loader.dataset)
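
# train() returns the per-sample average loss: each batch's loss is weighted by
# its size before dividing by the dataset length, so partial final batches are
# averaged correctly.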

if __name__ == '__main__':
    main()