--- a
+++ b/src/utils/models.py
@@ -0,0 +1,400 @@
+"""
+PyTorch neural network model definitions.
+
+Consists of simple parameterised models:
+
+- MLP: Dense feedforward ANN / "multilayer perceptron"
+- CNN: 1d CNN / "temporal CNN" (TCN)
+- RNN: Recurrent neural network
+- GRU: Gated recurrent unit
+- LSTM: Long short-term memory RNN
+- Transformer: Transformer encoder
+
+Models generally follow the format:
+
+=================================================================
+Layer (type:depth-idx)                   Output Shape
+=================================================================
+SimpleMLP                                --
+├─Sequential: 1-1
+│    └─Sequential: 2-1                   [n, hidden_dim]
+│    │    └─Linear: 3-1
+│    │    └─Nonlinearity: 3-2
+│    └─Sequential: 2-2                   [n, hidden_dim]
+│    │    └─Linear: 3-3
+│    │    └─Nonlinearity: 3-4
+│    │
+│       ... (n_layers) ...
+│    │
+│    └─Sequential: 2-n                   [n, hidden_dim]
+│    │    └─Linear: 3-2n+1
+│    │    └─Nonlinearity: 3-2n+2
+├─Sequential: 1-2                        [n, hidden_dim//2]
+│    └─Linear: 2-1
+│    └─Linear: 2-2                       [n, output_size]
+=================================================================
+
+Where the number of layers, layer width, nonlinearity, and degree of dropout
+are parameterised.
+
+Model-specific parameters:
+
+- CNN:          kernel width
+- RNN/LSTM/GRU: bidirectionality
+- Transformer:  number of heads
+"""
+
+from torch import nn
+
+
+class SimpleMLP(nn.Module):
+    """
+    Feed-forward network ("multilayer perceptron").
+    """
+
+    def __init__(
+        self,
+        n_channels,
+        seq_len,
+        hidden_dim,
+        n_layers,
+        output_size=2,
+        dropout=0,
+        nonlinearity="relu",
+    ):
+        super().__init__()
+
+        if nonlinearity == "relu":
+            nonlinearity = nn.ReLU
+        elif nonlinearity == "tanh":
+            nonlinearity = nn.Tanh
+        else:
+            raise ValueError(f"Unsupported nonlinearity: {nonlinearity}")
+
+        layers = []
+
+        for i in range(n_layers):
+            if i == 0:
+                # First layer maps the flattened (seq_len * n_channels) input
+                # to the hidden dimension.
+                current_layer = nn.Sequential(
+                    nn.Linear(
+                        in_features=seq_len * n_channels,
+                        out_features=hidden_dim,
+                        bias=True,
+                    ),
+                    nonlinearity(),
+                    nn.Dropout(p=dropout),
+                )
+            else:
+                current_layer = nn.Sequential(
+                    nn.Linear(
+                        in_features=hidden_dim, out_features=hidden_dim, bias=True
+                    ),
+                    nonlinearity(),
+                    nn.Dropout(p=dropout),
+                )
+            layers.append(current_layer)
+
+        self.features = nn.Sequential(*layers)
+        self.fc = nn.Sequential(
+            nn.Linear(in_features=hidden_dim, out_features=hidden_dim // 2, bias=True),
+            nn.Linear(in_features=hidden_dim // 2, out_features=output_size, bias=True),
+        )
+
+    def forward(self, x):
+        """
+        Forward pass of model.
+        """
+        batch_size = x.shape[0]
+
+        # Flatten (batch, seq_len, n_channels) to (batch, seq_len * n_channels)
+        out = x.view(batch_size, -1)
+        out = self.features(out)
+        out = self.fc(out)
+        return out
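+
+
+# Illustrative usage sketch (not part of the original API): the shapes below
+# -- batch of 8, 64 time steps, 3 channels -- are assumed purely to show how
+# SimpleMLP flattens its input before the dense layers.
+def _example_mlp_usage():
+    """Hypothetical smoke test: returns logits of shape (8, output_size)."""
+    import torch
+
+    x = torch.randn(8, 64, 3)  # (batch, seq_len, n_channels)
+    model = SimpleMLP(n_channels=3, seq_len=64, hidden_dim=32, n_layers=2)
+    return model(x)  # flattened internally to (8, 64 * 3) before the MLP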
+
+
+class SimpleRNN(nn.Module):
+    """
+    RNN
+    """
+
+    def __init__(
+        self,
+        n_channels,
+        seq_len,
+        hidden_dim,
+        n_layers,
+        output_size=2,
+        bidirectional=True,
+        nonlinearity="tanh",
+        dropout=0,
+    ):
+        super().__init__()
+
+        # Bidirectional RNNs double the size of the output features
+        scalar = 2 if bidirectional else 1
+
+        self.rnn = nn.RNN(
+            n_channels,
+            hidden_dim,
+            n_layers,
+            batch_first=True,
+            bidirectional=bidirectional,
+            dropout=dropout,
+            nonlinearity=nonlinearity,
+        )
+        self.fc = nn.Sequential(
+            nn.Linear(
+                in_features=scalar * seq_len * hidden_dim,
+                out_features=scalar * seq_len * hidden_dim // 2,
+                bias=True,
+            ),
+            nn.Linear(
+                in_features=scalar * seq_len * hidden_dim // 2,
+                out_features=output_size,
+                bias=True,
+            ),
+        )
+
+    def forward(self, x):
+        """
+        Forward pass of model.
+        """
+        batch_size = x.shape[0]
+
+        # Concatenate the hidden states of every time step before the head
+        out, _ = self.rnn(x)
+        out = out.reshape(batch_size, -1)
+        out = self.fc(out)
+        return out
+
+
+class SimpleLSTM(nn.Module):
+    """
+    LSTM
+    """
+
+    def __init__(
+        self,
+        n_channels,
+        seq_len,
+        hidden_dim,
+        n_layers,
+        output_size=2,
+        bidirectional=True,
+        dropout=0,
+    ):
+        super().__init__()
+
+        # Bidirectional RNNs double the size of the output features
+        scalar = 2 if bidirectional else 1
+
+        self.lstm = nn.LSTM(
+            n_channels,
+            hidden_dim,
+            n_layers,
+            batch_first=True,
+            bidirectional=bidirectional,
+            dropout=dropout,
+        )
+        self.fc = nn.Sequential(
+            nn.Linear(
+                in_features=scalar * seq_len * hidden_dim,
+                out_features=scalar * seq_len * hidden_dim // 2,
+                bias=True,
+            ),
+            nn.Linear(
+                in_features=scalar * seq_len * hidden_dim // 2,
+                out_features=output_size,
+                bias=True,
+            ),
+        )
+
+    def forward(self, x):
+        """
+        Forward pass of model.
+        """
+        batch_size = x.shape[0]
+
+        out, _ = self.lstm(x)
+        out = out.reshape(batch_size, -1)
+        out = self.fc(out)
+        return out
+
+
+class SimpleGRU(nn.Module):
+    """
+    GRU
+    """
+
+    def __init__(
+        self,
+        n_channels,
+        seq_len,
+        hidden_dim,
+        n_layers,
+        output_size=2,
+        bidirectional=True,
+        dropout=0,
+    ):
+        super().__init__()
+
+        # Bidirectional RNNs double the size of the output features
+        scalar = 2 if bidirectional else 1
+
+        self.gru = nn.GRU(
+            n_channels,
+            hidden_dim,
+            n_layers,
+            batch_first=True,
+            bidirectional=bidirectional,
+            dropout=dropout,
+        )
+        self.fc = nn.Sequential(
+            nn.Linear(
+                in_features=scalar * seq_len * hidden_dim,
+                out_features=scalar * seq_len * hidden_dim // 2,
+                bias=True,
+            ),
+            nn.Linear(
+                in_features=scalar * seq_len * hidden_dim // 2,
+                out_features=output_size,
+                bias=True,
+            ),
+        )
+
+    def forward(self, x):
+        """
+        Forward pass of model.
+        """
+        batch_size = x.shape[0]
+
+        out, _ = self.gru(x)
+        out = out.reshape(batch_size, -1)
+        out = self.fc(out)
+        return out
+
+
+class SimpleCNN(nn.Module):
+    """
+    1d CNN (also known as TCN)
+
+    `kernel_size` must be odd for `padding` to preserve the sequence length.
+    """
+
+    def __init__(
+        self,
+        n_channels,
+        seq_len,
+        hidden_dim,
+        n_layers,
+        output_size=2,
+        kernel_size=3,
+        nonlinearity="relu",
+    ):
+        super().__init__()
+
+        if nonlinearity == "relu":
+            nonlinearity = nn.ReLU
+        elif nonlinearity == "tanh":
+            nonlinearity = nn.Tanh
+        else:
+            raise ValueError(f"Unsupported nonlinearity: {nonlinearity}")
+
+        layers = []
+        n_pools = 0
+
+        for i in range(n_layers):
+            in_channels = n_channels if i == 0 else hidden_dim
+
+            current_layer = nn.Sequential(
+                nn.Conv1d(
+                    in_channels,
+                    hidden_dim,
+                    kernel_size,
+                    stride=1,
+                    padding=kernel_size // 2,
+                ),
+                # JA: Investigate removing BatchNorm as bad for CL
+                # nn.BatchNorm1d(hidden_dim),
+                nonlinearity(),
+            )
+            layers.append(current_layer)
+
+            # Ensure MaxPools don't wash out the entire sequence
+            if seq_len // 2 ** (n_pools + 1) > 2:
+                n_pools += 1
+                layers.append(nn.MaxPool1d(kernel_size=2, stride=2))
+
+        self.cnn_layers = nn.Sequential(*layers)
+        self.fc = nn.Sequential(
+            nn.Linear(
+                in_features=hidden_dim * (seq_len // 2**n_pools),
+                out_features=(hidden_dim * (seq_len // 2**n_pools)) // 2,
+                bias=True,
+            ),
+            nn.Linear(
+                in_features=(hidden_dim * (seq_len // 2**n_pools)) // 2,
+                out_features=output_size,
+                bias=True,
+            ),
+        )
+
+    def forward(self, x):
+        """
+        Forward pass of model.
+        """
+        batch_size = x.shape[0]
+
+        # Conv1d expects (batch, channels, seq_len)
+        out = x.swapdims(1, 2)
+        out = self.cnn_layers(out)
+        out = out.reshape(batch_size, -1)
+        out = self.fc(out)
+        return out
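+
+
+# Illustrative usage sketch (not part of the original API): shapes are assumed
+# for demonstration only. With bidirectional=True the GRU emits 2 * hidden_dim
+# features per time step, so the head sees 2 * seq_len * hidden_dim inputs.
+def _example_gru_usage():
+    """Hypothetical smoke test: returns logits of shape (8, output_size)."""
+    import torch
+
+    x = torch.randn(8, 64, 3)  # (batch, seq_len, n_channels)
+    model = SimpleGRU(n_channels=3, seq_len=64, hidden_dim=32, n_layers=2)
+    return model(x)  # hidden states flattened to (8, 2 * 64 * 32) before the head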
+
+
+class SimpleTransformer(nn.Module):
+    """
+    Transformer encoder.
+    """
+
+    def __init__(
+        self,
+        n_channels,
+        seq_len,
+        hidden_dim,
+        n_layers,
+        n_heads=8,
+        output_size=2,
+        nonlinearity="relu",
+        dropout=0,
+    ):
+        super().__init__()
+
+        # d_model (= seq_len here) must be divisible by the number of attention
+        # heads, so reduce n_heads until it divides seq_len evenly.
+        # JA: need to make this more elegant
+        while seq_len % n_heads != 0:
+            n_heads -= 1
+
+        transformer_layer = nn.TransformerEncoderLayer(
+            d_model=seq_len,
+            dim_feedforward=hidden_dim,
+            nhead=n_heads,
+            activation=nonlinearity,
+            dropout=dropout,
+            batch_first=True,
+        )
+        self.transformer = nn.TransformerEncoder(transformer_layer, num_layers=n_layers)
+        self.fc = nn.Linear(seq_len * n_channels, output_size)
+
+    def forward(self, x):
+        """
+        Forward pass of model.
+        """
+        batch_size = x.shape[0]
+
+        # Treat each channel as a token whose embedding is its full sequence,
+        # i.e. attention is applied across channels with d_model == seq_len.
+        out = x.swapdims(1, 2)
+        out = self.transformer(out)
+        out = out.reshape(batch_size, -1)
+        out = self.fc(out)
+        return out
+
+
+MODELS = {
+    "MLP": SimpleMLP,
+    "CNN": SimpleCNN,
+    "RNN": SimpleRNN,
+    "LSTM": SimpleLSTM,
+    "GRU": SimpleGRU,
+    "Transformer": SimpleTransformer,
+}
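+
+
+# Illustrative smoke test (not part of the original module): instantiates each
+# registered model with assumed hyperparameters and checks the output shape.
+# The shapes and hyperparameter values below are placeholders, not defaults
+# used elsewhere in the project.
+if __name__ == "__main__":
+    import torch
+
+    n, seq_len, n_channels = 4, 32, 3
+    x = torch.randn(n, seq_len, n_channels)
+
+    for name, model_cls in MODELS.items():
+        model = model_cls(
+            n_channels=n_channels, seq_len=seq_len, hidden_dim=16, n_layers=2
+        )
+        out = model(x)
+        assert out.shape == (n, 2), f"{name}: unexpected output shape {out.shape}"
+        print(f"{name}: output shape {tuple(out.shape)}")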