
b/stay_admission/baseline.py
import math
import pickle

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from operations import *
import model
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


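# LSTM_bimodal: unimodal LSTM baseline over the time-series input. With pretrain=True the
# input is first encoded by model.LSTM_Encoder; the LSTM output is pooled over the sequence
# by MaxPoolLayer (defined in operations.py, not shown here) and fed to a sigmoid output head.
# Expected input shape is (batch, length, vocab_size1), matching the smoke test at the bottom
# of this file.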
class LSTM_bimodal(nn.Module):
    def __init__(self, vocab_size1, vocab_size2, d_model=256, dropout=0.5, dropout_emb=0.5, length=48, pretrain=False):
        super().__init__()
        self.embbedding1 = nn.Sequential(nn.Linear(vocab_size1, d_model), nn.ReLU())
        self.linear = nn.Linear(vocab_size2, d_model)
        self.dropout = nn.Dropout(dropout)
        self.emb_dropout = nn.Dropout(dropout_emb)
        self.output_mlp = nn.Sequential(nn.Linear(d_model, 1))
        self.pooler = MaxPoolLayer()
        if pretrain:
            self.rnns = nn.LSTM(d_model, d_model, 1, bidirectional=False, batch_first=True)
        else:
            self.rnns = nn.LSTM(vocab_size1, d_model, 1, bidirectional=False, batch_first=True)
        self.sig = nn.Sigmoid()
        self.ts_encoder = model.LSTM_Encoder(length, vocab_size1, d_model)
        self.linear_2 = nn.Linear(32, d_model)
        self.pretrain = pretrain

    def forward(self, x):
        if self.pretrain:
            x = self.ts_encoder(x)[0]
            x = self.emb_dropout(x)

        rnn_output, _ = self.rnns(x)
        x = self.pooler(rnn_output)
        x = self.output_mlp(x)
        x = self.sig(x)
        return x


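# Transformer: single-block self-attention baseline (SelfAttention and FFN come from
# operations.py) over the time-series input, followed by pooling and a sigmoid output head.
# The pretrain branch mirrors LSTM_bimodal and encodes the input with model.LSTM_Encoder first.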
class Transformer(nn.Module):
    def __init__(self, vocab_size1, vocab_size2, d_model, dropout=0.5, dropout_emb=0.5, length=48, pretrain=False):
        super().__init__()
        self.embbedding1 = nn.Sequential(nn.Linear(vocab_size1, d_model), nn.ReLU())
        self.linear = nn.Linear(vocab_size2, d_model)
        self.dropout = nn.Dropout(dropout)
        self.emb_dropout = nn.Dropout(dropout_emb)
        self.output_mlp = nn.Sequential(nn.Linear(d_model, 1))
        self.pooler = MaxPoolLayer()
        self.attention = SelfAttention(d_model)
        self.ffn = FFN(d_model)
        self.sig = nn.Sigmoid()
        self.ts_encoder = model.LSTM_Encoder(length, vocab_size1, d_model)
        self.linear_2 = nn.Linear(32, d_model)
        self.pretrain = pretrain

    def forward(self, x):
        if self.pretrain:
            x = self.ts_encoder(x)[0]
            x = self.emb_dropout(x)
        else:
            x = self.embbedding1(x)

        x = self.attention(x, None, None)
        x = self.ffn(x, None, None)
        x = self.dropout(x)
        x = self.pooler(x)
        x = self.output_mlp(x)
        x = self.sig(x)
        return x


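# ClinicalT5: text-only baseline that encodes clinical notes with the encoder of a locally
# stored Clinical-T5-Base checkpoint, mean-pools the token embeddings, and applies a linear
# plus sigmoid head. ts_x and tb_x are accepted for interface compatibility but unused here.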
class ClinicalT5(nn.Module):
    def __init__(self, d_model=256):
        super().__init__()

        self.sig = nn.Sigmoid()
        self.t5 = AutoModelForSeq2SeqLM.from_pretrained("LLM/physionet.org/files/clinical-t5/1.0.0/Clinical-T5-Base").encoder
        self.fc2 = nn.Linear(768, 1)
        self.pooler = MaxPoolLayer()
        self.relu1 = nn.ReLU()

    def forward(self, ts_x, tb_x, input_ids, attention_mask):
        text = self.t5(input_ids=input_ids, attention_mask=attention_mask, return_dict=True).last_hidden_state
        sent_emb = torch.mean(text, dim=1)
        sent_emb = self.fc2(sent_emb)
        x = self.sig(sent_emb)
        return x


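# Raim: RAIM-style recurrent attention baseline. At each time step the visible history is
# zero-padded to the full sequence length, masked by a guidance tensor derived from the
# second modality, scored by a small MLP, soft-attended into a context vector, and fed to a
# GRU cell; the final hidden state (plus the static features s) drives the sigmoid output.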
class Raim(nn.Module):
    def __init__(self, vocab_size1, vocab_size2, vocab_size3, d_model, dropout=0.1, dropout_emb=0.1, length=48):
        super().__init__()
        self.embbedding1 = nn.Sequential(nn.Linear(vocab_size1, d_model), nn.ReLU())
        self.embbedding2 = nn.Sequential(nn.Linear(vocab_size2, d_model), nn.ReLU())
        self.linear = nn.Linear(vocab_size3, d_model)
        self.dropout = nn.Dropout(dropout)
        self.emb_dropout = nn.Dropout(dropout_emb)
        self.output_mlp = nn.Sequential(nn.Linear(d_model, 2))
        self.pooler = MaxPoolLayer()

        self.hidden_size = d_model

        self.rnn = nn.LSTM(d_model, d_model, 2, dropout=0.5)
        self.attn = nn.Linear(10, 10)
        self.attn1 = nn.Linear(60, 10)

        self.dense_h = nn.Linear(d_model, 1)
        self.softmax = nn.Softmax(dim=1)
        self.hidden2label = nn.Linear(d_model, 1)
        self.grucell = nn.GRUCell(d_model, d_model)

        self.mlp_for_x = nn.Linear(d_model, 1, bias=False)
        self.mlp_for_hidden = nn.Linear(d_model, length, bias=True)

        self.sigmoid = nn.Sigmoid()

    def init_hidden(self, batch_size):
        return Variable(torch.zeros(batch_size, self.hidden_size))

    def forward(self, x1, x2, s):
        x1 = self.embbedding1(x1)
        x2 = self.embbedding2(x2)
        s = self.linear(s)
        input_seqs = x1 + x2
        x = input_seqs
        self.hidden = self.init_hidden(x.size(0)).to(x.device)
        for i in range(x.size(1)):
            tt = x[:, 0:i + 1, :].reshape(x.size(0), (i + 1) * x[:, 0:i + 1, :].shape[2])
            if i < x.size(1) - 1:
                padding = torch.zeros(x.size(0), x.size(1) * x.size(2) - tt.shape[1]).to(x.device)
                self.temp1 = torch.cat((tt, padding), 1)
            else:
                self.temp1 = tt

            self.input_padded = self.temp1.reshape(x.size(0), x.size(1), x.size(-1))

            #### multiply with guidance #######
            temp_guidance = torch.zeros(x.size(0), x.size(1), 1).to(x.device)

            # temp_guidance[:, 0:i + 1, :] = x2[:, 0:i + 1, 0].unsqueeze(-1)

            if i > 0:
                zero_idx = torch.where(torch.sum(x2[:, :i, 0], dim=1) == 0)
                if len(zero_idx[0]) > 0:
                    temp_guidance[zero_idx[0], :i, 0] = 1

            temp_guidance[:, i, :] = 1

            self.guided_input = torch.mul(self.input_padded, temp_guidance)

            ######### MLP ###########
            self.t1 = self.mlp_for_x(self.guided_input) + self.mlp_for_hidden(self.hidden).reshape(x.size(0), x.size(1), 1)

            ######### softmax -> multiply -> context vector ###########
            self.t1_softmax = self.softmax(self.t1)
            final_output = torch.mul(self.input_padded, self.t1_softmax)

            context_vec = torch.sum(final_output, dim=1)

            self.hx = self.grucell(context_vec, self.hidden)
            self.hidden = self.hx

        y = self.hidden2label(self.hidden + s)
        return self.sigmoid(y)


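# DCMN: dual-channel memory network baseline. Each modality is embedded with a strided
# Conv1d plus linear layer, encoded by LSTMs into memory/query pairs, cross-attended (each
# channel queries the other's memory) with a sigmoid gate, and the two gated outputs are
# summed with the static features before the output head.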
class DCMN(nn.Module):

    def __init__(self, vocab_size1, vocab_size2, vocab_size3, d_model, dropout=0.1, dropout_emb=0.1, length=48):
        super().__init__()
        self.embbedding1 = nn.Sequential(nn.Conv1d(in_channels=1, out_channels=1, kernel_size=10, stride=5),
                                         nn.ReLU(),
                                         nn.Linear((vocab_size1 - 10) // 5 + 1, d_model))
        self.embbedding2 = nn.Sequential(nn.Conv1d(in_channels=1, out_channels=1, kernel_size=10, stride=5),
                                         nn.ReLU(),
                                         nn.Linear((vocab_size2 - 10) // 5 + 1, d_model))
        self.linear = nn.Linear(vocab_size3, d_model)
        self.batchnorm1 = nn.BatchNorm1d(d_model)
        self.batchnorm2 = nn.BatchNorm1d(d_model)
        self.conv = nn.Conv1d(d_model, d_model, 3, padding=1)
        self.dropout = nn.Dropout(dropout)
        self.emb_dropout = nn.Dropout(dropout_emb)
        self.output_mlp = nn.Sequential(nn.Linear(d_model, 1))
        self.c_emb = nn.LSTM(d_model, d_model, 1, bidirectional=False, batch_first=True)
        self.c_out = nn.LSTM(d_model, d_model, 1, bidirectional=False, batch_first=True)
        self.w_emb = nn.LSTM(d_model, d_model, 1, bidirectional=False, batch_first=True)
        self.w_out = nn.LSTM(d_model, d_model, 1, bidirectional=False, batch_first=True)
        self.linear1 = nn.Linear(d_model, d_model)
        self.linear2 = nn.Linear(d_model, d_model)
        self.linear3 = nn.Linear(d_model, d_model)
        self.linear4 = nn.Linear(d_model, d_model)
        self.gate_linear = nn.Linear(d_model, d_model)
        self.gate_linear2 = nn.Linear(d_model, d_model)
        self.pooler = MaxPoolLayer()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x1, x2, s):
        bs, l, fdim = x1.size()
        x1 = x1.view(bs * l, -1).unsqueeze(1)
        x2 = x2.view(bs * l, -1).unsqueeze(1)
        x1 = self.embbedding1(x1)
        x2 = self.embbedding2(x2)
        # squeeze only the singleton channel dim so a batch size of 1 is not collapsed
        x1 = x1.squeeze(1).view(bs, l, -1)
        x2 = x2.squeeze(1).view(bs, l, -1)
        s = self.dropout(self.linear(s))
        x1 = self.batchnorm1(x1.permute(0, 2, 1)).permute(0, 2, 1)
        x2 = self.batchnorm2(x2.permute(0, 2, 1)).permute(0, 2, 1)
        wm_embedding_memory, _ = self.w_emb(x1)
        wm_out_query, _ = self.w_out(x1)
        cm_embedding_memory, _ = self.c_emb(x2)
        cm_out_query, _ = self.c_out(x2)
        wm_in = cm_out_query[:, -1]
        cm_in = wm_out_query[:, -1]
        w_embedding_E = self.linear1(wm_embedding_memory)
        w_embedding_F = self.linear2(wm_embedding_memory)
        wm_out = torch.matmul(wm_in.unsqueeze(1), w_embedding_E.permute(0, 2, 1))
        wm_prob = torch.softmax(wm_out, dim=-1)
        wm_contex = torch.matmul(wm_prob, w_embedding_F)
        wm_gate_prob = torch.sigmoid(self.gate_linear(wm_in)).unsqueeze(1)
        wm_dout = wm_contex * wm_gate_prob + wm_in.unsqueeze(1) * (1 - wm_gate_prob)

        c_embedding_E = self.linear3(cm_embedding_memory)
        c_embedding_F = self.linear4(cm_embedding_memory)
        cm_out = torch.matmul(cm_in.unsqueeze(1), c_embedding_E.permute(0, 2, 1))
        cm_prob = torch.softmax(cm_out, dim=-1)
        cm_contex = torch.matmul(cm_prob, c_embedding_F)
        cm_gate_prob = torch.sigmoid(self.gate_linear2(cm_in)).unsqueeze(1)
        cm_dout = cm_contex * cm_gate_prob + cm_in.unsqueeze(1) * (1 - cm_gate_prob)
        output = wm_dout + cm_dout
        # drop the singleton query dim explicitly before adding the static features
        output = self.output_mlp(output.squeeze(1) + s)
        return self.sigmoid(output)


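# Mufasa: MUFASA-style fusion baseline. The continuous and categorical streams each pass
# through their own self-attention and feed-forward blocks with residual connections; the
# static features are broadcast over time and fused with the categorical branch, refined by
# a Conv1d branch, and all branches are summed, pooled, and passed to the sigmoid head.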
class Mufasa(nn.Module):

    def __init__(self, vocab_size1, vocab_size2, vocab_size3, d_model, dropout=0.1, dropout_emb=0.1, length=48):
        super().__init__()
        self.embbedding1 = nn.Sequential(nn.Conv1d(in_channels=1, out_channels=1, kernel_size=10, stride=5),
                                         nn.ReLU(),
                                         nn.Linear((vocab_size1 - 10) // 5 + 1, d_model))
        self.embbedding2 = nn.Sequential(nn.Conv1d(in_channels=1, out_channels=1, kernel_size=10, stride=5),
                                         nn.ReLU(),
                                         nn.Linear((vocab_size2 - 10) // 5 + 1, d_model))
        self.linear = nn.Linear(vocab_size3, d_model)
        self.linear_conti = nn.Linear(d_model, d_model)
        self.linear_cate = nn.Linear(2 * d_model, d_model)
        self.linears = nn.Linear(2 * d_model, d_model)
        self.linear_late = nn.Sequential(nn.Linear(d_model, d_model), nn.ReLU(inplace=False))
        self.dense = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.ReLU(inplace=False), nn.Linear(4 * d_model, d_model))
        self.relu = nn.ReLU(inplace=False)
        self.layernorm = nn.LayerNorm(d_model)
        self.layernorm2 = nn.LayerNorm(d_model)
        self.layernorm3 = nn.LayerNorm(d_model)
        self.self_att = SelfAttention(d_model)
        self.self_att2 = SelfAttention(d_model)
        self.conv = nn.Conv1d(d_model, d_model, 3, padding=1)
        self.leaky = nn.LeakyReLU(inplace=False)
        self.pooler = MaxPoolLayer()
        self.output_mlp = nn.Sequential(nn.Linear(d_model, 1))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x1, x2, s):
        bs, l, fdim = x1.size()
        x1 = x1.view(bs * l, -1).unsqueeze(1).clone()
        x2 = x2.view(bs * l, -1).unsqueeze(1).clone()
        x1 = self.embbedding1(x1)
        x2 = self.embbedding2(x2)
        # squeeze only the singleton channel dim so a batch size of 1 is not collapsed
        x1 = x1.squeeze(1).view(bs, l, -1)
        x2 = x2.squeeze(1).view(bs, l, -1)
        s = self.linear(s)
        continues_res = x2
        continues_hs = self.layernorm(x2)
        continues_hs = self.self_att(continues_hs, None, None)
        continues_hs = self.leaky(continues_hs)
        continues_hs = continues_res + continues_hs
        continuous_res = continues_hs
        continues_hs = self.layernorm(continues_hs)
        continues_hs = self.linear_conti(continues_hs)
        continues_hs = self.relu(continues_hs)
        continues_hs = continuous_res + continues_hs
        categorical_res = x1
        categorical_hs = self.layernorm2(x1)
        categorical_hs = self.self_att2(categorical_hs, None, None)
        categorical_hs = torch.cat((categorical_hs, categorical_res), dim=-1)
        categorical_res = categorical_hs.clone()
        categorical_hs = self.linear_cate(categorical_hs)
        categorical_hs = self.relu(categorical_hs)
        categorical_res = self.linears(categorical_res)
        categorical_hybrid_point = categorical_hs + categorical_res
        categorical_late_point = self.linear_late(categorical_res)
        temp = s.unsqueeze(1).clone()
        fusion_hs = temp.expand_as(categorical_hybrid_point) + categorical_hybrid_point
        fusion_res = fusion_hs
        fusion_hs = self.layernorm3(fusion_hs)
        fusion_branch = self.conv(fusion_hs.permute(0, 2, 1)).permute(0, 2, 1)
        out = fusion_res + fusion_hs + fusion_branch + categorical_late_point + continues_hs
        out = self.pooler(out)
        out = self.output_mlp(out)
        return self.sigmoid(out)


if __name__ == '__main__':
    net = Transformer(1318, 73, 256)
    x1 = torch.randn((32, 48, 1318))
    s = torch.randn((32, 73))
    # Transformer.forward takes only the time-series tensor
    print(net(x1).size())
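
    # Additional smoke tests; a minimal sketch, assuming the same illustrative shapes as the
    # example above and that MaxPoolLayer/SelfAttention from operations.py pool/attend over
    # the time dimension as their names suggest.
    lstm_net = LSTM_bimodal(1318, 73)
    print(lstm_net(x1).size())

    raim_net = Raim(1318, 73, 73, 256)
    x2 = torch.randn((32, 48, 73))
    print(raim_net(x1, x2, s).size())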