# patient/models/adacare.py
import torch
from torch import nn
import torch.nn.utils.rnn as rnn_utils
from torch.utils import data
from torch.autograd import Variable
# from baseline import *
from models.baseline import *
import torch.nn.functional as F

class Sparsemax(nn.Module):
    """Sparsemax activation: a sparse alternative to softmax (Martins & Astudillo, 2016)."""

    def __init__(self, dim=None):
        super(Sparsemax, self).__init__()

        self.dim = -1 if dim is None else dim

    def forward(self, input, device='cuda'):
        original_size = input.size()
        input = input.view(-1, input.size(self.dim))

        dim = 1
        number_of_logits = input.size(dim)

        # Subtract the row-wise max for numerical stability (does not change the output).
        input = input - torch.max(input, dim=dim, keepdim=True)[0].expand_as(input)

        # Sort logits in descending order and find the support size
        # k(z) = max{k : 1 + k * z_(k) > sum_{j<=k} z_(j)}.
        zs = torch.sort(input=input, dim=dim, descending=True)[0]
        rhos = torch.arange(start=1, end=number_of_logits + 1, device=device, dtype=torch.float32).view(1, -1)
        rhos = rhos.expand_as(zs)

        bound = 1 + rhos * zs
        cumulative_sum_zs = torch.cumsum(zs, dim)
        is_gt = torch.gt(bound, cumulative_sum_zs).to(input.dtype)
        k = torch.max(is_gt * rhos, dim, keepdim=True)[0]

        # Threshold tau(z) = (sum of the supported logits - 1) / k(z).
        zs_sparse = is_gt * zs
        taus = (torch.sum(zs_sparse, dim, keepdim=True) - 1) / k
        taus = taus.expand_as(input)

        # p_i = max(0, z_i - tau(z)); entries below the threshold become exactly zero.
        self.output = torch.max(torch.zeros_like(input), input - taus)

        output = self.output.view(original_size)

        return output

    def backward(self, grad_output):
        # Note: this is a plain method, not an autograd.Function.backward, so autograd
        # never calls it; gradients flow through the differentiable ops in forward().
        dim = 1

        nonzeros = torch.ne(self.output, 0)
        grad_sum = torch.sum(grad_output * nonzeros, dim=dim) / torch.sum(nonzeros, dim=dim)
        self.grad_input = nonzeros * (grad_output - grad_sum.expand_as(grad_output))

        return self.grad_input

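# Quick illustration (not part of the model): sparsemax maps logits to a probability
# vector that still sums to 1 but, unlike softmax, zeroes out small entries entirely.
# A minimal sketch on CPU, assuming device='cpu' is passed explicitly:
#
#   sm = Sparsemax(dim=1)
#   sm(torch.tensor([[1.0, 2.0, 0.1]]), device='cpu')
#   # -> tensor([[0., 1., 0.]]); the smaller logits fall below the threshold tau.
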
class CausalConv1d(torch.nn.Conv1d):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 dilation=1,
                 groups=1,
                 bias=True):
        # Pad by (kernel_size - 1) * dilation so that output[t] depends only on input[<= t].
        self.__padding = (kernel_size - 1) * dilation

        super(CausalConv1d, self).__init__(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=self.__padding,
            dilation=dilation,
            groups=groups,
            bias=bias)

    def forward(self, input):
        result = super(CausalConv1d, self).forward(input)
        if self.__padding != 0:
            # Conv1d pads both sides; trim the right overhang so the output stays causal
            # and keeps the same time length as the input.
            return result[:, :, :-self.__padding]
        return result

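# Shape sketch (illustrative only): with stride=1, CausalConv1d preserves the time length,
# e.g.
#
#   conv = CausalConv1d(in_channels=8, out_channels=16, kernel_size=2, dilation=3)
#   conv(torch.randn(4, 8, 20)).shape   # -> torch.Size([4, 16, 20])
#
# because the padding of (kernel_size - 1) * dilation = 3 positions added on each side
# is trimmed back off the right after the convolution.
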
class Recalibration(nn.Module):
    def __init__(self, channel, reduction=9, use_h=True, use_c=True, activation='sigmoid'):
        super(Recalibration, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.use_h = use_h
        self.use_c = use_c
        scale_dim = 0
        self.activation = activation

        # Squeeze: compress the channel descriptor of the latest timestep to channel // reduction.
        self.nn_c = nn.Linear(channel, channel // reduction)
        scale_dim += channel // reduction

        # Excite: expand back to one scaling coefficient per channel.
        self.nn_rescale = nn.Linear(scale_dim, channel)
        self.sparsemax = Sparsemax(dim=1)

    def forward(self, x, device='cuda'):
        # x: (batch, channel, time); the gate is computed from the most recent timestep only.
        b, c, t = x.size()

        y_origin = x[:, :, -1]
        se_c = self.nn_c(y_origin)
        se_c = torch.relu(se_c)
        y = se_c

        y = self.nn_rescale(y).view(b, c, 1)
        if self.activation == 'sigmoid':
            y = torch.sigmoid(y)
        else:
            y = self.sparsemax(y, device)
        return x * y.expand_as(x), y

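# Usage sketch (illustrative only): the module returns the rescaled sequence together
# with the per-channel attention it applied, e.g.
#
#   recal = Recalibration(channel=12, reduction=4, activation='sigmoid')
#   scaled, att = recal(torch.randn(2, 12, 7), device='cpu')
#   # scaled: (2, 12, 7); att: (2, 12, 1) with values in (0, 1)
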
class AdaCare(nn.Module):
    def __init__(self, vocab_size, hidden_dim=128, kernel_size=2, kernel_num=64, input_dim=128, output_dim=1, dropout=0.5, r_v=4,
                 r_c=4, activation='sigmoid', pretrain=False):
        super(AdaCare, self).__init__()
        # self.embedding = nn.Embedding(vocab_size + 1, hidden_dim, padding_idx=-1)
        # self.embedding = nn.Sequential(nn.Linear(vocab_size + 1, hidden_dim), nn.ReLU())
        self.embbedding1 = nn.Sequential(
            nn.Linear(vocab_size, hidden_dim), nn.LayerNorm(hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.ReLU())
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.kernel_num = kernel_num
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.dropout = dropout

        # Three causal convolutions over the embedded sequence with dilations 1, 3 and 5.
        self.nn_conv1 = CausalConv1d(input_dim, kernel_num, kernel_size, 1, 1)
        self.nn_conv3 = CausalConv1d(input_dim, kernel_num, kernel_size, 1, 3)
        self.nn_conv5 = CausalConv1d(input_dim, kernel_num, kernel_size, 1, 5)
        torch.nn.init.xavier_uniform_(self.nn_conv1.weight)
        torch.nn.init.xavier_uniform_(self.nn_conv3.weight)
        torch.nn.init.xavier_uniform_(self.nn_conv5.weight)

        self.nn_convse = Recalibration(3 * kernel_num, r_c, use_h=False, use_c=True, activation='sigmoid')
        self.nn_inputse = Recalibration(input_dim, r_v, use_h=False, use_c=True, activation=activation)
        self.rnn = nn.GRUCell(input_dim + 3 * kernel_num, hidden_dim)
        # Note: the output head always emits 2 logits, regardless of output_dim.
        self.nn_output = nn.Linear(hidden_dim, 2)
        self.nn_dropout = nn.Dropout(dropout)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        self.pooler = MaxPoolLayer()
        self.pretrain = pretrain
        if self.pretrain in (True, "True"):
            print(self.state_dict()["embbedding1.0.weight"])
            pt_emb = torch.load("/home/xmw5190/HMMP/model/HADM_AE_200_5_28.p").state_dict()

            # Keep only the pretrained encoder weights that match this model's embbedding1 layers;
            # load_state_dict(strict=False) copies just those entries.
            state_dict = {k.split('enc.')[-1]: v for k, v in pt_emb.items()
                          if k.split('enc.')[-1] in self.state_dict().keys() and "embbedding1" in k}
            print(state_dict.keys())
            self.load_state_dict(state_dict, strict=False)
            print(self.state_dict()["embbedding1.0.weight"])
        # if self.pretrain == True or self.pretrain == "True":
        #     print(self.state_dict()["embedding.0.weight"])
        #     pt_emb = torch.load("/home/xmw5190/HMMP/model/HADM_AE_200_new_mask_lowlr.p").state_dict()
        #
        #     weight, bias = pt_emb['enc.embbedding1.0.weight'], pt_emb["enc.embbedding1.0.bias"]
        #     self.state_dict().update({"embedding.0.weight": weight})
        #     self.state_dict().update({"embedding.0.bias": bias})
        #     self.load_state_dict({"embedding.0.weight": weight, "embedding.0.bias": bias}, strict=False)
        #     print(self.state_dict()["embedding.0.weight"])

    def forward(self, input, device):
        # input shape [batch_size, timestep, feature_dim]
        batch_size = input.size(0)
        time_step = input.size(1)
        # feature_dim = input.size(2)
        input = self.embbedding1(input)  # .sum(dim=2)

        cur_h = torch.zeros(batch_size, self.hidden_dim, device=device)
        inputse_att = []
        convse_att = []
        h = []

        # Causal convolutions expect (batch, channels, time).
        conv_input = input.permute(0, 2, 1)
        conv_res1 = self.nn_conv1(conv_input)
        conv_res3 = self.nn_conv3(conv_input)
        conv_res5 = self.nn_conv5(conv_input)

        conv_res = torch.cat((conv_res1, conv_res3, conv_res5), dim=1)
        conv_res = self.relu(conv_res)

        for cur_time in range(time_step):
            # Recalibrate the convolution features and the embedded inputs seen so far,
            # then feed the latest recalibrated step to the GRU cell.
            convse_res, cur_convatt = self.nn_convse(conv_res[:, :, :cur_time + 1], device=device)
            inputse_res, cur_inputatt = self.nn_inputse(input[:, :cur_time + 1, :].permute(0, 2, 1), device=device)
            cur_input = torch.cat((convse_res[:, :, -1], inputse_res[:, :, -1]), dim=-1)

            cur_h = self.rnn(cur_input, cur_h)
            h.append(cur_h)
            convse_att.append(cur_convatt)
            inputse_att.append(cur_inputatt)

        h = torch.stack(h).permute(1, 0, 2)
        h_reshape = h.contiguous().view(batch_size, time_step, self.hidden_dim)
        if self.dropout > 0.0:
            h_reshape = self.nn_dropout(h_reshape)
        output = self.pooler(h_reshape)
        output = self.nn_output(output)
        return output

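# Data flow summary (for a batch of B sequences with T timesteps): the embedding maps
# (B, T, vocab_size) -> (B, T, hidden_dim); the three dilated causal convolutions yield
# (B, 3 * kernel_num, T); at each step t both feature maps are recalibrated and their
# latest columns are concatenated into a vector of size input_dim + 3 * kernel_num for
# the GRU cell; max-pooling over time and the linear head produce (B, 2) logits.
# Note that the convolutions and the input recalibration are built for input_dim
# channels, so the shapes only line up when input_dim == hidden_dim, as below.
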
if __name__ == '__main__':
    model = AdaCare(vocab_size=7687, hidden_dim=64, kernel_size=2, kernel_num=64, input_dim=64, output_dim=1, dropout=0.5, r_v=4,
                    r_c=4, activation='sigmoid', pretrain="False")
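    # Smoke-test sketch (assumption: MaxPoolLayer from models.baseline pools over the
    # time dimension, so the pooled output is (batch, hidden_dim)). With the sizes above:
    #
    #   x = torch.randn(4, 16, 7687)        # 4 patients, 16 visits, 7687 features
    #   logits = model(x, device='cpu')     # expected shape: (4, 2)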