Diff of /dl/utils/solver.py [000000] .. [4807fa]

import time
import shutil
import os.path
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim
import torch.utils.data
import torch.utils.model_zoo as model_zoo
import torchvision.transforms as transforms
import torchvision.datasets
import torchvision.models

from .utils import AverageMeter, check_acc
from ..models.densenet import DenseNet
from .sampler import BatchLoader

if torch.cuda.is_available():
    dtype = {'float': torch.cuda.FloatTensor, 'long': torch.cuda.LongTensor, 'byte': torch.cuda.ByteTensor}
else:
    dtype = {'float': torch.FloatTensor, 'long': torch.LongTensor, 'byte': torch.ByteTensor}
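
# The dtype map doubles as a device-agnostic type check; a minimal sketch,
# assuming a CPU run:
#   x = torch.randn(4)                   # a FloatTensor
#   isinstance(x, dtype['float'])        # True
#   isinstance(x.long(), dtype['long'])  # True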


class Solver(object):
    """Solver

    Args:
        model: the network to train; called as model(X)
        data: dict of tensors ('X_train', 'y_train', 'X_val', 'y_val', ...)
            or of loaders ('train_loader', 'val_loader', ...)
        optimizer: e.g., torch.optim.Adam(model.parameters())
        loss_fn: loss function; e.g., torch.nn.CrossEntropyLoss()
        resume: file path to a checkpoint saved by train()
    """
    def __init__(self, model, data, optimizer, loss_fn, resume=None):
        self.model = model
        self.data = data
        self.optimizer = optimizer
        self.loss_fn = loss_fn

        # keep track of loss and accuracy during training
        self.losses_train = []
        self.losses_val = []
        self.acc_train = []
        self.acc_val = []
        self.best_acc_val = 0
        self.epoch_counter = 0

        if resume:
            if os.path.isfile(resume):
                checkpoint = torch.load(resume)
                self.model.load_state_dict(checkpoint['model_state'])
                self.optimizer = checkpoint['optimizer']
                self.best_acc_val = checkpoint['best_acc_val']
                self.epoch_counter = checkpoint['epoch']
                self.losses_train = checkpoint['losses_train']
                self.losses_val = checkpoint['losses_val']
                self.acc_train = checkpoint['acc_train']
                self.acc_val = checkpoint['acc_val']
            else:
                print("==> No checkpoint found at '{}'".format(resume))

    def _reset_avg_meter(self):
        """Reset loss_epoch, top1, top5, and batch_time at the beginning of each epoch."""
        self.loss_epoch = AverageMeter()
        self.top1 = AverageMeter()
        self.top5 = AverageMeter()
        self.batch_time = AverageMeter()
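
    # AverageMeter and check_acc come from .utils. The code assumes
    # meter.update(val, n=1) maintains .val (last value) and .avg (running mean),
    # and that check_acc(y_pred, y, (1, k)) returns top-1 and top-k accuracy.
    # A minimal AverageMeter sketch under that assumption:
    #   class AverageMeter:
    #       def __init__(self):
    #           self.val = self.avg = self.sum = self.count = 0
    #       def update(self, val, n=1):
    #           self.val = val
    #           self.sum += val * n
    #           self.count += n
    #           self.avg = self.sum / self.count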

    def run_one_epoch(self, epoch, batch_size=100, num_samples=None, print_every=100,
                      training=True, balanced_sample=False, topk=5):
        """Run one epoch of training or validation.

        Args:
            epoch: int; epoch counter, used for printing only
            batch_size: int, default: 100
            num_samples: int, default: None.
                How many samples to use, in case we don't want to train on a whole epoch
            print_every: int, default: 100; print stats every print_every iterations
            training: bool, default: True. If True, train; else validate
            balanced_sample: bool, default: False. Used for unbalanced datasets
            topk: int, default: 5; k for the top-k accuracy metric
        """
        if 'train_loader' in self.data:
            # This is for image-related tasks
            dataloader = self.data['train_loader'] if training else self.data['val_loader']
            # This is very important! dataloader.batch_size is controlled by
            # dataloader.batch_sampler.batch_size, not the other way around,
            # (probably) because the dataloader was created by passing batch_size
            dataloader.batch_sampler.batch_size = batch_size
            # len(dataset) rather than len(dataset.imgs), so datasets other than
            # ImageFolder work too
            N = len(dataloader.dataset)
            num_chunks = (N + batch_size - 1) // batch_size
        elif 'X_train' in self.data:
            X, y = (self.data['X_train'], self.data['y_train']) if training else (self.data['X_val'], self.data['y_val'])
            N = X.size(0)
            if num_samples and 0 < num_samples < N:
                N = num_samples

            if balanced_sample and isinstance(y, dtype['long']):
                dataloader = BatchLoader((X[:N], y[:N]), batch_size)
                num_chunks = len(dataloader)
            else:
                shuffle_idx = torch.randperm(N)
                X = torch.index_select(X, 0, shuffle_idx)
                y = torch.index_select(y, 0, shuffle_idx)
                # ceiling division, e.g., N = 1050, batch_size = 100 -> 11 chunks
                num_chunks = (N + batch_size - 1) // batch_size
                X_chunks = X.chunk(num_chunks)
                y_chunks = y.chunk(num_chunks)
                dataloader = zip(X_chunks, y_chunks)
        else:
            raise ValueError('data must contain either X_train or train_loader')

        if training:
            print("Training:")
        else:
            print("Validating:")

        self._reset_avg_meter()
        end_time = time.time()
        for i, (X, y) in enumerate(dataloader):
            X = Variable(X)
            y = Variable(y)

            y_pred = self.model(X)
            loss = self.loss_fn(y_pred, y)

            if training:
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            self.loss_epoch.update(loss.item(), y.size(0))
            # For classification tasks, y.data is a LongTensor;
            # for regression tasks, y.data is a FloatTensor
            is_classification = isinstance(y.data, dtype['long'])
            if is_classification:
                res = check_acc(y_pred, y, (1, topk))
                self.top1.update(res[0].item())
                self.top5.update(res[1].item())
            else:
                # top1 is approximately the 'inverse' of the loss
                self.top1.update(1. / (loss.item() + 1.), y.size(0))
            self.batch_time.update(time.time() - end_time)
            end_time = time.time()

            if training:
                self.losses_train.append(self.loss_epoch.avg)
                self.acc_train.append(self.top1.avg)
            else:
                self.losses_val.append(self.loss_epoch.avg)
                self.acc_val.append(self.top1.avg)

            if print_every and (i + 1) % print_every == 0:
                print('Epoch {0}: iteration {1}/{2}\t'
                      'loss: {losses.val:.3f}, avg: {losses.avg:.3f}\t'
                      'Prec@1: {prec1.val:.3f}, avg: {prec1.avg:.3f}\t'
                      'Prec@5: {prec5.val:.3f}, avg: {prec5.avg:.3f}\t'
                      'batch time: {batch_time.val:.3f} avg: {batch_time.avg:.3f}'.format(
                          epoch + 1, i + 1, num_chunks, losses=self.loss_epoch, prec1=self.top1,
                          prec5=self.top5, batch_time=self.batch_time))
                sys.stdout.flush()

        return self.top1.avg
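
    # Example (sketch): a single manual validation pass
    #   val_acc = solver.run_one_epoch(epoch=0, batch_size=256, training=False)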

    def train_eval(self, num_iter=100, batch_size=100, X=None, y=None, X_val=None, y_val=None,
                   X_test=None, y_test=None, eval_test=False, balanced_sample=False, allow_duplicate=False,
                   max_redundancy=1000, seed=None):
        """Train for num_iter batches, evaluating a validation (and optionally a
        test) batch after every training batch; returns dicts of per-batch and
        running-average losses and accuracies."""
        if X is None or y is None:
            X, y = self.data['X_train'], self.data['y_train']
        # Currently only for classification tasks: y is a LongTensor
        assert isinstance(y, dtype['long'])
        if X_val is None or y_val is None:
            X_val, y_val = self.data['X_val'], self.data['y_val']
        if eval_test and (X_test is None or y_test is None):
            X_test, y_test = self.data['X_test'], self.data['y_test']

        dataloader_train = BatchLoader((X, y), batch_size, balanced=balanced_sample,
            num_iter=num_iter, allow_duplicate=allow_duplicate, max_redundancy=max_redundancy,
            shuffle=True, seed=seed)
        dataloader_val = BatchLoader((X_val, y_val), batch_size, balanced=balanced_sample,
            num_iter=num_iter, allow_duplicate=allow_duplicate, max_redundancy=max_redundancy,
            shuffle=True, seed=seed)
        if X_test is not None:
            dataloader_test = BatchLoader((X_test, y_test), batch_size, balanced=balanced_sample,
                num_iter=num_iter, allow_duplicate=allow_duplicate, max_redundancy=max_redundancy,
                shuffle=True, seed=seed)
        else:
            dataloader_test = [None] * num_iter
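        # BatchLoader comes from .sampler; the calls above assume it is an
        # iterable of (X_batch, y_batch) pairs with a __len__, and that
        # balanced=True draws class-balanced batches from an unbalanced dataset.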

        loss_train_meter = AverageMeter()
        loss_train = {'avg': [], 'batch': []}
        acc_train_meter = AverageMeter()
        acc_train = {'avg': [], 'batch': []}
        loss_val_meter = AverageMeter()
        loss_val = {'avg': [], 'batch': []}
        acc_val_meter = AverageMeter()
        acc_val = {'avg': [], 'batch': []}
        loss_test_meter = AverageMeter()
        loss_test = {'avg': [], 'batch': []}
        acc_test_meter = AverageMeter()
        acc_test = {'avg': [], 'batch': []}

        def forward(X, y, loss_meter, losses, acc_meter, acc, training=False):
            X = Variable(X)
            y = Variable(y)
            y_pred = self.model(X)
            loss = self.loss_fn(y_pred, y)
            loss_meter.update(loss.item(), y.size(0))
            losses['avg'].append(loss_meter.avg)
            losses['batch'].append(loss.item())
            res = check_acc(y_pred, y, (1,))
            acc_meter.update(res[0].item(), y.size(0))
            acc['avg'].append(acc_meter.avg)
            acc['batch'].append(res[0].item())

            if training:
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            return y_pred, loss

        for (X, y), (X_val, y_val), test_data in zip(dataloader_train,
                dataloader_val, dataloader_test):
            forward(X, y, loss_train_meter, loss_train, acc_train_meter, acc_train,
                training=True)
            forward(X_val, y_val, loss_val_meter, loss_val, acc_val_meter, acc_val,
                training=False)
            if test_data is not None:
                X_test, y_test = test_data
                forward(X_test, y_test, loss_test_meter, loss_test, acc_test_meter,
                    acc_test, training=False)

        if eval_test:
            return loss_train, acc_train, loss_val, acc_val, loss_test, acc_test
        else:
            return loss_train, acc_train, loss_val, acc_val
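
    # Example (sketch): 500 iterations of interleaved train/val evaluation
    #   loss_tr, acc_tr, loss_v, acc_v = solver.train_eval(num_iter=500, batch_size=64)
    #   # each dict's 'batch' list holds per-batch values, 'avg' the running means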

    def train(self, num_epoch=10, batch_size=100, num_samples=None, print_every=100,
              use_validation=True, save_checkpoint=True, file_prefix='', balanced_sample=False, topk=5):
        """Train the model for num_epoch epochs.

        Args:
            num_epoch: int, default: 10
            batch_size: int, default: 100
            num_samples: int, default: None
            print_every: int, default: 100
            use_validation: bool, default: True. If True, run_one_epoch for both training and validation
            save_checkpoint: bool, default: True. If True, save a checkpoint named
                file_prefix + 'checkpoint%d.pth' % self.epoch_counter and copy the best
                model to file_prefix + 'model_best.pth'
            file_prefix: str, default: ''
            balanced_sample: bool; used for sampling balanced batches from an unbalanced dataset
            topk: int, default: 5; k for the top-k accuracy metric
        """
        for i in range(self.epoch_counter, self.epoch_counter + num_epoch):
            accuracy = self.run_one_epoch(i, batch_size, num_samples, print_every,
                                          balanced_sample=balanced_sample, topk=topk)
            # In the rare case that no validation set is wanted, the best model
            # is selected on training accuracy instead
            if use_validation:
                accuracy = self.run_one_epoch(i, batch_size, num_samples, print_every,
                                              training=False, balanced_sample=balanced_sample, topk=topk)

            if accuracy > self.best_acc_val:
                self.best_acc_val = accuracy
                if save_checkpoint:
                    # note: the whole optimizer object is pickled (and restored
                    # as-is in __init__), not just its state_dict
                    state = {'model_state': self.model.state_dict(),
                             'optimizer': self.optimizer,
                             'best_acc_val': self.best_acc_val,
                             'epoch': i + 1,
                             'losses_train': self.losses_train,
                             'losses_val': self.losses_val,
                             'acc_train': self.acc_train,
                             'acc_val': self.acc_val}
                    filename = file_prefix + 'checkpoint%d.pth' % (i + 1)
                    torch.save(state, filename)
                    shutil.copyfile(filename, file_prefix + 'model_best.pth')
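
    # Checkpoints are only written when validation accuracy improves, so a run of
    # solver.train(num_epoch=3, file_prefix='mnist-') may produce, e.g.,
    # 'mnist-checkpoint1.pth' and 'mnist-checkpoint3.pth' but skip epoch 2.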

    def predict(self, batch_size=100, save_file=True, file_prefix='', X=None, y=None, topk=5, verbose=False):
        """Predict on the test set (or on X).

        Args:
            batch_size: int, default: 100; can be larger if memory allows
            save_file: bool, default: True; if True, save predictions to
                file_prefix + 'y_pred.pth'
            file_prefix: str, default: ''
            X: default: None. If not None, use X instead of self.data['X_test']
            y: default: None. Similar to X
            topk: int, default: 5; k for the top-k accuracy metric
            verbose: bool, default: False; if True, print test-set statistics
        """
        if X is None:
            if 'X_test' in self.data:
                X = self.data['X_test']
            elif 'test_loader' in self.data:
                X = self.data['test_loader']
            else:
                raise ValueError('If X is None, then self.data '
                                 'must contain either X_test or test_loader')

        if y is None and 'y_test' in self.data:
            y = self.data['y_test']

        is_truth_avail = isinstance(y, dtype['long']) or isinstance(y, dtype['float'])

        if isinstance(X, dtype['float']):
            # X is a tensor: split it (and y, if available) into batches
            N = X.size(0)
            num_chunks = (N + batch_size - 1) // batch_size
            X_chunks = X.chunk(num_chunks)
            y_chunks = y.chunk(num_chunks) if is_truth_avail else [None] * num_chunks
            dataloader = zip(X_chunks, y_chunks)
        else:
            # X is a DataLoader yielding (X, y) batches
            dataloader = X

        self._reset_avg_meter()
        end_time = time.time()
        y_pred = []
        for X, y in dataloader:
            X = Variable(X)
            if y is not None:
                y = Variable(y)

            y_pred_tmp = self.model(X)  # sometimes the model outputs a tuple

            if is_truth_avail:
                loss = self.loss_fn(y_pred_tmp, y)
                self.loss_epoch.update(loss.item(), y.size(0))
                if isinstance(y.data, dtype['long']):
                    res = check_acc(y_pred_tmp, y, (1, topk))
                    self.top1.update(res[0].item())
                    self.top5.update(res[1].item())
                else:
                    self.top1.update(1. / (loss.item() + 1.), y.size(0))
            self.batch_time.update(time.time() - end_time)
            end_time = time.time()
            if isinstance(y_pred_tmp, tuple):
                y_pred_tmp = y_pred_tmp[0]
            y_pred.append(y_pred_tmp)

        if is_truth_avail and verbose:
            print('Test set: loss: {losses.avg:.3f}\t'
                  'AP@1: {prec1.avg:.3f}\t'
                  'AP@5: {prec5.avg:.3f}\t'
                  'batch time: {batch_time.avg:.3f}'.format(
                      losses=self.loss_epoch, prec1=self.top1,
                      prec5=self.top5, batch_time=self.batch_time))
            sys.stdout.flush()
        y_pred = torch.cat(y_pred, 0)
        if save_file:
            torch.save({'y_pred': y_pred}, file_prefix + 'y_pred.pth')
        return y_pred
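
    # Example (sketch): predicted class labels from the returned scores
    #   scores = solver.predict(batch_size=256, save_file=False)
    #   labels = scores.max(1)[1]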


if __name__ == '__main__':

    mnist_train = torchvision.datasets.MNIST('/projects/academic/jamesjar/tianlema/dl-datasets/mnist',
                                             transform=transforms.Compose([transforms.ToTensor(),
                                                                           transforms.Normalize((0.1307,), (0.3081,))]))
    train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=200)

    mnist_test = torchvision.datasets.MNIST('/projects/academic/jamesjar/tianlema/dl-datasets/mnist',
                                            transform=transforms.Compose([transforms.ToTensor(),
                                                                          transforms.Normalize((0.1307,), (0.3081,))]),
                                            train=False)
    test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=200)

    # materialize the loaders into single tensors
    train = list(train_loader)
    train = list(zip(*train))
    X_train = torch.cat(train[0], 0)
    y_train = torch.cat(train[1], 0)

    # hold out the last 10000 training images for validation
    X_val = X_train[50000:]
    y_val = y_train[50000:]
    X_train = X_train[:50000]
    y_train = y_train[:50000]

    test = list(test_loader)
    test = list(zip(*test))
    X_test = torch.cat(test[0], 0)
    y_test = torch.cat(test[1], 0)

    data = {'X_train': X_train, 'y_train': y_train, 'X_val': X_val, 'y_val': y_val,
            'X_test': X_test, 'y_test': y_test}
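
    # Shapes at this point: X_train (50000, 1, 28, 28) float, y_train (50000,) long;
    # X_val/y_val and X_test/y_test hold 10000 samples each.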

    model = DenseNet(input_param=(1, 64), block_layers=(6, 4), num_classes=10,
                     growth_rate=32, bn_size=2, dropout_rate=0, transition_pool_param=(3, 1, 1))

    loss_fn = nn.CrossEntropyLoss()

    optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-4)

    solver = Solver(model, data, optimizer, loss_fn)
    solver.train(num_epoch=2, file_prefix='mnist-')
    solver.predict(file_prefix='mnist-')
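
    # The saved scores can be reloaded later (sketch):
    #   y_pred = torch.load('mnist-y_pred.pth')['y_pred']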