Diff of /train.py [000000] .. [455abf]

Switch to side-by-side view

--- a
+++ b/train.py
@@ -0,0 +1,458 @@
+#!/usr/bin/env python3
+#
+# Note -- this training script is tweaked from the original at:
+#           https://github.com/pytorch/examples/tree/master/imagenet
+#
+# For a step-by-step guide to transfer learning with PyTorch, see:
+#           https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html
+#
+import argparse
+import os
+import random
+
+import time
+import shutil
+import warnings
+import datetime
+
+import torch
+import torch.nn as nn
+import torch.nn.parallel
+import torch.nn.functional as F
+import torch.backends.cudnn as cudnn
+import torch.optim
+import torch.utils.data
+import torchvision.transforms as transforms
+import torchvision.datasets as datasets
+import torchvision.models as models
+
+from torch.utils.tensorboard import SummaryWriter
+
+from voc import VOCDataset
+from nuswide import NUSWideDataset
+from reshape import reshape_model
+
+
# enumerate the classification architectures exposed by torchvision
# (lowercase callables in models.__dict__, skipping dunder entries)
model_names = sorted(
    arch for arch in models.__dict__
    if arch.islower() and not arch.startswith("__") and callable(models.__dict__[arch])
)
+
+
# parse command-line arguments
parser = argparse.ArgumentParser(description='PyTorch Image Classifier Training')

parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--dataset-type', type=str, default='folder',
                    choices=['folder', 'nuswide', 'voc'],
                    help='specify the dataset type (default: folder)')
parser.add_argument('--multi-label', action='store_true',
                    help='multi-label model (aka image tagging)')
parser.add_argument('--multi-label-threshold', type=float, default=0.5,
                    help='confidence threshold for counting a prediction as correct')
parser.add_argument('--model-dir', type=str, default='models',
                    help='path to desired output directory for saving model '
                         'checkpoints (default: models/)')
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                    choices=model_names,
                    help='model architecture: ' + ' | '.join(model_names) + ' (default: resnet18)')
parser.add_argument('--resolution', default=224, type=int, metavar='N',
                    help='input NxN image resolution of model (default: 224x224) '
                         'note than Inception models should use 299x299')
parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                    help='number of data loading workers (default: 2)')
parser.add_argument('--epochs', default=35, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=8, type=int, metavar='N',
                    help='mini-batch size (default: 8)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)',
                    dest='weight_decay')
