Diff of /train.py [000000] .. [455abf]

Switch to unified view

a b/train.py
1
#!/usr/bin/env python3
2
#
3
# Note -- this training script is tweaked from the original at:
4
#           https://github.com/pytorch/examples/tree/master/imagenet
5
#
6
# For a step-by-step guide to transfer learning with PyTorch, see:
7
#           https://pytorch.org/tutorials/beginner/finetuning_torchvision_models_tutorial.html
8
#
9
import argparse
10
import os
11
import random
12
13
import time
14
import shutil
15
import warnings
16
import datetime
17
18
import torch
19
import torch.nn as nn
20
import torch.nn.parallel
21
import torch.nn.functional as F
22
import torch.backends.cudnn as cudnn
23
import torch.optim
24
import torch.utils.data
25
import torchvision.transforms as transforms
26
import torchvision.datasets as datasets
27
import torchvision.models as models
28
29
from torch.utils.tensorboard import SummaryWriter
30
31
from voc import VOCDataset
32
from nuswide import NUSWideDataset
33
from reshape import reshape_model
34
35
36
# enumerate the network architectures available in torchvision.models:
# public (non-dunder), lowercase, callable entries are model constructors
model_names = sorted(
    name for name, ctor in models.__dict__.items()
    if name.islower() and not name.startswith("__") and callable(ctor)
)
40
41
42
# parse command-line arguments
parser = argparse.ArgumentParser(description='PyTorch Image Classifier Training')

# dataset options
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--dataset-type', type=str, default='folder',
                    choices=['folder', 'nuswide', 'voc'],
                    help='specify the dataset type (default: folder)')
parser.add_argument('--multi-label', action='store_true',
                    help='multi-label model (aka image tagging)')
parser.add_argument('--multi-label-threshold', type=float, default=0.5,
                    help='confidence threshold for counting a prediction as correct')
parser.add_argument('--model-dir', type=str, default='models',
                    help='path to desired output directory for saving model '
                    'checkpoints (default: models/)')

# model options
parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                    choices=model_names,
                    help='model architecture: ' + ' | '.join(model_names) + ' (default: resnet18)')
parser.add_argument('--resolution', default=224, type=int, metavar='N',
                    help='input NxN image resolution of model (default: 224x224) '
                         'note than Inception models should use 299x299')

# training hyperparameters
parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                    help='number of data loading workers (default: 2)')
parser.add_argument('--epochs', default=35, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=8, type=int, metavar='N',
                    help='mini-batch size (default: 8)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='initial learning rate', dest='lr')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)',
                    dest='weight_decay')
