Diff of /CRCNet/train.py [000000] .. [1fc74a]

Switch to unified view

a b/CRCNet/train.py
1
import datetime
2
import os
3
import time
4
5
import torch
6
import torch.utils.data
7
from torch import nn
8
import torchvision
9
from torchvision import transforms
10
11
import utils
12
13
try:
14
    from apex import amp
15
except ImportError:
16
    amp = None
17
18
19
def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq, apex=False):
20
    model.train()
21
    metric_logger = utils.MetricLogger(delimiter="  ")
22
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
23
    metric_logger.add_meter('img/s', utils.SmoothedValue(window_size=10, fmt='{value}'))
24
25
    header = 'Epoch: [{}]'.format(epoch)
26
    for image, target in metric_logger.log_every(data_loader, print_freq, header):
27
        start_time = time.time()
28
        image, target = image.to(device), target.to(device)
29
        output = model(image)
30
        loss = criterion(output, target)
31
32
        optimizer.zero_grad()
33
        if apex:
34
            with amp.scale_loss(loss, optimizer) as scaled_loss:
35
                scaled_loss.backward()
36
        else:
37
            loss.backward()
38
        optimizer.step()
39
40
        acc1, acc5 = utils.accuracy(output, target, topk=(1, 2))
41
        batch_size = image.shape[0]
42
        metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
43
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
44
        metric_logger.meters['acc2'].update(acc5.item(), n=batch_size)
45
        metric_logger.meters['img/s'].update(batch_size / (time.time() - start_time))
46
47
48
def evaluate(model, criterion, data_loader, device, print_freq=100):
49
    model.eval()
50
    metric_logger = utils.MetricLogger(delimiter="  ")
51
    header = 'Test:'
52
    with torch.no_grad():
53
        for image, target in metric_logger.log_every(data_loader, print_freq, header):
54
            image = image.to(device, non_blocking=True)
55
            target = target.to(device, non_blocking=True)
56
            output = model(image)
57
            loss = criterion(output, target)
58
59
            acc1, acc5 = utils.accuracy(output, target, topk=(1, 2))
60
            # FIXME need to take into account that the datasets
61
            # could have been padded in distributed setup
62
            batch_size = image.shape[0]
63
            metric_logger.update(loss=loss.item())
64
            metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
65
            metric_logger.meters['acc2'].update(acc5.item(), n=batch_size)
66
    # gather the stats from all processes
67
    metric_logger.synchronize_between_processes()
68
69
    print(' * Acc@1 {top1.global_avg:.3f} Acc@2 {top5.global_avg:.3f}'
70
          .format(top1=metric_logger.acc1, top5=metric_logger.acc2))
71
    return metric_logger.acc1.global_avg
72
73
74
def _get_cache_path(filepath):
75
    import hashlib
76
    h = hashlib.sha1(filepath.encode()).hexdigest()
77
    cache_path = os.path.join("~", ".torch", "vision", "datasets", "imagefolder", h[:10] + ".pt")
78
    cache_path = os.path.expanduser(cache_path)
79
    return cache_path
80
81
82
def load_data(traindir, valdir, cache_dataset, distributed):
83
    # Data loading code
84
    print("Loading data")
85
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
86
                                     std=[0.229, 0.224, 0.225])
87
88
    print("Loading training data")
89
    st = time.time()
90
    cache_path = _get_cache_path(traindir)
91
    if cache_dataset and os.path.exists(cache_path):
92
        # Attention, as the transforms are also cached!
93
        print("Loading dataset_train from {}".format(cache_path))
94
        dataset, _ = torch.load(cache_path)
95
    else:
96
        #dataset = torchvision.datasets.ImageFolder(
97
        #    traindir,
98
        #    transforms.Compose([
99
        #        transforms.RandomResizedCrop(224),
100
        #        transforms.RandomHorizontalFlip(),
101
        #        transforms.ToTensor(),
102
        #        normalize,
103
        #    ]))
