# patient/models/adacare.py

import torch
from torch import nn
import torch.nn.utils.rnn as rnn_utils
from torch.utils import data
from torch.autograd import Variable
# from baseline import *
from models.baseline import *
import torch.nn.functional as F


class Sparsemax(nn.Module):
    """Sparsemax function."""

    def __init__(self, dim=None):
        super(Sparsemax, self).__init__()

        self.dim = -1 if dim is None else dim

    def forward(self, input, device='cuda'):
        original_size = input.size()
        input = input.view(-1, input.size(self.dim))

        dim = 1
        number_of_logits = input.size(dim)

        # Shift logits by their max for numerical stability (does not change the result).
        input = input - torch.max(input, dim=dim, keepdim=True)[0].expand_as(input)

        # Sort logits in descending order and determine the support size k.
        zs = torch.sort(input=input, dim=dim, descending=True)[0]
        range_ = torch.arange(start=1, end=number_of_logits + 1, device=device, dtype=torch.float32).view(1, -1)
        range_ = range_.expand_as(zs)

        bound = 1 + range_ * zs
        cumulative_sum_zs = torch.cumsum(zs, dim)
        is_gt = torch.gt(bound, cumulative_sum_zs).type(input.type())
        k = torch.max(is_gt * range_, dim, keepdim=True)[0]

        # Threshold tau and projection onto the probability simplex.
        zs_sparse = is_gt * zs
        taus = (torch.sum(zs_sparse, dim, keepdim=True) - 1) / k
        taus = taus.expand_as(input)

        self.output = torch.max(torch.zeros_like(input), input - taus)

        output = self.output.view(original_size)

        return output

    def backward(self, grad_output):
        # Note: forward uses differentiable tensor ops, so autograd does not call
        # this method; it is kept from the original implementation.
        dim = 1

        nonzeros = torch.ne(self.output, 0)
        grad_sum = torch.sum(grad_output * nonzeros, dim=dim) / torch.sum(nonzeros, dim=dim)
        self.grad_input = nonzeros * (grad_output - grad_sum.expand_as(grad_output))

        return self.grad_input

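# Added note (not in the original file): Sparsemax.forward computes the Euclidean
# projection of the logits onto the probability simplex (Martins & Astudillo, 2016).
# With z sorted in decreasing order, the support size is
# k(z) = max{k : 1 + k * z_k > sum_{j<=k} z_j}, the threshold is
# tau = (sum_{j<=k(z)} z_j - 1) / k(z), and the output is max(z - tau, 0),
# so low-scoring entries receive exactly zero weight.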

class CausalConv1d(torch.nn.Conv1d):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 dilation=1,
                 groups=1,
                 bias=True):
        self.__padding = (kernel_size - 1) * dilation

        super(CausalConv1d, self).__init__(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=self.__padding,
            dilation=dilation,
            groups=groups,
            bias=bias)

    def forward(self, input):
        result = super(CausalConv1d, self).forward(input)
        if self.__padding != 0:
            return result[:, :, :-self.__padding]
        return result

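# Added note (not in the original file): CausalConv1d relies on nn.Conv1d padding
# both ends by (kernel_size - 1) * dilation, which makes the raw output
# (kernel_size - 1) * dilation steps longer than the input; trimming that many
# steps from the right restores the original length and ensures each output
# position only sees current and past inputs. For example, with kernel_size=2
# and dilation=3 the padding is 3, a length-T input produces length T + 3, and
# the slice [:, :, :-3] brings it back to T.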

class Recalibration(nn.Module):
    def __init__(self, channel, reduction=9, use_h=True, use_c=True, activation='sigmoid'):
        super(Recalibration, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.use_h = use_h
        self.use_c = use_c
        scale_dim = 0
        self.activation = activation

        self.nn_c = nn.Linear(channel, channel // reduction)
        scale_dim += channel // reduction

        self.nn_rescale = nn.Linear(scale_dim, channel)
        self.sparsemax = Sparsemax(dim=1)

    def forward(self, x, device='cuda'):
        b, c, t = x.size()

        y_origin = x[:, :, -1]
        se_c = self.nn_c(y_origin)
        se_c = torch.relu(se_c)
        y = se_c

        y = self.nn_rescale(y).view(b, c, 1)
        if self.activation == 'sigmoid':
            y = torch.sigmoid(y)
        else:
            y = self.sparsemax(y, device)
        return x * y.expand_as(x), y

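# Added note (not in the original file): Recalibration is a squeeze-and-excitation
# style gate driven by the latest time step. The channel vector x[:, :, -1] is
# squeezed to channel // reduction units, passed through ReLU, expanded back to
# `channel` units, and turned into per-channel weights with sigmoid (dense) or
# Sparsemax (sparse). The same weights rescale every time step of x, and they
# are returned alongside the rescaled tensor so the caller can inspect them as
# attention values. use_h, use_c and avg_pool are stored but not used here.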

class AdaCare(nn.Module):
    def __init__(self, vocab_size, hidden_dim=128, kernel_size=2, kernel_num=64, input_dim=128, output_dim=1, dropout=0.5, r_v=4,
                 r_c=4, activation='sigmoid', pretrain=False):
        super(AdaCare, self).__init__()
        # self.embedding = nn.Embedding(vocab_size + 1, hidden_dim, padding_idx=-1)
        # self.embedding = nn.Sequential(nn.Linear(vocab_size + 1, hidden_dim), nn.ReLU())
        self.embbedding1 = nn.Sequential(nn.Linear(vocab_size, hidden_dim), nn.LayerNorm(hidden_dim), nn.ReLU(),
                                         nn.Linear(hidden_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.ReLU())
        self.hidden_dim = hidden_dim
        self.kernel_size = kernel_size
        self.kernel_num = kernel_num
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.dropout = dropout

        # Three causal convolution branches with dilations 1, 3 and 5.
        self.nn_conv1 = CausalConv1d(input_dim, kernel_num, kernel_size, 1, 1)
        self.nn_conv3 = CausalConv1d(input_dim, kernel_num, kernel_size, 1, 3)
        self.nn_conv5 = CausalConv1d(input_dim, kernel_num, kernel_size, 1, 5)
        torch.nn.init.xavier_uniform_(self.nn_conv1.weight)
        torch.nn.init.xavier_uniform_(self.nn_conv3.weight)
        torch.nn.init.xavier_uniform_(self.nn_conv5.weight)

        self.nn_convse = Recalibration(3 * kernel_num, r_c, use_h=False, use_c=True, activation='sigmoid')
        self.nn_inputse = Recalibration(input_dim, r_v, use_h=False, use_c=True, activation=activation)
        self.rnn = nn.GRUCell(input_dim + 3 * kernel_num, hidden_dim)
        # Note: the output layer is hard-coded to 2 classes; output_dim is stored but unused.
        self.nn_output = nn.Linear(hidden_dim, 2)
        self.nn_dropout = nn.Dropout(dropout)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        self.pooler = MaxPoolLayer()
        self.pretrain = pretrain
        if self.pretrain == True or self.pretrain == "True":
            print(self.state_dict()["embbedding1.0.weight"])
            # Load the pretrained encoder and keep only the matching embbedding1.* entries.
            pt_emb = torch.load("/home/xmw5190/HMMP/model/HADM_AE_200_5_28.p").state_dict()

            state_dict = {k.split('enc.')[-1]: v for k, v in pt_emb.items()
                          if k.split('enc.')[-1] in self.state_dict().keys() and "embbedding1" in k}
            print(state_dict.keys())
            # load_state_dict below performs the actual copy; updating the state_dict() snapshot has no effect on the module.
            self.state_dict().update(state_dict)
            self.load_state_dict(state_dict, strict=False)
            print(self.state_dict()["embbedding1.0.weight"])
        # if self.pretrain == True or self.pretrain == "True":
        #     print(self.state_dict()["embedding.0.weight"])
        #     pt_emb = torch.load("/home/xmw5190/HMMP/model/HADM_AE_200_new_mask_lowlr.p").state_dict()

        #     weight, bias = pt_emb['enc.embbedding1.0.weight'], pt_emb["enc.embbedding1.0.bias"]
        #     self.state_dict().update({"embedding.0.weight":weight})
        #     self.state_dict().update({"embedding.0.bias":bias})
        #     self.load_state_dict({"embedding.0.weight":weight, "embedding.0.bias":bias},strict = False)
        #     print(self.state_dict()["embedding.0.weight"])

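    # Forward-pass overview (added comment, not in the original file): the input
    # [batch, time, vocab_size] is embedded to hidden_dim features (which must
    # equal input_dim for the conv branches), run through the three dilated
    # causal convolutions once, and then, at every time step t, the conv
    # features and the embedded features over steps [0, t] are recalibrated.
    # The recalibrated vectors at step t are concatenated and fed to the
    # GRUCell; the hidden states are pooled by MaxPoolLayer (from
    # models.baseline, presumably a max over the time dimension) and projected
    # to 2 logits.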
    def forward(self, input, device):
        # input shape [batch_size, timestep, feature_dim]
        batch_size = input.size(0)
        time_step = input.size(1)
        # feature_dim = input.size(2)
        input = self.embbedding1(input)  # .sum(dim=2)

        cur_h = Variable(torch.zeros(batch_size, self.hidden_dim)).to(device)
        inputse_att = []
        convse_att = []
        h = []

        # Run the three dilated causal convolutions over the full sequence once.
        conv_input = input.permute(0, 2, 1)
        conv_res1 = self.nn_conv1(conv_input)
        conv_res3 = self.nn_conv3(conv_input)
        conv_res5 = self.nn_conv5(conv_input)

        conv_res = torch.cat((conv_res1, conv_res3, conv_res5), dim=1)
        conv_res = self.relu(conv_res)

        for cur_time in range(time_step):
            # Recalibrate conv features and raw features over the prefix [0, cur_time].
            convse_res, cur_convatt = self.nn_convse(conv_res[:, :, :cur_time + 1], device=device)
            inputse_res, cur_inputatt = self.nn_inputse(input[:, :cur_time + 1, :].permute(0, 2, 1), device=device)
            cur_input = torch.cat((convse_res[:, :, -1], inputse_res[:, :, -1]), dim=-1)

            cur_h = self.rnn(cur_input, cur_h)
            h.append(cur_h)
            convse_att.append(cur_convatt)
            inputse_att.append(cur_inputatt)

        h = torch.stack(h).permute(1, 0, 2)
        h_reshape = h.contiguous().view(batch_size, time_step, self.hidden_dim)
        if self.dropout > 0.0:
            h_reshape = self.nn_dropout(h_reshape)
        output = self.pooler(h_reshape)
        output = self.nn_output(output)
        return output


if __name__ == '__main__':
    model = AdaCare(vocab_size=7687, hidden_dim=64, kernel_size=2, kernel_num=64, input_dim=64, output_dim=1, dropout=0.5, r_v=4,
                    r_c=4, activation='sigmoid', pretrain="False")
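
    # Minimal smoke-test sketch (added, not in the original file; assumes
    # models.baseline provides MaxPoolLayer and that it pools over the time
    # dimension): push a random batch through the model on CPU and check shapes.
    dummy_input = torch.randn(4, 10, 7687)  # [batch_size, time_step, vocab_size]
    with torch.no_grad():
        logits = model(dummy_input, device='cpu')
    print(logits.shape)  # expected: torch.Size([4, 2])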