# models.py
import torch
import torch.nn as nn
import numpy as np
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class LstmAttEncoder(nn.Module):
    # BiLSTM encoder that returns the full hidden-state sequence (used before cross attention).
    def __init__(self, in_feat: int = 100):
        super().__init__()
        self.lstm = nn.LSTM(input_size=in_feat, hidden_size=in_feat, bidirectional=True, batch_first=True)

    def forward(self, token_embeds, attention_mask):
        # token_embeds is expected to be a PackedSequence (see SemAttention.forward).
        output, (h, c) = self.lstm(token_embeds)
        # Unpack to a padded tensor [B, L, 2H] plus the true length of each sample.
        output, lens_output = pad_packed_sequence(output, batch_first=True)
        return output, lens_output


class LstmDecoder(nn.Module):
    # BiLSTM that re-encodes the cross-attended sequence and mean-pools it into one vector.
    # dropout_prob is accepted for API consistency with the other modules but not used here.
    def __init__(self, in_feat: int = 100, dropout_prob: float = 0.1):
        super().__init__()
        self.lstm = nn.LSTM(input_size=in_feat, hidden_size=in_feat, bidirectional=True, batch_first=True)

    def forward(self, token_embeds, attention_mask):
        # token_embeds is expected to be a PackedSequence.
        batch_size = attention_mask.size(0)
        output, (h, c) = self.lstm(token_embeds)
        output, lens_output = pad_packed_sequence(output, batch_first=True)   # [B, L, 2H]
        # Mean-pool the hidden states over the valid (non-padding) timesteps of each sample.
        output = torch.stack([torch.mean(output[i][:lens_output[i]], dim=0) for i in range(batch_size)], dim=0)
        return output


class Encoder(nn.Module):
    # Simple bag-of-words encoder: sum the non-padding token embeddings, then two Tanh layers.
    def __init__(self, in_feat=100, dropout_prob=0.1) -> None:
        super().__init__()
        self.linear1 = nn.Linear(in_feat, in_feat)
        self.linear2 = nn.Linear(in_feat, in_feat)
        self.act = nn.Tanh()
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, token_embeds, attention_mask):
        batch_size = token_embeds.size(0)
        # Boolean-mask out padding before summing; cast in case the mask is a 0/1 integer tensor,
        # which would otherwise be misinterpreted as index positions.
        x = torch.stack([token_embeds[i, attention_mask[i, :].bool(), :].sum(dim=0) for i in range(batch_size)], dim=0)
        x = self.act(self.linear1(x))
        x = self.act(self.linear2(self.dropout(x)))
        return x


class LstmEncoder(nn.Module):
    # BiLSTM encoder that mean-pools the hidden states into a single vector per text.
    def __init__(self, in_feat: int = 100, dropout_prob: float = 0.1):
        super().__init__()
        self.lstm = nn.LSTM(input_size=in_feat, hidden_size=in_feat, bidirectional=True, batch_first=True)

    def forward(self, token_embeds, attention_mask):
        # token_embeds is expected to be a PackedSequence (see SemLSTM.forward).
        batch_size = attention_mask.size(0)
        output, (h, c) = self.lstm(token_embeds)
        output, lens_output = pad_packed_sequence(output, batch_first=True)
        # Average the bidirectional hidden states over each sample's valid length.
        output = torch.stack([output[i][:lens_output[i]].mean(dim=0) for i in range(batch_size)], dim=0)
        return output


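# Usage sketch (illustrative, not part of the model code): the LSTM-based encoders above
# expect a PackedSequence, so callers pack the padded embeddings with the per-sample
# lengths derived from the attention mask, e.g.:
#
#     lengths = attention_mask.cpu().long().sum(dim=-1)                            # [B]
#     packed = pack_padded_sequence(embeds, lengths, batch_first=True, enforce_sorted=False)
#     pooled = LstmEncoder(in_feat=100)(packed, attention_mask)                    # [B, 2 * in_feat]
#
# Here `embeds` and `attention_mask` are placeholder names for a [B, L, in_feat] embedding
# tensor and its [B, L] 0/1 mask.
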
class Classifier(nn.Module):
    # Two-layer MLP classification head with dropout.
    def __init__(self, in_feat, num_labels: int, dropout_prob: float = 0.1):
        super().__init__()
        self.dense1 = nn.Linear(in_feat, in_feat // 2)
        self.dense2 = nn.Linear(in_feat // 2, num_labels)
        self.act = nn.Tanh()
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        x = self.act(self.dense1(self.dropout(x)))
        x = self.dense2(self.dropout(x))
        return x


class SemNN(nn.Module):
    # Baseline: bag-of-words Encoder on each text, concatenate the two vectors, then classify.
    def __init__(self,
        in_feat=100,
        num_labels=3,
        dropout_prob=0.1,
        w2v_mapping=None,
        vocab_size=None,
        word_embedding_dim=None
    ):
        super().__init__()
        self.num_labels = num_labels
        self._init_word_embedding(w2v_mapping, vocab_size, word_embedding_dim)
        self.encoder = Encoder(in_feat=in_feat)
        self.classifier = Classifier(in_feat=2 * in_feat, num_labels=num_labels, dropout_prob=dropout_prob)

    def _init_word_embedding(self, state_dict=None, vocab_size=None, word_embedding_dim=None):
        if state_dict is None:
            # No pretrained vectors: learn the embedding from scratch.
            self.word_embedding = nn.Embedding(vocab_size, word_embedding_dim, padding_idx=0)
        else:
            # state_dict is expected to expose a pretrained matrix via `.vectors`
            # (e.g. a gensim KeyedVectors-like object); index 0 is reserved for padding.
            state_dict = torch.tensor(state_dict.vectors, dtype=torch.float32)
            state_dict[0] = torch.zeros(state_dict.size(-1))
            self.word_embedding = nn.Embedding.from_pretrained(state_dict, freeze=True, padding_idx=0)

    def forward(self,
                text_a_inputs_id,
                text_b_inputs_id,
                text_a_attention_mask,
                text_b_attention_mask):
        # Embedding lookup: [B, L] -> [B, L, H]
        text_a_vec = self.word_embedding(text_a_inputs_id)
        text_b_vec = self.word_embedding(text_b_inputs_id)

        # Encode each text into a single vector: [B, H]
        text_a_vec = self.encoder(text_a_vec, text_a_attention_mask)
        text_b_vec = self.encoder(text_b_vec, text_b_attention_mask)

        # Concatenate the two sentence vectors and classify
        pooler_output = torch.cat([text_a_vec, text_b_vec], dim=-1)
        logits = self.classifier(pooler_output)

        return logits


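# Usage sketch (illustrative only; tensor names and sizes below are placeholders):
#
#     model = SemNN(in_feat=100, num_labels=3, vocab_size=5000, word_embedding_dim=100)
#     logits = model(ids_a, ids_b, mask_a, mask_b)   # ids_*: [B, L] token ids, mask_*: [B, L] 0/1
#     # logits: [B, num_labels]
#
# A runnable smoke test for all three models is sketched at the bottom of the file.
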
class CrossAttention(nn.Module):
    # Soft alignment between the two token sequences, enhanced with difference and
    # element-wise product features, followed by a projection back down.
    def __init__(self, in_feat, dropout_prob):
        super().__init__()
        self.dense = nn.Linear(4 * in_feat, in_feat // 2)
        self.act = nn.ReLU()
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, a, b, mask_a, mask_b):
        # a: [B, L1, H]   b: [B, L2, H]

        # Raw attention scores [B, L1, L2]
        cross_attn = torch.matmul(a, b.transpose(1, 2))

        # Mask out b's (L2) padding positions before the softmax over L2
        row_attn = cross_attn.masked_fill(~mask_b.bool().unsqueeze(1), -1e9)
        row_attn = row_attn.softmax(dim=2)   # [B, L1, L2]

        # Mask out a's (L1) padding positions before the softmax over L1
        col_attn = cross_attn.permute(0, 2, 1).contiguous()   # [B, L2, L1]
        col_attn = col_attn.masked_fill(~mask_a.bool().unsqueeze(1), -1e9)
        col_attn = col_attn.softmax(dim=2)   # [B, L2, L1]

        # Attention-weighted summaries of the other sequence
        att_a = torch.matmul(row_attn, b)   # [B, L1, H]
        att_b = torch.matmul(col_attn, a)   # [B, L2, H]

        # Difference and element-wise product features
        diff_a = a - att_a
        diff_b = b - att_b
        prod_a = a * att_a
        prod_b = b * att_b

        # Concatenate original, aligned, difference and product features
        a = torch.cat([a, att_a, diff_a, prod_a], dim=-1)   # [B, L1, 4H]
        b = torch.cat([b, att_b, diff_b, prod_b], dim=-1)   # [B, L2, 4H]

        # Project back down
        a = self.act(self.dense(self.dropout(a)))   # [B, L1, H/2]
        b = self.act(self.dense(self.dropout(b)))   # [B, L2, H/2]

        return a, b


class SemLSTM(nn.Module):
    # Siamese BiLSTM: encode each text with a shared LstmEncoder, concatenate, classify.
    def __init__(self,
        in_feat=100,
        num_labels=3,
        dropout_prob=0.1,
        w2v_mapping=None,
        vocab_size=None,
        word_embedding_dim=None
    ):
        super().__init__()
        self.num_labels = num_labels
        self._init_word_embedding(w2v_mapping, vocab_size, word_embedding_dim)
        self.encoder = LstmEncoder(in_feat=in_feat)
        # Each text pools to 2*in_feat (bidirectional), so the concatenation is 4*in_feat wide.
        self.classifier = Classifier(in_feat=4 * in_feat, num_labels=num_labels, dropout_prob=dropout_prob)

    def _init_word_embedding(self, state_dict=None, vocab_size=None, word_embedding_dim=None):
        if state_dict is None:
            self.word_embedding = nn.Embedding(vocab_size, word_embedding_dim, padding_idx=0)
        else:
            # Pretrained vectors exposed via `.vectors`; index 0 is reserved for padding.
            state_dict = torch.tensor(state_dict.vectors, dtype=torch.float32)
            state_dict[0] = torch.zeros(state_dict.size(-1))
            self.word_embedding = nn.Embedding.from_pretrained(state_dict, freeze=True, padding_idx=0)

    def forward(self,
                text_a_inputs_id,
                text_b_inputs_id,
                text_a_attention_mask,
                text_b_attention_mask):
        # Embedding
        text_a_vec = self.word_embedding(text_a_inputs_id)
        text_b_vec = self.word_embedding(text_b_inputs_id)

        # Pack with the true lengths so the LSTM skips padding
        text_a_vec = pack_padded_sequence(text_a_vec, text_a_attention_mask.cpu().long().sum(dim=-1), batch_first=True, enforce_sorted=False)
        text_b_vec = pack_padded_sequence(text_b_vec, text_b_attention_mask.cpu().long().sum(dim=-1), batch_first=True, enforce_sorted=False)

        # BiLSTM encoding + mean pooling: [B, 2H] per text
        text_a_vec = self.encoder(text_a_vec, text_a_attention_mask)
        text_b_vec = self.encoder(text_b_vec, text_b_attention_mask)

        # Concatenate and classify
        pooler_output = torch.cat([text_a_vec, text_b_vec], dim=-1)
        logits = self.classifier(pooler_output)

        return logits


class SemAttention(nn.Module):
    # BiLSTM encoder + cross attention between the two texts + BiLSTM decoder, then classify.
    def __init__(self,
        in_feat=100,
        num_labels=3,
        dropout_prob=0.1,
        w2v_mapping=None,
        vocab_size=None,
        word_embedding_dim=None
    ):
        super().__init__()
        self.num_labels = num_labels
        self._init_word_embedding(w2v_mapping, vocab_size, word_embedding_dim)
        self.encoder = LstmAttEncoder(in_feat=in_feat)
        self.classifier = Classifier(in_feat=4 * in_feat, num_labels=num_labels, dropout_prob=dropout_prob)
        # The encoder outputs 2*in_feat per token, so cross attention operates at that width.
        self.crossattention = CrossAttention(in_feat=2 * in_feat, dropout_prob=dropout_prob)
        self.decoder = LstmDecoder(in_feat=in_feat, dropout_prob=dropout_prob)

    def _init_word_embedding(self, state_dict=None, vocab_size=None, word_embedding_dim=None):
        if state_dict is None:
            self.word_embedding = nn.Embedding(vocab_size, word_embedding_dim, padding_idx=0)
        else:
            # Pretrained vectors exposed via `.vectors`; index 0 is reserved for padding.
            state_dict = torch.tensor(state_dict.vectors, dtype=torch.float32)
            state_dict[0] = torch.zeros(state_dict.size(-1))
            self.word_embedding = nn.Embedding.from_pretrained(state_dict, freeze=True, padding_idx=0)

    def forward(self,
                text_a_inputs_id,
                text_b_inputs_id,
                text_a_attention_mask,
                text_b_attention_mask):
        # Embedding: [B, L1, H] and [B, L2, H]
        text_a_vec = self.word_embedding(text_a_inputs_id)
        text_b_vec = self.word_embedding(text_b_inputs_id)

        # Pack the embeddings with the true lengths so the LSTM skips padding
        text_a_vec = pack_padded_sequence(text_a_vec, text_a_attention_mask.cpu().long().sum(dim=-1), batch_first=True, enforce_sorted=False)
        text_b_vec = pack_padded_sequence(text_b_vec, text_b_attention_mask.cpu().long().sum(dim=-1), batch_first=True, enforce_sorted=False)
        # Round-trip the masks through pack/pad so they are truncated to the batch's max length
        text_a_attention_mask = pack_padded_sequence(text_a_attention_mask, text_a_attention_mask.cpu().long().sum(dim=-1), batch_first=True, enforce_sorted=False)
        text_b_attention_mask = pack_padded_sequence(text_b_attention_mask, text_b_attention_mask.cpu().long().sum(dim=-1), batch_first=True, enforce_sorted=False)
        text_a_attention_mask, _ = pad_packed_sequence(text_a_attention_mask, batch_first=True)
        text_b_attention_mask, _ = pad_packed_sequence(text_b_attention_mask, batch_first=True)

        # BiLSTM encoder: padded hidden states [B, L1, 2H] / [B, L2, 2H] plus lengths
        text_a_vec, text_a_len = self.encoder(text_a_vec, text_a_attention_mask)
        text_b_vec, text_b_len = self.encoder(text_b_vec, text_b_attention_mask)

        # Cross attention between the two sequences: [B, L1, H] / [B, L2, H]
        text_a_vec, text_b_vec = self.crossattention(text_a_vec, text_b_vec, text_a_attention_mask, text_b_attention_mask)
        text_a_vec = pack_padded_sequence(text_a_vec, text_a_len, batch_first=True, enforce_sorted=False)
        text_b_vec = pack_padded_sequence(text_b_vec, text_b_len, batch_first=True, enforce_sorted=False)

        # BiLSTM decoder with mean pooling: [B, 2H] per text
        text_a_vec = self.decoder(text_a_vec, text_a_attention_mask)
        text_b_vec = self.decoder(text_b_vec, text_b_attention_mask)

        # Concatenate and classify
        pooler_output = torch.cat([text_a_vec, text_b_vec], dim=-1)
        logits = self.classifier(pooler_output)

        return logits
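

# Minimal smoke test (a sketch, not part of the original training pipeline): builds each
# model with randomly initialised embeddings (w2v_mapping=None) and runs a forward pass
# on dummy token ids and 0/1 attention masks. All sizes below are arbitrary placeholders.
if __name__ == "__main__":
    torch.manual_seed(0)
    batch_size, max_len, vocab_size, dim = 4, 12, 1000, 100
    lengths = torch.tensor([12, 9, 7, 5])
    ids_a = torch.randint(1, vocab_size, (batch_size, max_len))
    ids_b = torch.randint(1, vocab_size, (batch_size, max_len))
    # 0/1 mask: position j of sample i is valid iff j < lengths[i]
    mask = (torch.arange(max_len).unsqueeze(0) < lengths.unsqueeze(1)).long()

    for model_cls in (SemNN, SemLSTM, SemAttention):
        model = model_cls(in_feat=dim, num_labels=3, vocab_size=vocab_size, word_embedding_dim=dim)
        logits = model(ids_a, ids_b, mask, mask)
        print(model_cls.__name__, tuple(logits.shape))   # expected: (4, 3)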