104
        dataset = utils.CSVDataset(
105
            traindir,
106
            transforms.Compose([
107
                transforms.RandomResizedCrop(224),
108
                transforms.RandomPerspective(),
109
                transforms.RandomHorizontalFlip(),
110
                transforms.RandomRotation(degrees=180),
111
                transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.0),
112
                transforms.ToTensor(),
113
                normalize,
114
            ]))
115
        if cache_dataset:
116
            print("Saving dataset_train to {}".format(cache_path))
117
            utils.mkdir(os.path.dirname(cache_path))
118
            utils.save_on_master((dataset, traindir), cache_path)
119
    print("Took", time.time() - st)
120
121
    print("Loading validation data")
122
    cache_path = _get_cache_path(valdir)
123
    if cache_dataset and os.path.exists(cache_path):
124
        # Attention, as the transforms are also cached!
125
        print("Loading dataset_test from {}".format(cache_path))
126
        dataset_test, _ = torch.load(cache_path)
127
    else:
128
        #dataset_test = torchvision.datasets.ImageFolder(
129
        #    valdir,
130
        #    transforms.Compose([
131
        #        transforms.Resize(256),
132
        #        transforms.CenterCrop(224),
133
        #        transforms.ToTensor(),
134
        #        normalize,
135
        #    ]))
136
        dataset_test = utils.CSVDataset(
137
            valdir,
138
            transforms.Compose([
139
                transforms.Resize(256),
140
                transforms.CenterCrop(224),
141
                transforms.ToTensor(),
142
                normalize,
143
            ]))
144
        if cache_dataset:
145
            print("Saving dataset_test to {}".format(cache_path))
146
            utils.mkdir(os.path.dirname(cache_path))
147
            utils.save_on_master((dataset_test, valdir), cache_path)
148
149
    print("Creating data loaders")
150
    if distributed:
151
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
152
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
153
    else:
154
        train_sampler = torch.utils.data.RandomSampler(dataset)
155
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)
156
157
    return dataset, dataset_test, train_sampler, test_sampler
158
159
160
def main(args):
161
    if args.apex and amp is None:
162
        raise RuntimeError("Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
163
                           "to enable mixed-precision training.")
164
165
    if args.output_dir:
166
        utils.mkdir(args.output_dir)
167
168
    utils.init_distributed_mode(args)
169
    print(args)
170
171
    device = torch.device(args.device)
172
173
    torch.backends.cudnn.benchmark = True
174
175
    #train_dir = os.path.join(args.data_path, 'train')
176
    #val_dir = os.path.join(args.data_path, 'val')
177
    train_dir = args.train_file
178
    val_dir = args.val_file
179
    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir,
180
                                                                   args.cache_dataset, args.distributed)
181
    data_loader = torch.utils.data.DataLoader(
182
        dataset, batch_size=args.batch_size,
183
        sampler=train_sampler, num_workers=args.workers, pin_memory=True)
184
185
    data_loader_test = torch.utils.data.DataLoader(
186
        dataset_test, batch_size=args.batch_size,
187
        sampler=test_sampler, num_workers=args.workers, pin_memory=True)
188
189
    print("Creating model")
190
    #model = torchvision.models.__dict__[args.model](pretrained=args.pretrained)
191
    model = torchvision.models.densenet169(pretrained=args.pretrained)
192
    # modify the last layer for objective task
193
    num_ftrs = model.features.norm5.num_features
194
    model.classifier = nn.Linear(num_ftrs, args.num_classes)
195
        
196
    model.to(device)
197
    if args.distributed and args.sync_bn:
198
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
199
200
    if args.focal_loss:
201
        import pytorch_toolbelt.losses
202
        criterion = pytorch_toolbelt.losses.FocalLoss()
203
    else:
204
        criterion = nn.CrossEntropyLoss()
