|
a |
|
b/SNN_training.py |
|
|
1 |
import torch |
|
|
2 |
from transformers import AdamW |
|
|
3 |
|
|
|
4 |
def _mean_where(values, mask):
    """Return the mean of ``values[mask]`` as a Python float, or ``nan`` if ``mask`` selects nothing."""
    selected = values[mask]
    if selected.numel() == 0:
        # e.g. a batch that contains only matching (or only non-matching) pairs
        return float('nan')
    return selected.float().mean().item()


def train_siamese_network(model, dataloaders, num_epochs, device, learning_rate=0.005, max_grad_norm=1.0):
    """
    Train the given SNN (siamese) model and record losses and similarity scores.

    :param model: SNN model, called as ``model(seq1, seq2, mask1, mask2)``. Its output is fed to
        ``torch.nn.BCELoss``, so it must be a probability in [0, 1]. It must have as many
        elements as the batch labels.
    :param dataloaders: a dict with a ``'train'`` and a ``'val'`` data loader. Each batch is
        ``(seq1, seq2, mask1, mask2, label)``, where label is 1 for same class and 0 for
        different class.
    :param num_epochs: number of epochs
    :param device: ``'cpu'``, ``'cuda'``, ``'cuda:N'`` or a ``torch.device``. Every batch is
        moved to it.
    :param learning_rate: AdamW learning rate (default 0.005)
    :param max_grad_norm: max gradient norm used for clipping (default 1.0)

    :return: a tuple ``(train_loss_history, val_loss_history, similarities)``:
        train_loss_history - list of train losses by epochs (``nan`` for an empty loader)
        val_loss_history - list of validation losses by epochs (``nan`` for an empty loader)
        similarities - ``[non_matching_similarity, matching_similarity,
        val_non_matching_similarity, val_matching_similarity]``. Each entry holds, per
        training/validation batch, the mean model output over that batch's different-class
        or same-class pairs (``nan`` if the batch has no pair of that kind).
    """
    train_loss_history = []
    val_loss_history = []

    non_matching_similarity = []
    matching_similarity = []
    val_non_matching_similarity = []
    val_matching_similarity = []

    # Route each phase to its own lists, so validation batches never leak into the training statistics.
    loss_history = {'train': train_loss_history, 'val': val_loss_history}
    similarity_lists = {
        'train': (non_matching_similarity, matching_similarity),
        'val': (val_non_matching_similarity, val_matching_similarity),
    }

    # the labels are same class (1) vs. different class (0); the model outputs probabilities
    criterion = torch.nn.BCELoss(reduction='mean')
    # torch's AdamW configured with the defaults of the deprecated transformers.AdamW
    # (eps=1e-6, no weight decay), so the optimization is unchanged
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, eps=1e-6, weight_decay=0.0)

    for epoch in range(num_epochs):  # loop over the train dataset multiple times
        # Each epoch has a training and validation phase
        for phase in ('train', 'val'):
            is_train = phase == 'train'
            model.train(is_train)  # train(False) is equivalent to eval()

            non_matching, matching = similarity_lists[phase]
            running_loss = 0.0
            n_samples = 0

            for batch in dataloaders[phase]:
                # .to() is a no-op when the tensor is already on `device`
                seq1, seq2, mask1, mask2, label = (t.to(device) for t in batch)

                # track history only in train
                with torch.set_grad_enabled(is_train):
                    # call the module (not .forward) so registered hooks run
                    output = model(seq1, seq2, mask1, mask2)
                    loss = criterion(output, label.view(output.size()))

                    # backward + optimize only in the training phase
                    if is_train:
                        optimizer.zero_grad()
                        loss.backward()
                        # clip the gradients to prevent exploding gradients
                        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
                        optimizer.step()

                # save the batch's mean similarity for non-matching (0) and matching (1) pairs
                scores = output.detach().cpu().reshape(-1)
                labels = label.detach().cpu().reshape(-1)
                non_matching.append(_mean_where(scores, labels == 0))
                matching.append(_mean_where(scores, labels == 1))

                # the loss is a batch mean; weight it by the batch size, since the last
                # batch may be smaller than the others
                batch_size = seq1.size(0)
                running_loss += loss.item() * batch_size
                n_samples += batch_size

            epoch_loss = running_loss / n_samples if n_samples else float('nan')
            loss_history[phase].append(epoch_loss)

            print('Epoch {} | {} loss: {:.3f}'.format(epoch, phase, epoch_loss))

    return train_loss_history, val_loss_history, [non_matching_similarity, matching_similarity, val_non_matching_similarity, val_matching_similarity]
|
|
97 |
|