Patient2Vec.py
"""
Patient2Vec: a self-attentive representation learning framework
author: Jinghe Zhang
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import numpy as np


class Patient2Vec(nn.Module):
    """
    Self-attentive representation learning framework,
    including a convolutional embedding layer and
    a recurrent autoencoder with an encoder, recurrent module, and a decoder.
    In addition, a linear layer is placed on top of each decoding step,
    and its weights are shared across these steps.
    (A toy usage sketch follows the class definition below.)
    """

    def __init__(self, input_size, hidden_size, n_layers, att_dim, initrange,
                 output_size, rnn_type, seq_len, pad_size, n_filters, bi, dropout_p=0.5):
        """
        Initialize the recurrent model.
        :param input_size: int
        :param hidden_size: int
        :param n_layers: number of layers; int
        :param att_dim: dimension of the attention; int
        :param initrange: upper bound of the initial weights; symmetric
        :param output_size: int
        :param rnn_type: str, such as 'GRU'
        :param seq_len: length of the sequence; int
        :param pad_size: padding size; int
        :param n_filters: number of hops; int
        :param bi: bidirectional; bool
        :param dropout_p: dropout rate; float
        """
        super(Patient2Vec, self).__init__()

        self.initrange = initrange
        # number of RNN directions (2 if bidirectional)
        self.b = 1
        if bi:
            self.b = 2

        # convolution filters: self.conv produces the record-level (alpha) weights,
        # self.conv2 the hop-level (beta) attention weights
        self.conv = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=input_size, stride=2)
        self.conv2 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=hidden_size * self.b, stride=2)
        # recurrent layer (bidirectional if bi=True)
        self.rnn = getattr(nn, rnn_type)(input_size, hidden_size, n_layers, dropout=dropout_p,
                                         batch_first=True, bias=True, bidirectional=bi)
        # attention weight matrix (defined but not used in the forward pass below)
        self.att_w1 = nn.Linear(hidden_size * self.b, att_dim, bias=False)
        # final linear layer
        self.linear = nn.Linear(hidden_size * self.b * n_filters + 3, output_size, bias=True)

        # softmax over the last dimension
        self.func_softmax = nn.Softmax(dim=-1)
        self.func_sigmoid = nn.Sigmoid()
        self.func_tanh = nn.Hardtanh(0, 1)
        # Add dropout
        self.dropout_p = dropout_p
        self.dropout = nn.Dropout(p=self.dropout_p)
        self.init_weights()

        self.pad_size = pad_size
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.seq_len = seq_len
        self.n_filters = n_filters

    def init_weights(self):
        """
        weight initialization
        """
        for param in self.parameters():
            param.data.uniform_(-self.initrange, self.initrange)

    def convolutional_layer(self, inputs):
        """
        Convolutional embedding layer.
        A shared 1-d convolution scores each of the pad_size records within a
        time step; the scores are squashed to [0, 1] with Hardtanh and used as
        weights (alpha) to combine the records into one vector per time step.
        :param inputs: (batch, seq_len, pad_size, input_size)
        :return: embeddings (batch, seq_len, input_size) and
                 the alpha weights (batch, seq_len, pad_size)
        """
        convolution_all = []
        conv_wts = []
        for i in range(self.seq_len):
            convolution_one_month = []
            for j in range(self.pad_size):
                # score for record j of time step i: (batch, 1, 1)
                convolution = self.conv(torch.unsqueeze(inputs[:, i, j], dim=1))
                convolution_one_month.append(convolution)
            convolution_one_month = torch.stack(convolution_one_month)
            convolution_one_month = torch.squeeze(convolution_one_month, dim=3)
            convolution_one_month = torch.transpose(convolution_one_month, 0, 1)
            convolution_one_month = torch.transpose(convolution_one_month, 1, 2)
            convolution_one_month = torch.squeeze(convolution_one_month, dim=1)
            convolution_one_month = self.func_tanh(convolution_one_month)
            convolution_one_month = torch.unsqueeze(convolution_one_month, dim=1)
            # alpha-weighted sum of the records: (batch, 1, input_size)
            vec = torch.bmm(convolution_one_month, inputs[:, i])
            convolution_all.append(vec)
            conv_wts.append(convolution_one_month)
        convolution_all = torch.stack(convolution_all, dim=1)
        convolution_all = torch.squeeze(convolution_all, dim=2)
        conv_wts = torch.squeeze(torch.stack(conv_wts, dim=1), dim=2)
        return convolution_all, conv_wts

    def encode_rnn(self, embedding, batch_size):
        """
        Run the recurrent layer over the embedded sequence and return the
        hidden states at every step: (batch, seq_len, hidden_size * b).
        """
        self.weight = next(self.parameters()).data
        # zero initial hidden state (a single tensor, as expected by GRU/RNN;
        # an LSTM would require an (h_0, c_0) tuple instead)
        init_state = Variable(self.weight.new(self.n_layers * self.b,
                                              batch_size, self.hidden_size).zero_())
        embedding = self.dropout(embedding)
        outputs_rnn, states_rnn = self.rnn(embedding, init_state)
        return outputs_rnn

    def add_beta_attention(self, states, batch_size):
        """
        Multi-hop (beta) attention over the RNN states.
        :param states: (batch, seq_len, hidden_size * b)
        :return: attention weights (batch, n_filters, seq_len) and the
                 flattened context vector (batch, n_filters * hidden_size * b)
        """
        att_wts = []
        for i in range(self.seq_len):
            # one attention score per hop (filter) for step i: (batch, n_filters)
            m1 = self.conv2(torch.unsqueeze(states[:, i], dim=1))
            att_wts.append(torch.squeeze(m1, dim=2))
        att_wts = torch.stack(att_wts, dim=2)
        att_beta = []
        for i in range(self.n_filters):
            # normalize each hop's scores over the sequence dimension
            a0 = self.func_softmax(att_wts[:, i])
            att_beta.append(a0)
        att_beta = torch.stack(att_beta, dim=1)
        context = torch.bmm(att_beta, states)
        context = context.view(batch_size, -1)
        return att_beta, context

    def forward(self, inputs, inputs_other, batch_size):
        """
        Forward pass: convolutional embedding, recurrent encoding,
        beta attention, and the final linear layer.
        """
        # Convolutional embedding
        convolutions, alpha = self.convolutional_layer(inputs)
        # RNN
        states_rnn = self.encode_rnn(convolutions, batch_size)
        # Add attentions and get context vector
        beta, context = self.add_beta_attention(states_rnn, batch_size)
        # Final linear layer with demographic info added as extra variables
        context_v2 = torch.cat((context, inputs_other), 1)
        linear_y = self.linear(context_v2)
        out = self.func_softmax(linear_y)
        return out, alpha, beta
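

# ---------------------------------------------------------------------------
# The sketch below is not part of the original model code. It is a minimal,
# hypothetical usage example with made-up toy dimensions, meant only to
# illustrate the tensor shapes Patient2Vec.forward() expects:
#   inputs:       (batch_size, seq_len, pad_size, input_size)
#   inputs_other: (batch_size, 3)    # e.g. three extra demographic features
# The helper name _demo_forward and every hyperparameter value are assumptions.
# ---------------------------------------------------------------------------
def _demo_forward(batch_size=4, input_size=20, hidden_size=16, n_layers=2,
                  seq_len=5, pad_size=3, n_filters=4, output_size=2):
    """Instantiate the model with toy hyperparameters and run one forward pass."""
    model = Patient2Vec(input_size=input_size, hidden_size=hidden_size,
                        n_layers=n_layers, att_dim=16, initrange=0.1,
                        output_size=output_size, rnn_type='GRU',
                        seq_len=seq_len, pad_size=pad_size,
                        n_filters=n_filters, bi=True, dropout_p=0.5)
    inputs = Variable(torch.randn(batch_size, seq_len, pad_size, input_size))
    inputs_other = Variable(torch.randn(batch_size, 3))
    out, alpha, beta = model(inputs, inputs_other, batch_size)
    return out, alpha, beta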
|
|
|
def get_loss(pred, y, criterion, mtr, a=0.5):
    """
    Calculate the loss: the criterion loss plus a penalty on the beta
    attention matrix, derived from beta * beta^T compared against the identity.
    :param pred: predicted value
    :param y: actual value
    :param criterion: nn.CrossEntropyLoss
    :param mtr: beta matrix
    :param a: weight of the attention penalty term; float
    """
    mtr_t = torch.transpose(mtr, 1, 2)
    aa = torch.bmm(mtr, mtr_t)
    loss_fn = 0
    for i in range(aa.size()[0]):
        # accumulate the penalty for each sample in the batch
        aai = torch.add(aa[i], Variable(torch.neg(torch.eye(mtr.size()[1]))))
        loss_fn += torch.trace(torch.mul(aai, aai))
    loss_fn /= aa.size()[0]
    loss = torch.add(criterion(pred, y), loss_fn * a)
    return loss
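

# ---------------------------------------------------------------------------
# Hypothetical end-to-end check (again, not part of the original code): run
# the toy forward pass sketched in _demo_forward() above and compute the
# penalized loss with get_loss(), assuming a CrossEntropyLoss criterion and
# integer class labels as suggested by the get_loss docstring.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    pred, alpha, beta = _demo_forward(batch_size=4)
    labels = Variable(torch.LongTensor([0, 1, 0, 1]))   # toy binary labels
    criterion = nn.CrossEntropyLoss()
    loss = get_loss(pred, labels, criterion, beta, a=0.5)
    print('prediction:', pred.size())   # (4, output_size)
    print('alpha:', alpha.size())       # (4, seq_len, pad_size)
    print('beta:', beta.size())         # (4, n_filters, seq_len)
    print('loss:', float(loss))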