Switch to unified view

a b/app/models/backbones/rnn.py
1
from torch import nn
2
3
4
class RNN(nn.Module):
5
    def __init__(
6
        self,
7
        demo_dim,
8
        lab_dim,
9
        max_visits,
10
        hidden_dim,
11
        act_layer=nn.GELU,
12
        drop=0.0,
13
    ):
14
        super(RNN, self).__init__()
15
16
        # hyperparameters
17
        self.demo_dim = demo_dim
18
        self.lab_dim = lab_dim
19
        self.max_visits = max_visits
20
        self.hidden_dim = hidden_dim
21
22
        self.proj = nn.Linear(demo_dim + lab_dim, hidden_dim)
23
        self.act = act_layer()
24
        self.bn = nn.BatchNorm1d(max_visits)
25
        self.rnn = nn.RNN(
26
            input_size=hidden_dim,
27
            hidden_size=hidden_dim,
28
            num_layers=1,
29
            batch_first=True,
30
        )
31
32
    def forward(self, x, device, info=None):
33
        """extra info is not used here"""
34
        x = self.proj(x)
35
        # x = self.act(x)
36
        # x = self.bn(x)
37
        x, _ = self.rnn(x)
38
        return x