Diff of /src/utils/models.py [000000] .. [66326d]

a b/src/utils/models.py
"""
PyTorch neural network model definitions.

Consists of simple parameterised models:

- MLP:         Dense Feedforward ANN      / "Multilayer Perceptron"
- CNN:         1d CNN                     / "Temporal CNN" (TCN)
- RNN:         Recurrent Neural Network
- GRU:         Gated Recurrent Unit
- LSTM:        Long Short-Term Memory RNN
- Transformer: Transformer encoder

Models generally follow the format:

=================================================================
Layer (type:depth-idx)                   Output Shape
=================================================================
SimpleMLP                                --
├─Sequential: 1-1
│    └─Sequential: 2-1                   [n, hidden_dim]
│    │    └─Linear: 3-1
│    │    └─Nonlinearity: 3-2
│    └─Sequential: 2-2                   [n, hidden_dim]
│    │    └─Linear: 3-3
│    │    └─Nonlinearity: 3-4
│    │
                      ... (n_layers) ...
│    │
│    └─Sequential: 2-n                   [n, hidden_dim]
│    │    └─Linear: 3-(2n+1)
│    │    └─Nonlinearity: 3-(2n+2)
├─Sequential: 1-2                        [n, hidden_dim//2]
│    └─Linear: 2-1
│    └─Linear: 2-2                       [n, output_size]
=================================================================

The number of layers, layer width, nonlinearity, and degree of dropout are parameterised for all models.

Model-specific parameters:

- CNN          Kernel width
- RNN/LSTM/GRU Bidirectionality
- Transformer  Number of heads

"""

from torch import nn


class SimpleMLP(nn.Module):
    """
    Feed-forward network ("multi-layer perceptron")
    """

    def __init__(
        self,
        n_channels,
        seq_len,
        hidden_dim,
        n_layers,
        output_size=2,
        dropout=0,
        nonlinearity="relu",
    ):
        super().__init__()

        if nonlinearity == "relu":
            nonlinearity = nn.ReLU
        elif nonlinearity == "tanh":
            nonlinearity = nn.Tanh
        else:
            raise ValueError(f"Unsupported nonlinearity: {nonlinearity!r}")

        layers = []

        for i in range(n_layers):
            if i == 0:
                current_layer = nn.Sequential(
                    nn.Linear(
                        in_features=seq_len * n_channels,
                        out_features=hidden_dim,
                        bias=True,
                    ),
                    nonlinearity(),
                    nn.Dropout(p=dropout),
                )
            else:
                current_layer = nn.Sequential(
                    nn.Linear(
                        in_features=hidden_dim, out_features=hidden_dim, bias=True
                    ),
                    nonlinearity(),
                    nn.Dropout(p=dropout),
                )
            layers.append(current_layer)

        self.features = nn.Sequential(*layers)
        self.fc = nn.Sequential(
            nn.Linear(in_features=hidden_dim, out_features=hidden_dim // 2, bias=True),
            nn.Linear(in_features=hidden_dim // 2, out_features=output_size, bias=True),
        )

    def forward(self, x):
        """
        Forward pass of model.
        """
        batch_size = x.shape[0]

        # Flatten (batch, seq_len, n_channels) -> (batch, seq_len * n_channels)
        out = x.view(batch_size, -1)
        out = self.features(out)
        out = self.fc(out)
        return out

class SimpleRNN(nn.Module):
    """
    RNN
    """

    def __init__(
        self,
        n_channels,
        seq_len,
        hidden_dim,
        n_layers,
        output_size=2,
        bidirectional=True,
        nonlinearity="tanh",
        dropout=0,
    ):
        super().__init__()

        scalar = 2 if bidirectional else 1

        self.rnn = nn.RNN(
            n_channels,
            hidden_dim,
            n_layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=dropout,
            nonlinearity=nonlinearity,
        )
        self.fc = nn.Sequential(
            nn.Linear(
                in_features=scalar * seq_len * hidden_dim,
                out_features=scalar * seq_len * hidden_dim // 2,
                bias=True,
            ),
            nn.Linear(
                in_features=scalar * seq_len * hidden_dim // 2,
                out_features=output_size,
                bias=True,
            ),
        )

    def forward(self, x):
        """
        Forward pass of model.
        """
        batch_size = x.shape[0]

        out, _ = self.rnn(x)
        # out: (batch, seq_len, scalar * hidden_dim); flatten for the linear head
        out = out.reshape(batch_size, -1)
        out = self.fc(out)
        return out

class SimpleLSTM(nn.Module):
    """
    LSTM
    """

    def __init__(
        self,
        n_channels,
        seq_len,
        hidden_dim,
        n_layers,
        output_size=2,
        bidirectional=True,
        dropout=0,
    ):
        super().__init__()

        scalar = 2 if bidirectional else 1

        self.lstm = nn.LSTM(
            n_channels,
            hidden_dim,
            n_layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=dropout,
        )
        self.fc = nn.Sequential(
            nn.Linear(
                in_features=scalar * seq_len * hidden_dim,
                out_features=scalar * seq_len * hidden_dim // 2,
                bias=True,
            ),
            nn.Linear(
                in_features=scalar * seq_len * hidden_dim // 2,
                out_features=output_size,
                bias=True,
            ),
        )

    def forward(self, x):
        """
        Forward pass of model.
        """
        batch_size = x.shape[0]

        out, _ = self.lstm(x)
        out = out.reshape(batch_size, -1)
        out = self.fc(out)
        return out

class SimpleGRU(nn.Module):
    """
    GRU
    """

    def __init__(
        self,
        n_channels,
        seq_len,
        hidden_dim,
        n_layers,
        output_size=2,
        bidirectional=True,
        dropout=0,
    ):
        super().__init__()

        scalar = 2 if bidirectional else 1

        self.gru = nn.GRU(
            n_channels,
            hidden_dim,
            n_layers,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=dropout,
        )
        self.fc = nn.Sequential(
            nn.Linear(
                in_features=scalar * seq_len * hidden_dim,
                out_features=scalar * seq_len * hidden_dim // 2,
                bias=True,
            ),
            nn.Linear(
                in_features=scalar * seq_len * hidden_dim // 2,
                out_features=output_size,
                bias=True,
            ),
        )

    def forward(self, x):
        """
        Forward pass of model.
        """
        batch_size = x.shape[0]

        out, _ = self.gru(x)
        out = out.reshape(batch_size, -1)
        out = self.fc(out)
        return out

class SimpleCNN(nn.Module):
    """
    1d CNN (also known as a Temporal CNN / TCN)

    `kernel_size` must be odd for `padding` to work as expected.
    """

    def __init__(
        self,
        n_channels,
        seq_len,
        hidden_dim,
        n_layers,
        output_size=2,
        kernel_size=3,
        nonlinearity="relu",
    ):
        super().__init__()

        if nonlinearity == "relu":
            nonlinearity = nn.ReLU
        elif nonlinearity == "tanh":
            nonlinearity = nn.Tanh
        else:
            raise ValueError(f"Unsupported nonlinearity: {nonlinearity!r}")

        layers = []
        n_pools = 0

        for i in range(n_layers):
            in_channels = n_channels if i == 0 else hidden_dim

            current_layer = nn.Sequential(
                nn.Conv1d(
                    in_channels,
                    hidden_dim,
                    kernel_size,
                    stride=1,
                    padding=kernel_size // 2,
                ),
                # JA: Investigate removing BatchNorm as bad for CL
                # nn.BatchNorm1d(hidden_dim),
                nonlinearity(),
            )
            layers.append(current_layer)

            # Ensure MaxPools don't wash out entire sequence
            if seq_len // 2 ** (n_pools + 1) > 2:
                n_pools += 1
                layers.append(nn.MaxPool1d(kernel_size=2, stride=2))

        self.cnn_layers = nn.Sequential(*layers)
        self.fc = nn.Sequential(
            nn.Linear(
                in_features=hidden_dim * (seq_len // 2**n_pools),
                out_features=(hidden_dim * (seq_len // 2**n_pools)) // 2,
                bias=True,
            ),
            nn.Linear(
                in_features=(hidden_dim * (seq_len // 2**n_pools)) // 2,
                out_features=output_size,
                bias=True,
            ),
        )

    def forward(self, x):
        """
        Forward pass of model.
        """
        batch_size = x.shape[0]

        # Conv1d expects channels first: (batch, seq_len, n_channels) -> (batch, n_channels, seq_len)
        out = x.swapdims(1, 2)
        out = self.cnn_layers(out)
        out = out.reshape(batch_size, -1)
        out = self.fc(out)
        return out

class SimpleTransformer(nn.Module):
    """
    Transformer encoder.
    """

    def __init__(
        self,
        n_channels,
        seq_len,
        hidden_dim,
        n_layers,
        n_heads=8,
        output_size=2,
        nonlinearity="relu",
        dropout=0,
    ):
        super().__init__()

        # nhead must evenly divide d_model (= seq_len), so reduce n_heads until it does.
        # JA: need to make this more elegant
        while seq_len % n_heads != 0:
            n_heads -= 1

        transformer_layer = nn.TransformerEncoderLayer(
            d_model=seq_len,
            dim_feedforward=hidden_dim,
            nhead=n_heads,
            activation=nonlinearity,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer = nn.TransformerEncoder(transformer_layer, num_layers=n_layers)
        self.fc = nn.Linear(seq_len * n_channels, output_size)

    def forward(self, x):
        """
        Forward pass of model.
        """
        batch_size = x.shape[0]

        # Treat each channel as a token of dimension seq_len:
        # (batch, seq_len, n_channels) -> (batch, n_channels, seq_len)
        out = x.swapdims(1, 2)
        out = self.transformer(out)
        out = out.reshape(batch_size, -1)
        out = self.fc(out)
        return out

MODELS = {
    "MLP": SimpleMLP,
    "CNN": SimpleCNN,
    "RNN": SimpleRNN,
    "LSTM": SimpleLSTM,
    "GRU": SimpleGRU,
    "Transformer": SimpleTransformer,
}
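

# Illustrative smoke test: a minimal sketch of how the MODELS registry is intended
# to be used, assuming inputs of shape (batch, seq_len, n_channels) as expected by
# the forward passes above. The hyperparameter values below are arbitrary.
if __name__ == "__main__":
    import torch

    n_channels, seq_len = 6, 40
    x = torch.randn(8, seq_len, n_channels)  # (batch, seq_len, n_channels)

    for name, model_cls in MODELS.items():
        model = model_cls(
            n_channels=n_channels,
            seq_len=seq_len,
            hidden_dim=32,
            n_layers=2,
        )
        out = model(x)
        print(f"{name}: output shape {tuple(out.shape)}")  # expect (8, output_size=2)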