# model.py
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import copy

# Run on GPU when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)


def clones(module, N):
    """Return a ModuleList holding N deep copies of `module`."""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def clone_params(param, N):
    """Return a ParameterList holding N deep copies of `param`."""
    return nn.ParameterList([copy.deepcopy(param) for _ in range(N)])


class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with a learned gain (a_2) and bias (b_2)."""

    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2


class GraphLayer(nn.Module):

    def __init__(self, in_features, hidden_features, out_features, num_of_nodes,
                 num_of_heads, dropout, alpha, concat=True):
        super(GraphLayer, self).__init__()
        self.in_features = in_features
        self.hidden_features = hidden_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat
        self.num_of_nodes = num_of_nodes
        self.num_of_heads = num_of_heads
        # One linear projection and one attention vector per head.
        self.W = clones(nn.Linear(in_features, hidden_features), num_of_heads)
        self.a = clone_params(nn.Parameter(torch.rand(size=(1, 2 * hidden_features)), requires_grad=True), num_of_heads)
        self.ffn = nn.Sequential(
            nn.Linear(out_features, out_features),
            nn.ReLU()
        )
        if not concat:
            self.V = nn.Linear(hidden_features, out_features)
        else:
            self.V = nn.Linear(num_of_heads * hidden_features, out_features)
        self.dropout = nn.Dropout(dropout)
        self.leakyrelu = nn.LeakyReLU(self.alpha)
        self.norm = LayerNorm(hidden_features)

    def initialize(self):
        for i in range(len(self.W)):
            nn.init.xavier_normal_(self.W[i].weight.data)
        for i in range(len(self.a)):
            nn.init.xavier_normal_(self.a[i].data)
        if not self.concat:
            nn.init.xavier_normal_(self.V.weight.data)
            nn.init.xavier_normal_(self.ffn[0].weight.data)

    def attention(self, linear, a, N, data, edge):
        data = linear(data).unsqueeze(0)
        assert not torch.isnan(data).any()
        # edge: 2 x E (pairs of node indices)
        h = torch.cat((data[:, edge[0, :], :], data[:, edge[1, :], :]), dim=0)
        data = data.squeeze(0)
        # h: 2 x E x hidden_features
        assert not torch.isnan(h).any()
        # edge_h: 2*hidden_features x E
        edge_h = torch.cat((h[0, :, :], h[1, :, :]), dim=1).transpose(0, 1)
        # edge_e: E (unnormalized attention weight per edge)
        edge_e = torch.exp(self.leakyrelu(a.mm(edge_h).squeeze()) / np.sqrt(self.hidden_features * self.num_of_heads))
        assert not torch.isnan(edge_e).any()
        # Scatter the edge weights into a sparse N x N attention matrix.
        edge_e = torch.sparse_coo_tensor(edge, edge_e, torch.Size([N, N]))
        e_rowsum = torch.sparse.mm(edge_e, torch.ones(size=(N, 1)).to(device))
        # e_rowsum: N x 1
        # Give nodes with no incoming edges a self-loop of weight 1 so the
        # normalization below never divides by zero.
        row_check = (e_rowsum == 0)
        e_rowsum[row_check] = 1
        zero_idx = row_check.nonzero()[:, 0]
        edge_e = edge_e.add(
            torch.sparse_coo_tensor(zero_idx.repeat(2, 1), torch.ones(len(zero_idx)).to(device), torch.Size([N, N])))
        h_prime = torch.sparse.mm(edge_e, data)
        assert not torch.isnan(h_prime).any()
        # h_prime: N x hidden_features
        h_prime.div_(e_rowsum)
        assert not torch.isnan(h_prime).any()
        return h_prime

    def forward(self, edge, data=None):
        N = self.num_of_nodes
        if self.concat:
            # Concatenate the per-head attention outputs along the feature dimension.
            h_prime = torch.cat([self.attention(l, a, N, data, edge) for l, a in zip(self.W, self.a)], dim=1)
        else:
            # Average the per-head attention outputs.
            h_prime = torch.stack([self.attention(l, a, N, data, edge) for l, a in zip(self.W, self.a)], dim=0).mean(
                dim=0)
        h_prime = self.dropout(h_prime)
        if self.concat:
            return F.elu(self.norm(h_prime))
        else:
            return self.V(F.relu(self.norm(h_prime)))

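
# GraphLayer usage note: `edge` is a 2 x E LongTensor of node-index pairs and
# `data` is an N x in_features tensor with N == num_of_nodes. With concat=True
# the head outputs are concatenated and layer-normalized, which as written only
# broadcasts when num_of_heads * hidden_features matches the LayerNorm size
# (effectively a single head); with concat=False the head outputs are averaged
# and projected to out_features through self.V.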


class VariationalGNN(nn.Module):

    def __init__(self, in_features, out_features, num_of_nodes, n_heads, n_layers,
                 dropout, alpha, variational=True, none_graph_features=0, concat=True):
        super(VariationalGNN, self).__init__()
        self.variational = variational
        self.num_of_nodes = num_of_nodes + 1 - none_graph_features
        self.embed = nn.Embedding(self.num_of_nodes, in_features, padding_idx=0)
        self.in_att = clones(
            GraphLayer(in_features, in_features, in_features, self.num_of_nodes,
                       n_heads, dropout, alpha, concat=True), n_layers)
        self.out_features = out_features
        self.out_att = GraphLayer(in_features, in_features, out_features, self.num_of_nodes,
                                  n_heads, dropout, alpha, concat=False)
        self.n_heads = n_heads
        self.dropout = nn.Dropout(dropout)
        self.parameterize = nn.Linear(out_features, out_features * 2)
        self.out_layer = nn.Sequential(
            nn.Linear(out_features, out_features),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(out_features, 1))
        self.none_graph_features = none_graph_features
        if none_graph_features > 0:
            self.features_ffn = nn.Sequential(
                nn.Linear(none_graph_features, out_features // 2),
                nn.ReLU(),
                nn.Dropout(dropout))
            self.out_layer = nn.Sequential(
                nn.Linear(out_features + out_features // 2, out_features),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(out_features, 1))
        for i in range(n_layers):
            self.in_att[i].initialize()

    def data_to_edges(self, data):
        """Build fully-connected edge lists over the nonzero (observed) codes in `data`.

        Returns `(input_edges, output_edges)`; `output_edges` additionally connects the
        observed codes to an extra node with index `length + 1`, whose representation is
        later read out as the graph-level summary (`h_prime[-1]` in `encoder_decoder`).
        """
        data = data.bool()
        length = data.size()[0]
        nonzero = data.nonzero()
        if nonzero.size()[0] == 0:
            return torch.LongTensor([[0], [0]]), torch.LongTensor([[length + 1], [length + 1]])
        if self.training:
            # Randomly drop roughly 5% of the observed codes during training.
            mask = torch.rand(nonzero.size()[0])
            mask = mask > 0.05
            nonzero = nonzero[mask]
            if nonzero.size()[0] == 0:
                return torch.LongTensor([[0], [0]]), torch.LongTensor([[length + 1], [length + 1]])
        # Shift indices by 1 so that index 0 stays reserved for the embedding padding index.
        nonzero = nonzero.transpose(0, 1) + 1
        lengths = nonzero.size()[1]
        input_edges = torch.cat((nonzero.repeat(1, lengths),
                                 nonzero.repeat(lengths, 1).transpose(0, 1)
                                 .contiguous().view((1, lengths ** 2))), dim=0)

        nonzero = torch.cat((nonzero, torch.LongTensor([[length + 1]]).to(device)), dim=1)
        lengths = nonzero.size()[1]
        output_edges = torch.cat((nonzero.repeat(1, lengths),
                                  nonzero.repeat(lengths, 1).transpose(0, 1)
                                  .contiguous().view((1, lengths ** 2))), dim=0)
        return input_edges.to(device), output_edges.to(device)

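    # Worked example (illustrative, not part of the original code; ignores the
    # training-time edge dropout): for a single record data = [0, 1, 1, 0], the
    # nonzero positions {1, 2} are shifted to node indices {2, 3}, so input_edges
    # contains every ordered pair over {2, 3} (including self-loops) and
    # output_edges every ordered pair over {2, 3, 5}, where 5 = length + 1 is the
    # extra readout node.
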
    def reparameterise(self, mu, logvar):
        # Reparameterization trick: sample z = mu + sigma * eps with eps ~ N(0, I)
        # during training; use the mean directly at evaluation time.
        if self.training:
            std = logvar.mul(0.5).exp_()
            eps = std.data.new(std.size()).normal_()
            return eps.mul(std).add_(mu)
        else:
            return mu

    def encoder_decoder(self, data):
        N = self.num_of_nodes
        input_edges, output_edges = self.data_to_edges(data)
        h_prime = self.embed(torch.arange(N).long().to(device))
        # Encode: propagate node embeddings over the graph of observed codes.
        for attn in self.in_att:
            h_prime = attn(input_edges, h_prime)
        if self.variational:
            h_prime = self.parameterize(h_prime).view(-1, 2, self.out_features)
            h_prime = self.dropout(h_prime)
            mu = h_prime[:, 0, :]
            logvar = h_prime[:, 1, :]
            h_prime = self.reparameterise(mu, logvar)
            mu = mu[data, :]
            logvar = logvar[data, :]
        # Decode: propagate once more over the graph that includes the readout node.
        h_prime = self.out_att(output_edges, h_prime)
        if self.variational:
            # Closed-form KL divergence between N(mu, sigma^2) and N(0, 1),
            # summed over feature dimensions and averaged over the selected rows.
            return h_prime[-1], 0.5 * torch.sum(logvar.exp() - logvar - 1 + mu.pow(2)) / mu.size()[0]
        else:
            return h_prime[-1], torch.tensor(0.0).to(device)

    def forward(self, data):
        # Process the batch one record at a time and stack the per-record results.
        batch_size = data.size()[0]
        # In the eICU data, the first feature (whether the patient has been
        # admitted before) is not part of the graph.
        if self.none_graph_features == 0:
            outputs = [self.encoder_decoder(data[i, :]) for i in range(batch_size)]
            return self.out_layer(F.relu(torch.stack([out[0] for out in outputs]))), \
                torch.sum(torch.stack([out[1] for out in outputs]))
        else:
            outputs = [(data[i, :self.none_graph_features],
                        self.encoder_decoder(data[i, self.none_graph_features:])) for i in range(batch_size)]
            return self.out_layer(F.relu(
                torch.stack([torch.cat((self.features_ffn(torch.FloatTensor([out[0]]).to(device)), out[1][0]))
                             for out in outputs]))), \
                torch.sum(torch.stack([out[1][1] for out in outputs]), dim=-1)
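

if __name__ == "__main__":
    # Minimal smoke test: a sketch, not part of the original training pipeline.
    # The feature count, hyperparameters, and the num_of_nodes = num_codes + 1
    # convention below are illustrative assumptions.
    num_codes = 10  # number of binary code features per record (assumed)
    model = VariationalGNN(in_features=8, out_features=8,
                           num_of_nodes=num_codes + 1,
                           n_heads=1, n_layers=1,
                           dropout=0.1, alpha=0.2,
                           variational=True).to(device)
    model.eval()  # disable the training-time edge dropout and reparameterization noise

    # Two toy records with a handful of observed codes each.
    data = torch.zeros(2, num_codes, dtype=torch.long)
    data[0, [1, 3, 5]] = 1
    data[1, [0, 2, 7, 9]] = 1

    with torch.no_grad():
        logits, kl = model(data.to(device))
    print(logits.shape, float(kl))  # expect torch.Size([2, 1]) and a non-negative KL term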