# SNN_model.py
1
import torch
2
import torch.nn as nn
3
import torch.nn.functional as F
4
5
class BERT_Arch(nn.Module):
    """CNN head on top of a BERT-style encoder producing a 128-d feature vector.

    The encoder's token-level hidden states are treated as a 1-D signal along
    the sequence axis:
      conv (tri-grams) -> avg-pool -> conv -> global max-pool -> linear -> dropout.

    Args:
        bert: encoder called as ``bert(seq, attention_mask=mask, return_dict=False)``.
            Element 0 of the returned tuple must be the hidden states with 768
            features per token. The permute in ``forward`` assumes the layout
            (batch, seq_len, 768) — TODO confirm against the encoder used.
    """

    def __init__(self, bert):
        super().__init__()
        self.bert = bert
        # kernel_size=3 == tri-gram windows over the token embeddings.
        self.conv1 = nn.Conv1d(in_channels=768, out_channels=128, kernel_size=3, stride=1)
        self.avg_pooling = nn.AvgPool1d(kernel_size=2)
        self.conv2 = nn.Conv1d(in_channels=128, out_channels=64, kernel_size=3, stride=1)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(64, 128)
        self.dropout = nn.Dropout(0.2)

    def forward(self, seq, mask):
        """Encode ``seq`` into one 128-dimensional feature vector per sample.

        Args:
            seq: token ids, passed unchanged to the encoder.
            mask: attention mask, passed to the encoder as ``attention_mask``.

        Returns:
            Tensor of shape (batch, 128): the fc output after dropout.

        Note:
            conv1 (k=3), avg-pool (k=2) and conv2 (k=3) each shrink the
            sequence axis. The input therefore needs at least 8 positions,
            otherwise conv2 has nothing left to convolve and raises.
        """
        # With return_dict=False the encoder returns a tuple whose element 0 is
        # the token-level hidden states. Indexing (rather than unpacking exactly
        # two values) also works for encoders returning more or fewer extras.
        hs = self.bert(seq, attention_mask=mask, return_dict=False)[0]

        # Conv1d expects (batch, channels, length): move the 768 embedding
        # features onto the channel axis, (b, L, 768) -> (b, 768, L).
        # contiguous() materializes the permuted layout in memory.
        x = hs.permute(0, 2, 1).contiguous()

        x = F.relu(self.conv1(x))     # (b, 128, L - 2)
        x = self.avg_pooling(x)       # (b, 128, (L - 2) // 2)
        x = F.relu(self.conv2(x))     # (b, 64, (L - 2) // 2 - 2)
        # Global max over all remaining positions -> (b, 64, 1). Same result as
        # max_pool1d with kernel_size equal to the current length.
        x = F.adaptive_max_pool1d(x, 1)
        x = self.flatten(x)           # (b, 64)
        x = self.fc(x)                # (b, 128)
        return self.dropout(x)
38
39
40
class SiameseNeuralNetwork(nn.Module):
    """Twin network that scores the similarity of two token sequences.

    Both sequences are embedded by the same ``bert_arch`` module, so the weights
    are shared. The element-wise absolute difference of the two embeddings feeds
    a ``Linear(128, 1)`` followed by a ``Sigmoid``. The result is one score in
    (0, 1) per pair.

    Because the head already ends in a Sigmoid, it presumably pairs with a
    probability loss such as ``nn.BCELoss``. To train with
    ``nn.BCEWithLogitsLoss``, remove the Sigmoid, since that loss applies the
    sigmoid itself.
    """

    def __init__(self, bert_arch):
        super().__init__()
        self.bert_arch = bert_arch
        # Children "0" (Linear) and "1" (Sigmoid) appear in state_dict keys.
        self.distance_layer = nn.Sequential(
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def forward(self, seq1, seq2, mask1, mask2):
        """Return the similarity score of each (seq1, seq2) pair (last dim of size 1)."""
        # Presumably (batch, 128) each, matching the Linear(128, 1) head.
        left = self.bert_arch(seq1, mask1)
        right = self.bert_arch(seq2, mask2)
        gap = (left - right).abs()
        return self.distance_layer(gap)
55
56
class ContrastiveLoss(nn.Module):
    """Contrastive loss for pairs of embeddings.

    ``target`` is 1 when the two samples come from the same class and 0
    otherwise. With ``d`` the squared Euclidean distance of a pair, the
    per-pair loss is::

        0.5 * (target * d + (1 - target) * relu(margin - sqrt(d + eps)) ** 2)

    Same-class pairs are pulled together. Different-class pairs are pushed
    apart and stop contributing once their distance reaches ``margin``.

    Args:
        margin: Euclidean distance beyond which a different-class pair
            contributes zero loss.
    """

    def __init__(self, margin):
        super().__init__()
        self.margin = margin
        # Added inside the sqrt so its gradient stays finite when d == 0.
        self.eps = 1e-9

    def forward(self, output1, output2, target, size_average=True):
        """Compute the contrastive loss for a batch of embedding pairs.

        Args:
            output1, output2: embeddings of the two samples. The squared
                differences are summed over dim 1, presumably the feature
                dim of (batch, features) inputs.
            target: per-pair labels (1 = same class, 0 = different), shaped
                (batch,) or (batch, 1). Any numeric or bool dtype works.
            size_average: return the mean over pairs if True, else the sum.

        Returns:
            Scalar tensor.
        """
        distances = (output2 - output1).pow(2).sum(1)  # squared distances
        target = target.float()
        # A (batch, 1) target (e.g. labels shaped for BCELoss) would otherwise
        # broadcast against the (batch,) distances into a (batch, batch) matrix
        # and silently yield a wrong loss.
        if target.dim() == distances.dim() + 1 and target.size(-1) == 1:
            target = target.squeeze(-1)
        similar_term = target * distances
        dissimilar_term = (1.0 - target) * F.relu(
            self.margin - (distances + self.eps).sqrt()
        ).pow(2)
        losses = 0.5 * (similar_term + dissimilar_term)
        return losses.mean() if size_average else losses.sum()