# app/models/backbones/concare.py
# import packages
import copy
import math

import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable

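# Model overview (as implemented below):
#   * SingleAttention      - time-aware attention over one feature's GRU hidden states
#   * FinalAttentionQKV    - attention pooling over the per-feature context vectors
#   * MultiHeadedAttention - cross-feature self-attention with a DeCov regularizer
#   * ConCare              - one GRU + SingleAttention per lab feature, a demographics
#                            projection, self-attention, a position-wise feed-forward
#                            block, and final attention pooling over features.

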
class SingleAttention(nn.Module):
    """Time-dimension attention over a single feature's GRU hidden states.

    ``attention_type`` selects the scoring function ("add", "mul", "concat", or the
    time-aware "new" variant used by ConCare below); ``forward`` returns the attended
    context vector together with the attention weights.
    """

    def __init__(
        self,
        attention_input_dim,
        attention_hidden_dim,
        attention_type="add",
        demographic_dim=12,
        time_aware=False,
        use_demographic=False,
    ):
        super(SingleAttention, self).__init__()

        self.attention_type = attention_type
        self.attention_hidden_dim = attention_hidden_dim
        self.attention_input_dim = attention_input_dim
        self.use_demographic = use_demographic
        self.demographic_dim = demographic_dim
        self.time_aware = time_aware

        # batch_time = torch.arange(0, batch_mask.size()[1], dtype=torch.float32).reshape(1, batch_mask.size()[1], 1)
        # batch_time = batch_time.repeat(batch_mask.size()[0], 1, 1)

        if attention_type == "add":
            if self.time_aware:
                # self.Wx = nn.Parameter(torch.randn(attention_input_dim+1, attention_hidden_dim))
                self.Wx = nn.Parameter(
                    torch.randn(attention_input_dim, attention_hidden_dim)
                )
                self.Wtime_aware = nn.Parameter(torch.randn(1, attention_hidden_dim))
                nn.init.kaiming_uniform_(self.Wtime_aware, a=math.sqrt(5))
            else:
                self.Wx = nn.Parameter(
                    torch.randn(attention_input_dim, attention_hidden_dim)
                )
            self.Wt = nn.Parameter(
                torch.randn(attention_input_dim, attention_hidden_dim)
            )
            self.Wd = nn.Parameter(torch.randn(demographic_dim, attention_hidden_dim))
            self.bh = nn.Parameter(torch.zeros(attention_hidden_dim))
            self.Wa = nn.Parameter(torch.randn(attention_hidden_dim, 1))
            self.ba = nn.Parameter(torch.zeros(1))

            nn.init.kaiming_uniform_(self.Wd, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.Wx, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.Wt, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))
        elif attention_type == "mul":
            self.Wa = nn.Parameter(
                torch.randn(attention_input_dim, attention_input_dim)
            )
            self.ba = nn.Parameter(torch.zeros(1))

            nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))
        elif attention_type == "concat":
            if self.time_aware:
                self.Wh = nn.Parameter(
                    torch.randn(2 * attention_input_dim + 1, attention_hidden_dim)
                )
            else:
                self.Wh = nn.Parameter(
                    torch.randn(2 * attention_input_dim, attention_hidden_dim)
                )

            self.Wa = nn.Parameter(torch.randn(attention_hidden_dim, 1))
            self.ba = nn.Parameter(torch.zeros(1))

            nn.init.kaiming_uniform_(self.Wh, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))

        elif attention_type == "new":
            self.Wt = nn.Parameter(
                torch.randn(attention_input_dim, attention_hidden_dim)
            )
            self.Wx = nn.Parameter(
                torch.randn(attention_input_dim, attention_hidden_dim)
            )

            self.rate = nn.Parameter(torch.zeros(1) + 0.8)
            nn.init.kaiming_uniform_(self.Wx, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.Wt, a=math.sqrt(5))

        else:
            raise RuntimeError("Wrong attention type.")

        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=1)
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()

    def forward(self, input, device, demo=None):
        batch_size, time_step, input_dim = input.size()  # batch_size * time_step * hidden_dim(i)

        # Distance of each step from the most recent one, starting at 1 for the last step.
        time_decays = (
            torch.tensor(range(time_step - 1, -1, -1), dtype=torch.float32)
            .unsqueeze(-1)
            .unsqueeze(0)
            .to(device=device)
        )  # 1*t*1
        b_time_decays = time_decays.repeat(batch_size, 1, 1) + 1  # b t 1

        if self.attention_type == "add":  # B*T*I  @ H*I
            q = torch.matmul(input[:, -1, :], self.Wt)  # b h
            q = torch.reshape(q, (batch_size, 1, self.attention_hidden_dim))  # B*1*H
            k = torch.matmul(input, self.Wx)  # b t h
            if self.time_aware:
                time_hidden = torch.matmul(b_time_decays, self.Wtime_aware)  # b t h
            if self.use_demographic:
                d = torch.matmul(demo, self.Wd)  # B*H
                d = torch.reshape(
                    d, (batch_size, 1, self.attention_hidden_dim)
                )  # b 1 h
            h = q + k + self.bh  # b t h
            if self.time_aware:
                h += time_hidden
            h = self.tanh(h)  # B*T*H
            e = torch.matmul(h, self.Wa) + self.ba  # B*T*1
            e = torch.reshape(e, (batch_size, time_step))  # b t
        elif self.attention_type == "mul":
            e = torch.matmul(input[:, -1, :], self.Wa)  # b i
            e = (
                torch.matmul(e.unsqueeze(1), input.permute(0, 2, 1)).squeeze() + self.ba
            )  # b t
        elif self.attention_type == "concat":
            q = input[:, -1, :].unsqueeze(1).repeat(1, time_step, 1)  # b t i
            k = input
            c = torch.cat((q, k), dim=-1)  # B*T*2I
            if self.time_aware:
                c = torch.cat((c, b_time_decays), dim=-1)  # B*T*(2I+1)
            h = torch.matmul(c, self.Wh)
            h = self.tanh(h)
            e = torch.matmul(h, self.Wa) + self.ba  # B*T*1
            e = torch.reshape(e, (batch_size, time_step))  # b t

        elif self.attention_type == "new":
            q = torch.matmul(input[:, -1, :], self.Wt)  # b h
            q = torch.reshape(q, (batch_size, 1, self.attention_hidden_dim))  # B*1*H
            k = torch.matmul(input, self.Wx)  # b t h
            dot_product = torch.matmul(q, k.transpose(1, 2)).squeeze()  # b t
            denominator = self.sigmoid(self.rate) * (
                torch.log(2.72 + (1 - self.sigmoid(dot_product)))  # 2.72 approximates e
                * (b_time_decays.squeeze())
            )
            e = self.relu(self.sigmoid(dot_product) / (denominator))  # b * t

        a = self.softmax(e)  # B*T
        v = torch.matmul(a.unsqueeze(1), input).squeeze()  # B*I

        return v, a


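# Reading aid for the time-aware "new" scoring above (comment only, no behavior):
# with the query q taken from the last step and keys k_t,
#   e_t = ReLU( sigmoid(q . k_t) / ( sigmoid(rate) * log(2.72 + 1 - sigmoid(q . k_t)) * delta_t ) )
# where delta_t is the distance from the most recent step (delta = 1 for the last
# step) and 2.72 approximates e; the scores e_t are then softmax-normalized over time.

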
class FinalAttentionQKV(nn.Module):
    """Attention pooling over the second dimension of its input.

    In ConCare the input is the stack of per-feature (plus demographic) context
    vectors, so this produces a single pooled context vector per sample.
    """

    def __init__(
        self,
        attention_input_dim,
        attention_hidden_dim,
        attention_type="add",
        dropout=None,
    ):
        super(FinalAttentionQKV, self).__init__()

        self.attention_type = attention_type
        self.attention_hidden_dim = attention_hidden_dim
        self.attention_input_dim = attention_input_dim

        self.W_q = nn.Linear(attention_input_dim, attention_hidden_dim)
        self.W_k = nn.Linear(attention_input_dim, attention_hidden_dim)
        self.W_v = nn.Linear(attention_input_dim, attention_hidden_dim)

        self.W_out = nn.Linear(attention_hidden_dim, 1)

        self.b_in = nn.Parameter(torch.zeros(1))
        self.b_out = nn.Parameter(torch.zeros(1))

        nn.init.kaiming_uniform_(self.W_q.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W_k.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W_v.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W_out.weight, a=math.sqrt(5))

        self.Wh = nn.Parameter(
            torch.randn(2 * attention_input_dim, attention_hidden_dim)
        )
        self.Wa = nn.Parameter(torch.randn(attention_hidden_dim, 1))
        self.ba = nn.Parameter(torch.zeros(1))

        nn.init.kaiming_uniform_(self.Wh, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))

        # Only construct the dropout layer when a rate is given, so the
        # "is not None" check in forward() is meaningful.
        self.dropout = nn.Dropout(p=dropout) if dropout is not None else None
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input):
        batch_size, time_step, input_dim = input.size()  # batch_size * (input_dim + 1) * hidden_dim
        input_q = self.W_q(input[:, -1, :])  # b h
        input_k = self.W_k(input)  # b t h
        input_v = self.W_v(input)  # b t h

        if self.attention_type == "add":  # B*T*I  @ H*I
            q = torch.reshape(
                input_q, (batch_size, 1, self.attention_hidden_dim)
            )  # B*1*H
            h = q + input_k + self.b_in  # b t h
            h = self.tanh(h)  # B*T*H
            e = self.W_out(h)  # b t 1
            e = torch.reshape(e, (batch_size, time_step))  # b t

        elif self.attention_type == "mul":
            q = torch.reshape(
                input_q, (batch_size, self.attention_hidden_dim, 1)
            )  # B*H*1
            e = torch.matmul(input_k, q).squeeze()  # b t

        elif self.attention_type == "concat":
            q = input_q.unsqueeze(1).repeat(1, time_step, 1)  # b t h
            k = input_k
            c = torch.cat((q, k), dim=-1)  # B*T*2H
            h = torch.matmul(c, self.Wh)
            h = self.tanh(h)
            e = torch.matmul(h, self.Wa) + self.ba  # B*T*1
            e = torch.reshape(e, (batch_size, time_step))  # b t

        a = self.softmax(e)  # B*T
        if self.dropout is not None:
            a = self.dropout(a)
        v = torch.matmul(a.unsqueeze(1), input_v).squeeze()  # B*H

        return v, a


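# Reading aid for the "mul" variant used by ConCare below (comment only):
#   a = softmax(K q) over the feature axis,  v = a^T V,
# i.e. dot-product attention pooling with learned Q/K/V projections.

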
class PositionwiseFeedForward(nn.Module):  # new added
    "Implements FFN equation."

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # The second return value keeps the interface consistent with
        # MultiHeadedAttention, whose forward also returns an auxiliary value.
        return self.w_2(self.dropout(F.relu(self.w_1(x)))), None


class PositionalEncoding(nn.Module):  # new added / not used anymore
    "Implement the PE function."

    def __init__(self, d_model, dropout, max_len=400):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0.0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0.0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + Variable(self.pe[:, : x.size(1)], requires_grad=False)
        return self.dropout(x)


class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0):
        "Take in model size and number of heads."
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        self.linears = nn.ModuleList(
            [nn.Linear(d_model, self.d_k * self.h) for _ in range(3)]
        )
        self.final_linear = nn.Linear(d_model, d_model)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def attention(self, query, key, value, mask=None, dropout=None):
        "Compute 'Scaled Dot Product Attention'"
        d_k = query.size(-1)  # b h t d_k
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)  # b h t t
        if mask is not None:  # 1 1 t t
            scores = scores.masked_fill(mask == 0, -1e9)  # b h t t, lower-triangular
        p_attn = F.softmax(scores, dim=-1)  # b h t t
        if dropout is not None:
            p_attn = dropout(p_attn)
        return torch.matmul(p_attn, value), p_attn  # b h t v (d_k)

    def cov(self, m, y=None):
        # Covariance between the rows of m (optionally stacked with y),
        # computed across its columns.
        if y is not None:
            m = torch.cat((m, y), dim=0)
        m_exp = torch.mean(m, dim=1)
        x = m - m_exp[:, None]
        cov = 1 / (x.size(1) - 1) * x.mm(x.t())
        return cov

    def forward(self, query, key, value, mask=None):
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)  # 1 1 t t

        nbatches = query.size(0)  # b
        feature_dim = query.size(1)  # i+1

        # input size -> batch_size * d_input * hidden_dim

        # d_model => h * d_k
        query, key, value = [
            linear(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
            for linear, x in zip(self.linears, (query, key, value))
        ]  # b num_head d_input d_k

        x, self.attn = self.attention(
            query, key, value, mask=mask, dropout=self.dropout
        )  # b num_head d_input d_v (d_k)

        x = (
            x.transpose(1, 2).contiguous().view(nbatches, -1, self.h * self.d_k)
        )  # batch_size * d_input * hidden_dim

        # DeCov regularizer: penalize the off-diagonal covariance of each feature's
        # hidden dimensions across the batch.
        DeCov_contexts = x.transpose(0, 1).transpose(1, 2)  # I+1 H B
        Covs = self.cov(DeCov_contexts[0, :, :])
        DeCov_loss = 0.5 * (
            torch.norm(Covs, p="fro") ** 2 - torch.norm(torch.diag(Covs)) ** 2
        )
        for i in range(feature_dim - 1):
            Covs = self.cov(DeCov_contexts[i + 1, :, :])
            DeCov_loss += 0.5 * (
                torch.norm(Covs, p="fro") ** 2 - torch.norm(torch.diag(Covs)) ** 2
            )

        return self.final_linear(x), DeCov_loss


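# Reading aid for the DeCov term above (comment only): per feature i, with C_i the
# covariance matrix of its hidden units across the batch,
#   DeCov_loss = sum_i 0.5 * ( ||C_i||_F^2 - ||diag(C_i)||_2^2 ),
# which penalizes only the off-diagonal entries and so discourages redundant,
# highly correlated hidden dimensions.

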
class LayerNorm(nn.Module):
    def __init__(self, features, eps=1e-7):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2


class SublayerConnection(nn.Module):
    """
    A residual connection followed by a layer norm.
    Note for code simplicity the norm is first as opposed to last.
    """

    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        "Apply residual connection to any sublayer with the same size."
        # The wrapped sublayer returns (output, auxiliary value); the residual is
        # applied to the output and the auxiliary value is passed through unchanged.
        returned_value = sublayer(self.norm(x))
        return x + self.dropout(returned_value[0]), returned_value[1]


class ConCare(nn.Module):
    def __init__(
        self,
        lab_dim,
        hidden_dim,
        demo_dim,
        d_model,
        MHD_num_head,
        d_ff,
        # output_dim,
        # device,
        drop=0.5,
    ):
        super(ConCare, self).__init__()

        # hyperparameters
        self.lab_dim = lab_dim
        self.hidden_dim = hidden_dim  # d_model
        self.d_model = d_model
        self.MHD_num_head = MHD_num_head
        self.d_ff = d_ff
        # self.output_dim = output_dim
        self.drop = drop
        self.demo_dim = demo_dim

        # layers
        self.PositionalEncoding = PositionalEncoding(
            self.d_model, dropout=0, max_len=400
        )

        self.GRUs = nn.ModuleList(
            [
                copy.deepcopy(nn.GRU(1, self.hidden_dim, batch_first=True))
                for _ in range(self.lab_dim)
            ]
        )
        self.LastStepAttentions = nn.ModuleList(
            [
                copy.deepcopy(
                    SingleAttention(
                        self.hidden_dim,
                        8,
                        attention_type="new",
                        demographic_dim=12,
                        time_aware=True,
                        use_demographic=False,
                    )
                )
                for _ in range(self.lab_dim)
            ]
        )

        self.FinalAttentionQKV = FinalAttentionQKV(
            self.hidden_dim,
            self.hidden_dim,
            attention_type="mul",
            dropout=self.drop,
        )

        self.MultiHeadedAttention = MultiHeadedAttention(
            self.MHD_num_head, self.d_model, dropout=self.drop
        )
        self.SublayerConnection = SublayerConnection(self.d_model, dropout=self.drop)

        self.PositionwiseFeedForward = PositionwiseFeedForward(
            self.d_model, self.d_ff, dropout=0.1
        )

        self.demo_lab_proj = nn.Linear(self.demo_dim + self.lab_dim, self.hidden_dim)
        self.demo_proj_main = nn.Linear(self.demo_dim, self.hidden_dim)
        self.demo_proj = nn.Linear(self.demo_dim, self.hidden_dim)
        self.output0 = nn.Linear(self.hidden_dim, self.hidden_dim)
        # self.output1 = nn.Linear(self.hidden_dim, self.output_dim)

        self.dropout = nn.Dropout(p=self.drop)
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()

    def concare_encoder(self, input, demo_input, device):
        # input shape [batch_size, timestep, feature_dim]
        demo_main = self.tanh(self.demo_proj_main(demo_input)).unsqueeze(1)  # b 1 h

        batch_size = input.size(0)
        time_step = input.size(1)
        feature_dim = input.size(2)
        assert feature_dim == self.lab_dim  # input Tensor : 256 * 48 * 76
        assert self.d_model % self.MHD_num_head == 0

        # Run one GRU + time-aware attention per lab feature, then stack the
        # resulting per-feature context vectors.
        per_feature_contexts = []
        for i in range(feature_dim):
            h0 = Variable(
                torch.zeros(batch_size, self.hidden_dim).to(device=device).unsqueeze(0)
            )
            gru_embeded_input = self.GRUs[i](
                input[:, :, i].unsqueeze(-1).to(device=device), h0
            )[0]  # b t h
            per_feature_contexts.append(
                self.LastStepAttentions[i](gru_embeded_input, device)[0].unsqueeze(1)
            )  # b 1 h

        Attention_embeded_input = torch.cat(
            per_feature_contexts + [demo_main], 1
        )  # b i+1 h
        posi_input = self.dropout(
            Attention_embeded_input
        )  # batch_size * (d_input+1) * hidden_dim

        # Note: the lambdas below ignore the normalized tensor that SublayerConnection
        # passes in and operate on posi_input / contexts directly (original behavior,
        # kept as-is); only the residual and dropout of SublayerConnection apply.
        contexts = self.SublayerConnection(
            posi_input,
            lambda x: self.MultiHeadedAttention(
                posi_input, posi_input, posi_input, None
            ),
        )  # batch_size * d_input * hidden_dim

        DeCov_loss = contexts[1]
        contexts = contexts[0]

        contexts = self.SublayerConnection(
            contexts, lambda x: self.PositionwiseFeedForward(contexts)
        )[0]

        weighted_contexts = self.FinalAttentionQKV(contexts)[0]
        return weighted_contexts

    def forward(self, x, device, info=None):
        """Extra info is not used here."""
        batch_size, time_steps, _ = x.size()
        demo_input = x[:, 0, : self.demo_dim]
        lab_input = x[:, :, self.demo_dim :]
        # Allocate the output on the target device so the per-step writes below work
        # on GPU as well as CPU.
        out = torch.zeros((batch_size, time_steps, self.hidden_dim), device=device)
        for cur_time in range(time_steps):
            cur_lab = lab_input[:, : cur_time + 1, :]
            if cur_time == 0:
                # The first step has no history to attend over, so it is handled with
                # a plain projection of the demographics + labs at time 0.
                out[:, cur_time, :] = self.demo_lab_proj(x[:, 0, :])
            else:
                out[:, cur_time, :] = self.concare_encoder(cur_lab, demo_input, device)
        return out
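

# Minimal smoke-test sketch (not part of the original module). The dimensions below
# are arbitrary illustrative values, not the project's configuration; the test only
# checks that a forward pass runs on CPU and yields the expected output shape.
if __name__ == "__main__":
    demo_dim, lab_dim = 4, 5
    hidden_dim = d_model = 32  # the self-attention block assumes hidden_dim == d_model
    model = ConCare(
        lab_dim=lab_dim,
        hidden_dim=hidden_dim,
        demo_dim=demo_dim,
        d_model=d_model,
        MHD_num_head=4,
        d_ff=64,
        drop=0.5,
    )
    x = torch.randn(2, 3, demo_dim + lab_dim)  # batch of 2, 3 time steps
    out = model(x, device="cpu")
    print(out.shape)  # expected: torch.Size([2, 3, 32])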