--- /dev/null
+++ b/Patient2Vec.py
@@ -0,0 +1,157 @@
+"""
+Patient2Vec: a self-attentive representation learning framework
+author: Jinghe Zhang
+"""
+import torch
+import torch.nn as nn
+
+
+class Patient2Vec(nn.Module):
+    """
+    Self-attentive representation learning framework,
+    including a convolutional embedding layer and
+    a recurrent autoencoder with an encoder, recurrent module, and a decoder.
+    In addition, a linear layer sits on top of each decoding step,
+    with weights shared across these steps.
+    """
+
+    def __init__(self, input_size, hidden_size, n_layers, att_dim, initrange,
+                 output_size, rnn_type, seq_len, pad_size, n_filters, bi, dropout_p=0.5):
+        """
+        Initialize a recurrent model
+        :param input_size: int
+        :param hidden_size: int
+        :param n_layers: number of layers; int
+        :param att_dim: dimension of the attention; int
+        :param initrange: upper bound of the initial weights; weights are
+                          drawn uniformly from [-initrange, initrange]
+        :param output_size: int
+        :param rnn_type: str, such as 'GRU'
+        :param seq_len: length of the sequence; int
+        :param pad_size: padding size; int
+        :param n_filters: number of attention hops; int
+        :param bi: bidirectional; bool
+        :param dropout_p: dropout rate; float
+        """
+        super(Patient2Vec, self).__init__()
+
+        self.initrange = initrange
+        # number of directions: 2 if bidirectional, else 1
+        self.b = 2 if bi else 1
+
+        # convolution over the events within one visit (alpha attention)
+        self.conv = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=input_size, stride=2)
+        # convolution over the RNN states (beta attention, one channel per hop)
+        self.conv2 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=hidden_size * self.b, stride=2)
+        # (bidirectional) RNN
+        self.rnn = getattr(nn, rnn_type)(input_size, hidden_size, n_layers, dropout=dropout_p,
+                                         batch_first=True, bias=True, bidirectional=bi)
+        # attention weight matrix (defined here but not used in the forward pass below)
+        self.att_w1 = nn.Linear(hidden_size * self.b, att_dim, bias=False)
+        # final linear layer; the "+ 3" leaves room for 3 extra demographic variables
+        self.linear = nn.Linear(hidden_size * self.b * n_filters + 3, output_size, bias=True)
+
+        self.func_softmax = nn.Softmax(dim=1)
+        self.func_sigmoid = nn.Sigmoid()
+        # hard tanh clipped to [0, 1], used to squash the convolutional weights
+        self.func_tanh = nn.Hardtanh(0, 1)
+        # Add dropout
+        self.dropout_p = dropout_p
+        self.dropout = nn.Dropout(p=self.dropout_p)
+        self.init_weights()
+
+        self.pad_size = pad_size
+        self.input_size = input_size
+        self.hidden_size = hidden_size
+        self.output_size = output_size
+        self.n_layers = n_layers
+        self.seq_len = seq_len
+        self.n_filters = n_filters
+
+    def init_weights(self):
+        """
+        Initialize all weights uniformly in [-initrange, initrange]
+        """
+        for param in self.parameters():
+            param.data.uniform_(-self.initrange, self.initrange)
+
+    def convolutional_layer(self, inputs):
+        # inputs: (batch, seq_len, pad_size, input_size)
+        convolution_all = []
+        conv_wts = []
+        for i in range(self.seq_len):
+            convolution_one_month = []
+            for j in range(self.pad_size):
+                # (batch, 1, input_size) -> (batch, 1, 1)
+                convolution = self.conv(torch.unsqueeze(inputs[:, i, j], dim=1))
+                convolution_one_month.append(convolution)
+            # (pad_size, batch, 1, 1) -> (batch, 1, pad_size)
+            convolution_one_month = torch.stack(convolution_one_month)
+            convolution_one_month = torch.squeeze(convolution_one_month, dim=3)
+            convolution_one_month = torch.transpose(convolution_one_month, 0, 1)
+            convolution_one_month = torch.transpose(convolution_one_month, 1, 2)
+            convolution_one_month = torch.squeeze(convolution_one_month, dim=1)
+            convolution_one_month = self.func_tanh(convolution_one_month)
+            convolution_one_month = torch.unsqueeze(convolution_one_month, dim=1)
+            # weighted sum of the events in one visit: (batch, 1, input_size)
+            vec = torch.bmm(convolution_one_month, inputs[:, i])
+            convolution_all.append(vec)
+            conv_wts.append(convolution_one_month)
+        # (batch, seq_len, input_size)
+        convolution_all = torch.stack(convolution_all, dim=1)
+        convolution_all = torch.squeeze(convolution_all, dim=2)
+        # (batch, seq_len, pad_size)
+        conv_wts = torch.squeeze(torch.stack(conv_wts, dim=1), dim=2)
+        return convolution_all, conv_wts
+
+    def encode_rnn(self, embedding, batch_size):
+        # initial hidden state: zeros of shape (n_layers * n_directions, batch, hidden);
+        # note: an LSTM would require an (h0, c0) tuple instead of a single tensor
+        weight = next(self.parameters()).data
+        init_state = weight.new_zeros(self.n_layers * self.b, batch_size, self.hidden_size)
+        embedding = self.dropout(embedding)
+        outputs_rnn, _ = self.rnn(embedding, init_state)
+        return outputs_rnn
+
+    def add_beta_attention(self, states, batch_size):
+        # beta attention: one weight per visit and per hop
+        att_wts = []
+        for i in range(self.seq_len):
+            # (batch, 1, hidden * b) -> (batch, n_filters, 1)
+            m1 = self.conv2(torch.unsqueeze(states[:, i], dim=1))
+            att_wts.append(torch.squeeze(m1, dim=2))
+        # (batch, n_filters, seq_len)
+        att_wts = torch.stack(att_wts, dim=2)
+        att_beta = []
+        for i in range(self.n_filters):
+            # normalize each hop's weights over the sequence
+            a0 = self.func_softmax(att_wts[:, i])
+            att_beta.append(a0)
+        att_beta = torch.stack(att_beta, dim=1)
+        # context vector: (batch, n_filters, hidden * b) -> (batch, n_filters * hidden * b)
+        context = torch.bmm(att_beta, states)
+        context = context.view(batch_size, -1)
+        return att_beta, context
+
+    def forward(self, inputs, inputs_other, batch_size):
+        """
+        the recurrent module
+        """
+        # Convolutional layer (alpha attention over events within a visit)
+        convolutions, alpha = self.convolutional_layer(inputs)
+        # RNN
+        states_rnn = self.encode_rnn(convolutions, batch_size)
+        # Add attention and get the context vector
+        beta, context = self.add_beta_attention(states_rnn, batch_size)
+        # Final linear layer with demographic info added as extra variables
+        context_v2 = torch.cat((context, inputs_other), 1)
+        linear_y = self.linear(context_v2)
+        # return raw logits; nn.CrossEntropyLoss in get_loss applies
+        # log-softmax internally, so a softmax here would be applied twice
+        return linear_y, alpha, beta
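+
+
+# Expected tensor shapes, inferred from the layer definitions above (the
+# original code does not state them explicitly, so treat these as assumptions):
+#   inputs:       (batch, seq_len, pad_size, input_size)
+#   inputs_other: (batch, 3) extra demographic variables (the "+ 3" in self.linear)
+#   alpha:        (batch, seq_len, pad_size) within-visit convolutional weights
+#   beta:         (batch, n_filters, seq_len) attention weights over visits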
+
+
+def get_loss(pred, y, criterion, mtr, a=0.5):
+    """
+    Calculate the loss: prediction loss plus a penalty that encourages
+    the attention hops to focus on different parts of the sequence
+    :param pred: predicted logits
+    :param y: actual labels
+    :param criterion: loss function, e.g. nn.CrossEntropyLoss
+    :param mtr: beta attention matrix, shape (batch, n_filters, seq_len)
+    :param a: weight of the attention penalty; float
+    """
+    mtr_t = torch.transpose(mtr, 1, 2)
+    # (batch, n_filters, n_filters): similarity between attention hops
+    aa = torch.bmm(mtr, mtr_t)
+    eye = torch.eye(mtr.size(1), device=mtr.device)
+    loss_fn = 0
+    for i in range(aa.size(0)):
+        aai = aa[i] - eye
+        # keep the penalty inside the autograd graph (no .data detaching)
+        # so that it contributes gradients during training
+        loss_fn = loss_fn + torch.trace(aai * aai)
+    loss_fn = loss_fn / aa.size(0)
+    loss = criterion(pred, y) + a * loss_fn
+    return loss
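+
+
+if __name__ == '__main__':
+    # Minimal smoke test on synthetic data -- an illustrative sketch, not part
+    # of the original model. All sizes and the (batch, seq_len, pad_size,
+    # input_size) input layout below are assumptions inferred from the
+    # docstrings and layer definitions above.
+    batch_size, seq_len, pad_size, input_size = 4, 6, 3, 10
+    model = Patient2Vec(input_size=input_size, hidden_size=8, n_layers=2,
+                        att_dim=16, initrange=0.1, output_size=2,
+                        rnn_type='GRU', seq_len=seq_len, pad_size=pad_size,
+                        n_filters=2, bi=True, dropout_p=0.5)
+    inputs = torch.randn(batch_size, seq_len, pad_size, input_size)
+    inputs_other = torch.randn(batch_size, 3)  # 3 demographic variables
+    y = torch.randint(0, 2, (batch_size,))
+    logits, alpha, beta = model(inputs, inputs_other, batch_size)
+    loss = get_loss(logits, y, nn.CrossEntropyLoss(), beta, a=0.5)
+    print(logits.shape, alpha.shape, beta.shape, loss.item())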