parser.add_argument('-p', '--print-freq', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')

# checkpointing / evaluation options
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
# NOTE(review): default=True with action='store_true' means --pretrained is
# effectively always on and cannot be disabled from the command line
parser.add_argument('--pretrained', dest='pretrained', action='store_true', default=True,
                    help='use pre-trained model')
parser.add_argument('--seed', default=None, type=int,
                    help='seed for initializing training')
parser.add_argument('--gpu', default=0, type=int,
                    help='GPU ID to use (default: 0)')

args = parser.parse_args()
92
93
94
# open tensorboard logger (to model_dir/tensorboard/<timestamp>)
tensorboard = SummaryWriter(log_dir=os.path.join(args.model_dir, "tensorboard", f"{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}"))
# fixed: the tensorboard CLI flag is --logdir (the previously-printed --log-dir is not recognized)
print(f"To start tensorboard run:  tensorboard --logdir={os.path.join(args.model_dir, 'tensorboard')}")

# variable for storing the best model accuracy so far (updated in main())
best_accuracy = 0
100
101
102
def main(args):
    """
    Load dataset, setup model, and train for N epochs.

    Uses the parsed command-line `args` for configuration, updates the
    module-level `best_accuracy`, and saves a checkpoint (plus the best
    model so far) after every epoch via save_checkpoint().
    """
    global best_accuracy

    if args.seed is not None:
        # seed for reproducibility -- cuDNN determinism costs speed
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    if args.gpu is not None:
        print(f"=> using GPU {args.gpu} ({torch.cuda.get_device_name(args.gpu)})")

    # setup data transformations (ImageNet mean/std -- the torchvision models are pre-trained on ImageNet)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # training uses random crop/flip augmentation
    train_transforms = transforms.Compose([
        transforms.RandomResizedCrop(args.resolution),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])

    # validation uses a deterministic resize + center crop
    val_transforms = transforms.Compose([
        transforms.Resize(args.resolution),
        transforms.CenterCrop(args.resolution),
        transforms.ToTensor(),
        normalize,
    ])

    # load the dataset (argparse choices= restricts dataset_type to these three)
    if args.dataset_type == 'folder':
        train_dataset = datasets.ImageFolder(os.path.join(args.data, 'train'), train_transforms)
        val_dataset = datasets.ImageFolder(os.path.join(args.data, 'val'), val_transforms)
    elif args.dataset_type == 'nuswide':
        train_dataset = NUSWideDataset(args.data, 'trainval', train_transforms)
        val_dataset = NUSWideDataset(args.data, 'test', val_transforms)
    elif args.dataset_type == 'voc':
        train_dataset = VOCDataset(args.data, 'trainval', train_transforms)
        val_dataset = VOCDataset(args.data, 'val', val_transforms)

    # nuswide/voc images carry multiple labels, so they require the multi-label loss
    if (args.dataset_type == 'nuswide' or args.dataset_type == 'voc') and (not args.multi_label):
        raise ValueError("nuswide or voc datasets should be run with --multi-label")

    print(f"=> dataset classes:  {len(train_dataset.classes)}  {train_dataset.classes}")

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # create or load the model if using pre-trained (the default)
    if args.pretrained:
        print(f"=> using pre-trained model '{args.arch}'")
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print(f"=> creating model '{args.arch}'")
        model = models.__dict__[args.arch]()

    # reshape the model for the number of classes in the dataset
    model = reshape_model(model, args.arch, len(train_dataset.classes))

    # define loss function (criterion) and optimizer
    if args.multi_label:
        criterion = nn.BCEWithLogitsLoss()
    else:
        criterion = nn.CrossEntropyLoss()

    optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # transfer the model to the GPU that it should be run on
    if args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        model = model.cuda(args.gpu)
        criterion = criterion.cuda(args.gpu)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print(f"=> loading checkpoint '{args.resume}'")
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch'] + 1
            # fixed: checkpoints written by save_checkpoint() store validation
            # accuracy under checkpoint['accuracy']['val'] (a float), not under
            # 'best_accuracy' -- the old lookup raised KeyError on resume.
            # keep 'best_accuracy' as a fallback for checkpoints from other scripts.
            best_accuracy = checkpoint.get('best_accuracy', checkpoint['accuracy']['val'])
            if args.gpu is not None and torch.is_tensor(best_accuracy):
                best_accuracy = best_accuracy.to(args.gpu)   # best_accuracy may be from a checkpoint from a different GPU
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print(f"=> loaded checkpoint '{args.resume}' (epoch {checkpoint['epoch']})")
        else:
            print(f"=> no checkpoint found at '{args.resume}'")

    cudnn.benchmark = True

    # if in evaluation mode, only run validation
    if args.evaluate:
        validate(val_loader, model, criterion, 0)
        return

    # train for the specified number of epochs
    for epoch in range(args.start_epoch, args.epochs):
        # decay the learning rate
        adjust_learning_rate(optimizer, epoch)

        # train for one epoch
        train_loss, train_acc = train(train_loader, model, criterion, optimizer, epoch)

        # evaluate on validation set
        val_loss, val_acc = validate(val_loader, model, criterion, epoch)

        # remember best acc@1 and save checkpoint
        is_best = val_acc > best_accuracy
        best_accuracy = max(val_acc, best_accuracy)

        print(f"=> Epoch {epoch}")
        print(f"  * Train Loss     {train_loss:.4e}")
        print(f"  * Train Accuracy {train_acc:.4f}")
        print(f"  * Val Loss       {val_loss:.4e}")
        print(f"  * Val Accuracy   {val_acc:.4f}{'*' if is_best else ''}")

        save_checkpoint({
            'epoch': epoch,
            'arch': args.arch,
            'resolution': args.resolution,
            'classes': train_dataset.classes,
            'num_classes': len(train_dataset.classes),
            'multi_label': args.multi_label,
            'state_dict': model.state_dict(),
            'accuracy': {'train': train_acc, 'val': val_acc},
            'loss' : {'train': train_loss, 'val': val_loss},
            'optimizer' : optimizer.state_dict(),
        }, is_best)
245
246
247
def train(train_loader, model, criterion, optimizer, epoch):
    """
    Run one training epoch over the dataset.

    Logs epoch-average loss/accuracy to tensorboard and returns them
    as the tuple (avg_loss, avg_accuracy).
    """
    meter_batch = AverageMeter('Time', ':6.3f')
    meter_data = AverageMeter('Data', ':6.3f')
    meter_loss = AverageMeter('Loss', ':.4e')
    meter_acc = AverageMeter('Accuracy', ':7.3f')

    progress = ProgressMeter(
        len(train_loader),
        [meter_batch, meter_data, meter_loss, meter_acc],
        prefix=f"Epoch: [{epoch}]")

    # enable training mode (dropout active, batchnorm stats updating)
    model.train()

    epoch_start = time.time()
    tick = epoch_start

    for step, (images, target) in enumerate(train_loader):
        # time spent waiting on the data loader
        meter_data.update(time.time() - tick)

        if args.gpu is not None:
            images = images.cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)

        # forward pass
        output = model(images)
        loss = criterion(output, target)

        # record loss and accuracy, weighted by batch size
        batch_size = images.size(0)
        meter_loss.update(loss.item(), batch_size)
        meter_acc.update(accuracy(output, target), batch_size)

        # backward pass and SGD parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # total elapsed time for this batch
        meter_batch.update(time.time() - tick)
        tick = time.time()

        if step % args.print_freq == 0 or step == len(train_loader)-1:
            progress.display(step)

    print(f"Epoch: [{epoch}] completed, elapsed time {time.time() - epoch_start:6.3f} seconds")

    tensorboard.add_scalar('Loss/train', meter_loss.avg, epoch)
    tensorboard.add_scalar('Accuracy/train', meter_acc.avg, epoch)

    return meter_loss.avg, meter_acc.avg
