app/models/backbones/concare.py

# import packages
import copy
import math

import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable


class SingleAttention(nn.Module):
    def __init__(
        self,
        attention_input_dim,
        attention_hidden_dim,
        attention_type="add",
        demographic_dim=12,
        time_aware=False,
        use_demographic=False,
    ):
        super(SingleAttention, self).__init__()

        self.attention_type = attention_type
        self.attention_hidden_dim = attention_hidden_dim
        self.attention_input_dim = attention_input_dim
        self.use_demographic = use_demographic
        self.demographic_dim = demographic_dim
        self.time_aware = time_aware

        # batch_time = torch.arange(0, batch_mask.size()[1], dtype=torch.float32).reshape(1, batch_mask.size()[1], 1)
        # batch_time = batch_time.repeat(batch_mask.size()[0], 1, 1)

        if attention_type == "add":
            if self.time_aware:
                # self.Wx = nn.Parameter(torch.randn(attention_input_dim+1, attention_hidden_dim))
                self.Wx = nn.Parameter(
                    torch.randn(attention_input_dim, attention_hidden_dim)
                )
                self.Wtime_aware = nn.Parameter(torch.randn(1, attention_hidden_dim))
                nn.init.kaiming_uniform_(self.Wtime_aware, a=math.sqrt(5))
            else:
                self.Wx = nn.Parameter(
                    torch.randn(attention_input_dim, attention_hidden_dim)
                )
            self.Wt = nn.Parameter(
                torch.randn(attention_input_dim, attention_hidden_dim)
            )
            self.Wd = nn.Parameter(torch.randn(demographic_dim, attention_hidden_dim))
            self.bh = nn.Parameter(
                torch.zeros(
                    attention_hidden_dim,
                )
            )
            self.Wa = nn.Parameter(torch.randn(attention_hidden_dim, 1))
            self.ba = nn.Parameter(
                torch.zeros(
                    1,
                )
            )

            nn.init.kaiming_uniform_(self.Wd, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.Wx, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.Wt, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))
        elif attention_type == "mul":
            self.Wa = nn.Parameter(
                torch.randn(attention_input_dim, attention_input_dim)
            )
            self.ba = nn.Parameter(
                torch.zeros(
                    1,
                )
            )

            nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))
        elif attention_type == "concat":
            if self.time_aware:
                self.Wh = nn.Parameter(
                    torch.randn(2 * attention_input_dim + 1, attention_hidden_dim)
                )
            else:
                self.Wh = nn.Parameter(
                    torch.randn(2 * attention_input_dim, attention_hidden_dim)
                )

            self.Wa = nn.Parameter(torch.randn(attention_hidden_dim, 1))
            self.ba = nn.Parameter(
                torch.zeros(
                    1,
                )
            )

            nn.init.kaiming_uniform_(self.Wh, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))

        elif attention_type == "new":
            self.Wt = nn.Parameter(
                torch.randn(attention_input_dim, attention_hidden_dim)
            )
            self.Wx = nn.Parameter(
                torch.randn(attention_input_dim, attention_hidden_dim)
            )

            self.rate = nn.Parameter(torch.zeros(1) + 0.8)
            nn.init.kaiming_uniform_(self.Wx, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.Wt, a=math.sqrt(5))

        else:
            raise RuntimeError("Wrong attention type.")

        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=1)
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()

    def forward(self, input, device, demo=None):

        (
            batch_size,
            time_step,
            input_dim,
        ) = input.size()  # batch_size * time_step * hidden_dim(i)

        time_decays = (
            torch.tensor(range(time_step - 1, -1, -1), dtype=torch.float32)
            .unsqueeze(-1).unsqueeze(0).to(device=device)
        )  # 1*t*1
        b_time_decays = time_decays.repeat(batch_size, 1, 1) + 1  # b t 1

        if self.attention_type == "add":  # B*T*I @ H*I
            q = torch.matmul(input[:, -1, :], self.Wt)  # b h
            q = torch.reshape(q, (batch_size, 1, self.attention_hidden_dim))  # B*1*H
            if self.time_aware:
                k = torch.matmul(input, self.Wx)  # b t h
                time_hidden = torch.matmul(b_time_decays, self.Wtime_aware)  # b t h
            else:
                k = torch.matmul(input, self.Wx)  # b t h
            if self.use_demographic:
                d = torch.matmul(demo, self.Wd)  # B*H
                d = torch.reshape(
                    d, (batch_size, 1, self.attention_hidden_dim)
                )  # b 1 h
            h = q + k + self.bh  # b t h
            if self.time_aware:
                h += time_hidden
            h = self.tanh(h)  # B*T*H
            e = torch.matmul(h, self.Wa) + self.ba  # B*T*1
            e = torch.reshape(e, (batch_size, time_step))  # b t
        elif self.attention_type == "mul":
            e = torch.matmul(input[:, -1, :], self.Wa)  # b i
            e = (
                torch.matmul(e.unsqueeze(1), input.permute(0, 2, 1)).squeeze() + self.ba
            )  # b t
        elif self.attention_type == "concat":
            q = input[:, -1, :].unsqueeze(1).repeat(1, time_step, 1)  # b t i
            k = input
            c = torch.cat((q, k), dim=-1)  # B*T*2I
            if self.time_aware:
                c = torch.cat((c, b_time_decays), dim=-1)  # B*T*(2I+1)
            h = torch.matmul(c, self.Wh)
            h = self.tanh(h)
            e = torch.matmul(h, self.Wa) + self.ba  # B*T*1
            e = torch.reshape(e, (batch_size, time_step))  # b t

        elif self.attention_type == "new":

            q = torch.matmul(input[:, -1, :], self.Wt)  # b h
            q = torch.reshape(q, (batch_size, 1, self.attention_hidden_dim))  # B*1*H
            k = torch.matmul(input, self.Wx)  # b t h
            dot_product = torch.matmul(q, k.transpose(1, 2)).squeeze()  # b t
            # time-aware decay: 2.72 approximates e; larger b_time_decays (older steps) shrink the score
            denominator = self.sigmoid(self.rate) * (
                torch.log(2.72 + (1 - self.sigmoid(dot_product)))
                * (b_time_decays.squeeze())
            )
            e = self.relu(self.sigmoid(dot_product) / (denominator))  # b * t

        a = self.softmax(e)  # B*T
        v = torch.matmul(a.unsqueeze(1), input).squeeze()  # B*I

        return v, a
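
# Illustrative shape check (a sketch, not used by the module; the sizes below are
# assumptions): with attention_input_dim=64 and attention_hidden_dim=8, the "new"
# time-aware attention collapses a (batch, time, feature) sequence into one vector
# per sample plus attention weights over the time steps.
#
#     attn = SingleAttention(64, 8, attention_type="new", time_aware=True)
#     x = torch.randn(32, 48, 64)
#     v, a = attn(x, device=torch.device("cpu"))
#     # v: (32, 64) time-aggregated representation, a: (32, 48) attention weights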


class FinalAttentionQKV(nn.Module):
    def __init__(
        self,
        attention_input_dim,
        attention_hidden_dim,
        attention_type="add",
        dropout=None,
    ):
        super(FinalAttentionQKV, self).__init__()

        self.attention_type = attention_type
        self.attention_hidden_dim = attention_hidden_dim
        self.attention_input_dim = attention_input_dim

        self.W_q = nn.Linear(attention_input_dim, attention_hidden_dim)
        self.W_k = nn.Linear(attention_input_dim, attention_hidden_dim)
        self.W_v = nn.Linear(attention_input_dim, attention_hidden_dim)

        self.W_out = nn.Linear(attention_hidden_dim, 1)

        self.b_in = nn.Parameter(
            torch.zeros(
                1,
            )
        )
        self.b_out = nn.Parameter(
            torch.zeros(
                1,
            )
        )

        nn.init.kaiming_uniform_(self.W_q.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W_k.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W_v.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W_out.weight, a=math.sqrt(5))

        self.Wh = nn.Parameter(
            torch.randn(2 * attention_input_dim, attention_hidden_dim)
        )
        self.Wa = nn.Parameter(torch.randn(attention_hidden_dim, 1))
        self.ba = nn.Parameter(
            torch.zeros(
                1,
            )
        )

        nn.init.kaiming_uniform_(self.Wh, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))

        self.dropout = nn.Dropout(p=dropout)  # dropout is expected to be a float; p=None would raise
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):

        (
            batch_size,
            time_step,
            input_dim,
        ) = input.size()  # batch_size * (input_dim + 1) * hidden_dim(i)
        input_q = self.W_q(input[:, -1, :])  # b h
        input_k = self.W_k(input)  # b t h
        input_v = self.W_v(input)  # b t h

        if self.attention_type == "add":  # B*T*I @ H*I

            q = torch.reshape(
                input_q, (batch_size, 1, self.attention_hidden_dim)
            )  # B*1*H
            h = q + input_k + self.b_in  # b t h
            h = self.tanh(h)  # B*T*H
            e = self.W_out(h)  # b t 1
            e = torch.reshape(e, (batch_size, time_step))  # b t

        elif self.attention_type == "mul":
            q = torch.reshape(
                input_q, (batch_size, self.attention_hidden_dim, 1)
            )  # B*h 1
            e = torch.matmul(input_k, q).squeeze()  # b t

        elif self.attention_type == "concat":
            q = input_q.unsqueeze(1).repeat(1, time_step, 1)  # b t h
            k = input_k
            c = torch.cat((q, k), dim=-1)  # B*T*2I
            h = torch.matmul(c, self.Wh)
            h = self.tanh(h)
            e = torch.matmul(h, self.Wa) + self.ba  # B*T*1
            e = torch.reshape(e, (batch_size, time_step))  # b t

        a = self.softmax(e)  # B*T
        if self.dropout is not None:
            a = self.dropout(a)
        v = torch.matmul(a.unsqueeze(1), input_v).squeeze()  # B*I

        return v, a
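
# Illustrative shape check (a sketch; the sizes are assumptions): FinalAttentionQKV
# fuses one embedding per feature into a single vector, using the last row as the
# query. Note that dropout must be passed as a float, as ConCare does below.
#
#     fattn = FinalAttentionQKV(64, 64, attention_type="mul", dropout=0.1)
#     feats = torch.randn(32, 17, 64)  # batch, features (+ demographics), hidden
#     fused, weights = fattn(feats)    # fused: (32, 64), weights: (32, 17)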


class PositionwiseFeedForward(nn.Module):  # new added
    "Implements FFN equation."

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # returns (output, None) to match the (output, aux) convention used by SublayerConnection
        return self.w_2(self.dropout(F.relu(self.w_1(x)))), None


class PositionalEncoding(nn.Module):  # new added / not used anymore
    "Implement the PE function."

    def __init__(self, d_model, dropout, max_len=400):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0.0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0.0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + Variable(self.pe[:, : x.size(1)], requires_grad=False)
        return self.dropout(x)


class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0):
        "Take in model size and number of heads."
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        self.linears = nn.ModuleList(
            [nn.Linear(d_model, self.d_k * self.h) for _ in range(3)]
        )
        self.final_linear = nn.Linear(d_model, d_model)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def attention(self, query, key, value, mask=None, dropout=None):
        "Compute 'Scaled Dot Product Attention'"
        d_k = query.size(-1)  # b h t d_k
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)  # b h t t
        if mask is not None:  # 1 1 t t
            scores = scores.masked_fill(mask == 0, -1e9)  # b h t t (lower triangular)
        p_attn = F.softmax(scores, dim=-1)  # b h t t
        if dropout is not None:
            p_attn = dropout(p_attn)
        return torch.matmul(p_attn, value), p_attn  # b h t v (d_k)

    def cov(self, m, y=None):
        if y is not None:
            m = torch.cat((m, y), dim=0)
        m_exp = torch.mean(m, dim=1)
        x = m - m_exp[:, None]
        cov = 1 / (x.size(1) - 1) * x.mm(x.t())
        return cov

    def forward(self, query, key, value, mask=None):
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)  # 1 1 t t

        nbatches = query.size(0)  # b
        input_dim = query.size(1)  # i+1
        feature_dim = query.size(1)  # i+1

        # input size -> # batch_size * d_input * hidden_dim

        # d_model => h * d_k
        query, key, value = [
            l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            for l, x in zip(self.linears, (query, key, value))
        ]  # b num_head d_input d_k

        x, self.attn = self.attention(
            query, key, value, mask=mask, dropout=self.dropout
        )  # b num_head d_input d_v (d_k)

        x = (
            x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        )  # batch_size * d_input * hidden_dim

        # DeCov regularizer: penalize off-diagonal covariance of each feature's context
        DeCov_contexts = x.transpose(0, 1).transpose(1, 2)  # I+1 H B
        Covs = self.cov(DeCov_contexts[0, :, :])
        DeCov_loss = 0.5 * (
            torch.norm(Covs, p="fro") ** 2 - torch.norm(torch.diag(Covs)) ** 2
        )
        for i in range(feature_dim - 1):
            Covs = self.cov(DeCov_contexts[i + 1, :, :])
            DeCov_loss += 0.5 * (
                torch.norm(Covs, p="fro") ** 2 - torch.norm(torch.diag(Covs)) ** 2
            )

        return self.final_linear(x), DeCov_loss
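
# Illustrative shape check (a sketch; the sizes are assumptions): the forward pass
# returns both the attended contexts and the scalar DeCov regularizer, which callers
# are expected to unpack.
#
#     mha = MultiHeadedAttention(4, 64, dropout=0.1)
#     ctx = torch.randn(32, 18, 64)               # batch, features (+ demo), d_model
#     out, decov = mha(ctx, ctx, ctx, mask=None)  # out: (32, 18, 64), decov: scalar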


class LayerNorm(nn.Module):
    def __init__(self, features, eps=1e-7):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2


class SublayerConnection(nn.Module):
    """
    A residual connection followed by a layer norm.
    Note for code simplicity the norm is first as opposed to last.
    """

    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        "Apply residual connection to any sublayer with the same size."
        # sublayer is expected to return an (output, aux) pair; the aux value
        # (e.g. the DeCov loss) is passed through unchanged.
        returned_value = sublayer(self.norm(x))
        return x + self.dropout(returned_value[0]), returned_value[1]


class ConCare(nn.Module):
    def __init__(
        self,
        lab_dim,  # lab_dim
        hidden_dim,
        demo_dim,
        d_model,
        MHD_num_head,
        d_ff,
        # output_dim,
        # device,
        drop=0.5,
    ):
        super(ConCare, self).__init__()

        # hyperparameters
        self.lab_dim = lab_dim
        self.hidden_dim = hidden_dim  # d_model
        self.d_model = d_model
        self.MHD_num_head = MHD_num_head
        self.d_ff = d_ff
        # self.output_dim = output_dim
        self.drop = drop
        self.demo_dim = demo_dim

        # layers
        self.PositionalEncoding = PositionalEncoding(
            self.d_model, dropout=0, max_len=400
        )

        self.GRUs = nn.ModuleList(
            [
                copy.deepcopy(nn.GRU(1, self.hidden_dim, batch_first=True))
                for _ in range(self.lab_dim)
            ]
        )
        self.LastStepAttentions = nn.ModuleList(
            [
                copy.deepcopy(
                    SingleAttention(
                        self.hidden_dim,
                        8,
                        attention_type="new",
                        demographic_dim=12,
                        time_aware=True,
                        use_demographic=False,
                    )
                )
                for _ in range(self.lab_dim)
            ]
        )

        self.FinalAttentionQKV = FinalAttentionQKV(
            self.hidden_dim,
            self.hidden_dim,
            attention_type="mul",
            dropout=self.drop,
        )

        self.MultiHeadedAttention = MultiHeadedAttention(
            self.MHD_num_head, self.d_model, dropout=self.drop
        )
        self.SublayerConnection = SublayerConnection(self.d_model, dropout=self.drop)

        self.PositionwiseFeedForward = PositionwiseFeedForward(
            self.d_model, self.d_ff, dropout=0.1
        )

        self.demo_lab_proj = nn.Linear(self.demo_dim + self.lab_dim, self.hidden_dim)
        self.demo_proj_main = nn.Linear(self.demo_dim, self.hidden_dim)
        self.demo_proj = nn.Linear(self.demo_dim, self.hidden_dim)
        self.output0 = nn.Linear(self.hidden_dim, self.hidden_dim)
        # self.output1 = nn.Linear(self.hidden_dim, self.output_dim)

        self.dropout = nn.Dropout(p=self.drop)
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()

    def concare_encoder(self, input, demo_input, device):

        # input shape [batch_size, timestep, feature_dim]
        demo_main = self.tanh(self.demo_proj_main(demo_input)).unsqueeze(
            1
        )  # b 1 hidden_dim

        batch_size = input.size(0)
        time_step = input.size(1)
        feature_dim = input.size(2)
        assert feature_dim == self.lab_dim  # input Tensor : 256 * 48 * 76
        assert self.d_model % self.MHD_num_head == 0

        # forward
        GRU_embeded_input = self.GRUs[0](
            input[:, :, 0].unsqueeze(-1).to(device=device),
            Variable(
                torch.zeros(batch_size, self.hidden_dim)
                .to(device=device)
                .unsqueeze(0)
            ),
        )[
            0
        ]  # b t h
        Attention_embeded_input = self.LastStepAttentions[0](GRU_embeded_input, device)[
            0
        ].unsqueeze(
            1
        )  # b 1 h

        for i in range(feature_dim - 1):
            embeded_input = self.GRUs[i + 1](
                input[:, :, i + 1].unsqueeze(-1),
                Variable(
                    torch.zeros(batch_size, self.hidden_dim)
                    .to(device=device)
                    .unsqueeze(0)
                ),
            )[
                0
            ]  # b t h
            embeded_input = self.LastStepAttentions[i + 1](embeded_input, device)[0].unsqueeze(
                1
            )  # b 1 h
            Attention_embeded_input = torch.cat(
                (Attention_embeded_input, embeded_input), 1
            )  # b i h
        Attention_embeded_input = torch.cat(
            (Attention_embeded_input, demo_main), 1
        )  # b i+1 h
        posi_input = self.dropout(
            Attention_embeded_input
        )  # batch_size * (d_input+1) * hidden_dim

        contexts = self.SublayerConnection(
            posi_input,
            # note: the lambda ignores its (normalized) argument and reuses posi_input directly
            lambda x: self.MultiHeadedAttention(
                posi_input, posi_input, posi_input, None
            ),
        )  # batch_size * (d_input+1) * hidden_dim

        DeCov_loss = contexts[1]
        contexts = contexts[0]

        contexts = self.SublayerConnection(
            contexts, lambda x: self.PositionwiseFeedForward(contexts)
        )[0]

        weighted_contexts = self.FinalAttentionQKV(contexts)[0]
        return weighted_contexts

    def forward(self, x, device, info=None):
        """extra info is not used here"""
        batch_size, time_steps, _ = x.size()
        demo_input = x[:, 0, : self.demo_dim]
        lab_input = x[:, :, self.demo_dim :]
        # keep the output buffer on the same device as the encoder outputs
        out = torch.zeros((batch_size, time_steps, self.hidden_dim), device=device)
        for cur_time in range(time_steps):
            # print(cur_time, end=" ")
            cur_lab = lab_input[:, : cur_time + 1, :]
            # print("cur_lab", cur_lab.shape)
            if cur_time == 0:
                out[:, cur_time, :] = self.demo_lab_proj(x[:, 0, :])
            else:
                out[:, cur_time, :] = self.concare_encoder(cur_lab, demo_input, device)
            # print()
        return out
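

if __name__ == "__main__":
    # Minimal smoke test (illustrative only; the dimensions below are assumptions,
    # not values used elsewhere in the project). The first demo_dim columns of x are
    # treated as static demographics, the rest as time-varying lab features.
    device = torch.device("cpu")
    demo_dim, lab_dim, hidden_dim = 2, 10, 32
    model = ConCare(
        lab_dim=lab_dim,
        hidden_dim=hidden_dim,
        demo_dim=demo_dim,
        d_model=hidden_dim,
        MHD_num_head=4,
        d_ff=64,
        drop=0.5,
    ).to(device)
    x = torch.randn(8, 6, demo_dim + lab_dim)  # batch, time steps, demo + lab features
    out = model(x, device)
    print(out.shape)  # expected: torch.Size([8, 6, 32])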