# app/models/backbones/rnn.py
from torch import nn
class RNN(nn.Module):
    """Vanilla RNN backbone over per-visit patient feature sequences.

    Concatenated demographic + lab features are linearly projected to
    ``hidden_dim`` and fed through a single-layer ``nn.RNN`` (batch-first).
    """

    def __init__(
        self,
        demo_dim,
        lab_dim,
        max_visits,
        hidden_dim,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        """
        Args:
            demo_dim: width of the demographic feature slice.
            lab_dim: width of the lab feature slice (input width is
                ``demo_dim + lab_dim``, per the projection layer below).
            max_visits: sequence length used to size the (currently unused)
                BatchNorm1d layer.
            hidden_dim: projection output size and RNN hidden size.
            act_layer: activation class; instantiated but not applied in
                ``forward`` — kept for backbone-interface compatibility.
            drop: accepted for backbone-interface compatibility; no dropout
                is applied by this backbone.
        """
        super().__init__()

        # hyperparameters
        self.demo_dim = demo_dim
        self.lab_dim = lab_dim
        self.max_visits = max_visits
        self.hidden_dim = hidden_dim

        self.proj = nn.Linear(demo_dim + lab_dim, hidden_dim)
        # act/bn are registered but bypassed in forward(); they are kept as
        # submodules so existing checkpoints/state dicts remain loadable.
        self.act = act_layer()
        self.bn = nn.BatchNorm1d(max_visits)
        self.rnn = nn.RNN(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=1,
            batch_first=True,
        )

    def forward(self, x, device, info=None):
        """Encode a batch of visit sequences.

        Args:
            x: float tensor shaped (batch, visits, demo_dim + lab_dim) —
                inferred from the projection layer's input width; confirm
                against callers.
            device: unused here; kept for backbone-interface compatibility.
            info: extra info, not used here.

        Returns:
            Per-visit hidden states of shape (batch, visits, hidden_dim).
        """
        x = self.proj(x)
        x, _ = self.rnn(x)
        return x