303
    
304
305
def validate(val_loader, model, criterion, epoch):
    """
    Measure model loss/accuracy across the entire validation set.

    Logs epoch-average loss/accuracy to tensorboard and returns them
    as the tuple (avg_loss, avg_accuracy).
    """
    meter_time = AverageMeter('Time', ':6.3f')
    meter_loss = AverageMeter('Loss', ':.4e')
    meter_acc = AverageMeter('Accuracy', ':7.3f')

    progress = ProgressMeter(
        len(val_loader),
        [meter_time, meter_loss, meter_acc],
        prefix='Val:   ')

    # inference mode (dropout off, batchnorm uses running stats)
    model.eval()

    # no gradients needed for evaluation
    with torch.no_grad():
        tick = time.time()
        for step, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
                target = target.cuda(args.gpu, non_blocking=True)

            # forward pass only
            output = model(images)
            loss = criterion(output, target)

            # record loss and accuracy, weighted by batch size
            batch_size = images.size(0)
            meter_loss.update(loss.item(), batch_size)
            meter_acc.update(accuracy(output, target), batch_size)

            # elapsed time for this batch
            meter_time.update(time.time() - tick)
            tick = time.time()

            if step % args.print_freq == 0 or step == len(val_loader)-1:
                progress.display(step)

    tensorboard.add_scalar('Loss/val', meter_loss.avg, epoch)
    tensorboard.add_scalar('Accuracy/val', meter_acc.avg, epoch)

    return meter_loss.avg, meter_acc.avg
