Diff of /pytorch/train.py [000000] .. [bca7a0]

a b/pytorch/train.py
from __future__ import absolute_import, division, print_function

import argparse
import random
import shutil
from os import getcwd
from os.path import exists, isdir, isfile, join

import numpy as np
import pandas as pd
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data as data
from sklearn.metrics import (accuracy_score, f1_score, precision_score, recall_score)
from tensorboardX import SummaryWriter
from torch.autograd import Variable
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm

import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from dataloader import MuraDataset

print("torch : {}".format(torch.__version__))
print("torch vision : {}".format(torchvision.__version__))
print("numpy : {}".format(np.__version__))
print("pandas : {}".format(pd.__version__))
model_names = sorted(name for name in models.__dict__ if name.islower() and not name.startswith("__"))

parser = argparse.ArgumentParser(description='Hyperparameters')
parser.add_argument('--data_dir', default='MURA-v1.0', metavar='DIR', help='path to dataset')
parser.add_argument('--arch', default='densenet121', choices=model_names, help='nn architecture')
parser.add_argument('--classes', default=2, type=int)
parser.add_argument('--workers', default=4, type=int)
parser.add_argument('--epochs', default=90, type=int)
parser.add_argument('--start-epoch', default=0, type=int, help='manual epoch number')
parser.add_argument('-b', '--batch-size', default=512, type=int, help='mini-batch size')
parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float, help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
parser.add_argument('--weight-decay', default=.1, type=float, help='weight decay')
parser.add_argument('--resume', default='', type=str, help='path to latest checkpoint')
parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model')
parser.add_argument('--fullretrain', dest='fullretrain', action='store_true', help='retrain all layers of the model')
parser.add_argument('--seed', default=1337, type=int, help='random seed')
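# Example invocation (values below are illustrative, not required defaults):
#   python train.py --arch densenet169 --pretrained --batch-size 64 --epochs 30 --data_dir MURA-v1.0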

best_val_loss = 0

tb_writer = SummaryWriter()


def main():
    global args, best_val_loss
    args = parser.parse_args()
    print("=> setting random seed to '{}'".format(args.seed))
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
        for param in model.parameters():
            param.requires_grad = False
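        # The pretrained backbone is frozen above; the fc/classifier layers created
        # below are fresh nn.Linear modules whose parameters default to
        # requires_grad=True, so only the classification head is trained unless
        # --fullretrain is given.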

        if 'resnet' in args.arch:
            # for param in model.layer4.parameters():
            model.fc = nn.Linear(2048, args.classes)

        if 'dense' in args.arch:
            if '121' in args.arch:
                # (classifier): Linear(in_features=1024)
                model.classifier = nn.Linear(1024, args.classes)
            elif '169' in args.arch:
                # (classifier): Linear(in_features=1664)
                model.classifier = nn.Linear(1664, args.classes)
            else:
                return

    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    model = torch.nn.DataParallel(model).cuda()
    # optionally resume from a checkpoint
    if args.resume:
        if isfile(args.resume):
            print("=> found checkpoint")
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_val_loss = checkpoint['best_val_loss']
            model.load_state_dict(checkpoint['state_dict'])

            args.epochs = args.epochs + args.start_epoch
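            # Resuming extends training by args.epochs additional epochs on top of
            # the epoch count stored in the checkpoint.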
            print("=> loading checkpoint '{}' with acc of '{}'".format(
                args.resume,
                checkpoint['best_val_loss'], ))

        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    cudnn.benchmark = True

    # Data loading code
    data_dir = join(getcwd(), args.data_dir)
    train_dir = join(data_dir, 'train')
    train_csv = join(data_dir, 'train.csv')
    val_dir = join(data_dir, 'valid')
    val_csv = join(data_dir, 'valid.csv')
    test_dir = join(data_dir, 'test')
    assert isdir(data_dir) and isdir(train_dir) and isdir(val_dir) and isdir(test_dir)
    assert exists(train_csv) and isfile(train_csv) and exists(val_csv) and isfile(val_csv)

    # Before feeding images into the network, we normalize each image to have
    # the same mean and standard deviation of images in the ImageNet training set.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # We then scale the variable-sized images to 224 × 224.
    # We augment by applying random lateral inversions and rotations.
    train_transforms = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        # transforms.RandomVerticalFlip(),
        # transforms.RandomRotation(30),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    train_data = MuraDataset(train_csv, transform=train_transforms)
    weights = train_data.balanced_weights
    weights = torch.DoubleTensor(weights)
    sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))
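    # The weighted sampler draws len(weights) images per epoch with replacement,
    # using the per-example balanced_weights from MuraDataset (presumably inverse
    # class frequency), so each batch is roughly balanced between normal and
    # abnormal images; this is why shuffle=True is left commented out below.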

    # num_of_sample = 37110
    # weights = 1 / torch.DoubleTensor([24121, 1300])
    # sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, num_of_sample)
    train_loader = data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        # shuffle=True,
        num_workers=args.workers,
        sampler=sampler,
        pin_memory=True)
    val_loader = data.DataLoader(
        MuraDataset(val_csv,
                    transforms.Compose([
                        transforms.Resize(224),
                        transforms.CenterCrop(224),
                        transforms.ToTensor(),
                        normalize,
                    ])),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)

    criterion = nn.CrossEntropyLoss().cuda()
    # We use an initial learning rate of 0.0001 that is decayed by a factor of
    # 10 each time the validation loss plateaus after an epoch, and pick the
    # model with the lowest validation loss
    if args.fullretrain:
        print("=> optimizing all layers")
        for param in model.parameters():
            param.requires_grad = True
        optimizer = optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay)
    else:
        print("=> optimizing fc/classifier layers")
        # ResNets expose their head as `fc`, DenseNets as `classifier`; pick whichever
        # exists so the default densenet121 arch does not fail here.
        head = model.module.classifier if hasattr(model.module, 'classifier') else model.module.fc
        optimizer = optim.Adam(head.parameters(), args.lr, weight_decay=args.weight_decay)

    scheduler = ReduceLROnPlateau(optimizer, 'max', patience=10, verbose=True)
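    # Note: scheduler.step() is called below with the value returned by validate(),
    # i.e. the encounter-level F1 score, so 'max' mode lowers the LR (by the default
    # factor of 0.1) once that score stops improving for `patience` epochs.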
    for epoch in range(args.start_epoch, args.epochs):
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)
        # evaluate on validation set
        val_loss = validate(val_loader, model, criterion, epoch)
        scheduler.step(val_loss)
        # remember the best validation score and save a checkpoint
        is_best = val_loss > best_val_loss
        best_val_loss = max(val_loss, best_val_loss)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_val_loss': best_val_loss,
        }, is_best)


def train(train_loader, model, criterion, optimizer, epoch):
    losses = AverageMeter()
    acc = AverageMeter()

    # ensure model is in train mode
    model.train()
    pbar = tqdm(train_loader)
    for i, (images, target, meta) in enumerate(pbar):
        target = target.cuda(async=True)
        image_var = Variable(images)
        label_var = Variable(target)

        # pass this batch through our model and get y_pred
        y_pred = model(image_var)

        # update loss metric
        loss = criterion(y_pred, label_var)
        losses.update(loss.data[0], images.size(0))

        # update accuracy metric
        prec1, _ = accuracy(y_pred.data, target, topk=(1, 1))
        acc.update(prec1[0], images.size(0))

        # compute gradient and do an optimizer step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_description("EPOCH[{0}][{1}/{2}]".format(epoch, i, len(train_loader)))
        pbar.set_postfix(
            acc="{acc.val:.4f} ({acc.avg:.4f})".format(acc=acc),
            loss="{loss.val:.4f} ({loss.avg:.4f})".format(loss=losses))

    tb_writer.add_scalar('train/loss', losses.avg, epoch)
    tb_writer.add_scalar('train/acc', acc.avg, epoch)
    return


def validate(val_loader, model, criterion, epoch):
    model.eval()
    acc = AverageMeter()
    losses = AverageMeter()
    meta_data = []
    pbar = tqdm(val_loader)
    for i, (images, target, meta) in enumerate(pbar):
        target = target.cuda(async=True)
        image_var = Variable(images, volatile=True)
        label_var = Variable(target, volatile=True)

        y_pred = model(image_var)
        # update loss metric
        loss = criterion(y_pred, label_var)
        losses.update(loss.data[0], images.size(0))

        # update accuracy metric on the GPU
        prec1, _ = accuracy(y_pred.data, target, topk=(1, 1))
        acc.update(prec1[0], images.size(0))

        sm = nn.Softmax()
        sm_pred = sm(y_pred).data.cpu().numpy()
        # y_norm_probs = sm_pred[:, 0] # p(normal)
        y_pred_probs = sm_pred[:, 1]  # p(abnormal)

        meta_data.append(
            pd.DataFrame({
                'img_filename': meta['img_filename'],
                'y_true': meta['y_true'].numpy(),
                'y_pred_probs': y_pred_probs,
                'patient': meta['patient'].numpy(),
                'study': meta['study'].numpy(),
                'image_num': meta['image_num'].numpy(),
                'encounter': meta['encounter'],
            }))

        pbar.set_description("VALIDATION[{}/{}]".format(i, len(val_loader)))
        pbar.set_postfix(
            acc="{acc.val:.4f} ({acc.avg:.4f})".format(acc=acc),
            loss="{loss.val:.4f} ({loss.avg:.4f})".format(loss=losses))
    df = pd.concat(meta_data)
    ab = df.groupby(['encounter'])[['y_pred_probs', 'y_true']].mean()
    ab['y_pred_round'] = ab.y_pred_probs.round()
    ab['y_pred_round'] = pd.to_numeric(ab.y_pred_round, downcast='integer')
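    # Aggregate to the encounter (study) level: average the per-image abnormality
    # probabilities within each encounter, then round to obtain a single binary
    # prediction per study before computing the metrics below.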

    f1_s = f1_score(ab.y_true, ab.y_pred_round)
    prec_s = precision_score(ab.y_true, ab.y_pred_round)
    rec_s = recall_score(ab.y_true, ab.y_pred_round)
    acc_s = accuracy_score(ab.y_true, ab.y_pred_round)
    tb_writer.add_scalar('val/f1_score', f1_s, epoch)
    tb_writer.add_scalar('val/precision', prec_s, epoch)
    tb_writer.add_scalar('val/recall', rec_s, epoch)
    tb_writer.add_scalar('val/accuracy', acc_s, epoch)
    # return the metric we want to evaluate this model's performance by
    return f1_s


def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(y_pred, y_actual, topk=(1, )):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = y_actual.size(0)

    _, pred = y_pred.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(y_actual.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))

    return res
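    # Illustrative example (not executed anywhere): for y_pred = [[0.2, 0.8], [0.9, 0.1]]
    # and y_actual = [1, 1], the top-1 prediction is right for the first row only,
    # so accuracy(y_pred, y_actual, topk=(1,)) returns [50.0].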


if __name__ == '__main__':
    main()