# stay_admission/operations.py

import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from sklearn.metrics import precision_recall_curve, auc

OPS1 = {
    'identity': lambda d_model: Identity(d_model),
    'ffn': lambda d_model: FFN(d_model),
    'interaction_1': lambda d_model: Attention_s1(d_model),
    'interaction_2': lambda d_model: Attention_s2(d_model),
}

OPS2 = {
    'identity': lambda d_model: Identity(d_model),
    'conv': lambda d_model: Conv(d_model),
    'attention': lambda d_model: SelfAttention(d_model),
    'rnn': lambda d_model: RNN(d_model),
    'ffn': lambda d_model: FFN(d_model),
    'interaction_1': lambda d_model: CatFC(d_model),
    'interaction_2': lambda d_model: Attention_x(d_model),
}

OPS3 = {
    'identity': lambda d_model: Identity(d_model),
    'zero': lambda d_model: Zero(d_model),
}

OPS4 = {
    'sum': lambda d_model: Sum(d_model),
    'mul': lambda d_model: Mul(d_model),
}
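
# Usage sketch (hypothetical d_model of 64): each OPS* registry maps an
# operation name to a constructor taking d_model, so candidate operations
# can be built by name, e.g.
#     op = OPS2['attention'](64)      # -> SelfAttention(64)
#     out = op(x, masks, lengths)     # x: (batch, seq_len, 64)
# (the interaction_* ops take (current_x, s, other_x) instead).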


class Zero(nn.Module):
    def __init__(self, d_model):
        super(Zero, self).__init__()

    def forward(self, x, masks, lengths):
        return torch.mul(x, 0)


class Sum(nn.Module):
    def __init__(self, d_model):
        super(Sum, self).__init__()

    def forward(self, all_x):
        out = all_x[0]
        for x in all_x[1:]:
            out = out + x
        return out


class Mul(nn.Module):
    def __init__(self, d_model):
        super(Mul, self).__init__()

    def forward(self, all_x):
        out = all_x[0]
        for x in all_x[1:]:
            out = out * x
        return out


class CatFC(nn.Module):
    def __init__(self, d_model):
        super(CatFC, self).__init__()
        self.ffn = nn.Sequential(nn.Linear(2 * d_model, 4 * d_model), nn.ReLU(),
                                 nn.Linear(4 * d_model, d_model))
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, current_x, s, other_x):
        s_ = s.unsqueeze(1).expand_as(current_x)
        x = torch.cat((current_x, s_), dim=-1)
        return self.layer_norm(self.ffn(x))


class Conv(nn.Module):
    def __init__(self, d_model):
        super(Conv, self).__init__()
        self.op = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(d_model, d_model, 3, padding=1),
            nn.BatchNorm1d(d_model, affine=True)
        )
        # self.batchnm = nn.BatchNorm1d(d_model, affine=True)
        # self.conv = nn.Conv1d(d_model, d_model, 3, padding=1)

    def forward(self, x, masks, lengths):
        # Conv1d/BatchNorm1d operate on (batch, channels, seq_len), so permute in and out.
        x = self.op(x.permute(0, 2, 1))
        return x.permute(0, 2, 1)


class FFN(nn.Module):
    def __init__(self, d_model):
        super(FFN, self).__init__()
        self.ffn = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.ReLU(),
                                 nn.Linear(4 * d_model, d_model))
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x, masks, lengths):
        x = self.layer_norm(x + self.ffn(x))
        return x


class Identity(nn.Module):
    def __init__(self, d_model):
        super(Identity, self).__init__()

    def forward(self, x, masks, lengths):
        return x


class SelfAttention(nn.Module):
    def __init__(self, in_feature, num_head=4, dropout=0.1):
        super(SelfAttention, self).__init__()
        self.in_feature = in_feature
        self.num_head = num_head
        self.size_per_head = in_feature // num_head
        self.out_dim = num_head * self.size_per_head
        assert self.size_per_head * num_head == in_feature
        self.q_linear = nn.Linear(in_feature, in_feature, bias=False)
        self.k_linear = nn.Linear(in_feature, in_feature, bias=False)
        self.v_linear = nn.Linear(in_feature, in_feature, bias=False)
        self.fc = nn.Linear(in_feature, in_feature, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(in_feature)

    def forward(self, x, attn_mask, lengths):
        # x: (batch, seq_len, in_feature); attn_mask and lengths are accepted for
        # interface compatibility with the other operations but are not applied here.
        batch_size = x.size(0)
        res = x
        query = self.q_linear(x)
        key = self.k_linear(x)
        value = self.v_linear(x)

        query = query.view(batch_size, self.num_head, -1, self.size_per_head)
        key = key.view(batch_size, self.num_head, -1, self.size_per_head)
        value = value.view(batch_size, self.num_head, -1, self.size_per_head)

        scale = np.sqrt(self.size_per_head)
        energy = torch.matmul(query, key.permute(0, 1, 3, 2)) / scale

        attention = torch.softmax(energy, dim=-1)
        x = torch.matmul(attention, value)
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.in_feature)
        x = self.fc(x)
        x = self.dropout(x)
        x += res
        x = self.layer_norm(x)
        return x


class Attention_s1(nn.Module):
    def __init__(self, in_feature, num_head=4, dropout=0.1):
        super(Attention_s1, self).__init__()
        self.in_feature = in_feature
        self.num_head = num_head
        self.size_per_head = in_feature // num_head
        self.out_dim = num_head * self.size_per_head
        assert self.size_per_head * num_head == in_feature
        self.q_linear = nn.Linear(in_feature, in_feature, bias=False)
        self.k_linear = nn.Linear(in_feature, in_feature, bias=False)
        self.v_linear = nn.Linear(in_feature, in_feature, bias=False)
        self.fc = nn.Linear(in_feature, in_feature, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(in_feature)

    def forward(self, s, x1, x2):
        batch_size = x1.size(0)
        s = s.unsqueeze(1)
        res = s
        query = self.q_linear(s)
        key = self.k_linear(x1)
        value = self.v_linear(x1)

        query = query.view(batch_size, self.num_head, -1, self.size_per_head)
        key = key.view(batch_size, self.num_head, -1, self.size_per_head)
        value = value.view(batch_size, self.num_head, -1, self.size_per_head)

        scale = np.sqrt(self.size_per_head)
        energy = torch.matmul(query, key.permute(0, 1, 3, 2)) / scale

        attention = torch.softmax(energy, dim=-1)
        x = torch.matmul(attention, value)
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.in_feature)
        x = self.fc(x)
        x = self.dropout(x)
        x += res
        x = self.layer_norm(x)
        return x.squeeze(1)  # (batch, 1, in_feature) -> (batch, in_feature)


class Attention_s2(nn.Module):
    def __init__(self, in_feature, num_head=4, dropout=0.1):
        super(Attention_s2, self).__init__()
        self.in_feature = in_feature
        self.num_head = num_head
        self.size_per_head = in_feature // num_head
        self.out_dim = num_head * self.size_per_head
        assert self.size_per_head * num_head == in_feature
        self.q_linear = nn.Linear(in_feature, in_feature, bias=False)
        self.k_linear = nn.Linear(in_feature, in_feature, bias=False)
        self.v_linear = nn.Linear(in_feature, in_feature, bias=False)
        self.fc = nn.Linear(in_feature, in_feature, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(in_feature)

    def forward(self, s, x1, x2):
        batch_size = x2.size(0)
        s = s.unsqueeze(1)
        res = s
        query = self.q_linear(s)
        key = self.k_linear(x2)
        value = self.v_linear(x2)

        query = query.view(batch_size, self.num_head, -1, self.size_per_head)
        key = key.view(batch_size, self.num_head, -1, self.size_per_head)
        value = value.view(batch_size, self.num_head, -1, self.size_per_head)

        scale = np.sqrt(self.size_per_head)
        energy = torch.matmul(query, key.permute(0, 1, 3, 2)) / scale

        attention = torch.softmax(energy, dim=-1)
        x = torch.matmul(attention, value)
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.in_feature)
        x = self.fc(x)
        x = self.dropout(x)
        x += res
        x = self.layer_norm(x)
        return x.squeeze(1)  # (batch, 1, in_feature) -> (batch, in_feature)


class Attention_x(nn.Module):
    def __init__(self, in_feature, num_head=4, dropout=0.1):
        super(Attention_x, self).__init__()
        self.in_feature = in_feature
        self.num_head = num_head
        self.size_per_head = in_feature // num_head
        self.out_dim = num_head * self.size_per_head
        assert self.size_per_head * num_head == in_feature
        self.q_linear = nn.Linear(in_feature, in_feature, bias=False)
        self.k_linear = nn.Linear(in_feature, in_feature, bias=False)
        self.v_linear = nn.Linear(in_feature, in_feature, bias=False)
        self.fc = nn.Linear(in_feature, in_feature, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(in_feature)

    def forward(self, current_x, s, other_x):
        batch_size = current_x.size(0)
        res = current_x
        query = self.q_linear(current_x)
        key = self.k_linear(other_x)
        value = self.v_linear(other_x)

        query = query.view(batch_size, self.num_head, -1, self.size_per_head)
        key = key.view(batch_size, self.num_head, -1, self.size_per_head)
        value = value.view(batch_size, self.num_head, -1, self.size_per_head)

        scale = np.sqrt(self.size_per_head)
        energy = torch.matmul(query, key.permute(0, 1, 3, 2)) / scale

        attention = torch.softmax(energy, dim=-1)
        x = torch.matmul(attention, value)
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.in_feature)
        x = self.fc(x)
        x = self.dropout(x)
        x += res
        x = self.layer_norm(x)
        return x


class RNN(nn.Module):
    def __init__(self, d_model):
        super(RNN, self).__init__()
        self.rnn = nn.GRU(d_model, d_model, num_layers=1, batch_first=True)

    def forward(self, x, masks, lengths):
        rnn_input = x
        rnn_output, _ = self.rnn(rnn_input)
        return rnn_output


class MaxPoolLayer(nn.Module):
    """
    A layer that performs max pooling along the sequence dimension
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs, mask_or_lengths=None):
        """
        inputs: tensor of shape (batch_size, seq_len, hidden_size)
        mask_or_lengths: tensor of shape (batch_size) or (batch_size, seq_len)

        returns: tensor of shape (batch_size, hidden_size)
        """
        bs, sl, _ = inputs.size()
        if mask_or_lengths is not None:
            if len(mask_or_lengths.size()) == 1:
                mask = (torch.arange(sl, device=inputs.device).unsqueeze(0).expand(bs, sl)
                        >= mask_or_lengths.unsqueeze(1))
            else:
                mask = mask_or_lengths
            inputs = inputs.masked_fill(mask.unsqueeze(-1).expand_as(inputs), float('-inf'))
        max_pooled = inputs.max(1)[0]
        return max_pooled
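
# Usage sketch (assumed shapes): padded positions are filled with -inf before the max, e.g.
#     pool = MaxPoolLayer()
#     pooled = pool(h, lengths)   # h: (batch, seq_len, hidden), lengths: (batch,)
#     # pooled: (batch, hidden)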


def prroc(testy, probs):
    precision, recall, thresholds = precision_recall_curve(testy, probs)
    return auc(recall, precision)
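

# Minimal smoke test (a sketch with assumed sizes, not part of the original
# training pipeline): builds each OPS2 candidate at d_model=64, runs a dummy
# forward pass, and checks prroc on random labels/scores.
if __name__ == "__main__":
    batch, seq_len, d_model = 2, 5, 64
    x = torch.randn(batch, seq_len, d_model)
    masks = torch.zeros(batch, seq_len, dtype=torch.bool)   # no padding
    lengths = torch.tensor([seq_len] * batch)

    for name, builder in OPS2.items():
        op = builder(d_model)
        if name in ('interaction_1', 'interaction_2'):
            # CatFC / Attention_x expect (current_x, s, other_x); here s is a
            # dummy per-sample summary of shape (batch, d_model).
            out = op(x, x.mean(dim=1), x)
        else:
            out = op(x, masks, lengths)
        print(name, tuple(out.shape))

    pooled = MaxPoolLayer()(x, lengths)
    print('maxpool', tuple(pooled.shape))

    labels = np.random.randint(0, 2, size=100)
    scores = np.random.rand(100)
    print('pr-auc', prroc(labels, scores))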