Diff of /train.py [000000] .. [df6751]

Switch to unified view

a b/train.py
1
import time
2
import copy
3
import torch
4
from torchnet import meter
5
from torch.autograd import Variable
6
from utils import plot_training
7
8
data_cat = ['train', 'valid'] # data categories
9
10
def train_model(model, criterion, optimizer, dataloaders, scheduler, 
11
                dataset_sizes, num_epochs):
12
    since = time.time()
13
    best_model_wts = copy.deepcopy(model.state_dict())
14
    best_acc = 0.0
15
    costs = {x:[] for x in data_cat} # for storing costs per epoch
16
    accs = {x:[] for x in data_cat} # for storing accuracies per epoch
17
    print('Train batches:', len(dataloaders['train']))
18
    print('Valid batches:', len(dataloaders['valid']), '\n')
19
    for epoch in range(num_epochs):
20
        confusion_matrix = {x: meter.ConfusionMeter(2, normalized=True) 
21
                            for x in data_cat}
22
        print('Epoch {}/{}'.format(epoch+1, num_epochs))
23
        print('-' * 10)
24
        # Each epoch has a training and validation phase
25
        for phase in data_cat:
26
            model.train(phase=='train')
27
            running_loss = 0.0
28
            running_corrects = 0
29
            # Iterate over data.
30
            for i, data in enumerate(dataloaders[phase]):
31
                # get the inputs
32
                print(i, end='\r')
33
                inputs = data['images'][0]
34
                labels = data['label'].type(torch.FloatTensor)
35
                # wrap them in Variable
36
                inputs = Variable(inputs.cuda())
37
                labels = Variable(labels.cuda())
38
                # zero the parameter gradients
39
                optimizer.zero_grad()
40
                # forward
41
                outputs = model(inputs)
42
                outputs = torch.mean(outputs)
43
                loss = criterion(outputs, labels, phase)
44
                running_loss += loss.data[0]
45
                # backward + optimize only if in training phase
46
                if phase == 'train':
47
                    loss.backward()
48
                    optimizer.step()
49
                # statistics
50
                preds = (outputs.data > 0.5).type(torch.cuda.FloatTensor)
51
                running_corrects += torch.sum(preds == labels.data)
52
                confusion_matrix[phase].add(preds, labels.data)
53
            epoch_loss = running_loss / dataset_sizes[phase]
54
            epoch_acc = running_corrects / dataset_sizes[phase]
55
            costs[phase].append(epoch_loss)
56
            accs[phase].append(epoch_acc)
57
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
58
                phase, epoch_loss, epoch_acc))
59
            print('Confusion Meter:\n', confusion_matrix[phase].value())
60
            # deep copy the model
61
            if phase == 'valid':
62
                scheduler.step(epoch_loss)
63
                if epoch_acc > best_acc:
64
                    best_acc = epoch_acc
65
                    best_model_wts = copy.deepcopy(model.state_dict())
66
        time_elapsed = time.time() - since
67
        print('Time elapsed: {:.0f}m {:.0f}s'.format(
68
                time_elapsed // 60, time_elapsed % 60))
69
        print()
70
    time_elapsed = time.time() - since
71
    print('Training complete in {:.0f}m {:.0f}s'.format(
72
        time_elapsed // 60, time_elapsed % 60))
73
    print('Best valid Acc: {:4f}'.format(best_acc))
74
    plot_training(costs, accs)
75
    # load best model weights
76
    model.load_state_dict(best_model_wts)
77
    return model
78
79
80
def get_metrics(model, criterion, dataloaders, dataset_sizes, phase='valid'):
81
    '''
82
    Loops over phase (train or valid) set to determine acc, loss and 
83
    confusion meter of the model.
84
    '''
85
    confusion_matrix = meter.ConfusionMeter(2, normalized=True)
86
    running_loss = 0.0
87
    running_corrects = 0
88
    for i, data in enumerate(dataloaders[phase]):
89
        print(i, end='\r')
90
        labels = data['label'].type(torch.FloatTensor)
91
        inputs = data['images'][0]
92
        # wrap them in Variable
93
        inputs = Variable(inputs.cuda())
94
        labels = Variable(labels.cuda())
95
        # forward
96
        outputs = model(inputs)
97
        outputs = torch.mean(outputs)
98
        loss = criterion(outputs, labels, phase)
99
        # statistics
100
        running_loss += loss.data[0] * inputs.size(0)
101
        preds = (outputs.data > 0.5).type(torch.cuda.FloatTensor)
102
        running_corrects += torch.sum(preds == labels.data)
103
        confusion_matrix.add(preds, labels.data)
104
105
    loss = running_loss / dataset_sizes[phase]
106
    acc = running_corrects / dataset_sizes[phase]
107
    print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, loss, acc))
108
    print('Confusion Meter:\n', confusion_matrix.value())