Diff of /src/training.py [000000] .. [f45789]

Switch to unified view

a b/src/training.py
1
import torch
2
import os
3
import copy
4
from tqdm import tqdm
5
import matplotlib.pyplot as plt
6
7
def save_losses_graph(train_losses, valid_losses, experiment_dir):
    """Plot the training and validation loss curves and save them as a PNG.

    The x positions are ``range(len(train_losses))``, i.e. one point per
    recorded training loss.  The figure is written to
    ``train_valid_loss.png`` inside ``experiment_dir`` and the current
    pyplot figure is cleared afterwards so the next plot starts empty.

    Args:
        train_losses: Sequence of training loss values.
        valid_losses: Sequence of validation loss values, plotted against
            the same x positions as ``train_losses``.
        experiment_dir: Directory the image file is written into.
    """
    x_axis = range(len(train_losses))

    # (values, matplotlib format string, legend label) for each curve.
    curves = (
        (train_losses, 'g', 'Training loss'),
        (valid_losses, 'b', 'Validation loss'),
    )
    for values, fmt, caption in curves:
        plt.plot(x_axis, values, fmt, label=caption)

    plt.title('Training and Validation loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend(loc='upper right')

    output_path = os.path.join(experiment_dir, 'train_valid_loss.png')
    plt.savefig(output_path)

    # Reset the shared pyplot figure so later plots do not draw on top of it.
    plt.clf()
17
18
def save_acc_graph(train_accs, valid_accs, experiment_dir):
    """Plot the training and validation accuracy curves and save them as a PNG.

    The x positions are ``range(len(train_accs))``, i.e. one point per
    recorded training accuracy.  The figure is written to
    ``train_valid_acc.png`` inside ``experiment_dir`` and the current
    pyplot figure is cleared afterwards so the next plot starts empty.

    Args:
        train_accs: Sequence of training accuracy values.
        valid_accs: Sequence of validation accuracy values, plotted against
            the same x positions as ``train_accs``.
        experiment_dir: Directory the image file is written into.
    """
    x_axis = range(len(train_accs))

    # (values, matplotlib format string, legend label) for each curve.
    curves = (
        (train_accs, 'g', 'Training acc'),
        (valid_accs, 'b', 'Validation acc'),
    )
    for values, fmt, caption in curves:
        plt.plot(x_axis, values, fmt, label=caption)

    plt.title('Training and Validation acc')
    plt.xlabel('Epochs')
    plt.ylabel('Acc')
    plt.legend(loc='upper right')

    output_path = os.path.join(experiment_dir, 'train_valid_acc.png')
    plt.savefig(output_path)

    # Reset the shared pyplot figure so later plots do not draw on top of it.
    plt.clf()
28
29
30
def train(conf):
    """Run the train/validation loop described by ``conf``.

    For every epoch the model is run over ``dataloaders['train']`` (with
    gradient updates) and then over ``dataloaders['valid']`` (gradients
    disabled).  After each validation pass the loss and accuracy curves are
    re-plotted into ``experiment_dir``, and whenever the validation accuracy
    exceeds the best value seen so far the model's ``state_dict`` is written
    to ``best_weights.pt`` in that directory.  The scheduler is stepped once
    per epoch, and the best validation accuracy is printed at the end.

    Args:
        conf: Mapping providing the keys ``'device'``, ``'model'``,
            ``'dataloaders'`` (indexable by ``'train'`` and ``'valid'``),
            ``'criterion'``, ``'optimizer'``, ``'scheduler'``,
            ``'num_epochs'`` and ``'experiment_dir'``.

    Returns:
        None.  Results are written to ``experiment_dir`` and printed.
    """
    device = conf['device']
    model = conf['model'].to(device)
    dataloaders = conf['dataloaders']
    criterion = conf['criterion']
    optimizer = conf['optimizer']
    scheduler = conf['scheduler']
    num_epochs = conf['num_epochs']
    experiment_dir = conf['experiment_dir']

    # Kept as a plain Python float throughout so the comparison below and the
    # final print never mix floats and tensors.
    best_acc = 0.0
    weights_path = os.path.join(experiment_dir, 'best_weights.pt')

    train_losses = []
    valid_losses = []
    train_accs = []
    valid_accs = []

    # Placeholder values reported in the progress bar until a phase has run.
    train_loss = valid_loss = -1.0
    train_acc = valid_acc = -1.0

    epoch_bar = tqdm(range(num_epochs), desc='Epoch', unit='epoch')
    for _ in epoch_bar:
        # Each epoch has a training and validation phase
        for phase in ['train', 'valid']:
            is_train = phase == 'train'
            if is_train:
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            batch_bar = tqdm(dataloaders[phase],
                             desc='Batch',
                             unit='batch',
                             leave=False)
            for inputs, labels in batch_bar:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward; track history only in the training phase
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    # Predicted class = index of the largest value along dim 1.
                    _, preds = torch.max(outputs, 1)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                # statistics (converted to Python numbers once per batch)
                batch_loss = loss.item()
                running_loss += batch_loss
                running_corrects += torch.sum(preds == labels).item()

                batch_bar.set_postfix(phase=phase, batch_loss=batch_loss)

            # Loss is averaged over len(dataloaders[phase]) (presumably the
            # number of batches); accuracy over the number of dataset samples.
            epoch_loss = running_loss / len(dataloaders[phase])
            epoch_acc = running_corrects / len(dataloaders[phase].dataset)

            if is_train:
                train_loss = epoch_loss
                train_acc = epoch_acc
                train_losses.append(train_loss)
                train_accs.append(train_acc)
                continue

            valid_loss = epoch_loss
            valid_acc = epoch_acc
            valid_losses.append(valid_loss)
            valid_accs.append(valid_acc)

            # Refresh the curves after every validation pass so the images on
            # disk always reflect the epochs completed so far.
            save_losses_graph(train_losses, valid_losses, experiment_dir)
            save_acc_graph(train_accs, valid_accs, experiment_dir)

            if valid_acc > best_acc:
                best_acc = valid_acc
                # torch.save serialises the state dict immediately, so no
                # in-memory copy of the weights is needed first.
                torch.save(model.state_dict(), weights_path)

        epoch_bar.set_postfix(tloss=train_loss, tacc=train_acc,
                              vloss=valid_loss, vacc=valid_acc)
        scheduler.step()

    print('Best valid Acc: {:.4f}'.format(best_acc))
124
125
if __name__ == '__main__':
    # Imported here so it is only needed when this file is run as a script.
    from config import get_config

    # Load the training configuration from the YAML file, then train.
    config_path = './conf/training.yaml'
    conf = get_config(config_path)
    train(conf)