|
a |
|
b/train.py |
|
|
1 |
import torch |
|
|
2 |
import settings |
|
|
3 |
import copy |
|
|
4 |
import time |
|
|
5 |
import torch.optim as optim |
|
|
6 |
|
|
|
7 |
from model.resnet import ResnetModel |
|
|
8 |
from torch.optim import lr_scheduler |
|
|
9 |
from dataloader.dataloader import get_data_loaders |
|
|
10 |
from torchvision import transforms |
|
|
11 |
|
|
|
12 |
|
|
|
13 |
import torch.nn as nn |
|
|
14 |
|
|
|
15 |
# Single global compute device: first CUDA GPU when available, else CPU.
# train() moves the model here and every batch is sent here explicitly.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
16 |
|
|
|
17 |
|
|
|
18 |
def train(model, dataloaders, num_epochs=50):
    """Train and validate *model*, keeping the best-validation-accuracy weights.

    Runs a standard two-phase (train/val) epoch loop with SGD + StepLR, using
    hyperparameters from the project-level ``settings`` module (``lr``,
    ``momentum``, ``step_size``, ``gamma``).

    Args:
        model: a torch.nn.Module classifier; wrapped in ``nn.DataParallel``
            when more than one GPU is visible.
        dataloaders: dict with 'train' and 'val' ``DataLoader`` entries.
        num_epochs: number of full train+val epochs to run.

    Side effects:
        Saves the best validation weights to 'checkpoints/best_model.pth'
        (the directory is created if missing) and prints per-epoch metrics.

    Returns:
        The model with the best validation weights loaded.
    """
    import os  # local import: only needed to ensure the checkpoint dir exists

    dataset_sizes = {phase: len(dataloaders[phase].dataset) for phase in ['train', 'val']}

    if torch.cuda.device_count() > 1:
        print("Usando", torch.cuda.device_count(), "GPUs")
        # NOTE(review): DataParallel prefixes state-dict keys with 'module.',
        # so the checkpoint saved below carries that prefix — confirm the
        # loading side strips or expects it.
        model = nn.DataParallel(model)

    model.to(device)

    criterion = nn.CrossEntropyLoss()

    optimizer = optim.SGD(model.parameters(), lr=settings.lr, momentum=settings.momentum)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=settings.step_size, gamma=settings.gamma)

    best_acc = 0.0
    best_model_wts = copy.deepcopy(model.state_dict())

    tic = time.time()
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 20)

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # Gradients are only tracked during the training phase;
                # validation runs the forward pass with autograd disabled.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # loss.item() is the batch mean: re-weight by batch size so
                # the epoch loss is a proper per-sample average.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            # Step the LR schedule once per epoch, after the training phase.
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # Snapshot weights whenever validation accuracy improves.
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

    time_elapsed = time.time() - tic
    print("Training complete in {:.0f}m {:.0f}s".format(time_elapsed // 60, time_elapsed % 60))
    print("Best val acc: {:.4f}".format(best_acc))

    model.load_state_dict(best_model_wts)

    # Fix: torch.save raises FileNotFoundError if 'checkpoints/' is absent —
    # create it (idempotently) before writing the checkpoint.
    os.makedirs('checkpoints', exist_ok=True)
    torch.save(best_model_wts, 'checkpoints/best_model.pth')

    return model
|
|
89 |
|
|
|
90 |
|
|
|
91 |
if __name__ == '__main__':

    def _build_transform():
        # Resize to 224, convert to tensor, normalize with ImageNet statistics.
        # Train and val currently share the exact same pipeline.
        return transforms.Compose([
            transforms.Resize(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    # Build the loaders (train transform first, then val) and a 2-class model,
    # then run the full training loop for the configured number of epochs.
    dataloaders = get_data_loaders(_build_transform(), _build_transform())
    model = ResnetModel(2)

    train(model, dataloaders, num_epochs=settings.epochs)