205
206
    optimizer = torch.optim.SGD(
207
        model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
208
209
    if args.apex:
210
        model, optimizer = amp.initialize(model, optimizer,
211
                                          opt_level=args.apex_opt_level
212
                                          )
213
214
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
215
216
    model_without_ddp = model
217
    if args.distributed:
218
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
219
        model_without_ddp = model.module
220
221
    if args.resume:
222
        checkpoint = torch.load(args.resume, map_location='cpu')
223
        model_without_ddp.load_state_dict(checkpoint['model'])
224
        optimizer.load_state_dict(checkpoint['optimizer'])
225
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
226
        args.start_epoch = checkpoint['epoch'] + 1
227
228
    if args.test_only:
229
        evaluate(model, criterion, data_loader_test, device=device)
230
        return
231
232
    print("Start training")
233
    start_time = time.time()
234
    for epoch in range(args.start_epoch, args.epochs):
235
        if args.distributed:
236
            train_sampler.set_epoch(epoch)
237
        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq, args.apex)
238
        lr_scheduler.step()
239
        evaluate(model, criterion, data_loader_test, device=device)
240
        if args.output_dir:
241
            checkpoint = {
242
                'model': model_without_ddp.state_dict(),
243
                'optimizer': optimizer.state_dict(),
244
                'lr_scheduler': lr_scheduler.state_dict(),
245
                'epoch': epoch,
246
                'args': args}
247
            utils.save_on_master(
248
                checkpoint,
249
                os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
250
            utils.save_on_master(
251
                checkpoint,
252
                os.path.join(args.output_dir, 'checkpoint.pth'))
253
254
    total_time = time.time() - start_time
255
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
256
    print('Training time {}'.format(total_time_str))
257
258
259
def parse_args():
260
    import argparse
261
    parser = argparse.ArgumentParser(description='PyTorch Classification Training')
262
263
    #parser.add_argument('--data-path', default='/datasets01/imagenet_full_size/061417/', help='dataset')
264
    parser.add_argument('--train-file', help='training set')
265
    parser.add_argument('--val-file', help='validation set')
266
    parser.add_argument('--num-classes', help='number of classes for the objective task', type=int)
267
    parser.add_argument(
268
        "--focal-loss", help="Use focal loss",action="store_true")
269
    
270
    #parser.add_argument('--model', default='resnet18', help='model')
271
    parser.add_argument('--device', default='cuda', help='device')
272
    parser.add_argument('-b', '--batch-size', default=32, type=int)
273
    parser.add_argument('--epochs', default=90, type=int, metavar='N',
274
                        help='number of total epochs to run')
275
    parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
276
                        help='number of data loading workers (default: 16)')
277
    parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
278
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
279
                        help='momentum')
280
    parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
281
                        metavar='W', help='weight decay (default: 1e-4)',
282
                        dest='weight_decay')
283
    parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs')
284
    parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma')
285
    parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
286
    parser.add_argument('--output-dir', default='.', help='path where to save')
287
    parser.add_argument('--resume', default='', help='resume from checkpoint')
288
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
289
                        help='start epoch')
290
    parser.add_argument(
291
        "--cache-dataset",
292
        dest="cache_dataset",
293
        help="Cache the datasets for quicker initialization. It also serializes the transforms",
294
        action="store_true",
295
    )
296
    parser.add_argument(
297
        "--sync-bn",
298
        dest="sync_bn",
299
        help="Use sync batch norm",
300
        action="store_true",
301
    )
302
    parser.add_argument(
303
        "--test-only",
304
        dest="test_only",
305
        help="Only test the model",
306
        action="store_true",
307
    )
308
    parser.add_argument(
309
        "--pretrained",
310
        dest="pretrained",
311
        help="Use pre-trained models from the modelzoo",
312
        action="store_true",
313
    )
314
315
    # Mixed precision training parameters
316
    parser.add_argument('--apex', action='store_true',
317
                        help='Use apex for mixed precision training')
318
    parser.add_argument('--apex-opt-level', default='O1', type=str,
319
                        help='For apex mixed precision training'
320
                             'O0 for FP32 training, O1 for mixed precision training.'
321
                             'For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet'
322
                        )
323
324
    # distributed training parameters
325
    parser.add_argument('--world-size', default=1, type=int,
326
                        help='number of distributed processes')
327
    parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
328
329
    args = parser.parse_args()
330
331
    return args
332
333
334
if __name__ == "__main__":
335
    args = parse_args()
336
    main(args)