SNN_training.py
import torch
from transformers import AdamW  # note: newer transformers versions deprecate this import; torch.optim.AdamW can be used instead

def train_siamese_network(model, dataloaders, num_epochs, device):
    """
    Train the given Siamese neural network (SNN) model.

    :param model: SNN model
    :param dataloaders: dict with 'train' and 'val' DataLoaders
    :param num_epochs: number of epochs
    :param device: 'cpu' or 'cuda'

    :return: train_loss_history - list of train losses by epoch
             val_loss_history - list of validation losses by epoch
             a list [non_matching_similarity, matching_similarity,
             val_non_matching_similarity, val_matching_similarity] of
             per-batch mean similarity scores
    """
    train_loss_history = []
    val_loss_history = []

    # per-batch mean similarity scores on the training data, split by label
    matching_similarity = []
    non_matching_similarity = []

    # per-batch mean similarity scores on the validation data, split by label
    val_matching_similarity = []
    val_non_matching_similarity = []

    # The labels are same class (1) vs. different class (0).
    # Alternatives tried: ContrastiveLoss(margin=1),
    # losses.ContrastiveLoss(pos_margin=0, neg_margin=1),
    # torch.nn.BCEWithLogitsLoss(reduction='mean')
    criterion = torch.nn.BCELoss(reduction='mean')

    learning_rate = 0.005  # also tried 0.1 and 1e-5
    # Alternative tried: torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    optimizer = AdamW(model.parameters(), lr=learning_rate)

    # Optional LR schedulers (kept for reference):
    #   new lr = lr * factor
    #   mode='min': track the minimum validation loss
    #   patience: number of plateau epochs tolerated before the LR is reduced
    #             (patience=0 reduces the LR after a single bad epoch)
    # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=1, verbose=True, min_lr=0.0001)
    # cyclic_scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.0001, max_lr=0.1, cycle_momentum=False)

    for epoch in range(num_epochs):  # loop over the train dataset multiple times

        # Each epoch has a training phase and a validation phase
        for phase in ['train', 'val']:

            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0

            for batch in dataloaders[phase]:

                seq1, seq2, mask1, mask2, label = batch

                if device == 'cuda':
                    seq1, seq2, mask1, mask2, label = (seq1.to(device), seq2.to(device),
                                                       mask1.to(device), mask2.to(device),
                                                       label.to(device))

                # zero the parameter gradients
                optimizer.zero_grad()

                # track gradient history only in the training phase
                with torch.set_grad_enabled(phase == 'train'):

                    # forward pass: similarity score for each pair in the batch
                    output = model(seq1, seq2, mask1, mask2)
                    # BCELoss expects float targets with the same shape as the output
                    loss = criterion(output, label.view(output.size()).float())

                    # backward + optimize only in the training phase
                    if phase == 'train':
                        loss.backward()
                        # clip gradients to 1.0 to help prevent exploding gradients
                        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                        optimizer.step()
                        # cyclic_scheduler.step()

                        # save per-batch mean similarity scores for the training data
                        output = output.cpu().detach().numpy()
                        label = label.cpu().numpy()
                        non_matching_similarity.append((sum(output[label == 0]) / sum(label == 0)).item())
                        matching_similarity.append((sum(output[label == 1]) / sum(label == 1)).item())

                    if phase == 'val':
                        val_non_matching_similarity.append((sum(output[label == 0]) / sum(label == 0)).item())
                        val_matching_similarity.append((sum(output[label == 1]) / sum(label == 1)).item())

                # multiply the mean batch loss by the batch size; the last batch may be
                # smaller, since the batch size does not necessarily divide the dataset size
                running_loss += loss.item() * seq1.size(0)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            if phase == 'train':
                train_loss_history.append(epoch_loss)
            else:
                val_loss_history.append(epoch_loss)
                # scheduler.step(epoch_loss)

            print('Epoch {} | {} loss: {:.3f}'.format(epoch, phase, epoch_loss))

    return train_loss_history, val_loss_history, [non_matching_similarity, matching_similarity, val_non_matching_similarity, val_matching_similarity]
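
# ---------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the original training script).
# It assumes a Siamese model whose forward pass takes (seq1, seq2, mask1, mask2)
# and returns sigmoid similarity scores in [0, 1], plus DataLoaders that yield
# (seq1, seq2, mask1, mask2, label) batches. SiameseNetwork, train_loader and
# val_loader are illustrative placeholder names.
#
#   device = 'cuda' if torch.cuda.is_available() else 'cpu'
#   model = SiameseNetwork().to(device)
#   dataloaders = {'train': train_loader, 'val': val_loader}
#   train_hist, val_hist, similarities = train_siamese_network(
#       model, dataloaders, num_epochs=10, device=device)
#   non_match_sim, match_sim, val_non_match_sim, val_match_sim = similarities
# ---------------------------------------------------------------------------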