Diff of /src/utils/models.py [000000] .. [66326d]

Switch to side-by-side view

--- a
+++ b/src/utils/models.py
@@ -0,0 +1,400 @@
+"""
+PyTorch Neural Network model definitions.
+
+Consists of simple parameterised:
+
+- MLP:         Dense Feedforward ANN      / "Multilayer Perceptron"
+- CNN:         1d CNN                     / "Temporal CNN" (TCN)
+- RNN:         Recurrent Neural network
+- GRU:         Gated Recurrent Unit
+- LSTM:        Long-short term memory RNN
+- Transformer: Transformer encoder
+
+Models generally of format:
+
+=================================================================
+Layer (type:depth-idx)                   Output Shape
+=================================================================
+SimpleMLP                                --
+├─Sequential: 1-1
+│    └─Sequential: 2-1                   [n, hidden_dim]
+│    │    └─Linear: 3-1
+│    │    └─Nonlinearity: 3-2
+│    └─Sequential: 2-2                   [n, hidden_dim]
+│    │    └─Linear: 3-3
+│    │    └─Non-linearity: 3-4
+|    |
+                      ... (n_layers) ...
+|    |
+│    └─Sequential: 2-n                   [n, hidden_dim]
+│    │    └─Linear: 3-2n+1
+│    │    └─Nonlinearity: 3-2n+2
+├─Sequential: 1-2                        [n, hidden_dim//2]
+│    └─Linear: 2-1
+|    └─Linear: 2-2                       [n, output_size]
+=================================================================
+
+Where the number of layers, layer width, nonlinearity, and degree of dropout are parameterised.
+
+Model specific parameters:
+
+- CNN          Kernel width
+- RNN/LSTM/GRU Bidirectionality
+- Transformer  Number of heads
+
+"""
+
+from torch import nn
+
+
+class SimpleMLP(nn.Module):
+    """
+    Feed-forward network ("multi-layer perceptron")
+    """
+
+    def __init__(
+        self,
+        n_channels,
+        seq_len,
+        hidden_dim,
+        n_layers,
+        output_size=2,
+        dropout=0,
+        nonlinearity="relu",
+    ):
+        super().__init__()
+
+        if nonlinearity == "relu":
+            nonlinearity = nn.ReLU
+        elif nonlinearity == "tanh":
+            nonlinearity = nn.Tanh
+
+        layers = []
+
+        for i in range(n_layers):
+            if i == 0:
+                current_layer = nn.Sequential(
+                    nn.Linear(
+                        in_features=seq_len * n_channels,
+                        out_features=hidden_dim,
+                        bias=True,
+                    ),
+                    nonlinearity(),
+                    nn.Dropout(p=dropout),
+                )
+            else:
+                current_layer = nn.Sequential(
+                    nn.Linear(
+                        in_features=hidden_dim, out_features=hidden_dim, bias=True
+                    ),
+                    nonlinearity(),
+                    nn.Dropout(p=dropout),
+                )
+            layers.append(current_layer)
+
+        self.features = nn.Sequential(*layers)
+        self.fc = nn.Sequential(
+            nn.Linear(in_features=hidden_dim, out_features=hidden_dim // 2, bias=True),
+            nn.Linear(in_features=hidden_dim // 2, out_features=output_size, bias=True),
+        )
+
+    def forward(self, x):
+        """
+        Forward pass of model.
+        """
+        batch_size = x.shape[0]
+
+        out = x.view(batch_size, -1)
+        out = self.features(out)
+        out = self.fc(out)
+        return out
+
+
+class SimpleRNN(nn.Module):
+    """
+    RNN
+    """
+
+    def __init__(
+        self,
+        n_channels,
+        seq_len,
+        hidden_dim,
+        n_layers,
+        output_size=2,
+        bidirectional=True,
+        nonlinearity="tanh",
+        dropout=0,
+    ):
+        super().__init__()
+
+        scalar = 2 if bidirectional else 1
+
+        self.rnn = nn.RNN(
+            n_channels,
+            hidden_dim,
+            n_layers,
+            batch_first=True,
+            bidirectional=bidirectional,
+            dropout=dropout,
+            nonlinearity=nonlinearity,
+        )
+        self.fc = nn.Sequential(
+            nn.Linear(
+                in_features=scalar * seq_len * hidden_dim,
+                out_features=scalar * seq_len * hidden_dim // 2,
+                bias=True,
+            ),
+            nn.Linear(
+                in_features=scalar * seq_len * hidden_dim // 2,
+                out_features=output_size,
+                bias=True,
+            ),
+        )
+
+    def forward(self, x):
+        """
+        Forward pass of model.
+        """
+        batch_size = x.shape[0]
+
+        out, _ = self.rnn(x)
+        out = out.reshape(batch_size, -1)
+        out = self.fc(out)
+        return out
+
+
+class SimpleLSTM(nn.Module):
+    """
+    LSTM
+    """
+
+    def __init__(
+        self,
+        n_channels,
+        seq_len,
+        hidden_dim,
+        n_layers,
+        output_size=2,
+        bidirectional=True,
+        dropout=0,
+    ):
+        super().__init__()
+
+        scalar = 2 if bidirectional else 1
+
+        self.lstm = nn.LSTM(
+            n_channels,
+            hidden_dim,
+            n_layers,
+            batch_first=True,
+            bidirectional=bidirectional,
+            dropout=dropout,
+        )
+        self.fc = nn.Sequential(
+            nn.Linear(
+                in_features=scalar * seq_len * hidden_dim,
+                out_features=scalar * seq_len * hidden_dim // 2,
+                bias=True,
+            ),
+            nn.Linear(
+                in_features=scalar * seq_len * hidden_dim // 2,
+                out_features=output_size,
+                bias=True,
+            ),
+        )
+
+    def forward(self, x):
+        """
+        Forward pass of model.
+        """
+        batch_size = x.shape[0]
+
+        out, _ = self.lstm(x)
+        out = out.reshape(batch_size, -1)
+        out = self.fc(out)
+        return out
+
+
+class SimpleGRU(nn.Module):
+    """
+    GRU
+    """
+
+    def __init__(
+        self,
+        n_channels,
+        seq_len,
+        hidden_dim,
+        n_layers,
+        output_size=2,
+        bidirectional=True,
+        dropout=0,
+    ):
+        super().__init__()
+
+        scalar = 2 if bidirectional else 1
+
+        self.lstm = nn.GRU(
+            n_channels,
+            hidden_dim,
+            n_layers,
+            batch_first=True,
+            bidirectional=bidirectional,
+            dropout=dropout,
+        )
+        self.fc = nn.Sequential(
+            nn.Linear(
+                in_features=scalar * seq_len * hidden_dim,
+                out_features=scalar * seq_len * hidden_dim // 2,
+                bias=True,
+            ),
+            nn.Linear(
+                in_features=scalar * seq_len * hidden_dim // 2,
+                out_features=output_size,
+                bias=True,
+            ),
+        )
+
+    def forward(self, x):
+        """
+        Forward pass of model.
+        """
+        batch_size = x.shape[0]
+
+        out, _ = self.lstm(x)
+        out = out.reshape(batch_size, -1)
+        out = self.fc(out)
+        return out
+
+
+class SimpleCNN(nn.Module):
+    """
+    1d CNN (also known as TCN)
+
+    `kernel_size` must be odd for `padding` to work as expected.
+    """
+
+    def __init__(
+        self,
+        n_channels,
+        seq_len,
+        hidden_dim,
+        n_layers,
+        output_size=2,
+        kernel_size=3,
+        nonlinearity="relu",
+    ):
+        super().__init__()
+
+        if nonlinearity == "relu":
+            nonlinearity = nn.ReLU
+        elif nonlinearity == "tanh":
+            nonlinearity = nn.Tanh
+
+        layers = []
+        n_pools = 0
+
+        for i in range(n_layers):
+            in_channels = n_channels if i == 0 else hidden_dim
+
+            current_layer = nn.Sequential(
+                nn.Conv1d(
+                    in_channels,
+                    hidden_dim,
+                    kernel_size,
+                    stride=1,
+                    padding=kernel_size // 2,
+                ),
+                # JA: Investigate removing BatchNorm as bad for CL
+                # nn.BatchNorm1d(hidden_dim),
+                nonlinearity(),
+            )
+            layers.append(current_layer)
+
+            # Ensure MaxPools don't wash out entire sequence
+            if seq_len // 2 ** (n_pools + 1) > 2:
+                n_pools += 1
+                layers.append(nn.MaxPool1d(kernel_size=2, stride=2))
+
+        self.cnn_layers = nn.Sequential(*layers)
+        self.fc = nn.Sequential(
+            nn.Linear(
+                in_features=hidden_dim * (seq_len // 2**n_pools),
+                out_features=(hidden_dim * (seq_len // 2**n_pools)) // 2,
+                bias=True,
+            ),
+            nn.Linear(
+                in_features=(hidden_dim * (seq_len // 2**n_pools)) // 2,
+                out_features=output_size,
+                bias=True,
+            ),
+        )
+
+    def forward(self, x):
+        """
+        Forward pass of model.
+        """
+        batch_size = x.shape[0]
+
+        out = x.swapdims(1, 2)
+        out = self.cnn_layers(out)
+        out = out.reshape(batch_size, -1)
+        out = self.fc(out)
+        return out
+
+
+class SimpleTransformer(nn.Module):
+    """
+    Transformer.
+    """
+
+    def __init__(
+        self,
+        n_channels,
+        seq_len,
+        hidden_dim,
+        n_layers,
+        n_heads=8,
+        output_size=2,
+        nonlinearity="relu",
+        dropout=0,
+    ):
+        super().__init__()
+
+        # JA: need to make this more elegant
+        while seq_len % n_heads != 0:
+            n_heads -= 1
+
+        transformer_layer = nn.TransformerEncoderLayer(
+            d_model=seq_len,
+            dim_feedforward=hidden_dim,
+            nhead=n_heads,
+            activation=nonlinearity,
+            dropout=dropout,
+            batch_first=True,
+        )
+        self.transformer = nn.TransformerEncoder(transformer_layer, num_layers=n_layers)
+        self.fc = nn.Linear(seq_len * n_channels, output_size)
+
+    def forward(self, x):
+        """
+        Forward pass of model.
+        """
+        batch_size = x.shape[0]
+
+        out = x.swapdims(1, 2)
+        out = self.transformer(out)
+        out = out.reshape(batch_size, -1)
+        out = self.fc(out)
+        return out
+
+
+MODELS = {
+    "MLP": SimpleMLP,
+    "CNN": SimpleCNN,
+    "RNN": SimpleRNN,
+    "LSTM": SimpleLSTM,
+    "GRU": SimpleGRU,
+    "Transformer": SimpleTransformer,
+}