parser.add_argument('-p', '--print-freq', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
# --pretrained defaults to True, so on its own the store_true flag was a no-op;
# --no-pretrained is added so training from scratch is actually reachable
parser.add_argument('--pretrained', dest='pretrained', action='store_true', default=True,
                    help='use pre-trained model (default: True)')
parser.add_argument('--no-pretrained', dest='pretrained', action='store_false',
                    help='train the model from scratch instead of using pre-trained weights')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training')
parser.add_argument('--gpu', default=0, type=int,
                    help='GPU ID to use (default: 0)')

args = parser.parse_args()
+
+
# open tensorboard logger (logs to <model_dir>/tensorboard/<timestamp>)
tensorboard = SummaryWriter(log_dir=os.path.join(args.model_dir, "tensorboard",
                                                 datetime.datetime.now().strftime('%Y%m%d_%H%M%S')))

# note: the tensorboard CLI flag is --logdir (the previous hint printed --log-dir,
# which tensorboard rejects)
print(f"To start tensorboard run:  tensorboard --logdir={os.path.join(args.model_dir, 'tensorboard')}")

# best validation accuracy observed so far (updated each epoch in main)
best_accuracy = 0
+
+
def main(args):
    """
    Load the dataset, setup the model, and train for args.epochs.

    A checkpoint is saved after every epoch via save_checkpoint(), and the
    best-performing model (by validation accuracy) is earmarked separately.

    Parameters:
        args -- parsed command-line arguments (see parser above)
    """
    global best_accuracy

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        print(f"=> using GPU {args.gpu} ({torch.cuda.get_device_name(args.gpu)})")

    # validate the dataset/multi-label combination up-front, before spending
    # time loading datasets that will only be rejected afterwards
    if args.dataset_type in ('nuswide', 'voc') and not args.multi_label:
        raise ValueError("nuswide or voc datasets should be run with --multi-label")

    # setup data transformations (ImageNet channel mean/std normalization)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_transforms = transforms.Compose([
        transforms.RandomResizedCrop(args.resolution),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    val_transforms = transforms.Compose([
        transforms.Resize(args.resolution),
        transforms.CenterCrop(args.resolution),
        transforms.ToTensor(),
        normalize,
    ])

    # load the dataset
    if args.dataset_type == 'folder':
        train_dataset = datasets.ImageFolder(os.path.join(args.data, 'train'), train_transforms)
        val_dataset = datasets.ImageFolder(os.path.join(args.data, 'val'), val_transforms)
    elif args.dataset_type == 'nuswide':
        train_dataset = NUSWideDataset(args.data, 'trainval', train_transforms)
        val_dataset = NUSWideDataset(args.data, 'test', val_transforms)
    elif args.dataset_type == 'voc':
        train_dataset = VOCDataset(args.data, 'trainval', train_transforms)
        val_dataset = VOCDataset(args.data, 'val', val_transforms)
    else:
        # argparse 'choices' normally prevents this; fail loudly if bypassed
        raise ValueError(f"invalid --dataset-type '{args.dataset_type}'")

    print(f"=> dataset classes:  {len(train_dataset.classes)}  {train_dataset.classes}")

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # create or load the model if using pre-trained (the default)
    if args.pretrained:
        print(f"=> using pre-trained model '{args.arch}'")
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print(f"=> creating model '{args.arch}'")
        model = models.__dict__[args.arch]()

    # reshape the final layer for the number of classes in the dataset
    model = reshape_model(model, args.arch, len(train_dataset.classes))

    # define loss function (criterion) and optimizer
    if args.multi_label:
        criterion = nn.BCEWithLogitsLoss()   # per-class sigmoid for image tagging
    else:
        criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # transfer the model to the GPU that it should be run on
    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
        criterion = criterion.cuda(args.gpu)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print(f"=> loading checkpoint '{args.resume}'")
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch'] + 1
            # checkpoints written by save_checkpoint() store the accuracy under
            # 'accuracy'/'val' and have no 'best_accuracy' key -- reading
            # checkpoint['best_accuracy'] unconditionally raised KeyError on resume
            best_accuracy = checkpoint.get('best_accuracy',
                                           checkpoint.get('accuracy', {}).get('val', 0))
            if args.gpu is not None and torch.is_tensor(best_accuracy):
                best_accuracy = best_accuracy.to(args.gpu)   # may be from a checkpoint from a different GPU
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})")
        else:
            print(f"=> no checkpoint found at '{args.resume}'")

    cudnn.benchmark = True

    # if in evaluation mode, only run validation
    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    # train for the specified number of epochs
    for epoch in range(args.start_epoch, args.epochs):
        # decay the learning rate
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train_loss, train_acc = train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        val_loss, val_acc = validate(val_loader, model, criterion, epoch)

        # remember the best accuracy and save checkpoint
        is_best = val_acc > best_accuracy
        best_accuracy = max(val_acc, best_accuracy)

        print(f"=> Epoch {epoch}")
        print(f"  * Train Loss     {train_loss:.4e}")
        print(f"  * Train Accuracy {train_acc:.4f}")
        print(f"  * Val Loss       {val_loss:.4e}")
        print(f"  * Val Accuracy   {val_acc:.4f}{'*' if is_best else ''}")

        save_checkpoint({
            'epoch': epoch,
            'arch': args.arch,
            'resolution': args.resolution,
            'classes': train_dataset.classes,
            'num_classes': len(train_dataset.classes),
            'multi_label': args.multi_label,
            'state_dict': model.state_dict(),
            'accuracy': {'train': train_acc, 'val': val_acc},
            'loss' : {'train': train_loss, 'val': val_loss},
            'optimizer' : optimizer.state_dict(),
        }, is_best)
