# pytorch/train.py
from __future__ import absolute_import, division, print_function

import argparse
import random
import shutil
from os import getcwd
from os.path import exists, isdir, isfile, join

import numpy as np
import pandas as pd
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data as data
from sklearn.metrics import (accuracy_score, f1_score, precision_score, recall_score)
from tensorboardX import SummaryWriter
from torch.autograd import Variable
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm

import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from dataloader import MuraDataset

print("torch : {}".format(torch.__version__))
print("torchvision : {}".format(torchvision.__version__))
print("numpy : {}".format(np.__version__))
print("pandas : {}".format(pd.__version__))

# Enumerate the lowercase, callable model constructors exported by
# torchvision.models (the `callable` check skips submodules such as
# `models.resnet`, which would otherwise pollute the choices).
model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__") and callable(models.__dict__[name]))
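# e.g. ['alexnet', 'densenet121', 'densenet161', 'densenet169', ...,
# 'vgg19_bn'], depending on the installed torchvision version.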

parser = argparse.ArgumentParser(description='Hyperparameters')
parser.add_argument('--data_dir', default='MURA-v1.0', metavar='DIR', help='path to dataset')
parser.add_argument('--arch', default='densenet121', choices=model_names, help='nn architecture')
parser.add_argument('--classes', default=2, type=int)
parser.add_argument('--workers', default=4, type=int)
parser.add_argument('--epochs', default=90, type=int)
parser.add_argument('--start-epoch', default=0, type=int, help='manual epoch number')
parser.add_argument('-b', '--batch-size', default=512, type=int, help='mini-batch size')
parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float, help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
parser.add_argument('--weight-decay', default=.1, type=float, help='weight decay')
parser.add_argument('--resume', default='', type=str, help='path to latest checkpoint')
parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model')
parser.add_argument('--fullretrain', dest='fullretrain', action='store_true', help='retrain all layers of the model')
parser.add_argument('--seed', default=1337, type=int, help='random seed')

# Despite the name, this tracks the best validation F1 score returned by
# validate(), so "best" means highest.
best_val_loss = 0

tb_writer = SummaryWriter()
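# SummaryWriter writes TensorBoard event files under ./runs/<timestamp> by
# default; view them with `tensorboard --logdir runs`.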


def main():
    global args, best_val_loss
    args = parser.parse_args()
    print("=> setting random seed to '{}'".format(args.seed))
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
        # Freeze the pretrained backbone so only the replacement head below is
        # trained (unless --fullretrain is passed).
        for param in model.parameters():
            param.requires_grad = False

        if 'resnet' in args.arch:
            # 2048 is the fc in_features for resnet50/101/152
            # (model.layer4 could also be unfrozen here for deeper fine-tuning)
            model.fc = nn.Linear(2048, args.classes)

        if 'dense' in args.arch:
            if '121' in args.arch:
                # (classifier): Linear(in_features=1024)
                model.classifier = nn.Linear(1024, args.classes)
            elif '169' in args.arch:
                # (classifier): Linear(in_features=1664)
                model.classifier = nn.Linear(1664, args.classes)
            else:
                print("=> unsupported densenet variant '{}'".format(args.arch))
                return

    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    model = torch.nn.DataParallel(model).cuda()
    # optionally resume from a checkpoint
    if args.resume:
        if isfile(args.resume):
            print("=> found checkpoint")
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_val_loss = checkpoint['best_val_loss']
            model.load_state_dict(checkpoint['state_dict'])

            args.epochs = args.epochs + args.start_epoch
            print("=> loaded checkpoint '{}' with best val F1 of '{}'".format(
                args.resume, checkpoint['best_val_loss']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    data_dir = join(getcwd(), args.data_dir)
    train_dir = join(data_dir, 'train')
    train_csv = join(data_dir, 'train.csv')
    val_dir = join(data_dir, 'valid')
    val_csv = join(data_dir, 'valid.csv')
    test_dir = join(data_dir, 'test')
    assert isdir(data_dir) and isdir(train_dir) and isdir(val_dir) and isdir(test_dir)
    assert exists(train_csv) and isfile(train_csv) and exists(val_csv) and isfile(val_csv)

    # Before feeding images into the network, we normalize each image to the
    # mean and standard deviation of the ImageNet training set.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # Scale the variable-sized images to 224 x 224 and augment with random
    # lateral inversions (random rotations are currently disabled below).
    train_transforms = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        # transforms.RandomVerticalFlip(),
        # transforms.RandomRotation(30),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    train_data = MuraDataset(train_csv, transform=train_transforms)
    weights = train_data.balanced_weights
    weights = torch.DoubleTensor(weights)
    sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))
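    # Why this balances classes: WeightedRandomSampler draws index i with
    # probability weights[i] / sum(weights), so if balanced_weights assigns
    # each sample the inverse frequency of its class (as the commented-out
    # variant below hardcodes), both classes are drawn with roughly equal
    # probability and each batch is class-balanced in expectation.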

    # num_of_sample = 37110
    # weights = 1 / torch.DoubleTensor([24121, 1300])
    # sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, num_of_sample)
    train_loader = data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        # shuffle=True,  # shuffle is mutually exclusive with a sampler
        num_workers=args.workers,
        sampler=sampler,
        pin_memory=True)
    val_loader = data.DataLoader(
        MuraDataset(val_csv,
                    transforms.Compose([
                        transforms.Resize(224),
                        transforms.CenterCrop(224),
                        transforms.ToTensor(),
                        normalize,
                    ])),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)
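    # Note: the validation pipeline applies the same resize/center-crop and
    # normalization but no random flips, so validation metrics are
    # deterministic across epochs.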

    criterion = nn.CrossEntropyLoss().cuda()
    # The learning rate is decayed by a factor of 10 whenever the validation
    # metric plateaus, and the model with the best validation metric is kept.
    # The scheduler runs in 'max' mode because validate() returns the
    # per-study F1 score, where higher is better.
    if args.fullretrain:
        print("=> optimizing all layers")
        for param in model.parameters():
            param.requires_grad = True
        optimizer = optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay)
    else:
        print("=> optimizing fc/classifier layers")
        # The trainable head is `fc` on resnets and `classifier` on densenets.
        head = model.module.fc if 'resnet' in args.arch else model.module.classifier
        optimizer = optim.Adam(head.parameters(), args.lr, weight_decay=args.weight_decay)

    scheduler = ReduceLROnPlateau(optimizer, 'max', patience=10, verbose=True)
    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)
        # evaluate on the validation set; despite the name, val_loss holds the
        # per-study F1 score returned by validate()
        val_loss = validate(val_loader, model, criterion, epoch)
        scheduler.step(val_loss)
        # remember the best F1 and save a checkpoint
        is_best = val_loss > best_val_loss
        best_val_loss = max(val_loss, best_val_loss)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_val_loss': best_val_loss,
        }, is_best)


def train(train_loader, model, criterion, optimizer, epoch):
    losses = AverageMeter()
    acc = AverageMeter()

    # ensure model is in train mode
    model.train()
    pbar = tqdm(train_loader)
    for i, (images, target, meta) in enumerate(pbar):
        # async= is the pre-0.4 PyTorch spelling of non_blocking=
        target = target.cuda(async=True)
        image_var = Variable(images)
        label_var = Variable(target)

        # pass this batch through our model and get y_pred
        y_pred = model(image_var)

        # update loss metric
        loss = criterion(y_pred, label_var)
        losses.update(loss.data[0], images.size(0))

        # update accuracy metric
        prec1 = accuracy(y_pred.data, target, topk=(1, ))[0]
        acc.update(prec1[0], images.size(0))

        # compute gradients and take an optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_description("EPOCH[{0}][{1}/{2}]".format(epoch, i, len(train_loader)))
        pbar.set_postfix(
            acc="{acc.val:.4f} ({acc.avg:.4f})".format(acc=acc),
            loss="{loss.val:.4f} ({loss.avg:.4f})".format(loss=losses))

    tb_writer.add_scalar('train/loss', losses.avg, epoch)
    tb_writer.add_scalar('train/acc', acc.avg, epoch)
    return


def validate(val_loader, model, criterion, epoch):
    model.eval()
    acc = AverageMeter()
    losses = AverageMeter()
    meta_data = []
    sm = nn.Softmax(dim=1)
    pbar = tqdm(val_loader)
    for i, (images, target, meta) in enumerate(pbar):
        target = target.cuda(async=True)
        # volatile=True disables autograd history in pre-0.4 PyTorch
        image_var = Variable(images, volatile=True)
        label_var = Variable(target, volatile=True)

        y_pred = model(image_var)
        # update loss metric
        loss = criterion(y_pred, label_var)
        losses.update(loss.data[0], images.size(0))

        # update accuracy metric
        prec1 = accuracy(y_pred.data, target, topk=(1, ))[0]
        acc.update(prec1[0], images.size(0))

        sm_pred = sm(y_pred).data.cpu().numpy()
        # y_norm_probs = sm_pred[:, 0]  # p(normal)
        y_pred_probs = sm_pred[:, 1]  # p(abnormal)

        meta_data.append(
            pd.DataFrame({
                'img_filename': meta['img_filename'],
                'y_true': meta['y_true'].numpy(),
                'y_pred_probs': y_pred_probs,
                'patient': meta['patient'].numpy(),
                'study': meta['study'].numpy(),
                'image_num': meta['image_num'].numpy(),
                'encounter': meta['encounter'],
            }))

        pbar.set_description("VALIDATION[{}/{}]".format(i, len(val_loader)))
        pbar.set_postfix(
            acc="{acc.val:.4f} ({acc.avg:.4f})".format(acc=acc),
            loss="{loss.val:.4f} ({loss.avg:.4f})".format(loss=losses))

    # Aggregate to the study ("encounter") level: average the per-image
    # abnormality probabilities within each study, then threshold at 0.5.
    df = pd.concat(meta_data)
    ab = df.groupby(['encounter'])[['y_pred_probs', 'y_true']].mean()
    ab['y_pred_round'] = ab.y_pred_probs.round()
    ab['y_pred_round'] = pd.to_numeric(ab.y_pred_round, downcast='integer')

    f1_s = f1_score(ab.y_true, ab.y_pred_round)
    prec_s = precision_score(ab.y_true, ab.y_pred_round)
    rec_s = recall_score(ab.y_true, ab.y_pred_round)
    acc_s = accuracy_score(ab.y_true, ab.y_pred_round)
    tb_writer.add_scalar('val/f1_score', f1_s, epoch)
    tb_writer.add_scalar('val/precision', prec_s, epoch)
    tb_writer.add_scalar('val/recall', rec_s, epoch)
    tb_writer.add_scalar('val/accuracy', acc_s, epoch)
    # return the per-study F1 score; main() uses it to pick the best model
    # and to drive the LR scheduler
    return f1_s


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')
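    # checkpoint.pth.tar is overwritten every epoch; model_best.pth.tar keeps
    # the weights with the highest validation F1 seen so far and is the file
    # one would typically pass to --resume.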


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
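# Usage sketch: m = AverageMeter(); m.update(loss_value, n=batch_size) keeps
# m.val as the most recent value and m.avg as the running, count-weighted mean.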


def accuracy(y_pred, y_actual, topk=(1, )):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = y_actual.size(0)

    # top-k predicted class indices, shape (maxk, batch_size) after transpose
    _, pred = y_pred.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(y_actual.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # a sample counts as correct if the true label is among its top k
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))

    return res
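# Example: for y_pred of shape (N, 2) and y_actual of shape (N,),
# accuracy(y_pred, y_actual, topk=(1, )) returns a one-element list whose
# entry is the top-1 accuracy as a percentage of the batch.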


if __name__ == '__main__':
    main()