braindecode/models/msvtnet.py

# Authors: Tao Yang <sheeptao@outlook.com>
#          Bruno Aristimunha <b.aristimunha@gmail.com> (braindecode adaptation)
#
import torch
import torch.nn as nn
from einops.layers.torch import Rearrange
from typing import List, Type, Union, Tuple, Optional, Dict
from braindecode.models.base import EEGModuleMixin


class _TSConv(nn.Sequential):
    """
    Time-Distributed Separable Convolution block.

    The architecture consists of:

    - **Temporal Convolution**
    - **Batch Normalization**
    - **Depthwise Spatial Convolution**
    - **Batch Normalization**
    - **Activation Function**
    - **First Pooling Layer**
    - **Dropout**
    - **Depthwise Temporal Convolution**
    - **Batch Normalization**
    - **Activation Function**
    - **Second Pooling Layer**
    - **Dropout**

    Parameters
    ----------
    n_channels : int
        Number of input channels (EEG channels).
    n_filters : int
        Number of filters for the convolution layers.
    conv1_kernel_size : int
        Kernel size for the first convolution layer.
    conv2_kernel_size : int
        Kernel size for the second convolution layer.
    depth_multiplier : int
        Depth multiplier for depthwise convolution.
    pool1_size : int
        Kernel size for the first pooling layer.
    pool2_size : int
        Kernel size for the second pooling layer.
    drop_prob : float
        Dropout probability.
    activation : Type[nn.Module], optional
        Activation function class to use, by default nn.ELU.
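
    Examples
    --------
    A minimal shape check (a sketch; the values below are illustrative
    defaults, not prescribed by the original paper):

    >>> import torch
    >>> block = _TSConv(n_channels=22, n_filters=9, conv1_kernel_size=15,
    ...                 conv2_kernel_size=15, depth_multiplier=2,
    ...                 pool1_size=8, pool2_size=7, drop_prob=0.3)
    >>> x = torch.randn(8, 1, 22, 1000)
    >>> block(x).shape
    torch.Size([8, 18, 1, 17])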
    """

    def __init__(
        self,
        n_channels: int,
        n_filters: int,
        conv1_kernel_size: int,
        conv2_kernel_size: int,
        depth_multiplier: int,
        pool1_size: int,
        pool2_size: int,
        drop_prob: float,
        activation: Type[nn.Module] = nn.ELU,
    ):
        super().__init__(
            # Temporal convolution
            nn.Conv2d(
                in_channels=1,
                out_channels=n_filters,
                kernel_size=(1, conv1_kernel_size),
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(n_filters),
            # Depthwise spatial convolution across EEG channels
            nn.Conv2d(
                in_channels=n_filters,
                out_channels=n_filters * depth_multiplier,
                kernel_size=(n_channels, 1),
                groups=n_filters,
                bias=False,
            ),
            nn.BatchNorm2d(n_filters * depth_multiplier),
            activation(),
            nn.AvgPool2d(kernel_size=(1, pool1_size)),
            nn.Dropout(drop_prob),
            # Depthwise temporal convolution
            nn.Conv2d(
                in_channels=n_filters * depth_multiplier,
                out_channels=n_filters * depth_multiplier,
                kernel_size=(1, conv2_kernel_size),
                padding="same",
                groups=n_filters * depth_multiplier,
                bias=False,
            ),
            nn.BatchNorm2d(n_filters * depth_multiplier),
            activation(),
            nn.AvgPool2d(kernel_size=(1, pool2_size)),
            nn.Dropout(drop_prob),
        )


class _PositionalEncoding(nn.Module):
    """
    Positional encoding module that adds learnable positional embeddings.

    Parameters
    ----------
    seq_length : int
        Sequence length.
    d_model : int
        Dimensionality of the model.
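
    Examples
    --------
    A minimal sketch with illustrative dimensions:

    >>> import torch
    >>> pe = _PositionalEncoding(seq_length=18, d_model=72)
    >>> pe(torch.randn(8, 18, 72)).shape
    torch.Size([8, 18, 72])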
    """

    def __init__(self, seq_length: int, d_model: int) -> None:
        super().__init__()
        self.seq_length = seq_length
        self.d_model = d_model
        # Learnable positional embedding, broadcast over the batch dimension
        self.pe = nn.Parameter(torch.zeros(1, seq_length, d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.pe
        return x


class _Transformer(nn.Module):
    """
    Transformer encoder module with learnable class token and positional encoding.

    Parameters
    ----------
    seq_length : int
        Sequence length of the input.
    d_model : int
        Dimensionality of the model.
    num_heads : int
        Number of heads in the multihead attention.
    feedforward_ratio : float
        Ratio to compute the dimension of the feedforward network.
    drop_prob : float, optional
        Dropout probability, by default 0.5.
    num_layers : int, optional
        Number of transformer encoder layers, by default 4.
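
    Examples
    --------
    A minimal sketch; the dimensions below are illustrative and only need
    ``d_model`` to be divisible by ``num_heads``:

    >>> import torch
    >>> trans = _Transformer(seq_length=17, d_model=72, num_heads=8,
    ...                      feedforward_ratio=1, drop_prob=0.5, num_layers=2)
    >>> trans(torch.randn(8, 17, 72)).shape  # class-token representation
    torch.Size([8, 72])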
    """

    def __init__(
        self,
        seq_length: int,
        d_model: int,
        num_heads: int,
        feedforward_ratio: float,
        drop_prob: float = 0.5,
        num_layers: int = 4,
    ) -> None:
        super().__init__()
        self.cls_embedding = nn.Parameter(torch.zeros(1, 1, d_model))
        self.pos_embedding = _PositionalEncoding(seq_length + 1, d_model)

        dim_ff = int(d_model * feedforward_ratio)
        self.dropout = nn.Dropout(drop_prob)
        self.trans = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model,
                num_heads,
                dim_ff,
                drop_prob,
                batch_first=True,
                norm_first=True,
            ),
            num_layers,
            norm=nn.LayerNorm(d_model),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size = x.shape[0]
        # Prepend the learnable class token to the sequence
        x = torch.cat((self.cls_embedding.expand(batch_size, -1, -1), x), dim=1)
        x = self.pos_embedding(x)
        x = self.dropout(x)
        # Return the encoded class-token representation
        return self.trans(x)[:, 0]


class _DenseLayers(nn.Sequential):
    """
    Flatten followed by a linear classification layer.

    Parameters
    ----------
    linear_in : int
        Input dimension to the linear layer.
    n_classes : int
        Number of output classes.
    """

    def __init__(self, linear_in: int, n_classes: int):
        super().__init__(
            nn.Flatten(),
            nn.Linear(linear_in, n_classes),
        )


class MSVTNet(EEGModuleMixin, nn.Module):
    """MSVTNet model from Liu et al. (2024) [msvt2024]_.

    This model implements a multi-scale convolutional transformer network
    for EEG signal classification, as described in [msvt2024]_.

    .. figure:: https://raw.githubusercontent.com/SheepTAO/MSVTNet/refs/heads/main/MSVTNet_Arch.png
       :align: center
       :alt: MSVTNet Architecture

    Parameters
    ----------
    n_filters_list : Tuple[int, ...], optional
        Number of filters for each TSConv branch, by default (9, 9, 9, 9).
    conv1_kernels_size : Tuple[int, ...], optional
        Kernel sizes for the first convolution in each TSConv branch,
        by default (15, 31, 63, 125).
    conv2_kernel_size : int, optional
        Kernel size for the second convolution in TSConv blocks, by default 15.
    depth_multiplier : int, optional
        Depth multiplier for depthwise convolution, by default 2.
    pool1_size : int, optional
        Pooling size for the first pooling layer in TSConv blocks, by default 8.
    pool2_size : int, optional
        Pooling size for the second pooling layer in TSConv blocks, by default 7.
    drop_prob : float, optional
        Dropout probability for the convolutional layers, by default 0.3.
    num_heads : int, optional
        Number of attention heads in the transformer encoder, by default 8.
    feedforward_ratio : float, optional
        Ratio used to compute the feedforward dimension in the transformer, by default 1.
    drop_prob_trans : float, optional
        Dropout probability for the transformer, by default 0.5.
    num_layers : int, optional
        Number of transformer encoder layers, by default 2.
    activation : Type[nn.Module], optional
        Activation function class to use, by default nn.ELU.
    return_features : bool, optional
        Whether to also return the branch classifier predictions, by default False.

    Notes
    -----
    This implementation is not guaranteed to be correct and has not been
    checked by the original authors; it is a reimplementation based on the
    original source code [msvt2024code]_.

    References
    ----------
    .. [msvt2024] Liu, K., et al. (2024). MSVTNet: Multi-Scale Vision
       Transformer Neural Network for EEG-Based Motor Imagery Decoding.
       IEEE Journal of Biomedical and Health Informatics.
    .. [msvt2024code] Liu, K., et al. (2024). MSVTNet: Multi-Scale Vision
       Transformer Neural Network for EEG-Based Motor Imagery Decoding.
       Source Code: https://github.com/SheepTAO/MSVTNet
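
    Examples
    --------
    A minimal usage sketch; the channel, class, and window sizes below are
    illustrative values, not taken from the paper:

    >>> import torch
    >>> model = MSVTNet(n_chans=22, n_outputs=4, n_times=1000)
    >>> x = torch.randn(8, 22, 1000)
    >>> model(x).shape
    torch.Size([8, 4])

    With ``return_features=True``, the forward pass also returns the
    auxiliary branch predictions:

    >>> model = MSVTNet(n_chans=22, n_outputs=4, n_times=1000,
    ...                 return_features=True)
    >>> out, branch_preds = model(x)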
    """

    def __init__(
        self,
        # braindecode parameters
        n_chans: Optional[int] = None,
        n_outputs: Optional[int] = None,
        n_times: Optional[int] = None,
        input_window_seconds: Optional[float] = None,
        sfreq: Optional[float] = None,
        chs_info: Optional[List[Dict]] = None,
        # Model's parameters
        n_filters_list: Tuple[int, ...] = (9, 9, 9, 9),
        conv1_kernels_size: Tuple[int, ...] = (15, 31, 63, 125),
        conv2_kernel_size: int = 15,
        depth_multiplier: int = 2,
        pool1_size: int = 8,
        pool2_size: int = 7,
        drop_prob: float = 0.3,
        num_heads: int = 8,
        feedforward_ratio: float = 1,
        drop_prob_trans: float = 0.5,
        num_layers: int = 2,
        activation: Type[nn.Module] = nn.ELU,
        return_features: bool = False,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        self.return_features = return_features
        assert len(n_filters_list) == len(conv1_kernels_size), (
            "The length of n_filters_list and conv1_kernels_size should be equal."
        )

        self.ensure_dim = Rearrange("batch chans time -> batch 1 chans time")
        self.mstsconv = nn.ModuleList(
            [
                nn.Sequential(
                    _TSConv(
                        self.n_chans,
                        n_filters_list[b],
                        conv1_kernels_size[b],
                        conv2_kernel_size,
                        depth_multiplier,
                        pool1_size,
                        pool2_size,
                        drop_prob,
                        activation,
                    ),
                    Rearrange("batch channels 1 time -> batch time channels"),
                )
                for b in range(len(n_filters_list))
            ]
        )
        # Infer per-branch flattened feature sizes with a dummy forward pass
        branch_linear_in = self._forward_flatten(cat=False)
        self.branch_head = nn.ModuleList(
            [
                _DenseLayers(branch_linear_in[b].shape[1], self.n_outputs)
                for b in range(len(n_filters_list))
            ]
        )

        seq_len, d_model = self._forward_mstsconv().shape[1:3]  # type: ignore
        self.transformer = _Transformer(
            seq_len,
            d_model,
            num_heads,
            feedforward_ratio,
            drop_prob_trans,
            num_layers,
        )

        linear_in = self._forward_flatten().shape[1]  # type: ignore
        self.flatten_layer = nn.Flatten()
        self.final_layer = nn.Linear(linear_in, self.n_outputs)

    def _forward_mstsconv(
        self, cat: bool = True
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        # Dummy forward pass (random input) used to infer output shapes at build time
        x = torch.randn(1, 1, self.n_chans, self.n_times)
        x = [tsconv(x) for tsconv in self.mstsconv]
        if cat:
            x = torch.cat(x, dim=2)
        return x

    def _forward_flatten(
        self, cat: bool = True
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        x = self._forward_mstsconv(cat)
        if cat:
            x = self.transformer(x)
            x = x.flatten(start_dim=1, end_dim=-1)
        else:
            x = [_.flatten(start_dim=1, end_dim=-1) for _ in x]
        return x

    def forward(
        self, x: torch.Tensor
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
        # x with shape: (batch, n_chans, n_times)
        x = self.ensure_dim(x)
        # x with shape: (batch, 1, n_chans, n_times)
        x_list = [tsconv(x) for tsconv in self.mstsconv]
        # x_list contains one tensor per branch, each of shape: (batch, seq_len, embed_dim)
        branch_preds = [
            branch(x_list[idx]) for idx, branch in enumerate(self.branch_head)
        ]
        # branch_preds contains one tensor per branch, each of shape: (batch, n_outputs)
        x = torch.stack(x_list, dim=2)
        x = x.view(x.size(0), x.size(1), -1)
        # x shape after concatenation across branches: (batch, seq_len, total_embed_dim)
        x = self.transformer(x)
        # x shape after transformer (class token): (batch, total_embed_dim)

        x = self.final_layer(x)
        return (x, branch_preds) if self.return_features else x