+
+
def train(train_loader, model, criterion, optimizer, epoch):
    """
    Train the model for one epoch over the dataset.

    Parameters:
        train_loader -- DataLoader yielding (images, target) batches
        model        -- network being trained (already on the target device)
        criterion    -- loss function (BCEWithLogitsLoss or CrossEntropyLoss)
        optimizer    -- optimizer stepped once per batch
        epoch        -- current epoch index (used for display and tensorboard)

    Returns:
        (average loss, average accuracy) over the epoch
    """
    batch_time = AverageMeter('Time', ':6.3f')
    data_time = AverageMeter('Data', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    acc = AverageMeter('Accuracy', ':7.3f')
    
    progress = ProgressMeter(
        len(train_loader),
        [batch_time, data_time, losses, acc],
        prefix=f"Epoch: [{epoch}]")

    # switch to train mode
    model.train()

    # get the start time
    epoch_start = time.time()
    end = epoch_start

    # train over each image batch from the dataset
    for i, (images, target) in enumerate(train_loader):
        # measure data loading time (taken before the GPU transfer on purpose,
        # so data_time reflects only the DataLoader wait)
        data_time.update(time.time() - end)

        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)

        # compute output
        output = model(images)
        loss = criterion(output, target)

        # record loss and measure accuracy (weighted by batch size, since the
        # final batch may be smaller)
        losses.update(loss.item(), images.size(0))
        acc.update(accuracy(output, target), images.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time for the whole batch (load + forward + backward)
        batch_time.update(time.time() - end)
        end = time.time()

        # print every args.print_freq batches, and always on the last batch
        if i % args.print_freq == 0 or i == len(train_loader)-1:
            progress.display(i)
    
    print(f"Epoch: [{epoch}] completed, elapsed time {time.time() - epoch_start:6.3f} seconds")

    tensorboard.add_scalar('Loss/train', losses.avg, epoch)
    tensorboard.add_scalar('Accuracy/train', acc.avg, epoch)

    return losses.avg, acc.avg
+    
+
def validate(val_loader, model, criterion, epoch):
    """
    Measure model loss/accuracy across the validation dataset.

    Parameters:
        val_loader -- DataLoader yielding (images, target) batches
        model      -- network to evaluate (already on the target device)
        criterion  -- loss function (matching the one used in training)
        epoch      -- current epoch index (used for tensorboard logging)

    Returns:
        (average loss, average accuracy) over the validation set
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    acc = AverageMeter('Accuracy', ':7.3f')
    
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, acc],
        prefix='Val:   ')

    # switch to evaluate mode (disables dropout, freezes batchnorm stats)
    model.eval()

    # no gradients needed for evaluation
    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
                target = target.cuda(args.gpu, non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # record loss and measure accuracy (weighted by batch size)
            losses.update(loss.item(), images.size(0))
            acc.update(accuracy(output, target), images.size(0))
            
            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            # print every args.print_freq batches, and always on the last batch
            if i % args.print_freq == 0 or i == len(val_loader)-1:
                progress.display(i)

    tensorboard.add_scalar('Loss/val', losses.avg, epoch)
    tensorboard.add_scalar('Accuracy/val', acc.avg, epoch)
    
    return losses.avg, acc.avg
+
+
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename='model_best.pth.tar', labels_filename='labels.txt'):
    """
    Save a model checkpoint file, along with the best-performing model if applicable.

    Parameters:
        state           -- dict of checkpoint data (epoch, state_dict, classes, ...)
        is_best         -- if True, also copy the checkpoint to best_filename
        filename        -- checkpoint filename (joined under args.model_dir if set)
        best_filename   -- best-model filename (joined under args.model_dir if set)
        labels_filename -- class-labels filename, written once on epoch 0
    """
    if args.model_dir:
        model_dir = os.path.expanduser(args.model_dir)

        # makedirs handles nested paths and pre-existing dirs (os.mkdir did neither)
        os.makedirs(model_dir, exist_ok=True)

        filename = os.path.join(model_dir, filename)
        best_filename = os.path.join(model_dir, best_filename)
        labels_filename = os.path.join(model_dir, labels_filename)

    # save the checkpoint (previously the non-best branch printed '(unknown)'
    # instead of the actual path)
    torch.save(state, filename)
    print(f"saved checkpoint to:  {filename}")

    # earmark the best checkpoint
    if is_best:
        shutil.copyfile(filename, best_filename)
        print(f"saved best model to:  {best_filename}")

    # save labels.txt on the first epoch (one class name per line)
    if state['epoch'] == 0:
        with open(labels_filename, 'w') as file:
            for label in state['classes']:
                file.write(f"{label}\n")
        print(f"saved class labels to:  {labels_filename}")
+            
+
def adjust_learning_rate(optimizer, epoch):
    """
    Apply a step-decay schedule: the initial LR (args.lr) scaled by 0.1
    for every 30 completed epochs.
    """
    decayed_lr = args.lr * (0.1 ** (epoch // 30))

    for group in optimizer.param_groups:
        group['lr'] = decayed_lr
+
+
def accuracy(output, target):
    """
    Compute the accuracy of predictions vs groundtruth, as a percentage (0-100).

    Parameters:
        output -- raw model logits, shape (batch, num_classes)
        target -- groundtruth labels: class indices for single-label,
                  or a binary per-class matrix for multi-label

    Returns:
        float accuracy percentage (mean fraction of correct predictions * 100)
    """
    with torch.no_grad():
        if args.multi_label:
            # threshold the per-class sigmoid confidences and compare
            # element-wise against the binary groundtruth matrix
            # (torch.sigmoid replaces the deprecated F.sigmoid)
            # https://medium.com/@yrodriguezmd/tackling-the-accuracy-multi-metric-9e2356f62513
            output = torch.sigmoid(output)
            preds = ((output >= args.multi_label_threshold) == target.bool())
        else:
            # single-label: the predicted class is the softmax argmax
            output = F.softmax(output, dim=-1)
            _, preds = torch.max(output, dim=-1)
            preds = (preds == target)

        return preds.float().mean().cpu().item() * 100.0
+        
+        
class AverageMeter(object):
    """
    Tracks the most recent value and the running average of a metric.

    Attributes:
        name  -- display name of the metric
        fmt   -- format spec used by __str__ (e.g. ':6.3f')
        val   -- most recently recorded value
        avg   -- running average, weighted by sample counts
        sum   -- weighted sum of all recorded values
        count -- total number of samples recorded
    """
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record a new value `val`, averaged over `n` samples."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

    def __str__(self):
        template = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return template.format(**self.__dict__)
+
+
class ProgressMeter(object):
    """
    Prints a one-line progress summary: a [batch/total] counter followed
    by the string form of each tracked meter.
    """
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the current batch index and every meter, joined on one line."""
        fields = [self.prefix + self.batch_fmtstr.format(batch)]
        fields.extend(str(meter) for meter in self.meters)
        print('  '.join(fields))

    def _get_batch_fmtstr(self, num_batches):
        # pad the batch index to the width of the total count, e.g. '[  7/100]'
        width = len(str(num_batches))
        field = '{:' + str(width) + 'd}'
        return '[' + field + '/' + field.format(num_batches) + ']'
+
+
# script entry point (args were already parsed at module load, above)
if __name__ == '__main__':
    main(args)