Diff of /train.py, revisions [000000] .. [ccc736] (unified view of b/train.py shown below).
1
import copy
import os
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import transforms

import settings
from dataloader.dataloader import get_data_loaders
from model.resnet import ResnetModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
17
18
def train(model, dataloaders, num_epochs=50):
    """Train and validate *model*, keeping the best-validation-accuracy weights.

    Runs a standard train/val loop with SGD + StepLR (hyperparameters read
    from the project-level ``settings`` module), tracks the highest validation
    accuracy seen, and at the end:

    * reloads the best weights into ``model``,
    * saves them to ``checkpoints/best_model.pth``.

    Args:
        model: a torch.nn.Module classifier; wrapped in ``nn.DataParallel``
            automatically when more than one GPU is visible.
        dataloaders: dict with ``'train'`` and ``'val'`` DataLoader entries.
        num_epochs: number of epochs to run (default 50).

    Returns:
        The model with the best validation weights loaded.
    """
    import os  # local import so this block stays self-contained

    dataset_sizes = {phase: len(dataloaders[phase].dataset) for phase in ['train', 'val']}

    if torch.cuda.device_count() > 1:
        print("Usando", torch.cuda.device_count(), "GPUs")
        model = nn.DataParallel(model)

    model.to(device)

    criterion = nn.CrossEntropyLoss()

    optimizer = optim.SGD(model.parameters(), lr=settings.lr, momentum=settings.momentum)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=settings.step_size, gamma=settings.gamma)

    best_acc = 0.0
    # NOTE(review): if DataParallel wrapped the model above, these keys carry a
    # "module." prefix; loaders of the checkpoint must account for that.
    best_model_wts = copy.deepcopy(model.state_dict())

    tic = time.time()
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 20)

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # Gradients only needed during the training phase.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # CrossEntropyLoss is batch-averaged; re-weight by batch size
                # so epoch_loss below is a true per-sample mean.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            # Step the LR schedule once per epoch, after the training phase.
            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # Snapshot weights whenever validation accuracy improves.
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

    time_elapsed = time.time() - tic
    print("Training complete in {:.0f}m {:.0f}s".format(time_elapsed // 60, time_elapsed % 60))
    print("Best val acc: {:.4f}".format(best_acc))

    model.load_state_dict(best_model_wts)

    # Fix: create the checkpoint directory if missing — torch.save does not
    # create parent directories and would raise otherwise.
    os.makedirs('checkpoints', exist_ok=True)
    torch.save(best_model_wts, 'checkpoints/best_model.pth')

    return model
91
if __name__ == '__main__':

    # Standard ImageNet channel statistics used by pretrained ResNets.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    def _make_transform():
        # Same preprocessing for both splits: resize to 224, convert to
        # tensor, normalize with ImageNet statistics.
        return transforms.Compose([
            transforms.Resize(224),
            transforms.ToTensor(),
            transforms.Normalize(imagenet_mean, imagenet_std),
        ])

    dataloaders = get_data_loaders(_make_transform(), _make_transform())

    # Binary classifier head.
    model = ResnetModel(2)

    train(model, dataloaders, num_epochs=settings.epochs)