# Authors: Tao Yang <sheeptao@outlook.com>
# Bruno Aristimunha <b.aristimunha@gmail.com> (braindecode adaptation)
#
import torch
import torch.nn as nn
from einops.layers.torch import Rearrange
from typing import List, Type, Union, Tuple, Optional, Dict
from braindecode.models.base import EEGModuleMixin
class _TSConv(nn.Sequential):
"""
Time-Distributed Separable Convolution block.
The architecture consists of:
- **Temporal Convolution**
- **Batch Normalization**
- **Depthwise Spatial Convolution**
- **Batch Normalization**
- **Activation Function**
- **First Pooling Layer**
- **Dropout**
- **Depthwise Temporal Convolution**
- **Batch Normalization**
- **Activation Function**
- **Second Pooling Layer**
- **Dropout**
Parameters
----------
n_channels : int
Number of input channels (EEG channels).
n_filters : int
Number of filters for the convolution layers.
conv1_kernel_size : int
Kernel size for the first convolution layer.
conv2_kernel_size : int
Kernel size for the second convolution layer.
depth_multiplier : int
Depth multiplier for depthwise convolution.
pool1_size : int
Kernel size for the first pooling layer.
pool2_size : int
Kernel size for the second pooling layer.
drop_prob : float
Dropout probability.
activation : Type[nn.Module], optional
Activation function class to use, by default nn.ELU.
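
    Examples
    --------
    A minimal shape sketch; the channel and sample counts below are
    illustrative values, not defaults:

    >>> import torch
    >>> block = _TSConv(
    ...     n_channels=22, n_filters=9, conv1_kernel_size=15,
    ...     conv2_kernel_size=15, depth_multiplier=2,
    ...     pool1_size=8, pool2_size=7, drop_prob=0.3,
    ... )
    >>> block(torch.randn(2, 1, 22, 1000)).shape  # (batch, 1, chans, times)
    torch.Size([2, 18, 1, 17])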
"""
def __init__(
self,
n_channels: int,
n_filters: int,
conv1_kernel_size: int,
conv2_kernel_size: int,
depth_multiplier: int,
pool1_size: int,
pool2_size: int,
drop_prob: float,
activation: Type[nn.Module] = nn.ELU,
):
super().__init__(
nn.Conv2d(
in_channels=1,
out_channels=n_filters,
kernel_size=(1, conv1_kernel_size),
padding="same",
bias=False,
),
nn.BatchNorm2d(n_filters),
nn.Conv2d(
in_channels=n_filters,
out_channels=n_filters * depth_multiplier,
kernel_size=(n_channels, 1),
groups=n_filters,
bias=False,
),
nn.BatchNorm2d(n_filters * depth_multiplier),
activation(),
nn.AvgPool2d(kernel_size=(1, pool1_size)),
nn.Dropout(drop_prob),
nn.Conv2d(
in_channels=n_filters * depth_multiplier,
out_channels=n_filters * depth_multiplier,
kernel_size=(1, conv2_kernel_size),
padding="same",
groups=n_filters * depth_multiplier,
bias=False,
),
nn.BatchNorm2d(n_filters * depth_multiplier),
activation(),
nn.AvgPool2d(kernel_size=(1, pool2_size)),
nn.Dropout(drop_prob),
)
class _PositionalEncoding(nn.Module):
"""
Positional encoding module that adds learnable positional embeddings.
Parameters
----------
seq_length : int
Sequence length.
d_model : int
Dimensionality of the model.
"""
def __init__(self, seq_length: int, d_model: int) -> None:
super().__init__()
self.seq_length = seq_length
self.d_model = d_model
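        # Learnable positional embedding (zero-initialized), broadcast over
        # the batch dimension when added to the input.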
self.pe = nn.Parameter(torch.zeros(1, seq_length, d_model))
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x + self.pe
return x
class _Transformer(nn.Module):
"""
Transformer encoder module with learnable class token and positional encoding.
Parameters
----------
seq_length : int
Sequence length of the input.
d_model : int
Dimensionality of the model.
num_heads : int
Number of heads in the multihead attention.
feedforward_ratio : float
Ratio to compute the dimension of the feedforward network.
drop_prob : float, optional
Dropout probability, by default 0.5.
num_layers : int, optional
Number of transformer encoder layers, by default 4.
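
    Examples
    --------
    A minimal shape sketch; ``d_model`` must be divisible by ``num_heads``,
    and the values below are illustrative:

    >>> import torch
    >>> trans = _Transformer(seq_length=17, d_model=72, num_heads=8,
    ...                      feedforward_ratio=1.0)
    >>> trans(torch.randn(2, 17, 72)).shape  # class-token representation
    torch.Size([2, 72])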
"""
def __init__(
self,
seq_length: int,
d_model: int,
num_heads: int,
feedforward_ratio: float,
drop_prob: float = 0.5,
num_layers: int = 4,
) -> None:
super().__init__()
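        # Learnable class token, prepended to the sequence; its final
        # representation is returned as the summary of the whole input.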
self.cls_embedding = nn.Parameter(torch.zeros(1, 1, d_model))
self.pos_embedding = _PositionalEncoding(seq_length + 1, d_model)
dim_ff = int(d_model * feedforward_ratio)
self.dropout = nn.Dropout(drop_prob)
self.trans = nn.TransformerEncoder(
nn.TransformerEncoderLayer(
d_model,
num_heads,
dim_ff,
drop_prob,
batch_first=True,
norm_first=True,
),
num_layers,
norm=nn.LayerNorm(d_model),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
batch_size = x.shape[0]
x = torch.cat((self.cls_embedding.expand(batch_size, -1, -1), x), dim=1)
x = self.pos_embedding(x)
x = self.dropout(x)
return self.trans(x)[:, 0]
class _DenseLayers(nn.Sequential):
"""
Final classification layers.
Parameters
----------
linear_in : int
Input dimension to the linear layer.
n_classes : int
Number of output classes.
"""
def __init__(self, linear_in: int, n_classes: int):
super().__init__(
nn.Flatten(),
nn.Linear(linear_in, n_classes),
)
class MSVTNet(EEGModuleMixin, nn.Module):
"""MSVTNet model from Liu K et al (2024) from [msvt2024]_.
This model implements a multi-scale convolutional transformer network
for EEG signal classification, as described in [msvt2024]_.
.. figure:: https://raw.githubusercontent.com/SheepTAO/MSVTNet/refs/heads/main/MSVTNet_Arch.png
:align: center
:alt: MSVTNet Architecture
Parameters
----------
    n_filters_list : Tuple[int, ...], optional
        Tuple of filter numbers for each TSConv block, by default (9, 9, 9, 9).
    conv1_kernels_size : Tuple[int, ...], optional
        Tuple of kernel sizes for the first convolution in each TSConv block,
        by default (15, 31, 63, 125).
conv2_kernel_size : int, optional
Kernel size for the second convolution in TSConv blocks, by default 15.
depth_multiplier : int, optional
Depth multiplier for depthwise convolution, by default 2.
pool1_size : int, optional
Pooling size for the first pooling layer in TSConv blocks, by default 8.
pool2_size : int, optional
Pooling size for the second pooling layer in TSConv blocks, by default 7.
drop_prob : float, optional
Dropout probability for convolutional layers, by default 0.3.
num_heads : int, optional
Number of attention heads in the transformer encoder, by default 8.
feedforward_ratio : float, optional
Ratio to compute feedforward dimension in the transformer, by default 1.
drop_prob_trans : float, optional
Dropout probability for the transformer, by default 0.5.
num_layers : int, optional
Number of transformer encoder layers, by default 2.
activation : Type[nn.Module], optional
Activation function class to use, by default nn.ELU.
return_features : bool, optional
Whether to return predictions from branch classifiers, by default False.
Notes
-----
    This implementation is not guaranteed to be correct and has not been
    checked by the original authors; it is a reimplementation based on the
    original code [msvt2024code]_.
References
----------
.. [msvt2024] Liu, K., et al. (2024). MSVTNet: Multi-Scale Vision
Transformer Neural Network for EEG-Based Motor Imagery Decoding.
       IEEE Journal of Biomedical and Health Informatics.
.. [msvt2024code] Liu, K., et al. (2024). MSVTNet: Multi-Scale Vision
Transformer Neural Network for EEG-Based Motor Imagery Decoding.
Source Code: https://github.com/SheepTAO/MSVTNet
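
    Examples
    --------
    A minimal usage sketch; the channel, class, and window sizes below are
    illustrative values, not defaults:

    >>> import torch
    >>> model = MSVTNet(n_chans=22, n_outputs=4, n_times=1000)
    >>> model(torch.randn(8, 22, 1000)).shape
    torch.Size([8, 4])

    With ``return_features=True``, the forward pass also returns the
    auxiliary per-branch predictions:

    >>> model = MSVTNet(n_chans=22, n_outputs=4, n_times=1000,
    ...                 return_features=True)
    >>> out, branch_preds = model(torch.randn(8, 22, 1000))
    >>> out.shape, len(branch_preds)
    (torch.Size([8, 4]), 4)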
"""
def __init__(
self,
# braindecode parameters
n_chans: Optional[int] = None,
n_outputs: Optional[int] = None,
n_times: Optional[int] = None,
input_window_seconds: Optional[float] = None,
sfreq: Optional[float] = None,
chs_info: Optional[List[Dict]] = None,
# Model's parameters
n_filters_list: Tuple[int, ...] = (9, 9, 9, 9),
conv1_kernels_size: Tuple[int, ...] = (15, 31, 63, 125),
conv2_kernel_size: int = 15,
depth_multiplier: int = 2,
pool1_size: int = 8,
pool2_size: int = 7,
drop_prob: float = 0.3,
num_heads: int = 8,
feedforward_ratio: float = 1,
drop_prob_trans: float = 0.5,
num_layers: int = 2,
activation: Type[nn.Module] = nn.ELU,
return_features: bool = False,
):
super().__init__(
n_outputs=n_outputs,
n_chans=n_chans,
chs_info=chs_info,
n_times=n_times,
input_window_seconds=input_window_seconds,
sfreq=sfreq,
)
del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq
self.return_features = return_features
        assert len(n_filters_list) == len(conv1_kernels_size), (
            "The lengths of n_filters_list and conv1_kernels_size must be equal."
        )
self.ensure_dim = Rearrange("batch chans time -> batch 1 chans time")
self.mstsconv = nn.ModuleList(
[
nn.Sequential(
_TSConv(
self.n_chans,
n_filters_list[b],
conv1_kernels_size[b],
conv2_kernel_size,
depth_multiplier,
pool1_size,
pool2_size,
drop_prob,
activation,
),
Rearrange("batch channels 1 time -> batch time channels"),
)
for b in range(len(n_filters_list))
]
)
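        # Infer each branch's flattened feature size with a dummy forward
        # pass, then build one auxiliary classification head per branch.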
branch_linear_in = self._forward_flatten(cat=False)
self.branch_head = nn.ModuleList(
[
_DenseLayers(branch_linear_in[b].shape[1], self.n_outputs)
for b in range(len(n_filters_list))
]
)
seq_len, d_model = self._forward_mstsconv().shape[1:3] # type: ignore
self.transformer = _Transformer(
seq_len,
d_model,
num_heads,
feedforward_ratio,
drop_prob_trans,
num_layers,
)
linear_in = self._forward_flatten().shape[1] # type: ignore
self.flatten_layer = nn.Flatten()
self.final_layer = nn.Linear(linear_in, self.n_outputs)
def _forward_mstsconv(
self, cat: bool = True
) -> Union[torch.Tensor, List[torch.Tensor]]:
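        # Dummy forward through the multi-scale TSConv branches with a
        # random input, used at build time to infer output shapes.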
x = torch.randn(1, 1, self.n_chans, self.n_times)
x = [tsconv(x) for tsconv in self.mstsconv]
if cat:
x = torch.cat(x, dim=2)
return x
def _forward_flatten(
self, cat: bool = True
) -> Union[torch.Tensor, List[torch.Tensor]]:
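        # Flattened feature sizes: per branch (cat=False) for the auxiliary
        # heads, or after the transformer (cat=True) for the final layer.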
x = self._forward_mstsconv(cat)
if cat:
x = self.transformer(x)
x = x.flatten(start_dim=1, end_dim=-1)
else:
x = [_.flatten(start_dim=1, end_dim=-1) for _ in x]
return x
def forward(
self, x: torch.Tensor
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
# x with shape: (batch, n_chans, n_times)
x = self.ensure_dim(x)
# x with shape: (batch, 1, n_chans, n_times)
x_list = [tsconv(x) for tsconv in self.mstsconv]
        # x_list: one tensor per branch, each of shape [batch_size, seq_len, embed_dim]
branch_preds = [
branch(x_list[idx]) for idx, branch in enumerate(self.branch_head)
]
        # branch_preds: one tensor per branch, each of shape [batch_size, n_outputs]
x = torch.stack(x_list, dim=2)
x = x.view(x.size(0), x.size(1), -1)
# x shape after concatenation: [batch_size, seq_len, total_embed_dim]
x = self.transformer(x)
# x shape after transformer: [batch_size, embed_dim]
x = self.final_layer(x)
return (x, branch_preds) if self.return_features else x
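

if __name__ == "__main__":
    # Minimal smoke test plus a joint-loss sketch: the final cross-entropy is
    # combined with auxiliary losses from the branch classifiers. The 0.3
    # weight is a hypothetical choice, not taken from the paper or the
    # original code (https://github.com/SheepTAO/MSVTNet).
    import torch.nn.functional as F

    model = MSVTNet(n_chans=22, n_outputs=4, n_times=1000, return_features=True)
    x = torch.randn(8, 22, 1000)
    y = torch.randint(0, 4, (8,))
    final_pred, branch_preds = model(x)
    aux_weight = 0.3  # hypothetical weight for the auxiliary branch losses
    loss = F.cross_entropy(final_pred, y) + aux_weight * sum(
        F.cross_entropy(pred, y) for pred in branch_preds
    )
    print(final_pred.shape, loss.item())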