|
a |
|
b/train.py |
|
|
1 |
import numpy as np |
|
|
2 |
import torch |
|
|
3 |
import csv |
|
|
4 |
import os |
|
|
5 |
|
|
|
6 |
device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
7 |
|
|
|
8 |
|
|
|
9 |
def validate(model, val_loader, criterion): |
|
|
10 |
model.eval() |
|
|
11 |
val_loss = [] |
|
|
12 |
total_predictions = 0.0 |
|
|
13 |
correct_predictions = 0.0 |
|
|
14 |
for idx, (X,Y) in enumerate(val_loader): |
|
|
15 |
|
|
|
16 |
X = X.to(device) |
|
|
17 |
Y = Y.to(device) |
|
|
18 |
|
|
|
19 |
preds = model(X) |
|
|
20 |
|
|
|
21 |
loss = criterion(preds,Y) |
|
|
22 |
|
|
|
23 |
val_loss.append(loss.item()) |
|
|
24 |
|
|
|
25 |
_, predicted = torch.max(preds, 1) |
|
|
26 |
total_predictions += Y.size(0) |
|
|
27 |
correct_predictions += (predicted == Y).sum().item() |
|
|
28 |
|
|
|
29 |
acc = (correct_predictions/total_predictions)*100.0 |
|
|
30 |
|
|
|
31 |
return np.mean(val_loss), acc |
|
|
32 |
|
|
|
33 |
|
|
|
34 |
def train(model, num_epochs, criterion, train_loader, val_loader, optimizer, scheduler, verbose = True): |
|
|
35 |
loss_train = [] |
|
|
36 |
loss_val = [] |
|
|
37 |
acc_train = [] |
|
|
38 |
acc_val = [] |
|
|
39 |
prev_dev_acc = 0 |
|
|
40 |
for epoch in range(num_epochs): |
|
|
41 |
model.train() |
|
|
42 |
epoch_loss = [] |
|
|
43 |
total_predictions = 0.0 |
|
|
44 |
correct_predictions = 0.0 |
|
|
45 |
for idx, (X, Y) in enumerate(train_loader): |
|
|
46 |
X = X.to(device) |
|
|
47 |
Y = Y.to(device) |
|
|
48 |
|
|
|
49 |
optimizer.zero_grad() |
|
|
50 |
|
|
|
51 |
preds = model(X) |
|
|
52 |
|
|
|
53 |
loss = criterion(preds, Y) |
|
|
54 |
|
|
|
55 |
_, predicted = torch.max(preds, 1) |
|
|
56 |
total_predictions += Y.size(0) |
|
|
57 |
correct_predictions += (predicted == Y).sum().item() |
|
|
58 |
|
|
|
59 |
loss.backward() |
|
|
60 |
|
|
|
61 |
optimizer.step() |
|
|
62 |
|
|
|
63 |
epoch_loss.append(loss.item()) |
|
|
64 |
|
|
|
65 |
loss_train.append(np.mean(epoch_loss)) |
|
|
66 |
acc = (correct_predictions/total_predictions)*100.0 |
|
|
67 |
acc_train.append(acc) |
|
|
68 |
|
|
|
69 |
epoch_val_loss, epoch_val_acc = validate(model, val_loader, criterion) |
|
|
70 |
loss_val.append(epoch_val_loss) |
|
|
71 |
acc_val.append(epoch_val_acc) |
|
|
72 |
if epoch_val_acc > prev_dev_acc: |
|
|
73 |
|
|
|
74 |
dir_name = "results/" |
|
|
75 |
test = os.listdir(dir_name) |
|
|
76 |
|
|
|
77 |
for item in test: |
|
|
78 |
if item.endswith(".pth"): |
|
|
79 |
os.remove(os.path.join(dir_name, item)) |
|
|
80 |
|
|
|
81 |
path = 'results/model_parameters_' + str(epoch_val_acc)[0: 5] + '.pth' |
|
|
82 |
torch.save(model.state_dict(), path) |
|
|
83 |
print('Saving model parameters...') |
|
|
84 |
print('Validation accuracy: ', epoch_val_acc) |
|
|
85 |
prev_dev_acc = epoch_val_acc |
|
|
86 |
|
|
|
87 |
|
|
|
88 |
if scheduler != None: |
|
|
89 |
scheduler.step() |
|
|
90 |
|
|
|
91 |
if verbose: |
|
|
92 |
print('EPOCH: ' + str(epoch)) |
|
|
93 |
print('TRAIN_LOSS: ' + str(loss_train[-1])) |
|
|
94 |
print('TRAIN_ACC: ' + str(acc_train[-1])) |
|
|
95 |
print('VAL_LOSS: ' + str(loss_val[-1])) |
|
|
96 |
print('VAL_ACC: ' + str(acc_val[-1])) |
|
|
97 |
print('+'*25) |
|
|
98 |
|
|
|
99 |
|
|
|
100 |
return loss_train, loss_val, acc_train, acc_val |
|
|
101 |
|
|
|
102 |
|
|
|
103 |
|
|
|
104 |
def evaluate(model, loader): |
|
|
105 |
observations = [] |
|
|
106 |
heading = ['Predictions', 'Labels'] |
|
|
107 |
model.eval() |
|
|
108 |
for idx, (X, Y) in enumerate(loader): |
|
|
109 |
X = X.to(device) |
|
|
110 |
Y = Y.to(device) |
|
|
111 |
preds = model(X) |
|
|
112 |
_, predicted = torch.max(preds, 1) |
|
|
113 |
observations.append([predicted.item(), Y.item()]) |
|
|
114 |
|
|
|
115 |
return np.array(observations) |
|
|
116 |
|
|
|
117 |
|