Diff of /Patient2Vec.py [000000] .. [a6870d]


"""
Patient2Vec: a self-attentive representation learning framework
author: Jinghe Zhang
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import numpy as np


class Patient2Vec(nn.Module):
    """
    Self-attentive representation learning framework,
    including a convolutional embedding layer and
    a recurrent autoencoder with an encoder, recurrent module, and a decoder.
    In addition, a linear layer is placed on top of each decode step, and its weights are shared across these steps.
    """

    def __init__(self, input_size, hidden_size, n_layers, att_dim, initrange,
                 output_size, rnn_type, seq_len, pad_size, n_filters, bi, dropout_p=0.5):
        """
        Initialize a recurrent model
        :param input_size: int
        :param hidden_size: int
        :param n_layers: number of layers; int
        :param att_dim: dimension of the attention; int
        :param initrange: upper bound of the initial weights; symmetric
        :param output_size: int
        :param rnn_type: str, such as 'GRU'
        :param seq_len: length of the sequence; int
        :param pad_size: padding size; int
        :param n_filters: number of hops; int
        :param bi: bidirectional; bool
        :param dropout_p: dropout rate; float
        """
        super(Patient2Vec, self).__init__()

        self.initrange = initrange
        # convolution
        self.b = 1
        if bi:
            self.b = 2

        self.conv = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=input_size, stride=2)
        self.conv2 = nn.Conv1d(in_channels=1, out_channels=n_filters, kernel_size=hidden_size * self.b, stride=2)
        # Bidirectional RNN
        self.rnn = getattr(nn, rnn_type)(input_size, hidden_size, n_layers, dropout=dropout_p,
                                         batch_first=True, bias=True, bidirectional=bi)
        # initialize 2-layer attention weight matrices
        self.att_w1 = nn.Linear(hidden_size * self.b, att_dim, bias=False)
        # final linear layer
        self.linear = nn.Linear(hidden_size * self.b * n_filters + 3, output_size, bias=True)

        self.func_softmax = nn.Softmax(dim=1)
        self.func_sigmoid = nn.Sigmoid()
        self.func_tanh = nn.Hardtanh(0, 1)
        # Add dropout
        self.dropout_p = dropout_p
        self.dropout = nn.Dropout(p=self.dropout_p)
        self.init_weights()

        self.pad_size = pad_size
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.seq_len = seq_len
        self.n_filters = n_filters

    def init_weights(self):
        """
        Initialize all parameters uniformly in [-initrange, initrange]
        """
        for param in self.parameters():
            param.data.uniform_(-self.initrange, self.initrange)

    def convolutional_layer(self, inputs):
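        # Within-visit (alpha) attention: self.conv scores each of the pad_size
        # events in visit i, Hardtanh(0, 1) bounds the scores to [0, 1], and
        # torch.bmm forms the weighted sum of the event vectors, one embedded
        # vector per visit. Returns embeddings of shape (batch, seq_len, input_size)
        # and weights of shape (batch, seq_len, pad_size).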
        convolution_all = []
        conv_wts = []
        for i in range(self.seq_len):
            convolution_one_month = []
            for j in range(self.pad_size):
                convolution = self.conv(torch.unsqueeze(inputs[:, i, j], dim=1))
                convolution_one_month.append(convolution)
            convolution_one_month = torch.stack(convolution_one_month)
            convolution_one_month = torch.squeeze(convolution_one_month, dim=3)
            convolution_one_month = torch.transpose(convolution_one_month, 0, 1)
            convolution_one_month = torch.transpose(convolution_one_month, 1, 2)
            convolution_one_month = torch.squeeze(convolution_one_month, dim=1)
            convolution_one_month = self.func_tanh(convolution_one_month)
            convolution_one_month = torch.unsqueeze(convolution_one_month, dim=1)
            vec = torch.bmm(convolution_one_month, inputs[:, i])
            convolution_all.append(vec)
            conv_wts.append(convolution_one_month)
        convolution_all = torch.stack(convolution_all, dim=1)
        convolution_all = torch.squeeze(convolution_all, dim=2)
        conv_wts = torch.squeeze(torch.stack(conv_wts, dim=1), dim=2)
        return convolution_all, conv_wts

    def encode_rnn(self, embedding, batch_size):
        self.weight = next(self.parameters()).data
        init_state = (Variable(self.weight.new(self.n_layers * self.b, batch_size, self.hidden_size).zero_()))
        embedding = self.dropout(embedding)
        outputs_rnn, states_rnn = self.rnn(embedding, init_state)
        return outputs_rnn

    def add_beta_attention(self, states, batch_size):
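        # Visit-level (beta) attention with n_filters hops: self.conv2 scores
        # each visit's RNN state, a softmax over the seq_len visits gives one
        # weight vector per hop, and the hop-weighted sums of the states are
        # flattened into the context vector of shape
        # (batch, hidden_size * b * n_filters).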
        # beta attention
        att_wts = []
        for i in range(self.seq_len):
            m1 = self.conv2(torch.unsqueeze(states[:, i], dim=1))
            att_wts.append(torch.squeeze(m1, dim=2))
        att_wts = torch.stack(att_wts, dim=2)
        att_beta = []
        for i in range(self.n_filters):
            a0 = self.func_softmax(att_wts[:, i])
            att_beta.append(a0)
        att_beta = torch.stack(att_beta, dim=1)
        context = torch.bmm(att_beta, states)
        context = context.view(batch_size, -1)
        return att_beta, context

    def forward(self, inputs, inputs_other, batch_size):
        """
        The full forward pass: convolutional embedding, recurrent encoding,
        beta attention, and the final linear layer.
        """
        # Convolutional
        convolutions, alpha = self.convolutional_layer(inputs)
        # RNN
        states_rnn = self.encode_rnn(convolutions, batch_size)
        # Add attention and get the context vector
        beta, context = self.add_beta_attention(states_rnn, batch_size)
        # Final linear layer with demographic info added as extra variables
        context_v2 = torch.cat((context, inputs_other), 1)
        linear_y = self.linear(context_v2)
        out = self.func_softmax(linear_y)
        return out, alpha, beta


def get_loss(pred, y, criterion, mtr, a=0.5):
    """
    Calculate the loss: cross-entropy on the predictions plus an attention
    penalty, weighted by a, computed from the beta matrix.
    :param pred: predicted value
    :param y: actual value
    :param criterion: nn.CrossEntropyLoss
    :param mtr: beta matrix
    :param a: weight of the attention penalty term
    """
    mtr_t = torch.transpose(mtr, 1, 2)
    aa = torch.bmm(mtr, mtr_t)
    loss_fn = 0
    for i in range(aa.size()[0]):
        aai = torch.add(aa[i, ], Variable(torch.neg(torch.eye(mtr.size()[1]))))
        loss_fn += torch.trace(torch.mul(aai, aai).data)
    loss_fn /= aa.size()[0]
    loss = torch.add(criterion(pred, y), Variable(torch.FloatTensor([loss_fn * a])))
    return loss
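

# A minimal end-to-end usage sketch. The sizes and hyper-parameters below are
# illustrative assumptions, chosen only so that the model and get_loss run on
# random data.
if __name__ == "__main__":
    batch_size, seq_len, pad_size, input_size = 4, 6, 5, 20
    hidden_size, att_dim, n_filters, output_size = 16, 16, 3, 2

    model = Patient2Vec(input_size=input_size, hidden_size=hidden_size, n_layers=1,
                        att_dim=att_dim, initrange=0.1, output_size=output_size,
                        rnn_type='GRU', seq_len=seq_len, pad_size=pad_size,
                        n_filters=n_filters, bi=True, dropout_p=0.5)

    # random stand-ins for the padded clinical-event tensor and the three
    # extra (demographic) variables consumed by the final linear layer
    inputs = Variable(torch.rand(batch_size, seq_len, pad_size, input_size))
    inputs_other = Variable(torch.rand(batch_size, 3))
    y = Variable(torch.LongTensor([0, 1, 0, 1]))

    out, alpha, beta = model(inputs, inputs_other, batch_size)
    loss = get_loss(out, y, nn.CrossEntropyLoss(), beta, a=0.5)
    print(out.size(), alpha.size(), beta.size(), loss)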