|
a |
|
b/SNN_model.py |
|
|
1 |
import torch |
|
|
2 |
import torch.nn as nn |
|
|
3 |
import torch.nn.functional as F |
|
|
4 |
|
|
|
5 |
class BERT_Arch(nn.Module):
    """Sequence encoder: a BERT-style backbone followed by a small 1-D CNN head.

    The backbone's per-token hidden states are convolved along the sequence
    axis, max-pooled over all remaining positions and projected to a
    128-dimensional feature vector.

    Args:
        bert: Callable backbone invoked as
            ``bert(seq, attention_mask=mask, return_dict=False)``. Its first
            returned element must be the per-token hidden states with 768
            features per token (fixed by ``conv1.in_channels``).
    """

    def __init__(self, bert):
        super().__init__()

        self.bert = bert
        # kernel_size=3 lets every filter see three consecutive tokens (tri-grams).
        self.conv1 = nn.Conv1d(in_channels=768, out_channels=128, kernel_size=3, stride=1)
        self.avg_pooling = nn.AvgPool1d(kernel_size=2)
        self.conv2 = nn.Conv1d(in_channels=128, out_channels=64, kernel_size=3, stride=1)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(64, 128)
        self.dropout = nn.Dropout(0.2)

    def forward(self, seq, mask):
        """Encode a batch of token sequences into 128-dimensional feature vectors.

        Args:
            seq: Token-id tensor forwarded unchanged to the backbone.
            mask: Attention mask forwarded to the backbone as ``attention_mask``.

        Returns:
            Tensor of shape ``(b, 128)``.
        """
        # Only the token-level hidden states are needed. Indexing the first
        # element (instead of unpacking exactly two values) also works with
        # backbones that return no pooled output or that return extra items.
        outputs = self.bert(seq, attention_mask=mask, return_dict=False)
        hs = outputs[0]

        # nn.Conv1d expects (b, channels, length), so move the embedding axis
        # in front of the sequence axis; contiguous() materialises the permuted
        # layout in memory. Shape: (b, 768, seq_len).
        x = hs.permute(0, 2, 1).contiguous()

        x = F.relu(self.conv1(x))  # (b, 128, L1), L1 = seq_len - 2
        x = self.avg_pooling(x)  # (b, 128, L2), L2 = L1 // 2
        x = F.relu(self.conv2(x))  # (b, 64, L3), L3 = L2 - 2
        # Pooling with a window as wide as the remaining length collapses the
        # sequence axis to a single position: (b, 64, 1).
        x = F.max_pool1d(x, kernel_size=x.shape[2])
        x = self.flatten(x)  # (b, 64)
        x = self.fc(x)  # (b, 128)
        x = self.dropout(x)

        return x
|
|
38 |
|
|
|
39 |
|
|
|
40 |
class SiameseNeuralNetwork(nn.Module):
    """Siamese scorer that runs both inputs through one shared encoder.

    The element-wise absolute difference of the two encodings is mapped by a
    linear layer followed by a sigmoid to a single score in (0, 1).

    Args:
        bert_arch: Shared encoder called as ``bert_arch(seq, mask)``; it must
            return 128 features per sample (fixed by the linear layer below).
    """

    def __init__(self, bert_arch):
        super().__init__()

        self.bert_arch = bert_arch
        # The sigmoid belongs here only while the training loss expects
        # probabilities; drop it when switching to BCEWithLogitsLoss, which
        # applies the sigmoid itself.
        scorer = nn.Linear(128, 1)
        squash = nn.Sigmoid()
        self.distance_layer = nn.Sequential(scorer, squash)

    def forward(self, seq1, seq2, mask1, mask2):
        """Score a batch of sequence pairs.

        Returns:
            The output of ``distance_layer`` applied to the absolute difference
            of the two encodings (one score per pair).
        """
        # Both branches share the same encoder weights.
        left = self.bert_arch(seq1, mask1)  # [batch_size, embedding_size]
        right = self.bert_arch(seq2, mask2)
        gap = (left - right).abs()
        return self.distance_layer(gap)
|
|
55 |
|
|
|
56 |
class ContrastiveLoss(nn.Module):
    """Contrastive loss over pairs of embeddings.

    Takes the embeddings of two samples and a target label that is 1 when the
    samples belong to the same class and 0 otherwise. Same-class pairs are
    penalised by their squared distance; different-class pairs are penalised
    by the squared amount by which their distance falls short of ``margin``.

    Args:
        margin: Distance beyond which different-class pairs add no loss.
    """

    def __init__(self, margin):
        super().__init__()
        self.margin = margin
        # Added under the square root so its gradient stays finite when the
        # two embeddings coincide (distance 0).
        self.eps = 1e-9

    def forward(self, output1, output2, target, size_average=True):
        """Compute the loss for a batch of embedding pairs.

        Args:
            output1: Embeddings of the first samples, shape ``(batch, dim)``.
            output2: Embeddings of the second samples, same shape as ``output1``.
            target: Pair labels (1 = same class, 0 = different class) with one
                entry per pair; ``(batch,)`` and ``(batch, 1)`` are both accepted.
            size_average: Return the mean over pairs if true, otherwise the sum.

        Returns:
            Scalar loss tensor.
        """
        distances = (output2 - output1).pow(2).sum(1)  # squared distances

        labels = target.float()
        # A (batch, 1) label tensor would broadcast against the (batch,)
        # distances into a (batch, batch) matrix and silently distort the
        # loss, so align the shapes whenever the element counts agree.
        if labels.shape != distances.shape and labels.numel() == distances.numel():
            labels = labels.reshape(distances.shape)

        shortfall = F.relu(self.margin - (distances + self.eps).sqrt()).pow(2)
        losses = 0.5 * (labels * distances + (1 - labels) * shortfall)
        return losses.mean() if size_average else losses.sum()