Diff of /pytorch/train.py [000000] .. [bca7a0]

--- /dev/null
+++ b/pytorch/train.py
@@ -0,0 +1,330 @@
+from __future__ import absolute_import, division, print_function
+
+import argparse
+import random
+import shutil
+from os import getcwd
+from os.path import exists, isdir, isfile, join
+
+import numpy as np
+import pandas as pd
+import torch
+import torch.backends.cudnn as cudnn
+import torch.nn as nn
+import torch.nn.parallel
+import torch.optim as optim
+import torch.utils.data as data
+from sklearn.metrics import (accuracy_score, f1_score, precision_score, recall_score)
+from tensorboardX import SummaryWriter
+from torch.autograd import Variable
+from torch.optim.lr_scheduler import ReduceLROnPlateau
+from tqdm import tqdm
+
+import torchvision
+import torchvision.models as models
+import torchvision.transforms as transforms
+from dataloader import MuraDataset
+
+print("torch : {}".format(torch.__version__))
+print("torch vision : {}".format(torchvision.__version__))
+print("numpy : {}".format(np.__version__))
+print("pandas : {}".format(pd.__version__))
+model_names = sorted(name for name in models.__dict__
+                     if name.islower() and not name.startswith("__") and callable(models.__dict__[name]))
+
+parser = argparse.ArgumentParser(description='Hyperparameters')
+parser.add_argument('--data_dir', default='MURA-v1.0', metavar='DIR', help='path to dataset')
+parser.add_argument('--arch', default='densenet121', choices=model_names, help='nn architecture')
+parser.add_argument('--classes', default=2, type=int)
+parser.add_argument('--workers', default=4, type=int)
+parser.add_argument('--epochs', default=90, type=int)
+parser.add_argument('--start-epoch', default=0, type=int, help='manual epoch number')
+parser.add_argument('-b', '--batch-size', default=512, type=int, help='mini-batch size')
+parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float, help='initial learning rate')
+parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
+parser.add_argument('--weight-decay', default=.1, type=float, help='weight decay')
+parser.add_argument('--resume', default='', type=str, help='path to latest checkpoint')
+parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model')
+parser.add_argument('--fullretrain', dest='fullretrain', action='store_true', help='retrain all layers of the model')
+parser.add_argument('--seed', default=1337, type=int, help='random seed')
+
+best_val_loss = 0  # despite the name, this tracks the best validation F1 (higher is better)
+
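+# tensorboardX writer; by default it logs to ./runs/<datetime>_<hostname>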
+tb_writer = SummaryWriter()
+
+
+def main():
+    global args, best_val_loss
+    args = parser.parse_args()
+    print("=> setting random seed to '{}'".format(args.seed))
+    random.seed(args.seed)
+    np.random.seed(args.seed)
+    torch.manual_seed(args.seed)
+    torch.cuda.manual_seed(args.seed)
+
+    if args.pretrained:
+        print("=> using pre-trained model '{}'".format(args.arch))
+        model = models.__dict__[args.arch](pretrained=True)
+        for param in model.parameters():
+            param.requires_grad = False
+
+        if 'resnet' in args.arch:
+            # replace the final fully connected layer; reading in_features off
+            # the model works for every resnet variant (512 or 2048 features)
+            model.fc = nn.Linear(model.fc.in_features, args.classes)
+
+        if 'dense' in args.arch:
+            # densenets expose their head as `classifier` (1024 features for
+            # densenet121, 1664 for densenet169)
+            model.classifier = nn.Linear(model.classifier.in_features, args.classes)
+
+    else:
+        print("=> creating model '{}'".format(args.arch))
+        model = models.__dict__[args.arch]()
+
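+    # wrap in DataParallel so each batch is split across all visible GPUs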
+    model = torch.nn.DataParallel(model).cuda()
+    # optionally resume from a checkpoint
+    if args.resume:
+        if isfile(args.resume):
+            print("=> found checkpoint")
+            checkpoint = torch.load(args.resume)
+            args.start_epoch = checkpoint['epoch']
+            best_val_loss = checkpoint['best_val_loss']
+            model.load_state_dict(checkpoint['state_dict'])
+
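+            # treat --epochs as "train this many more epochs" when resuming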
+            args.epochs = args.epochs + args.start_epoch
+            print("=> loading checkpoint '{}' with acc of '{}'".format(
+                args.resume,
+                checkpoint['best_val_loss'], ))
+
+        else:
+            print("=> no checkpoint found at '{}'".format(args.resume))
+
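+    # let cuDNN benchmark conv algorithms and cache the fastest one; this
+    # helps because every batch has the same fixed 224x224 input shape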
+    cudnn.benchmark = True
+
+    # Data loading code
+    data_dir = join(getcwd(), args.data_dir)
+    train_dir = join(data_dir, 'train')
+    train_csv = join(data_dir, 'train.csv')
+    val_dir = join(data_dir, 'valid')
+    val_csv = join(data_dir, 'valid.csv')
+    test_dir = join(data_dir, 'test')
+    assert isdir(data_dir) and isdir(train_dir) and isdir(val_dir) and isdir(test_dir)
+    assert exists(train_csv) and isfile(train_csv) and exists(val_csv) and isfile(val_csv)
+
+    # Before feeding images into the network, we normalize each image to have
+    # the same mean and standard deviation of images in the ImageNet training set.
+    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+
+    # We then scale the variable-sized images to 224 × 224.
+    # We augment with random lateral inversions; the vertical flip and
+    # rotation augmentations are left commented out below.
+    train_transforms = transforms.Compose([
+        transforms.Resize(224),
+        transforms.CenterCrop(224),
+        # transforms.RandomVerticalFlip(),
+        # transforms.RandomRotation(30),
+        transforms.RandomHorizontalFlip(),
+        transforms.ToTensor(),
+        normalize,
+    ])
+
+    train_data = MuraDataset(train_csv, transform=train_transforms)
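+    # MURA is class-imbalanced, so instead of uniform shuffling, sample each
+    # image with a weight presumably inversely proportional to its class
+    # frequency (balanced_weights is precomputed by MuraDataset)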
+    weights = train_data.balanced_weights
+    weights = torch.DoubleTensor(weights)
+    sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights))
+
+    # num_of_sample = 37110
+    # weights = 1 / torch.DoubleTensor([24121, 1300])
+    # sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, num_of_sample)
+    train_loader = data.DataLoader(
+        train_data,
+        batch_size=args.batch_size,
+        # shuffle=True,
+        num_workers=args.workers,
+        sampler=sampler,
+        pin_memory=True)
+    val_loader = data.DataLoader(
+        MuraDataset(val_csv,
+                    transforms.Compose([
+                        transforms.Resize(224),
+                        transforms.CenterCrop(224),
+                        transforms.ToTensor(),
+                        normalize,
+                    ])),
+        batch_size=args.batch_size,
+        shuffle=False,
+        num_workers=args.workers,
+        pin_memory=True)
+
+    criterion = nn.CrossEntropyLoss().cuda()
+    # The MURA paper uses an initial learning rate of 0.0001, decayed by a
+    # factor of 10 each time the validation loss plateaus after an epoch, and
+    # picks the model with the lowest validation loss. Here --lr defaults to
+    # 1e-3 and the scheduler runs in 'max' mode on the F1 that validate() returns.
+    if args.fullretrain:
+        print("=> optimizing all layers")
+        for param in model.parameters():
+            param.requires_grad = True
+        optimizer = optim.Adam(model.parameters(), args.lr, weight_decay=args.weight_decay)
+    else:
+        print("=> optimizing fc/classifier layers")
+        # only the replaced head is trainable; densenets call it `classifier`,
+        # resnets call it `fc`
+        if 'dense' in args.arch:
+            head_params = model.module.classifier.parameters()
+        else:
+            head_params = model.module.fc.parameters()
+        optimizer = optim.Adam(head_params, args.lr, weight_decay=args.weight_decay)
+
+    scheduler = ReduceLROnPlateau(optimizer, 'max', patience=10, verbose=True)
+    for epoch in range(args.start_epoch, args.epochs):
+        # train for one epoch
+        train(train_loader, model, criterion, optimizer, epoch)
+        # evaluate on the validation set; validate() returns the F1 score
+        val_f1 = validate(val_loader, model, criterion, epoch)
+        scheduler.step(val_f1)
+        # remember the best F1 seen so far and save a checkpoint
+        is_best = val_f1 > best_val_loss
+        best_val_loss = max(val_f1, best_val_loss)
+        save_checkpoint({
+            'epoch': epoch + 1,
+            'arch': args.arch,
+            'state_dict': model.state_dict(),
+            'best_val_loss': best_val_loss,
+        }, is_best)
+
+
+def train(train_loader, model, criterion, optimizer, epoch):
+    losses = AverageMeter()
+    acc = AverageMeter()
+
+    # ensure model is in train mode
+    model.train()
+    pbar = tqdm(train_loader)
+    for i, (images, target, meta) in enumerate(pbar):
+        target = target.cuda(async=True)
+        image_var = Variable(images)
+        label_var = Variable(target)
+
+        # pass this batch through our model and get y_pred
+        y_pred = model(image_var)
+
+        # update loss metric
+        loss = criterion(y_pred, label_var)
+        losses.update(loss.data[0], images.size(0))
+
+        # update accuracy metric
+        prec1 = accuracy(y_pred.data, target, topk=(1,))[0]
+        acc.update(prec1[0], images.size(0))
+
+        # compute gradients and take an optimizer (Adam) step
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+
+        pbar.set_description("EPOCH[{0}][{1}/{2}]".format(epoch, i, len(train_loader)))
+        pbar.set_postfix(
+            acc="{acc.val:.4f} ({acc.avg:.4f})".format(acc=acc),
+            loss="{loss.val:.4f} ({loss.avg:.4f})".format(loss=losses))
+
+    tb_writer.add_scalar('train/loss', losses.avg, epoch)
+    tb_writer.add_scalar('train/acc', acc.avg, epoch)
+    return
+
+
+def validate(val_loader, model, criterion, epoch):
+    model.eval()
+    acc = AverageMeter()
+    losses = AverageMeter()
+    meta_data = []
+    pbar = tqdm(val_loader)
+    for i, (images, target, meta) in enumerate(pbar):
+        target = target.cuda(async=True)
+        image_var = Variable(images, volatile=True)
+        label_var = Variable(target, volatile=True)
+
+        y_pred = model(image_var)
+        # update loss metric
+        loss = criterion(y_pred, label_var)
+        losses.update(loss.data[0], images.size(0))
+
+        # update accuracy metric
+        prec1 = accuracy(y_pred.data, target, topk=(1,))[0]
+        acc.update(prec1[0], images.size(0))
+
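+        # convert logits to probabilities so image-level predictions can be
+        # averaged per encounter below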
+        sm = nn.Softmax(dim=1)
+        sm_pred = sm(y_pred).data.cpu().numpy()
+        # y_norm_probs = sm_pred[:, 0] # p(normal)
+        y_pred_probs = sm_pred[:, 1]  # p(abnormal)
+
+        meta_data.append(
+            pd.DataFrame({
+                'img_filename': meta['img_filename'],
+                'y_true': meta['y_true'].numpy(),
+                'y_pred_probs': y_pred_probs,
+                'patient': meta['patient'].numpy(),
+                'study': meta['study'].numpy(),
+                'image_num': meta['image_num'].numpy(),
+                'encounter': meta['encounter'],
+            }))
+
+        pbar.set_description("VALIDATION[{}/{}]".format(i, len(val_loader)))
+        pbar.set_postfix(
+            acc="{acc.val:.4f} ({acc.avg:.4f})".format(acc=acc),
+            loss="{loss.val:.4f} ({loss.avg:.4f})".format(loss=losses))
+    df = pd.concat(meta_data)
+    ab = df.groupby('encounter')[['y_pred_probs', 'y_true']].mean()
+    ab['y_pred_round'] = ab.y_pred_probs.round().astype(int)
+
+    f1_s = f1_score(ab.y_true, ab.y_pred_round)
+    prec_s = precision_score(ab.y_true, ab.y_pred_round)
+    rec_s = recall_score(ab.y_true, ab.y_pred_round)
+    acc_s = accuracy_score(ab.y_true, ab.y_pred_round)
+    tb_writer.add_scalar('val/f1_score', f1_s, epoch)
+    tb_writer.add_scalar('val/precision', prec_s, epoch)
+    tb_writer.add_scalar('val/recall', rec_s, epoch)
+    tb_writer.add_scalar('val/accuracy', acc_s, epoch)
+    # return the metric we want to evaluate this model's performance by
+    return f1_s
+
+
+def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
+    torch.save(state, filename)
+    if is_best:
+        shutil.copyfile(filename, 'model_best.pth.tar')
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+
+def accuracy(y_pred, y_actual, topk=(1, )):
+    """Computes the precision@k for the specified values of k"""
+    maxk = max(topk)
+    batch_size = y_actual.size(0)
+
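+    # take the indices of the maxk highest scores, then compare them
+    # row-by-row against the broadcast ground-truth labels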
+    _, pred = y_pred.topk(maxk, 1, True, True)
+    pred = pred.t()
+    correct = pred.eq(y_actual.view(1, -1).expand_as(pred))
+
+    res = []
+    for k in topk:
+        correct_k = correct[:k].view(-1).float().sum(0)
+        res.append(correct_k.mul_(100.0 / batch_size))
+
+    return res
+
+
+if __name__ == '__main__':
+    main()