a b/train.py
1
import numpy as np
2
import torch
3
import csv 
4
import os
5
6
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
7
8
9
def validate(model, val_loader, criterion):
    """Evaluate ``model`` on ``val_loader`` and return ``(mean_loss, accuracy)``.

    The model is put into eval mode and left there.

    Args:
        model: Classifier; the predicted class is taken as the argmax over
            dim 1 of its output (``torch.max(preds, 1)``).
        val_loader: Iterable yielding ``(X, Y)`` tensor batches.
        criterion: Loss function called as ``criterion(preds, Y)``.

    Returns:
        Tuple ``(mean_loss, accuracy)``. ``mean_loss`` is the unweighted mean
        of the per-batch losses. NOTE: a smaller final batch counts as much as
        a full one. ``accuracy`` is a percentage in ``[0, 100]``.

    Raises:
        ValueError: If ``val_loader`` yields no samples.
    """
    model.eval()
    batch_losses = []
    total_predictions = 0
    correct_predictions = 0

    # Evaluation needs no gradients; without no_grad every forward pass would
    # build an autograd graph and hold its activations in memory.
    with torch.no_grad():
        for X, Y in val_loader:
            X = X.to(device)
            Y = Y.to(device)

            preds = model(X)
            batch_losses.append(criterion(preds, Y).item())

            _, predicted = torch.max(preds, 1)
            total_predictions += Y.size(0)
            correct_predictions += (predicted == Y).sum().item()

    if total_predictions == 0:
        raise ValueError('val_loader yielded no samples; cannot compute accuracy')

    acc = (correct_predictions / total_predictions) * 100.0
    return np.mean(batch_losses), acc
32
33
34
def _run_train_epoch(model, criterion, train_loader, optimizer):
    """Run one optimisation pass over ``train_loader``.

    Returns:
        Tuple ``(mean_batch_loss, accuracy_percent)`` for the epoch.

    Raises:
        ValueError: If ``train_loader`` yields no samples.
    """
    model.train()
    batch_losses = []
    total_predictions = 0
    correct_predictions = 0

    for X, Y in train_loader:
        X = X.to(device)
        Y = Y.to(device)

        optimizer.zero_grad()
        preds = model(X)
        loss = criterion(preds, Y)

        # Accuracy is measured on the predictions made before this batch's update.
        _, predicted = torch.max(preds, 1)
        total_predictions += Y.size(0)
        correct_predictions += (predicted == Y).sum().item()

        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())

    if total_predictions == 0:
        raise ValueError('train_loader yielded no samples; cannot compute accuracy')

    return np.mean(batch_losses), (correct_predictions / total_predictions) * 100.0


def _save_best_checkpoint(model, val_acc, results_dir):
    """Save ``model``'s state dict as the single best checkpoint in ``results_dir``.

    The file is named ``model_parameters_<acc>.pth``, where ``<acc>`` is the
    first five characters of ``str(val_acc)``. Older ``model_parameters_*.pth``
    files are removed only *after* the new file has been written. A failing
    ``torch.save`` therefore leaves the previous checkpoint in place.
    """
    os.makedirs(results_dir, exist_ok=True)
    path = os.path.join(results_dir, 'model_parameters_' + str(val_acc)[0: 5] + '.pth')
    torch.save(model.state_dict(), path)

    for item in os.listdir(results_dir):
        stale = os.path.join(results_dir, item)
        # Only prune checkpoints written by this function; unrelated .pth files
        # (e.g. pretrained weights) in the same directory are left alone.
        if (item.startswith('model_parameters_') and item.endswith('.pth')
                and os.path.abspath(stale) != os.path.abspath(path)):
            os.remove(stale)

    print('Saving model parameters...')
    print('Validation accuracy: ', val_acc)


def train(model, num_epochs, criterion, train_loader, val_loader, optimizer, scheduler,
          verbose=True, results_dir='results/'):
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    Whenever validation accuracy beats the best seen so far, the weights are
    checkpointed into ``results_dir`` and older checkpoints there are pruned.

    Args:
        model: Classifier to optimise; the predicted class is the argmax over
            dim 1 of its output.
        num_epochs: Number of passes over ``train_loader``.
        criterion: Loss function called as ``criterion(preds, Y)``.
        train_loader: Iterable of ``(X, Y)`` training batches.
        val_loader: Iterable of ``(X, Y)`` validation batches, passed to ``validate``.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: LR scheduler stepped once per epoch with no arguments, or ``None``.
        verbose: Print per-epoch metrics when true.
        results_dir: Checkpoint directory; created if it does not exist.

    Returns:
        Tuple of four per-epoch lists ``(loss_train, loss_val, acc_train, acc_val)``.
        Accuracies are percentages.
    """
    loss_train = []
    loss_val = []
    acc_train = []
    acc_val = []
    # Start below any reachable accuracy so the first epoch is always
    # checkpointed, even at 0 % accuracy.
    best_val_acc = float('-inf')

    for epoch in range(num_epochs):
        epoch_train_loss, epoch_train_acc = _run_train_epoch(
            model, criterion, train_loader, optimizer)
        loss_train.append(epoch_train_loss)
        acc_train.append(epoch_train_acc)

        epoch_val_loss, epoch_val_acc = validate(model, val_loader, criterion)
        loss_val.append(epoch_val_loss)
        acc_val.append(epoch_val_acc)

        if epoch_val_acc > best_val_acc:
            _save_best_checkpoint(model, epoch_val_acc, results_dir)
            best_val_acc = epoch_val_acc

        if scheduler is not None:
            # NOTE: stepped without a metric, so schedulers whose step() needs
            # one (e.g. ReduceLROnPlateau) are not supported here.
            scheduler.step()

        if verbose:
            print('EPOCH: ' + str(epoch))
            print('TRAIN_LOSS: ' + str(loss_train[-1]))
            print('TRAIN_ACC: ' + str(acc_train[-1]))
            print('VAL_LOSS: ' + str(loss_val[-1]))
            print('VAL_ACC: ' + str(acc_val[-1]))
            print('+' * 25)

    return loss_train, loss_val, acc_train, acc_val
101
102
103
104
def evaluate(model, loader):
    """Collect a ``[prediction, label]`` row for every sample in ``loader``.

    The model is put into eval mode and run without gradient tracking.

    Args:
        model: Classifier; the predicted class is the argmax over dim 1 of
            its output (``torch.max(preds, 1)``).
        loader: Iterable yielding ``(X, Y)`` tensor batches of any size.

    Returns:
        ``np.ndarray`` of shape ``(num_samples, 2)``. Column 0 holds predicted
        class indices and column 1 the corresponding labels. An empty loader
        gives shape ``(0, 2)``.

    Raises:
        ValueError: If a batch's label count differs from its prediction
            count (e.g. one-hot labels), since rows could not be paired.
    """
    observations = []
    model.eval()

    with torch.no_grad():
        for X, Y in loader:
            preds = model(X.to(device))
            _, predicted = torch.max(preds, 1)

            # Flatten both tensors so batches of any size are handled, not
            # only batch_size=1.
            batch_preds = predicted.reshape(-1).tolist()
            batch_labels = Y.reshape(-1).tolist()
            if len(batch_preds) != len(batch_labels):
                raise ValueError(
                    'got %d predictions but %d labels in a batch'
                    % (len(batch_preds), len(batch_labels)))
            observations.extend(zip(batch_preds, batch_labels))

    return np.array(observations).reshape(-1, 2)
116
117