Diff of /model.py [000000] .. [0d4320]

import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import copy

# Run on the GPU when one is available; tensors created inside the model are
# moved to this device explicitly.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)


def clones(module, N):
    """Produce N independent (deep-copied) copies of a module."""
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])


def clone_params(param, N):
    """Produce N independent (deep-copied) copies of a parameter."""
    return nn.ParameterList([copy.deepcopy(param) for _ in range(N)])
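

# Illustrative sketch (not part of the original file): clones() deep-copies the
# module, so each copy owns its own parameters rather than sharing weights.
def _clones_demo():
    layers = clones(nn.Linear(4, 4), N=2)
    # Different storage pointers confirm the two layers do not share weights.
    assert layers[0].weight.data_ptr() != layers[1].weight.data_ptr()
    return layers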


class LayerNorm(nn.Module):
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        # Normalize over the last dimension, then apply a learned scale (a_2)
        # and shift (b_2).
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
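

# Illustrative sketch (not from the original code): this LayerNorm normalizes
# over the last dimension using the sample standard deviation (std() with
# Bessel's correction) plus eps, rather than nn.LayerNorm's sqrt(var + eps).
def _layer_norm_demo():
    ln = LayerNorm(features=8)
    x = torch.randn(4, 8)
    y = ln(x)
    # Each row of y has approximately zero mean and unit spread.
    return y.mean(-1), y.std(-1)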


class GraphLayer(nn.Module):
    """Sparse graph attention layer (GAT-style) with multiple attention heads."""

    def __init__(self, in_features, hidden_features, out_features, num_of_nodes,
                 num_of_heads, dropout, alpha, concat=True):
        super(GraphLayer, self).__init__()
        self.in_features = in_features
        self.hidden_features = hidden_features
        self.out_features = out_features
        self.alpha = alpha
        self.concat = concat
        self.num_of_nodes = num_of_nodes
        self.num_of_heads = num_of_heads
        # One linear projection and one attention vector per head.
        self.W = clones(nn.Linear(in_features, hidden_features), num_of_heads)
        self.a = clone_params(nn.Parameter(torch.rand(size=(1, 2 * hidden_features)),
                                           requires_grad=True), num_of_heads)
        self.ffn = nn.Sequential(
            nn.Linear(out_features, out_features),
            nn.ReLU()
        )
        if not concat:
            self.V = nn.Linear(hidden_features, out_features)
        else:
            self.V = nn.Linear(num_of_heads * hidden_features, out_features)
        self.dropout = nn.Dropout(dropout)
        self.leakyrelu = nn.LeakyReLU(self.alpha)
        # Normalization over the hidden width; the concat path is only
        # shape-consistent when num_of_heads == 1.
        self.norm = LayerNorm(hidden_features)

    def initialize(self):
        # Xavier-initialize the per-head projections and attention vectors,
        # plus the output projections of the non-concat (output) layer.
        for i in range(len(self.W)):
            nn.init.xavier_normal_(self.W[i].weight.data)
        for i in range(len(self.a)):
            nn.init.xavier_normal_(self.a[i].data)
        if not self.concat:
            nn.init.xavier_normal_(self.V.weight.data)
            nn.init.xavier_normal_(self.ffn[0].weight.data)

    def attention(self, linear, a, N, data, edge):
        # data: N x in_features -> 1 x N x hidden_features after this head's projection.
        data = linear(data).unsqueeze(0)
        assert not torch.isnan(data).any()
        # h: 2 x E x hidden_features (source and target node features for each edge).
        h = torch.cat((data[:, edge[0, :], :], data[:, edge[1, :], :]), dim=0)
        data = data.squeeze(0)
        assert not torch.isnan(h).any()
        # edge_h: 2*hidden_features x E (concatenated source/target features per edge).
        edge_h = torch.cat((h[0, :, :], h[1, :, :]), dim=1).transpose(0, 1)
        # edge_e: E unnormalized attention weights (LeakyReLU scores, scaled and exponentiated).
        edge_e = torch.exp(self.leakyrelu(a.mm(edge_h).squeeze()) / np.sqrt(self.hidden_features * self.num_of_heads))
        assert not torch.isnan(edge_e).any()
        # Sparse N x N attention matrix; its row sums are the softmax normalizers.
        edge_e = torch.sparse_coo_tensor(edge, edge_e, torch.Size([N, N]))
        e_rowsum = torch.sparse.mm(edge_e, torch.ones(size=(N, 1)).to(device))
        # e_rowsum: N x 1. Give isolated nodes (zero row sum) a unit self-weight
        # to avoid division by zero.
        row_check = (e_rowsum == 0)
        e_rowsum[row_check] = 1
        zero_idx = row_check.nonzero()[:, 0]
        edge_e = edge_e.add(
            torch.sparse_coo_tensor(zero_idx.repeat(2, 1), torch.ones(len(zero_idx)).to(device), torch.Size([N, N])))
        # h_prime: N x hidden_features, attention-weighted aggregation of neighbours.
        h_prime = torch.sparse.mm(edge_e, data)
        assert not torch.isnan(h_prime).any()
        h_prime.div_(e_rowsum)
        assert not torch.isnan(h_prime).any()
        return h_prime

    def forward(self, edge, data=None):
        N = self.num_of_nodes
        # Run every attention head, then either concatenate or average the heads.
        if self.concat:
            h_prime = torch.cat([self.attention(l, a, N, data, edge) for l, a in zip(self.W, self.a)], dim=1)
        else:
            h_prime = torch.stack([self.attention(l, a, N, data, edge) for l, a in zip(self.W, self.a)], dim=0).mean(dim=0)
        h_prime = self.dropout(h_prime)
        if self.concat:
            return F.elu(self.norm(h_prime))
        else:
            return self.V(F.relu(self.norm(h_prime)))
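

# Illustrative usage sketch (not from the original file). It builds a tiny
# single-head GraphLayer and runs one forward pass; the sizes and the
# single-head setting mirror how VariationalGNN below instantiates the layer,
# since the concat path assumes num_of_heads == 1 (see the LayerNorm note above).
def _graph_layer_demo():
    num_nodes, feat = 5, 8
    layer = GraphLayer(in_features=feat, hidden_features=feat, out_features=feat,
                       num_of_nodes=num_nodes, num_of_heads=1,
                       dropout=0.1, alpha=0.2, concat=True).to(device)
    layer.initialize()
    # edge holds (source, target) node indices as a 2 x E LongTensor.
    edge = torch.LongTensor([[0, 1, 2, 3], [1, 2, 3, 0]]).to(device)
    data = torch.randn(num_nodes, feat).to(device)
    out = layer(edge, data)
    return out.shape  # torch.Size([5, 8]): one hidden vector per node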


class VariationalGNN(nn.Module):
    """Variationally regularized GNN: attention encoder over the observed code
    nodes, optional reparameterized latent node states, and an attention
    decoder that aggregates into a readout node for prediction."""

    def __init__(self, in_features, out_features, num_of_nodes, n_heads, n_layers,
                 dropout, alpha, variational=True, none_graph_features=0, concat=True):
        super(VariationalGNN, self).__init__()
        self.variational = variational
        # Node index 0 is the padding embedding (observed codes are shifted by +1);
        # one extra index serves as the readout node in encoder_decoder().
        self.num_of_nodes = num_of_nodes + 1 - none_graph_features
        self.embed = nn.Embedding(self.num_of_nodes, in_features, padding_idx=0)
        self.in_att = clones(
            GraphLayer(in_features, in_features, in_features, self.num_of_nodes,
                       n_heads, dropout, alpha, concat=True), n_layers)
        self.out_features = out_features
        self.out_att = GraphLayer(in_features, in_features, out_features, self.num_of_nodes,
                                  n_heads, dropout, alpha, concat=False)
        self.n_heads = n_heads
        self.dropout = nn.Dropout(dropout)
        # Maps each node state to (mu, logvar) for the variational branch.
        self.parameterize = nn.Linear(out_features, out_features * 2)
        self.out_layer = nn.Sequential(
            nn.Linear(out_features, out_features),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(out_features, 1))
        self.none_graph_features = none_graph_features
        if none_graph_features > 0:
            # Features that are not part of the graph go through their own FFN
            # and are concatenated with the graph readout before the classifier.
            self.features_ffn = nn.Sequential(
                nn.Linear(none_graph_features, out_features // 2),
                nn.ReLU(),
                nn.Dropout(dropout))
            self.out_layer = nn.Sequential(
                nn.Linear(out_features + out_features // 2, out_features),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(out_features, 1))
        for i in range(n_layers):
            self.in_att[i].initialize()

    def data_to_edges(self, data):
        """Build fully connected edge lists over the codes observed in one sample.

        Returns (input_edges, output_edges); output_edges additionally connect
        the readout node at index length + 1."""
        data = data.bool()
        length = data.size()[0]
        nonzero = data.nonzero()
        if nonzero.size()[0] == 0:
            # No observed codes: fall back to self-loops on the padding and readout nodes.
            return (torch.LongTensor([[0], [0]]).to(device),
                    torch.LongTensor([[length + 1], [length + 1]]).to(device))
        if self.training:
            # Randomly drop roughly 5% of the observed codes as augmentation.
            mask = torch.rand(nonzero.size()[0])
            mask = mask > 0.05
            nonzero = nonzero[mask]
            if nonzero.size()[0] == 0:
                return (torch.LongTensor([[0], [0]]).to(device),
                        torch.LongTensor([[length + 1], [length + 1]]).to(device))
        # Shift by +1 so that index 0 stays reserved for padding.
        nonzero = nonzero.transpose(0, 1) + 1
        lengths = nonzero.size()[1]
        input_edges = torch.cat((nonzero.repeat(1, lengths),
                                 nonzero.repeat(lengths, 1).transpose(0, 1)
                                 .contiguous().view((1, lengths ** 2))), dim=0)

        nonzero = torch.cat((nonzero, torch.LongTensor([[length + 1]]).to(device)), dim=1)
        lengths = nonzero.size()[1]
        output_edges = torch.cat((nonzero.repeat(1, lengths),
                                  nonzero.repeat(lengths, 1).transpose(0, 1)
                                  .contiguous().view((1, lengths ** 2))), dim=0)
        return input_edges.to(device), output_edges.to(device)

    def reparameterise(self, mu, logvar):
        # Reparameterization trick: sample z = mu + sigma * eps with eps ~ N(0, I)
        # during training; use the mean at evaluation time.
        if self.training:
            std = logvar.mul(0.5).exp_()
            eps = std.data.new(std.size()).normal_()
            return eps.mul(std).add_(mu)
        else:
            return mu

    def encoder_decoder(self, data):
        N = self.num_of_nodes
        input_edges, output_edges = self.data_to_edges(data)
        # Start from the embedding of every node, then run the encoder attention
        # layers over the graph of observed codes.
        h_prime = self.embed(torch.arange(N).long().to(device))
        for attn in self.in_att:
            h_prime = attn(input_edges, h_prime)
        if self.variational:
            # Split each node state into (mu, logvar) and reparameterize.
            h_prime = self.parameterize(h_prime).view(-1, 2, self.out_features)
            h_prime = self.dropout(h_prime)
            mu = h_prime[:, 0, :]
            logvar = h_prime[:, 1, :]
            h_prime = self.reparameterise(mu, logvar)
            # Keep only the rows indexed by the 0/1 input vector for the KL term.
            mu = mu[data, :]
            logvar = logvar[data, :]
        # Decoder attention over output_edges; h_prime[-1] (the last node) is
        # used as the graph-level representation.
        h_prime = self.out_att(output_edges, h_prime)
        if self.variational:
            # KL(q(z|x) || N(0, I)), averaged over the selected rows.
            return h_prime[-1], 0.5 * torch.sum(logvar.exp() - logvar - 1 + mu.pow(2)) / mu.size()[0]
        else:
            return h_prime[-1], torch.tensor(0.0).to(device)

    def forward(self, data):
        # Encode every sample in the batch separately, then stack the results.
        batch_size = data.size()[0]
        # In the eICU data, the first feature (whether the patient has been
        # admitted before) is not included in the graph and is handled by a
        # separate feed-forward branch.
        if self.none_graph_features == 0:
            outputs = [self.encoder_decoder(data[i, :]) for i in range(batch_size)]
            return self.out_layer(F.relu(torch.stack([out[0] for out in outputs]))), \
                   torch.sum(torch.stack([out[1] for out in outputs]))
        else:
            outputs = [(data[i, :self.none_graph_features],
                        self.encoder_decoder(data[i, self.none_graph_features:])) for i in range(batch_size)]
            # Non-graph features go through features_ffn and are concatenated
            # with the graph readout before the final classifier.
            return self.out_layer(F.relu(
                torch.stack([torch.cat((self.features_ffn(out[0].float().unsqueeze(0).to(device)).squeeze(0),
                                        out[1][0]))
                             for out in outputs]))), \
                   torch.sum(torch.stack([out[1][1] for out in outputs]), dim=-1)
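

# Illustrative end-to-end sketch (not from the original file). The sizes, the
# single attention head, and the "+ 1" on num_of_nodes (leaving room for the
# readout node used in encoder_decoder) are assumptions about how a training
# script might instantiate the model, not values taken from this repository.
def _variational_gnn_demo():
    num_codes = 10                                   # length of each 0/1 code vector
    model = VariationalGNN(in_features=8, out_features=8,
                           num_of_nodes=num_codes + 1,
                           n_heads=1, n_layers=1,
                           dropout=0.1, alpha=0.2, variational=True).to(device)
    # A toy batch of 4 patients; each row marks which codes were observed.
    data = torch.zeros(4, num_codes, dtype=torch.long, device=device)
    data[:, :3] = 1
    model.eval()                                     # no edge dropout or sampling noise
    with torch.no_grad():
        logits, kld = model(data)
    return logits.shape, kld                         # (4, 1) logits and a scalar KL term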