
a b/stay_admission/operations.py
import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import sklearn.metrics

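# Candidate operation registries: each entry maps an operation name to a
# constructor that takes the model width d_model and returns an nn.Module.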
OPS1 = {
    'identity': lambda d_model: Identity(d_model),
    'ffn': lambda d_model: FFN(d_model),
    'interaction_1': lambda d_model: Attention_s1(d_model),
    'interaction_2': lambda d_model: Attention_s2(d_model),
}

OPS2 = {
    'identity': lambda d_model: Identity(d_model),
    'conv': lambda d_model: Conv(d_model),
    'attention': lambda d_model: SelfAttention(d_model),
    'rnn': lambda d_model: RNN(d_model),
    'ffn': lambda d_model: FFN(d_model),
    'interaction_1': lambda d_model: CatFC(d_model),
    'interaction_2': lambda d_model: Attention_x(d_model),
}

OPS3 = {
    'identity': lambda d_model: Identity(d_model),
    'zero': lambda d_model: Zero(d_model),
}

OPS4 = {
    'sum': lambda d_model: Sum(d_model),
    'mul': lambda d_model: Mul(d_model),
}

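# Zero: drops a branch by multiplying its input by 0 (the 'zero' candidate in OPS3).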
class Zero(nn.Module):
    def __init__(self, d_model):
        super(Zero, self).__init__()

    def forward(self, x, masks, lengths):
        return torch.mul(x, 0)


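# Sum / Mul: combine a list of tensors element-wise (the 'sum' and 'mul'
# candidates in OPS4).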
class Sum(nn.Module):
    def __init__(self, d_model):
        super(Sum, self).__init__()

    def forward(self, all_x):
        # Accumulate out-of-place so the first input tensor is not modified.
        out = all_x[0]
        for x in all_x[1:]:
            out = out + x
        return out


class Mul(nn.Module):
    def __init__(self, d_model):
        super(Mul, self).__init__()

    def forward(self, all_x):
        out = all_x[0]
        for x in all_x[1:]:
            out = out * x
        return out

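# CatFC: broadcasts the summary vector s along the sequence, concatenates it
# with the current sequence features, and mixes them with a two-layer MLP
# followed by LayerNorm.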
class CatFC(nn.Module):
    def __init__(self, d_model):
        super(CatFC, self).__init__()
        self.ffn = nn.Sequential(nn.Linear(2 * d_model, 4 * d_model), nn.ReLU(),
                                 nn.Linear(4 * d_model, d_model))
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, current_x, s, other_x):
        s_ = s.unsqueeze(1).expand_as(current_x)
        x = torch.cat((current_x, s_), dim=-1)
        return self.layer_norm(self.ffn(x))


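# Conv: ReLU -> 1D convolution (kernel size 3, same padding) -> BatchNorm1d.
# The input is permuted to (batch, d_model, seq_len) for Conv1d/BatchNorm1d
# and permuted back afterwards.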
class Conv(nn.Module):
    def __init__(self, d_model):
        super(Conv, self).__init__()
        self.op = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(d_model, d_model, 3, padding=1),
            nn.BatchNorm1d(d_model, affine=True)
        )
        # self.batchnm = nn.BatchNorm1d(d_model, affine=True)
        # self.conv = nn.Conv1d(d_model, d_model, 3, padding=1)

    def forward(self, x, masks, lengths):
        x = self.op(x.permute(0, 2, 1))
        return x.permute(0, 2, 1)


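# FFN: position-wise feed-forward block (d_model -> 4*d_model -> d_model) with
# a residual connection and LayerNorm.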
class FFN(nn.Module):
    def __init__(self, d_model):
        super(FFN, self).__init__()
        self.ffn = nn.Sequential(nn.Linear(d_model, 4 * d_model), nn.ReLU(),
                                 nn.Linear(4 * d_model, d_model))
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x, masks, lengths):
        x = self.layer_norm(x + self.ffn(x))
        return x


class Identity(nn.Module):
    def __init__(self, d_model):
        super(Identity, self).__init__()

    def forward(self, x, masks, lengths):
        return x


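# SelfAttention: multi-head self-attention with dropout, a residual connection
# and LayerNorm. attn_mask and lengths are accepted for interface compatibility
# but are not applied to the attention scores.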
class SelfAttention(nn.Module):
    def __init__(self, in_feature, num_head=4, dropout=0.1):
        super(SelfAttention, self).__init__()
        self.in_feature = in_feature
        self.num_head = num_head
        self.size_per_head = in_feature // num_head
        self.out_dim = num_head * self.size_per_head
        assert self.size_per_head * num_head == in_feature
        self.q_linear = nn.Linear(in_feature, in_feature, bias=False)
        self.k_linear = nn.Linear(in_feature, in_feature, bias=False)
        self.v_linear = nn.Linear(in_feature, in_feature, bias=False)
        self.fc = nn.Linear(in_feature, in_feature, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(in_feature)

    def forward(self, x, attn_mask, lengths):
        batch_size = x.size(0)
        res = x
        query = self.q_linear(x)
        key = self.k_linear(x)
        value = self.v_linear(x)

        query = query.view(batch_size, self.num_head, -1, self.size_per_head)
        key = key.view(batch_size, self.num_head, -1, self.size_per_head)
        value = value.view(batch_size, self.num_head, -1, self.size_per_head)

        scale = np.sqrt(self.size_per_head)
        energy = torch.matmul(query, key.permute(0, 1, 3, 2)) / scale

        attention = torch.softmax(energy, dim=-1)
        x = torch.matmul(attention, value)
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.in_feature)
        x = self.fc(x)
        x = self.dropout(x)
        x += res
        x = self.layer_norm(x)
        return x


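# Attention_s1 / Attention_s2: multi-head cross-attention in which the summary
# vector s (used as a length-1 query) attends over the sequence x1 (resp. x2);
# the result is squeezed back to a per-sample vector.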
class Attention_s1(nn.Module):
    def __init__(self, in_feature, num_head=4, dropout=0.1):
        super(Attention_s1, self).__init__()
        self.in_feature = in_feature
        self.num_head = num_head
        self.size_per_head = in_feature // num_head
        self.out_dim = num_head * self.size_per_head
        assert self.size_per_head * num_head == in_feature
        self.q_linear = nn.Linear(in_feature, in_feature, bias=False)
        self.k_linear = nn.Linear(in_feature, in_feature, bias=False)
        self.v_linear = nn.Linear(in_feature, in_feature, bias=False)
        self.fc = nn.Linear(in_feature, in_feature, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(in_feature)

    def forward(self, s, x1, x2):
        batch_size = x1.size(0)
        s = s.unsqueeze(1)
        res = s
        query = self.q_linear(s)
        key = self.k_linear(x1)
        value = self.v_linear(x1)

        query = query.view(batch_size, self.num_head, -1, self.size_per_head)
        key = key.view(batch_size, self.num_head, -1, self.size_per_head)
        value = value.view(batch_size, self.num_head, -1, self.size_per_head)

        scale = np.sqrt(self.size_per_head)
        energy = torch.matmul(query, key.permute(0, 1, 3, 2)) / scale

        attention = torch.softmax(energy, dim=-1)
        x = torch.matmul(attention, value)
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.in_feature)
        x = self.fc(x)
        x = self.dropout(x)
        x += res
        x = self.layer_norm(x)
        # Squeeze the singleton query dimension (not the batch dimension).
        return x.squeeze(1)


class Attention_s2(nn.Module):
    def __init__(self, in_feature, num_head=4, dropout=0.1):
        super(Attention_s2, self).__init__()
        self.in_feature = in_feature
        self.num_head = num_head
        self.size_per_head = in_feature // num_head
        self.out_dim = num_head * self.size_per_head
        assert self.size_per_head * num_head == in_feature
        self.q_linear = nn.Linear(in_feature, in_feature, bias=False)
        self.k_linear = nn.Linear(in_feature, in_feature, bias=False)
        self.v_linear = nn.Linear(in_feature, in_feature, bias=False)
        self.fc = nn.Linear(in_feature, in_feature, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(in_feature)

    def forward(self, s, x1, x2):
        batch_size = x2.size(0)
        s = s.unsqueeze(1)
        res = s
        query = self.q_linear(s)
        key = self.k_linear(x2)
        value = self.v_linear(x2)

        query = query.view(batch_size, self.num_head, -1, self.size_per_head)
        key = key.view(batch_size, self.num_head, -1, self.size_per_head)
        value = value.view(batch_size, self.num_head, -1, self.size_per_head)

        scale = np.sqrt(self.size_per_head)
        energy = torch.matmul(query, key.permute(0, 1, 3, 2)) / scale

        attention = torch.softmax(energy, dim=-1)
        x = torch.matmul(attention, value)
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.in_feature)
        x = self.fc(x)
        x = self.dropout(x)
        x += res
        x = self.layer_norm(x)
        # Squeeze the singleton query dimension (not the batch dimension).
        return x.squeeze(1)

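# Attention_x: multi-head cross-attention in which the current sequence
# (current_x) queries the other sequence (other_x), with dropout, a residual
# connection and LayerNorm.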
class Attention_x(nn.Module):
    def __init__(self, in_feature, num_head=4, dropout=0.1):
        super(Attention_x, self).__init__()
        self.in_feature = in_feature
        self.num_head = num_head
        self.size_per_head = in_feature // num_head
        self.out_dim = num_head * self.size_per_head
        assert self.size_per_head * num_head == in_feature
        self.q_linear = nn.Linear(in_feature, in_feature, bias=False)
        self.k_linear = nn.Linear(in_feature, in_feature, bias=False)
        self.v_linear = nn.Linear(in_feature, in_feature, bias=False)
        self.fc = nn.Linear(in_feature, in_feature, bias=False)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(in_feature)

    def forward(self, current_x, s, other_x):
        batch_size = current_x.size(0)
        res = current_x
        query = self.q_linear(current_x)
        key = self.k_linear(other_x)
        value = self.v_linear(other_x)

        query = query.view(batch_size, self.num_head, -1, self.size_per_head)
        key = key.view(batch_size, self.num_head, -1, self.size_per_head)
        value = value.view(batch_size, self.num_head, -1, self.size_per_head)

        scale = np.sqrt(self.size_per_head)
        energy = torch.matmul(query, key.permute(0, 1, 3, 2)) / scale

        attention = torch.softmax(energy, dim=-1)
        x = torch.matmul(attention, value)
        x = x.permute(0, 2, 1, 3).contiguous()
        x = x.view(batch_size, -1, self.in_feature)
        x = self.fc(x)
        x = self.dropout(x)
        x += res
        x = self.layer_norm(x)
        return x


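# RNN: single-layer GRU over the (padded) sequence; masks and lengths are
# accepted but not used, so padded positions are processed as ordinary steps.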
class RNN(nn.Module):
    def __init__(self, d_model):
        super(RNN, self).__init__()
        self.rnn = nn.GRU(d_model, d_model, num_layers=1, batch_first=True)

    def forward(self, x, masks, lengths):
        rnn_input = x
        rnn_output, _ = self.rnn(rnn_input)
        return rnn_output


class MaxPoolLayer(nn.Module):
    """
    A layer that performs max pooling along the sequence dimension
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs, mask_or_lengths=None):
        """
        inputs: tensor of shape (batch_size, seq_len, hidden_size)
        mask_or_lengths: tensor of shape (batch_size) or (batch_size, seq_len)

        returns: tensor of shape (batch_size, hidden_size)
        """
        bs, sl, _ = inputs.size()
        if mask_or_lengths is not None:
            if len(mask_or_lengths.size()) == 1:
                mask = (torch.arange(sl, device=inputs.device).unsqueeze(0).expand(bs, sl)
                        >= mask_or_lengths.unsqueeze(1))
            else:
                mask = mask_or_lengths
            inputs = inputs.masked_fill(mask.unsqueeze(-1).expand_as(inputs), float('-inf'))
        max_pooled = inputs.max(1)[0]
        return max_pooled


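# prroc: area under the precision-recall curve (AUPRC) for binary labels
# `testy` and predicted scores `probs`.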
def prroc(testy, probs):
    precision, recall, thresholds = sklearn.metrics.precision_recall_curve(testy, probs)
    # Call sklearn.metrics.auc directly instead of shadowing it with a local name.
    return sklearn.metrics.auc(recall, precision)