347
348
349
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename='model_best.pth.tar', labels_filename='labels.txt'):
    """
    Save a model checkpoint file, along with the best-performing model if applicable.

    Parameters:
        state (dict)          -- checkpoint contents (epoch, arch, classes, state_dict, ...)
        is_best (bool)        -- if True, also copy the checkpoint to best_filename
        filename (str)        -- checkpoint filename (placed under args.model_dir if set)
        best_filename (str)   -- filename for the best-performing model copy
        labels_filename (str) -- filename for the class labels list (written on epoch 0)
    """
    if args.model_dir:
        model_dir = os.path.expanduser(args.model_dir)

        # create the output directory if needed
        # (fixed: os.mkdir failed when intermediate directories were missing,
        #  and raced if the dir appeared between the check and the call)
        os.makedirs(model_dir, exist_ok=True)

        filename = os.path.join(model_dir, filename)
        best_filename = os.path.join(model_dir, best_filename)
        labels_filename = os.path.join(model_dir, labels_filename)

    # save the checkpoint
    torch.save(state, filename)

    # earmark the best checkpoint
    if is_best:
        shutil.copyfile(filename, best_filename)
        print(f"saved best model to:  {best_filename}")
    else:
        # fixed: previously printed the literal '(unknown)' instead of the path
        print(f"saved checkpoint to:  {filename}")

    # save labels.txt on the first epoch
    if state['epoch'] == 0:
        with open(labels_filename, 'w') as file:
            for label in state['classes']:
                file.write(f"{label}\n")
        print(f"saved class labels to:  {labels_filename}")
379
            
380
381
def adjust_learning_rate(optimizer, epoch):
    """
    Sets the learning rate to the initial LR decayed by 10 every 30 epochs.
    """
    # step decay: divide the base LR by 10 once per 30 completed epochs
    decayed_lr = args.lr * (0.1 ** (epoch // 30))
    for group in optimizer.param_groups:
        group['lr'] = decayed_lr
388
389
390
def accuracy(output, target):
    """
    Computes the accuracy of predictions vs groundtruth, as a percentage in [0,100].

    Multi-label mode: the fraction of (sample, class) elements whose thresholded
    sigmoid confidence matches the binary groundtruth.
    Single-label mode: the fraction of samples whose argmax class equals the target.
    """
    with torch.no_grad():
        if args.multi_label:
            # fixed: F.sigmoid is deprecated -- torch.sigmoid is the supported API
            output = torch.sigmoid(output)
            # element-wise match between thresholded predictions and binary targets
            # https://medium.com/@yrodriguezmd/tackling-the-accuracy-multi-metric-9e2356f62513
            preds = ((output >= args.multi_label_threshold) == target.bool())
        else:
            # softmax doesn't change the argmax, but keeps output interpretable
            output = F.softmax(output, dim=-1)
            _, preds = torch.max(output, dim=-1)
            preds = (preds == target)

        return preds.float().mean().cpu().item() * 100.0
409
        
410
        
411
class AverageMeter(object):
    """
    Tracks the most recent value and the running (sample-weighted) average
    of a metric, for periodic progress display.
    """
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear all accumulated statistics back to zero."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Record a value observed over `n` samples and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # e.g. 'Loss 1.2345e-01 (2.3456e-01)' -- current value, then average
        template = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return template.format(**self.__dict__)
435
436
437
class ProgressMeter(object):
    """
    Prints a progress line '<prefix>[batch/total]  <meter>  <meter> ...'
    with the batch counter padded to a fixed width.
    """
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the current batch index and every meter, two-space separated."""
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('  '.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        """Build '[{:Nd}/total]' where N is the digit width of the batch count."""
        # fixed: was len(str(num_batches // 1)) -- the '// 1' was a no-op
        num_digits = len(str(num_batches))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'
455
456
457
# script entry point -- run training/evaluation with the parsed command-line args
if __name__ == '__main__':
    main(args)