|
a |
|
b/train.py |
|
|
1 |
import time |
|
|
2 |
import copy |
|
|
3 |
import torch |
|
|
4 |
from torchnet import meter |
|
|
5 |
from torch.autograd import Variable |
|
|
6 |
from utils import plot_training |
|
|
7 |
|
|
|
8 |
# Phase names, in the order train_model runs them each epoch; they are also
# the keys used for dataloaders, dataset_sizes and the per-phase metric dicts.
data_cat = ['train', 'valid']
|
|
9 |
|
|
|
10 |
def train_model(model, criterion, optimizer, dataloaders, scheduler,
                dataset_sizes, num_epochs):
    '''
    Train ``model`` and return it loaded with its best validation weights.

    Each epoch runs the phases in ``data_cat`` order: 'train' (forward,
    backward, optimizer step) then 'valid' (forward only). After the 'valid'
    phase ``scheduler.step`` is called with the validation loss, and the
    weights are snapshotted whenever validation accuracy improves.

    Every loader item is reduced to a single prediction: the outputs for
    ``data['images'][0]`` are averaged with ``torch.mean`` and thresholded at
    0.5. Loss and correct counts are therefore accumulated once per loader
    item and divided by ``dataset_sizes[phase]``.

    Args:
        model: network to train; inputs are moved with ``.cuda()``, so it
            presumably already lives on the GPU.
        criterion: callable ``criterion(outputs, labels, phase)`` returning a
            scalar loss tensor.
        optimizer: optimizer over ``model``'s parameters.
        dataloaders: dict with 'train' and 'valid' iterables yielding dicts
            with 'images' and 'label' keys.
        scheduler: object whose ``step`` takes the validation loss
            (e.g. ReduceLROnPlateau -- confirm against the caller).
        dataset_sizes: dict mapping 'train'/'valid' to the normaliser for the
            per-epoch loss and accuracy.
        num_epochs: number of epochs to run.

    Returns:
        ``model`` with the state dict of the best-validation-accuracy epoch
        loaded. Per-epoch loss/accuracy curves are passed to plot_training.

    Requires PyTorch >= 0.4 (``tensor.item()``, ``torch.set_grad_enabled``).
    '''
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    costs = {x: [] for x in data_cat}  # per-epoch loss, keyed by phase
    accs = {x: [] for x in data_cat}   # per-epoch accuracy, keyed by phase
    print('Train batches:', len(dataloaders['train']))
    print('Valid batches:', len(dataloaders['valid']), '\n')
    for epoch in range(num_epochs):
        confusion_matrix = {x: meter.ConfusionMeter(2, normalized=True)
                            for x in data_cat}
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 10)
        # Each epoch has a training and validation phase
        for phase in data_cat:
            is_train = phase == 'train'
            model.train(is_train)
            running_loss = 0.0
            running_corrects = 0
            for i, data in enumerate(dataloaders[phase]):
                print(i, end='\r')
                inputs = data['images'][0]
                labels = data['label'].type(torch.FloatTensor)
                inputs = Variable(inputs.cuda())
                labels = Variable(labels.cuda())
                optimizer.zero_grad()
                # Only record the autograd graph when we will backpropagate;
                # otherwise the valid phase builds graphs it never uses.
                with torch.set_grad_enabled(is_train):
                    # Average all outputs into one score per loader item.
                    outputs = torch.mean(model(inputs))
                    loss = criterion(outputs, labels, phase)
                # .item() instead of loss.data[0]: indexing the 0-dim loss
                # tensor raises on PyTorch >= 0.4.
                running_loss += loss.item()
                if is_train:
                    loss.backward()
                    optimizer.step()
                preds = (outputs.data > 0.5).type(torch.cuda.FloatTensor)
                # Keep the count a Python int so epoch_acc is a plain float
                # rather than a CUDA tensor (it is printed, compared, plotted).
                running_corrects += torch.sum(preds == labels.data).item()
                # ConfusionMeter expects 1-D predictions/targets; torch.mean
                # produces a 0-dim tensor, so flatten both.
                confusion_matrix[phase].add(preds.view(-1),
                                            labels.data.view(-1))
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            costs[phase].append(epoch_loss)
            accs[phase].append(epoch_acc)
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))
            print('Confusion Meter:\n', confusion_matrix[phase].value())
            if phase == 'valid':
                scheduler.step(epoch_loss)
                # deep copy the model whenever validation accuracy improves
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
        time_elapsed = time.time() - since
        print('Time elapsed: {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))
        print()
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best valid Acc: {:.4f}'.format(best_acc))
    plot_training(costs, accs)
    # load best model weights
    model.load_state_dict(best_model_wts)
    return model
|
|
78 |
|
|
|
79 |
|
|
|
80 |
def get_metrics(model, criterion, dataloaders, dataset_sizes, phase='valid'):
    '''
    Loops over phase (train or valid) set to determine acc, loss and
    confusion meter of the model, and prints them.

    Each loader item is reduced to one prediction (``torch.mean`` over the
    outputs for ``data['images'][0]``, thresholded at 0.5). Loss and correct
    counts are therefore accumulated once per item and divided by
    ``dataset_sizes[phase]``, the same normalisation train_model uses.

    The model is put in eval mode with gradients disabled for the pass, and
    its previous train/eval mode is restored afterwards (even on error).

    Args:
        model: network to evaluate; inputs are moved with ``.cuda()``, so it
            presumably lives on the GPU.
        criterion: callable ``criterion(outputs, labels, phase)`` returning a
            scalar loss tensor.
        dataloaders: dict of iterables keyed by phase, yielding dicts with
            'images' and 'label' keys.
        dataset_sizes: dict mapping phase to the normaliser for loss/acc.
        phase: which loader to evaluate ('train' or 'valid').

    Requires PyTorch >= 0.4 (``tensor.item()``, ``torch.no_grad``).
    '''
    was_training = model.training
    # Evaluation must not use dropout / update batch-norm running stats.
    model.train(False)
    confusion_matrix = meter.ConfusionMeter(2, normalized=True)
    running_loss = 0.0
    running_corrects = 0
    try:
        with torch.no_grad():
            for i, data in enumerate(dataloaders[phase]):
                print(i, end='\r')
                labels = data['label'].type(torch.FloatTensor)
                inputs = data['images'][0]
                inputs = Variable(inputs.cuda())
                labels = Variable(labels.cuda())
                # forward: average every output into one score per item
                outputs = torch.mean(model(inputs))
                loss = criterion(outputs, labels, phase)
                # One loss term per item, matching how correct predictions
                # are counted; weighting by inputs.size(0) would count items
                # with more images several times.
                running_loss += loss.item()
                preds = (outputs.data > 0.5).type(torch.cuda.FloatTensor)
                # Python int, so acc below is a plain float.
                running_corrects += torch.sum(preds == labels.data).item()
                # ConfusionMeter expects 1-D inputs; preds is 0-dim here.
                confusion_matrix.add(preds.view(-1), labels.data.view(-1))
    finally:
        model.train(was_training)

    loss = running_loss / dataset_sizes[phase]
    acc = running_corrects / dataset_sizes[phase]
    print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, loss, acc))
    print('Confusion Meter:\n', confusion_matrix.value())