AICare-baselines/models/aicare.py

import torch
from torch import nn
import torch.nn.functional as F
import math
import copy
import numpy as np
from torch.nn.utils.rnn import pack_padded_sequence


class Sparsemax(nn.Module):
    """Sparsemax function (a sparse alternative to softmax)."""

    def __init__(self, dim=None):
        super(Sparsemax, self).__init__()

        self.dim = -1 if dim is None else dim

    def forward(self, input, device=None):
        # Follow the input's device by default instead of assuming CUDA.
        device = input.device if device is None else device

        original_size = input.size()
        input = input.view(-1, input.size(self.dim))

        dim = 1
        number_of_logits = input.size(dim)

        # Shift logits for numerical stability.
        input = input - torch.max(input, dim=dim, keepdim=True)[0].expand_as(input)

        # Sort logits in descending order and find the support of the projection.
        zs = torch.sort(input=input, dim=dim, descending=True)[0]
        rhos = torch.arange(start=1, end=number_of_logits + 1, device=device, dtype=torch.float32).view(1, -1)
        rhos = rhos.expand_as(zs)

        bound = 1 + rhos * zs
        cumulative_sum_zs = torch.cumsum(zs, dim)
        is_gt = torch.gt(bound, cumulative_sum_zs).type(input.type())
        k = torch.max(is_gt * rhos, dim, keepdim=True)[0]

        # Compute the threshold tau and project onto the simplex.
        zs_sparse = is_gt * zs
        taus = (torch.sum(zs_sparse, dim, keepdim=True) - 1) / k
        taus = taus.expand_as(input)

        self.output = torch.max(torch.zeros_like(input), input - taus)

        output = self.output.view(original_size)

        return output

    def backward(self, grad_output):
        """Manual backward pass for the sparsemax output."""
        dim = 1

        nonzeros = torch.ne(self.output, 0)
        grad_sum = torch.sum(grad_output * nonzeros, dim=dim) / torch.sum(nonzeros, dim=dim)
        self.grad_input = nonzeros * (grad_output - grad_sum.unsqueeze(dim).expand_as(grad_output))

        return self.grad_input
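
# Illustrative sketch (not part of the original model): a quick sanity check for Sparsemax.
# It assumes CPU tensors; rows sum to 1 and low-scoring logits are driven exactly to zero,
# unlike softmax.
def _sparsemax_example():
    sparsemax = Sparsemax(dim=-1)
    logits = torch.tensor([[1.0, 2.0, 3.0],
                           [0.1, 0.1, 0.1]])
    probs = sparsemax(logits)
    # Expected: tensor([[0.0000, 0.0000, 1.0000],
    #                   [0.3333, 0.3333, 0.3333]])
    return probs
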
class SingleAttention(nn.Module):
    def __init__(self, attention_input_dim, attention_hidden_dim, attention_type='add', demographic_dim=12, time_aware=False, use_demographic=False):
        super(SingleAttention, self).__init__()

        self.attention_type = attention_type
        self.attention_hidden_dim = attention_hidden_dim
        self.attention_input_dim = attention_input_dim
        self.use_demographic = use_demographic
        self.demographic_dim = demographic_dim
        self.time_aware = time_aware

        # batch_time = torch.arange(0, batch_mask.size()[1], dtype=torch.float32).reshape(1, batch_mask.size()[1], 1)
        # batch_time = batch_time.repeat(batch_mask.size()[0], 1, 1)

        if attention_type == 'add':
            if self.time_aware:
                # self.Wx = nn.Parameter(torch.randn(attention_input_dim+1, attention_hidden_dim))
                self.Wx = nn.Parameter(torch.randn(attention_input_dim, attention_hidden_dim))
                self.Wtime_aware = nn.Parameter(torch.randn(1, attention_hidden_dim))
                nn.init.kaiming_uniform_(self.Wtime_aware, a=math.sqrt(5))
            else:
                self.Wx = nn.Parameter(torch.randn(attention_input_dim, attention_hidden_dim))
            self.Wt = nn.Parameter(torch.randn(attention_input_dim, attention_hidden_dim))
            self.Wd = nn.Parameter(torch.randn(demographic_dim, attention_hidden_dim))
            self.bh = nn.Parameter(torch.zeros(attention_hidden_dim,))
            self.Wa = nn.Parameter(torch.randn(attention_hidden_dim, 1))
            self.ba = nn.Parameter(torch.zeros(1,))

            nn.init.kaiming_uniform_(self.Wd, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.Wx, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.Wt, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))
        elif attention_type == 'mul':
            self.Wa = nn.Parameter(torch.randn(attention_input_dim, attention_input_dim))
            self.ba = nn.Parameter(torch.zeros(1,))

            nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))
        elif attention_type == 'concat':
            if self.time_aware:
                self.Wh = nn.Parameter(torch.randn(2*attention_input_dim+1, attention_hidden_dim))
            else:
                self.Wh = nn.Parameter(torch.randn(2*attention_input_dim, attention_hidden_dim))

            self.Wa = nn.Parameter(torch.randn(attention_hidden_dim, 1))
            self.ba = nn.Parameter(torch.zeros(1,))

            nn.init.kaiming_uniform_(self.Wh, a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))
        else:
            raise RuntimeError("Wrong attention type. Expected one of: 'add', 'mul', 'concat'.")

        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input, demo=None):

        batch_size, time_step, input_dim = input.size()  # batch_size * time_step * hidden_dim(i)
        # assert(input_dim == self.input_dim)

        # time_decays = torch.zeros((time_step,time_step)).to(device)# t*t
        # for this_time in range(time_step):
        #     for pre_time in range(time_step):
        #         if pre_time > this_time:
        #             break
        #         time_decays[this_time][pre_time] = torch.tensor(this_time - pre_time, dtype=torch.float32).to(device)
        # b_time_decays = tile(time_decays, 0, batch_size).view(batch_size,time_step,time_step).unsqueeze(-1).to(device)# b t t 1

        # Elapsed time from each step to the last step, on the same device as the input.
        time_decays = torch.arange(time_step - 1, -1, -1, dtype=torch.float32, device=input.device).unsqueeze(-1).unsqueeze(0)  # 1*t*1
        b_time_decays = time_decays.repeat(batch_size, 1, 1)  # b t 1

        if self.attention_type == 'add':  # B*T*I  @ H*I
            q = torch.matmul(input[:, -1, :], self.Wt)  # b h
            q = torch.reshape(q, (batch_size, 1, self.attention_hidden_dim))  # B*1*H
            if self.time_aware:
                # k_input = torch.cat((input, time), dim=-1)
                k = torch.matmul(input, self.Wx)  # b t h
                # k = torch.reshape(k, (batch_size, 1, time_step, self.attention_hidden_dim)) # B*1*T*H
                time_hidden = torch.matmul(b_time_decays, self.Wtime_aware)  # b t h
            else:
                k = torch.matmul(input, self.Wx)  # b t h
                # k = torch.reshape(k, (batch_size, 1, time_step, self.attention_hidden_dim)) # B*1*T*H
            if self.use_demographic:
                d = torch.matmul(demo, self.Wd)  # B*H
                d = torch.reshape(d, (batch_size, 1, self.attention_hidden_dim))  # b 1 h
            h = q + k + self.bh  # b t h
            if self.time_aware:
                h += time_hidden
            h = self.tanh(h)  # B*T*H
            e = torch.matmul(h, self.Wa) + self.ba  # B*T*1
            e = torch.reshape(e, (batch_size, time_step))  # b t
        elif self.attention_type == 'mul':
            e = torch.matmul(input[:, -1, :], self.Wa)  # b i
            e = torch.matmul(e.unsqueeze(1), input.permute(0, 2, 1)).squeeze() + self.ba  # b t
        elif self.attention_type == 'concat':
            q = input[:, -1, :].unsqueeze(1).repeat(1, time_step, 1)  # b t i
            k = input
            c = torch.cat((q, k), dim=-1)  # B*T*2I
            if self.time_aware:
                c = torch.cat((c, b_time_decays), dim=-1)  # B*T*(2I+1)
            h = torch.matmul(c, self.Wh)
            h = self.tanh(h)
            e = torch.matmul(h, self.Wa) + self.ba  # B*T*1
            e = torch.reshape(e, (batch_size, time_step))  # b t

        a = self.softmax(e)  # B*T
        v = torch.matmul(a.unsqueeze(1), input).squeeze()  # B*I

        return v, a

class FinalAttentionQKV(nn.Module):
    def __init__(self, attention_input_dim, attention_hidden_dim, attention_type='add', dropout=None):
        super(FinalAttentionQKV, self).__init__()

        self.attention_type = attention_type
        self.attention_hidden_dim = attention_hidden_dim
        self.attention_input_dim = attention_input_dim

        self.W_q = nn.Linear(attention_input_dim, attention_hidden_dim)
        self.W_k = nn.Linear(attention_input_dim, attention_hidden_dim)
        self.W_v = nn.Linear(attention_input_dim, attention_hidden_dim)

        self.W_out = nn.Linear(attention_hidden_dim, 1)

        self.b_in = nn.Parameter(torch.zeros(1,))
        self.b_out = nn.Parameter(torch.zeros(1,))

        nn.init.kaiming_uniform_(self.W_q.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W_k.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W_v.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.W_out.weight, a=math.sqrt(5))

        self.Wh = nn.Parameter(torch.randn(2*attention_input_dim, attention_hidden_dim))
        self.Wa = nn.Parameter(torch.randn(attention_hidden_dim, 1))
        self.ba = nn.Parameter(torch.zeros(1,))

        nn.init.kaiming_uniform_(self.Wh, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.Wa, a=math.sqrt(5))

        # The forward pass checks `self.dropout is not None`, so only build the layer when a rate is given.
        self.dropout = nn.Dropout(p=dropout) if dropout is not None else None
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=1)
        self.sigmoid = nn.Sigmoid()
        self.sparsemax = Sparsemax()

    def forward(self, input):

        batch_size, time_step, input_dim = input.size()  # batch_size * (lab_dim + 1) * hidden_dim
        input_q = self.W_q(torch.mean(input, 1))  # b h
        input_k = self.W_k(input[:, :-1, :])  # b t-1 h (the last position, the demographic embedding, is excluded from keys)
        input_v = self.W_v(input[:, :-1, :])  # b t-1 h

        if self.attention_type == 'add':  # B*T*I  @ H*I

            q = torch.reshape(input_q, (batch_size, 1, self.attention_hidden_dim))  # B*1*H
            h = q + input_k + self.b_in  # b t-1 h
            h = self.tanh(h)  # B*(T-1)*H
            e = self.W_out(h)  # b t-1 1
            e = torch.reshape(e, (batch_size, time_step - 1))  # b t-1

        elif self.attention_type == 'mul':
            q = torch.reshape(input_q, (batch_size, self.attention_hidden_dim, 1))  # B*H*1
            e = torch.matmul(input_k, q).squeeze()  # b t-1

        elif self.attention_type == 'concat':
            q = input_q.unsqueeze(1).repeat(1, time_step - 1, 1)  # b t-1 h (match the key length)
            k = input_k
            c = torch.cat((q, k), dim=-1)  # B*(T-1)*2H
            h = torch.matmul(c, self.Wh)
            h = self.tanh(h)
            e = torch.matmul(h, self.Wa) + self.ba  # B*(T-1)*1
            e = torch.reshape(e, (batch_size, time_step - 1))  # b t-1

        a = self.softmax(e)  # B*(T-1)
        if self.dropout is not None:
            a = self.dropout(a)
        v = torch.matmul(a.unsqueeze(1), input_v).squeeze()  # B*H

        return v, a
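
# Illustrative sketch (not part of the original model): FinalAttentionQKV in the 'mul'
# configuration used by AICare, attending over 17 lab embeddings plus one demographic
# embedding. The shapes are assumptions taken from the forward pass above.
def _final_attention_example():
    attn = FinalAttentionQKV(attention_input_dim=32, attention_hidden_dim=32,
                             attention_type='mul', dropout=0.5)
    features = torch.randn(4, 18, 32)   # batch * (lab_dim + 1) * hidden_dim
    v, a = attn(features)
    return v.shape, a.shape             # expected: (4, 32) and (4, 17)
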
def clones(module, N):
    "Produce N identical layers."
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

# def tile(a, dim, n_tile):
#     init_dim = a.size(dim)
#     repeat_idx = [1] * a.dim()
#     repeat_idx[dim] = n_tile
#     a = a.repeat(*(repeat_idx))
#     order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])).to(self.device)
#     return torch.index_select(a, dim, order_index).to(self.device)

class PositionwiseFeedForward(nn.Module):  # newly added
    "Implements FFN equation."
    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Returns (output, extra_loss) to match the SublayerConnection interface below.
        return self.w_2(self.dropout(F.relu(self.w_1(x)))), None

class PositionalEncoding(nn.Module):  # newly added / no longer used
    "Implement the PE function."
    def __init__(self, d_model, dropout, max_len=400):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0., max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0., d_model, 2) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # `pe` is a buffer and does not require gradients.
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

def subsequent_mask(size):
    "Mask out subsequent positions."
    attn_shape = (1, size, size)
    subsequent_mask = np.triu(np.ones(attn_shape), k=1).astype('uint8')
    return torch.from_numpy(subsequent_mask) == 0  # lower-triangular mask

def attention(query, key, value, mask=None, dropout=None):
    "Compute 'Scaled Dot Product Attention'"
    d_k = query.size(-1)  # b h t d_k
    scores = torch.matmul(query, key.transpose(-2, -1)) \
             / math.sqrt(d_k)  # b h t t
    if mask is not None:  # 1 1 t t
        scores = scores.masked_fill(mask == 0, -1e9)  # b h t t, lower triangle kept
    p_attn = F.softmax(scores, dim=-1)  # b h t t
    if dropout is not None:
        p_attn = dropout(p_attn)
    return torch.matmul(p_attn, value), p_attn  # b h t d_k
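
# Illustrative sketch (not part of the original model): scaled dot-product attention on
# random tensors with 2 heads and a causal mask. Shapes follow the comments in attention().
def _attention_example():
    q = torch.randn(4, 2, 18, 16)   # batch * heads * positions * d_k
    k = torch.randn(4, 2, 18, 16)
    v = torch.randn(4, 2, 18, 16)
    out, p_attn = attention(q, k, v, mask=subsequent_mask(18))
    return out.shape, p_attn.shape  # expected: (4, 2, 18, 16) and (4, 2, 18, 18)
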
class MultiHeadedAttention(nn.Module):
    def __init__(self, h, d_model, dropout=0):
        "Take in model size and number of heads."
        super(MultiHeadedAttention, self).__init__()
        assert d_model % h == 0
        # We assume d_v always equals d_k
        self.d_k = d_model // h
        self.h = h
        self.linears = clones(nn.Linear(d_model, self.d_k * self.h), 3)
        self.final_linear = nn.Linear(d_model, d_model)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value, mask=None):
        if mask is not None:
            # Same mask applied to all h heads.
            mask = mask.unsqueeze(1)  # 1 1 t t

        nbatches = query.size(0)  # b
        input_dim = query.size(1)  # number of positions (i+1)
        feature_dim = query.size(-1)  # d_model

        # input size -> batch_size * d_input * hidden_dim

        # d_model => h * d_k
        query, key, value = \
            [l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
             for l, x in zip(self.linears, (query, key, value))]  # b num_head d_input d_k

        x, self.attn = attention(query, key, value, mask=mask,
                                 dropout=self.dropout)  # b num_head d_input d_k

        x = x.transpose(1, 2).contiguous() \
             .view(nbatches, -1, self.h * self.d_k)  # batch_size * d_input * hidden_dim

        # DeCov regularizer: penalize off-diagonal covariance of each position's hidden state.
        DeCov_contexts = x.transpose(0, 1).transpose(1, 2)  # (i+1) * hidden_dim * batch_size
        DeCov_loss = 0.
        for i in range(DeCov_contexts.size(0)):
            Covs = cov(DeCov_contexts[i, :, :])
            DeCov_loss += 0.5 * (torch.norm(Covs, p='fro')**2 - torch.norm(torch.diag(Covs))**2)

        return self.final_linear(x), DeCov_loss

class LayerNorm(nn.Module):
    def __init__(self, size, eps=1e-7):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(size))
        self.b_2 = nn.Parameter(torch.zeros(size))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2

def cov(m, y=None):
    if y is not None:
        m = torch.cat((m, y), dim=0)
    m_exp = torch.mean(m, dim=1)
    x = m - m_exp[:, None]
    cov = 1 / (x.size(1) - 1) * x.mm(x.t())
    return cov
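
# Illustrative sketch (not part of the original model): the DeCov penalty used in
# MultiHeadedAttention, written out for a single hidden_dim * batch_size slice. It equals
# 0.5 * (||C||_F^2 - ||diag(C)||_2^2), i.e. the squared off-diagonal covariance.
def _decov_example():
    hidden = torch.randn(32, 4)     # hidden_dim * batch_size
    C = cov(hidden)                 # 32 * 32 covariance across the batch
    decov = 0.5 * (torch.norm(C, p='fro')**2 - torch.norm(torch.diag(C))**2)
    return decov
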
class SublayerConnection(nn.Module):
    """
    A residual connection followed by a layer norm.
    Note for code simplicity the norm is first as opposed to last.
    """
    def __init__(self, size, dropout):
        super(SublayerConnection, self).__init__()
        self.norm = LayerNorm(size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        "Apply residual connection to any sublayer with the same size."
        returned_value = sublayer(self.norm(x))
        return x + self.dropout(returned_value[0]), returned_value[1]

class AICare(nn.Module):
    def __init__(self, lab_dim=17, demo_dim=4, hidden_dim=32, d_model=32, MHD_num_head=4, d_ff=64, device='cuda', keep_prob=0.5, **kwargs):
        super(AICare, self).__init__()

        # hyperparameters
        self.lab_dim = lab_dim
        self.hidden_dim = hidden_dim  # d_model
        self.d_model = d_model
        self.MHD_num_head = MHD_num_head
        self.device = device
        self.d_ff = d_ff
        self.keep_prob = keep_prob

        # layers
        self.PositionalEncoding = PositionalEncoding(self.d_model, dropout=0, max_len=400)

        # self.GRUs = clones(nn.GRU(1, self.hidden_dim, batch_first=True), self.lab_dim)
        # One bidirectional RNN per lab feature; each consumes a univariate sequence.
        self.GRUs = clones(nn.RNN(1, self.hidden_dim, bidirectional=True, batch_first=True), self.lab_dim)

        self.LastStepAttentions = clones(SingleAttention(self.hidden_dim, 8, attention_type='concat', demographic_dim=12, time_aware=True, use_demographic=False), self.lab_dim)

        self.FinalAttentionQKV = FinalAttentionQKV(self.hidden_dim, self.hidden_dim, attention_type='mul', dropout=1 - self.keep_prob)

        self.MultiHeadedAttention = MultiHeadedAttention(self.MHD_num_head, self.d_model, dropout=1 - self.keep_prob)
        self.SublayerConnection = SublayerConnection(self.d_model, dropout=1 - self.keep_prob)

        self.PositionwiseFeedForward = PositionwiseFeedForward(self.d_model, self.d_ff, dropout=0.1)

        self.demo_proj_main = nn.Linear(demo_dim, self.hidden_dim)
        self.demo_proj = nn.Linear(demo_dim, self.hidden_dim)
        self.output_proj = nn.Linear(self.hidden_dim * 2, self.hidden_dim)

        self.dropout = nn.Dropout(p=1 - self.keep_prob)
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=-1)
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()

    def forward(self, input, demo_input, mask, **kwargs):
        # input shape [batch_size, timestep, feature_dim]
        demo_main = self.tanh(self.demo_proj_main(demo_input)).unsqueeze(1)  # b 1 hidden_dim

        batch_size = input.size(0)
        time_step = input.size(1)
        feature_dim = input.size(2)
        assert feature_dim == self.lab_dim  # input tensor: batch_size * time_step * lab_dim
        assert self.d_model % self.MHD_num_head == 0

        lens = mask.sum(dim=1)

        # Encode each lab feature with its own RNN; the summed last hidden states of both
        # directions give one embedding per feature.
        GRU_embeded_input = torch.sum(self.GRUs[0](pack_padded_sequence(input[:, :, 0].unsqueeze(-1), lens.cpu(), batch_first=True, enforce_sorted=False))[1], 0).squeeze().unsqueeze(1)  # b 1 h
        # print(GRU_embeded_input.shape)
        for i in range(feature_dim - 1):
            embeded_input = torch.sum(self.GRUs[i + 1](pack_padded_sequence(input[:, :, i + 1].unsqueeze(-1), lens.cpu(), batch_first=True, enforce_sorted=False))[1], 0).squeeze().unsqueeze(1)  # b 1 h
            GRU_embeded_input = torch.cat((GRU_embeded_input, embeded_input), 1)

        # print(demo_main.shape)
        GRU_embeded_input = torch.cat((GRU_embeded_input, demo_main), 1)  # b lab_dim+1 h
        posi_input = self.dropout(GRU_embeded_input)  # batch_size * d_input * hidden_dim

        weighted_contexts = self.FinalAttentionQKV(posi_input)[0]
        combined_hidden = torch.cat((weighted_contexts,
                                     demo_main.squeeze(1)), -1)  # b 2h
        out = self.output_proj(combined_hidden)
        # out = self.dropout(out)
        # output = self.output(self.dropout(combined_hidden))  # b 1
        # output = self.sigmoid(output)
        # return output
        return out
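

# A minimal smoke-test sketch (assumed shapes, not part of the original training pipeline):
# 4 patients, 48 visits, 17 lab features, 4 demographic features, and a full mask.
if __name__ == "__main__":
    model = AICare(lab_dim=17, demo_dim=4, hidden_dim=32, d_model=32, MHD_num_head=4, d_ff=64, device='cpu')
    model.eval()                                 # disable dropout for a deterministic check
    x = torch.randn(4, 48, 17)                   # lab sequences: batch * time_step * lab_dim
    demo = torch.randn(4, 4)                     # static demographics: batch * demo_dim
    mask = torch.ones(4, 48, dtype=torch.long)   # 1 marks an observed visit; lengths = mask.sum(dim=1)
    out = model(x, demo, mask)
    print(out.shape)                             # expected: torch.Size([4, 32])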