dl/utils/solver.py
|
import time
import shutil
import os.path
import sys

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets

from .utils import AverageMeter, check_acc
from ..models.densenet import DenseNet
from .sampler import BatchLoader

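# AverageMeter (from .utils) is assumed to track .val and .avg over weighted
# updates, and check_acc(y_pred, y, topk) to return the requested top-k
# accuracies; both are project helpers, not torch APIs.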
# Tensor classes keyed by name; the code below uses these to test target types
# (a LongTensor target marks a classification task) without caring whether
# CUDA is available.
if torch.cuda.is_available():
    dtype = {'float': torch.cuda.FloatTensor, 'long': torch.cuda.LongTensor,
             'byte': torch.cuda.ByteTensor}
else:
    dtype = {'float': torch.FloatTensor, 'long': torch.LongTensor,
             'byte': torch.ByteTensor}


class Solver(object):
    """Solver

    Args:
        model: nn.Module to be trained
        data: dict of tensors ('X_train', 'y_train', 'X_val', 'y_val', ...)
            and/or loaders ('train_loader', 'val_loader', 'test_loader')
        optimizer: e.g., torch.optim.Adam(model.parameters())
        loss_fn: loss function; e.g., torch.nn.CrossEntropyLoss()
        resume: file path to a checkpoint written by train()
    """
|
|
    def __init__(self, model, data, optimizer, loss_fn, resume=None):
        self.model = model
        self.data = data
        self.optimizer = optimizer
        self.loss_fn = loss_fn

        # keep track of loss and accuracy during training
        self.losses_train = []
        self.losses_val = []
        self.acc_train = []
        self.acc_val = []
        self.best_acc_val = 0
        self.epoch_counter = 0

        # Optionally resume from a checkpoint written by train(). Note that
        # the optimizer is restored as a whole pickled object, not via a
        # state_dict.
        if resume:
            if os.path.isfile(resume):
                checkpoint = torch.load(resume)
                self.model.load_state_dict(checkpoint['model_state'])
                self.optimizer = checkpoint['optimizer']
                self.best_acc_val = checkpoint['best_acc_val']
                self.epoch_counter = checkpoint['epoch']
                self.losses_train = checkpoint['losses_train']
                self.losses_val = checkpoint['losses_val']
                self.acc_train = checkpoint['acc_train']
                self.acc_val = checkpoint['acc_val']
            else:
                print("==> No checkpoint found at '{}'".format(resume))

    def _reset_avg_meter(self):
        """Reset the loss_epoch, top1, top5, and batch_time meters at the
        beginning of each epoch."""
        self.loss_epoch = AverageMeter()
        self.top1 = AverageMeter()
        self.top5 = AverageMeter()
        self.batch_time = AverageMeter()


    def run_one_epoch(self, epoch, batch_size=100, num_samples=None, print_every=100,
                      training=True, balanced_sample=False, topk=5):
        """Run one epoch of training or validation.

        Args:
            epoch: int; epoch counter, used for printing only
            batch_size: int, default: 100
            num_samples: int, default: None.
                Number of samples to use when we don't want to train on a whole epoch
            print_every: int, default: 100
            training: bool, default: True. If True, train; else validate
            balanced_sample: bool, default: False. Draw class-balanced batches
                from an unbalanced dataset
            topk: int, default: 5. k used for the top-k accuracy meter
        """
        if 'train_loader' in self.data:
            # This branch serves image tasks backed by torchvision-style loaders
            dataloader = self.data['train_loader'] if training else self.data['val_loader']
            # The effective batch size is controlled by
            # dataloader.batch_sampler.batch_size, not dataloader.batch_size
            # (the loader was built from a batch_size argument, which PyTorch
            # turns into a BatchSampler), so override it there.
            dataloader.batch_sampler.batch_size = batch_size
            N = len(dataloader.dataset.imgs)
            num_chunks = (N + batch_size - 1) // batch_size
        elif 'X_train' in self.data:
            X, y = (self.data['X_train'], self.data['y_train']) if training else (self.data['X_val'], self.data['y_val'])
            N = X.size(0)
            if num_samples and 0 < num_samples < N:
                N = num_samples

            if balanced_sample and isinstance(y, dtype['long']):
                dataloader = BatchLoader((X[:N], y[:N]), batch_size)
                num_chunks = len(dataloader)
            else:
                # Shuffle the first N samples, then split them into batches
                shuffle_idx = torch.randperm(N)
                X = torch.index_select(X, 0, shuffle_idx)
                y = torch.index_select(y, 0, shuffle_idx)
                num_chunks = (N + batch_size - 1) // batch_size
                X_chunks = X.chunk(num_chunks)
                y_chunks = y.chunk(num_chunks)
                dataloader = zip(X_chunks, y_chunks)
        else:
            raise ValueError('data must contain either X_train or train_loader')

        if training:
            print("Training:")
        else:
            print("Validating:")

        self._reset_avg_meter()
        end_time = time.time()
        for i, (X, y) in enumerate(dataloader):
            X = Variable(X)
            y = Variable(y)

            y_pred = self.model(X)
            loss = self.loss_fn(y_pred, y)

            if training:
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            self.loss_epoch.update(loss.item(), y.size(0))
            # For classification tasks y.data is a LongTensor;
            # for regression tasks it is a FloatTensor.
            is_classification = isinstance(y.data, dtype['long'])
            if is_classification:
                res = check_acc(y_pred, y, (1, topk))
                self.top1.update(res[0].item(), y.size(0))
                self.top5.update(res[1].item(), y.size(0))
            else:
                # top1 is approximately the 'inverse' of the loss
                self.top1.update(1. / (loss.item() + 1.), y.size(0))
            self.batch_time.update(time.time() - end_time)
            end_time = time.time()

            # Record per-iteration running averages so loss/accuracy curves
            # can be plotted later.
            if training:
                self.losses_train.append(self.loss_epoch.avg)
                self.acc_train.append(self.top1.avg)
            else:
                self.losses_val.append(self.loss_epoch.avg)
                self.acc_val.append(self.top1.avg)

            if print_every and (i + 1) % print_every == 0:
                print('Epoch {0}: iteration {1}/{2}\t'
                      'loss: {losses.val:.3f}, avg: {losses.avg:.3f}\t'
                      'Prec@1: {prec1.val:.3f}, avg: {prec1.avg:.3f}\t'
                      'Prec@5: {prec5.val:.3f}, avg: {prec5.avg:.3f}\t'
                      'batch time: {batch_time.val:.3f} avg: {batch_time.avg:.3f}'.format(
                          epoch + 1, i + 1, num_chunks, losses=self.loss_epoch, prec1=self.top1,
                          prec5=self.top5, batch_time=self.batch_time))
                sys.stdout.flush()

        return self.top1.avg

    def train_eval(self, num_iter=100, batch_size=100, X=None, y=None, X_val=None, y_val=None,
                   X_test=None, y_test=None, eval_test=False, balanced_sample=False, allow_duplicate=False,
                   max_redundancy=1000, seed=None):
        """Train for num_iter batches, evaluating one validation (and optionally
        one test) batch after every training batch; returns per-batch and
        running-average loss/accuracy curves.
        """
        if X is None or y is None:
            X, y = self.data['X_train'], self.data['y_train']
        # Currently only for classification tasks: y must be a LongTensor
        assert isinstance(y, dtype['long'])
        if X_val is None or y_val is None:
            X_val, y_val = self.data['X_val'], self.data['y_val']
        if eval_test and (X_test is None or y_test is None):
            X_test, y_test = self.data['X_test'], self.data['y_test']

        dataloader_train = BatchLoader((X, y), batch_size, balanced=balanced_sample,
                                       num_iter=num_iter, allow_duplicate=allow_duplicate,
                                       max_redundancy=max_redundancy, shuffle=True, seed=seed)
        dataloader_val = BatchLoader((X_val, y_val), batch_size, balanced=balanced_sample,
                                     num_iter=num_iter, allow_duplicate=allow_duplicate,
                                     max_redundancy=max_redundancy, shuffle=True, seed=seed)
        if X_test is not None:
            dataloader_test = BatchLoader((X_test, y_test), batch_size, balanced=balanced_sample,
                                          num_iter=num_iter, allow_duplicate=allow_duplicate,
                                          max_redundancy=max_redundancy, shuffle=True, seed=seed)
        else:
            dataloader_test = [None] * num_iter

        # Each metric keeps both the running average ('avg') and the raw
        # per-batch value ('batch').
        loss_train_meter = AverageMeter()
        loss_train = {'avg': [], 'batch': []}
        acc_train_meter = AverageMeter()
        acc_train = {'avg': [], 'batch': []}
        loss_val_meter = AverageMeter()
        loss_val = {'avg': [], 'batch': []}
        acc_val_meter = AverageMeter()
        acc_val = {'avg': [], 'batch': []}
        loss_test_meter = AverageMeter()
        loss_test = {'avg': [], 'batch': []}
        acc_test_meter = AverageMeter()
        acc_test = {'avg': [], 'batch': []}

        def forward(X, y, loss_meter, losses, acc_meter, acc, training=False):
            """Run one batch through the model, record loss/accuracy into the
            given meters, and backprop if training."""
            X = Variable(X)
            y = Variable(y)
            y_pred = self.model(X)
            loss = self.loss_fn(y_pred, y)
            loss_meter.update(loss.item(), y.size(0))
            losses['avg'].append(loss_meter.avg)
            losses['batch'].append(loss.item())
            res = check_acc(y_pred, y, (1,))
            acc_meter.update(res[0].item(), y.size(0))
            acc['avg'].append(acc_meter.avg)
            acc['batch'].append(res[0].item())

            if training:
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            return y_pred, loss

        for (X, y), (X_val, y_val), test_data in zip(dataloader_train,
                                                     dataloader_val, dataloader_test):
            forward(X, y, loss_train_meter, loss_train, acc_train_meter, acc_train,
                    training=True)
            forward(X_val, y_val, loss_val_meter, loss_val, acc_val_meter, acc_val,
                    training=False)
            if test_data is not None:
                X_test, y_test = test_data
                forward(X_test, y_test, loss_test_meter, loss_test, acc_test_meter,
                        acc_test, training=False)

        if eval_test:
            return loss_train, acc_train, loss_val, acc_val, loss_test, acc_test
        else:
            return loss_train, acc_train, loss_val, acc_val

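    # A sketch of how the returned curves might be used (the plotting code is
    # illustrative only, not part of this module):
    #
    #     loss_train, acc_train, loss_val, acc_val = solver.train_eval(num_iter=500)
    #     plt.plot(loss_train['avg'], label='train')
    #     plt.plot(loss_val['avg'], label='val')
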
    def train(self, num_epoch=10, batch_size=100, num_samples=None, print_every=100,
              use_validation=True, save_checkpoint=True, file_prefix='', balanced_sample=False, topk=5):
        """Train the model, optionally validating and checkpointing after each epoch.

        Args:
            num_epoch: int, default: 10
            batch_size: int, default: 100
            num_samples: int, default: None
            print_every: int, default: 100
            use_validation: bool, default: True. If True, run a validation pass
                after each training epoch
            save_checkpoint: bool, default: True. If True, save a checkpoint named
                file_prefix + 'checkpoint%d.pth' % epoch and copy the best model
                so far to file_prefix + 'model_best.pth'
            file_prefix: str, default: ''
            balanced_sample: bool; sample class-balanced batches from an
                unbalanced dataset
            topk: int, default: 5
        """
        for i in range(self.epoch_counter, self.epoch_counter + num_epoch):
            accuracy = self.run_one_epoch(i, batch_size, num_samples, print_every,
                                          balanced_sample=balanced_sample, topk=topk)
            # In the rare case we train without a validation set, training
            # accuracy is used for model selection instead.
            if use_validation:
                accuracy = self.run_one_epoch(i, batch_size, num_samples, print_every,
                                              training=False, balanced_sample=balanced_sample, topk=topk)
            self.epoch_counter = i + 1  # so later train() calls continue the epoch numbering

            # Checkpoint (and update the best model) whenever accuracy improves
            if accuracy > self.best_acc_val:
                self.best_acc_val = accuracy
                if save_checkpoint:
                    state = {'model_state': self.model.state_dict(),
                             'optimizer': self.optimizer,
                             'best_acc_val': self.best_acc_val,
                             'epoch': i + 1,
                             'losses_train': self.losses_train,
                             'losses_val': self.losses_val,
                             'acc_train': self.acc_train,
                             'acc_val': self.acc_val}
                    filename = file_prefix + 'checkpoint%d.pth' % (i + 1)
                    torch.save(state, filename)
                    shutil.copyfile(filename, file_prefix + 'model_best.pth')

    def predict(self, batch_size=100, save_file=True, file_prefix='', X=None, y=None, topk=5, verbose=False):
        """Predict on the test set.

        Args:
            batch_size: int, default: 100; can be larger if memory allows
            save_file: bool, default: True; if True, save predictions to
                file_prefix + 'y_pred.pth'
            file_prefix: str, default: ''
            X: default: None. If not None, use X instead of self.data['X_test']
            y: default: None. Similarly to X
            topk: int, default: 5; k for the top-k accuracy meter
            verbose: bool, default: False; if True, print test-set metrics
        """
        if X is None:
            if 'X_test' in self.data:
                X = self.data['X_test']
            elif 'test_loader' in self.data:
                X = self.data['test_loader']
                dataloader = X
            else:
                raise ValueError('If X is None, then self.data '
                                 'must contain either X_test or test_loader')

        if y is None and 'y_test' in self.data:
            y = self.data['y_test']

        is_truth_avail = isinstance(y, dtype['long']) or isinstance(y, dtype['float'])

        # When X is a tensor, split it (and y, if available) into batch-sized
        # chunks; when X is a DataLoader, iterate it directly below.
        if isinstance(X, dtype['float']):
            N = X.size(0)
            num_chunks = (N + batch_size - 1) // batch_size
            X_chunks = X.chunk(num_chunks)

            if is_truth_avail:
                y_chunks = y.chunk(num_chunks)
            else:
                y_chunks = [None] * num_chunks
            dataloader = zip(X_chunks, y_chunks)

        self._reset_avg_meter()
        end_time = time.time()
        y_pred = []
        for X, y in dataloader:
            X = Variable(X)

            y_pred_tmp = self.model(X)  # some models output a tuple

            if is_truth_avail and y is not None:
                y = Variable(y)
                loss = self.loss_fn(y_pred_tmp, y)
                self.loss_epoch.update(loss.item(), y.size(0))
                if isinstance(y.data, dtype['long']):
                    res = check_acc(y_pred_tmp, y, (1, topk))
                    self.top1.update(res[0].item(), y.size(0))
                    self.top5.update(res[1].item(), y.size(0))
                else:
                    self.top1.update(1. / (loss.item() + 1.), y.size(0))
            self.batch_time.update(time.time() - end_time)
            end_time = time.time()
            if isinstance(y_pred_tmp, tuple):
                y_pred_tmp = y_pred_tmp[0]
            y_pred.append(y_pred_tmp)

        if is_truth_avail and verbose:
            print('Test set: loss: {losses.avg:.3f}\t'
                  'AP@1: {prec1.avg:.3f}\t'
                  'AP@5: {prec5.avg:.3f}\t'
                  'batch time: {batch_time.avg:.3f}'.format(
                      losses=self.loss_epoch, prec1=self.top1,
                      prec5=self.top5, batch_time=self.batch_time))
            sys.stdout.flush()
        y_pred = torch.cat(y_pred, 0)
        if save_file:
            torch.save({'y_pred': y_pred}, file_prefix + 'y_pred.pth')
        return y_pred


if __name__ == '__main__':

    # Smoke test: train a small DenseNet on MNIST.
    mnist_train = torchvision.datasets.MNIST('/projects/academic/jamesjar/tianlema/dl-datasets/mnist',
                                             transform=transforms.Compose([transforms.ToTensor(),
                                                                           transforms.Normalize((0.1307,), (0.3081,))]))
    train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=200)

    mnist_test = torchvision.datasets.MNIST('/projects/academic/jamesjar/tianlema/dl-datasets/mnist',
                                            transform=transforms.Compose([transforms.ToTensor(),
                                                                          transforms.Normalize((0.1307,), (0.3081,))]),
                                            train=False)
    test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=200)

    # Materialize the loaders into tensors and carve out a 50000/10000
    # train/val split.
    train = list(train_loader)
    train = list(zip(*train))
    X_train = torch.cat(train[0], 0)
    y_train = torch.cat(train[1], 0)

    X_val = X_train[50000:]
    y_val = y_train[50000:]
    X_train = X_train[:50000]
    y_train = y_train[:50000]

    test = list(test_loader)
    test = list(zip(*test))
    X_test = torch.cat(test[0], 0)
    y_test = torch.cat(test[1], 0)

    data = {'X_train': X_train, 'y_train': y_train, 'X_val': X_val, 'y_val': y_val,
            'X_test': X_test, 'y_test': y_test}

    model = DenseNet(input_param=(1, 64), block_layers=(6, 4), num_classes=10,
                     growth_rate=32, bn_size=2, dropout_rate=0, transition_pool_param=(3, 1, 1))
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), weight_decay=1e-4)

    solver = Solver(model, data, optimizer, loss_fn)
    solver.train(num_epoch=2, file_prefix='mnist-')
    solver.predict(file_prefix='mnist-')
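    # Running this block writes mnist-checkpoint<epoch>.pth and
    # mnist-model_best.pth during training, then mnist-y_pred.pth from
    # predict(). The MNIST path above is cluster-specific; on other machines
    # you would likely need to point it elsewhere and pass download=True.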