braindecode/models/msvtnet.py
|
|
# Authors: Tao Yang <sheeptao@outlook.com>
#          Bruno Aristimunha <b.aristimunha@gmail.com> (braindecode adaptation)
#
import torch
import torch.nn as nn
from einops.layers.torch import Rearrange
from typing import List, Type, Union, Tuple, Optional, Dict
from braindecode.models.base import EEGModuleMixin

class _TSConv(nn.Sequential):
    """
    Time-Distributed Separable Convolution block.

    The architecture consists of:

    - **Temporal Convolution**
    - **Batch Normalization**
    - **Depthwise Spatial Convolution**
    - **Batch Normalization**
    - **Activation Function**
    - **First Pooling Layer**
    - **Dropout**
    - **Depthwise Temporal Convolution**
    - **Batch Normalization**
    - **Activation Function**
    - **Second Pooling Layer**
    - **Dropout**

    Parameters
    ----------
    n_channels : int
        Number of input channels (EEG channels).
    n_filters : int
        Number of filters for the convolution layers.
    conv1_kernel_size : int
        Kernel size for the first convolution layer.
    conv2_kernel_size : int
        Kernel size for the second convolution layer.
    depth_multiplier : int
        Depth multiplier for depthwise convolution.
    pool1_size : int
        Kernel size for the first pooling layer.
    pool2_size : int
        Kernel size for the second pooling layer.
    drop_prob : float
        Dropout probability.
    activation : Type[nn.Module], optional
        Activation function class to use, by default nn.ELU.
    """

    def __init__(
        self,
        n_channels: int,
        n_filters: int,
        conv1_kernel_size: int,
        conv2_kernel_size: int,
        depth_multiplier: int,
        pool1_size: int,
        pool2_size: int,
        drop_prob: float,
        activation: Type[nn.Module] = nn.ELU,
    ):
        super().__init__(
            nn.Conv2d(
                in_channels=1,
                out_channels=n_filters,
                kernel_size=(1, conv1_kernel_size),
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(n_filters),
            nn.Conv2d(
                in_channels=n_filters,
                out_channels=n_filters * depth_multiplier,
                kernel_size=(n_channels, 1),
                groups=n_filters,
                bias=False,
            ),
            nn.BatchNorm2d(n_filters * depth_multiplier),
            activation(),
            nn.AvgPool2d(kernel_size=(1, pool1_size)),
            nn.Dropout(drop_prob),
            nn.Conv2d(
                in_channels=n_filters * depth_multiplier,
                out_channels=n_filters * depth_multiplier,
                kernel_size=(1, conv2_kernel_size),
                padding="same",
                groups=n_filters * depth_multiplier,
                bias=False,
            ),
            nn.BatchNorm2d(n_filters * depth_multiplier),
            activation(),
            nn.AvgPool2d(kernel_size=(1, pool2_size)),
            nn.Dropout(drop_prob),
        )
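

# A minimal shape sketch for ``_TSConv`` (illustrative only, not part of the
# model). The channel and sample counts are assumptions; the remaining values
# are the MSVTNet defaults (n_filters=9, depth_multiplier=2, pools of 8 and 7):
#
#     >>> block = _TSConv(n_channels=22, n_filters=9, conv1_kernel_size=15,
#     ...                 conv2_kernel_size=15, depth_multiplier=2,
#     ...                 pool1_size=8, pool2_size=7, drop_prob=0.3)
#     >>> block(torch.randn(2, 1, 22, 1000)).shape
#     torch.Size([2, 18, 1, 17])
#
# The depthwise spatial convolution collapses the electrode axis to 1, and the
# two average-pooling layers shrink the time axis by pool1_size * pool2_size.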
|
|
class _PositionalEncoding(nn.Module):
    """
    Positional encoding module that adds learnable positional embeddings.

    Parameters
    ----------
    seq_length : int
        Sequence length.
    d_model : int
        Dimensionality of the model.
    """

    def __init__(self, seq_length: int, d_model: int) -> None:
        super().__init__()
        self.seq_length = seq_length
        self.d_model = d_model
        self.pe = nn.Parameter(torch.zeros(1, seq_length, d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.pe
        return x
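
# Illustrative sketch (not part of the model): the positional encoding is a
# single learnable parameter of shape (1, seq_length, d_model), broadcast over
# the batch dimension when added to the input. The sizes below are arbitrary:
#
#     >>> pe = _PositionalEncoding(seq_length=5, d_model=16)
#     >>> pe(torch.zeros(2, 5, 16)).shape
#     torch.Size([2, 5, 16])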
|
|
class _Transformer(nn.Module):
    """
    Transformer encoder module with learnable class token and positional encoding.

    Parameters
    ----------
    seq_length : int
        Sequence length of the input.
    d_model : int
        Dimensionality of the model.
    num_heads : int
        Number of heads in the multihead attention.
    feedforward_ratio : float
        Ratio to compute the dimension of the feedforward network.
    drop_prob : float, optional
        Dropout probability, by default 0.5.
    num_layers : int, optional
        Number of transformer encoder layers, by default 4.
    """

    def __init__(
        self,
        seq_length: int,
        d_model: int,
        num_heads: int,
        feedforward_ratio: float,
        drop_prob: float = 0.5,
        num_layers: int = 4,
    ) -> None:
        super().__init__()
        self.cls_embedding = nn.Parameter(torch.zeros(1, 1, d_model))
        self.pos_embedding = _PositionalEncoding(seq_length + 1, d_model)

        dim_ff = int(d_model * feedforward_ratio)
        self.dropout = nn.Dropout(drop_prob)
        self.trans = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model,
                num_heads,
                dim_ff,
                drop_prob,
                batch_first=True,
                norm_first=True,
            ),
            num_layers,
            norm=nn.LayerNorm(d_model),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size = x.shape[0]
        x = torch.cat((self.cls_embedding.expand(batch_size, -1, -1), x), dim=1)
        x = self.pos_embedding(x)
        x = self.dropout(x)
        return self.trans(x)[:, 0]
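
# Illustrative sketch (not part of the model): a learnable class token is
# prepended to the input, so the encoder processes seq_length + 1 tokens and
# only the encoded class token is returned as a sequence summary. The sizes
# below are assumptions chosen for the example, not the MSVTNet defaults:
#
#     >>> enc = _Transformer(seq_length=17, d_model=36, num_heads=4,
#     ...                    feedforward_ratio=1.0, num_layers=2)
#     >>> enc(torch.randn(2, 17, 36)).shape
#     torch.Size([2, 36])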
|
|
class _DenseLayers(nn.Sequential):
    """
    Final classification layers.

    Parameters
    ----------
    linear_in : int
        Input dimension to the linear layer.
    n_classes : int
        Number of output classes.
    """

    def __init__(self, linear_in: int, n_classes: int):
        super().__init__(
            nn.Flatten(),
            nn.Linear(linear_in, n_classes),
        )
|
|
class MSVTNet(EEGModuleMixin, nn.Module):
    """MSVTNet model from Liu K et al. (2024) [msvt2024]_.

    This model implements a multi-scale convolutional transformer network
    for EEG signal classification, as described in [msvt2024]_.

    .. figure:: https://raw.githubusercontent.com/SheepTAO/MSVTNet/refs/heads/main/MSVTNet_Arch.png
       :align: center
       :alt: MSVTNet Architecture

    Parameters
    ----------
    n_filters_list : Tuple[int, ...], optional
        Number of filters for each TSConv branch, by default (9, 9, 9, 9).
    conv1_kernels_size : Tuple[int, ...], optional
        Kernel sizes for the first convolution in each TSConv branch,
        by default (15, 31, 63, 125).
    conv2_kernel_size : int, optional
        Kernel size for the second convolution in TSConv blocks, by default 15.
    depth_multiplier : int, optional
        Depth multiplier for depthwise convolution, by default 2.
    pool1_size : int, optional
        Pooling size for the first pooling layer in TSConv blocks, by default 8.
    pool2_size : int, optional
        Pooling size for the second pooling layer in TSConv blocks, by default 7.
    drop_prob : float, optional
        Dropout probability for the convolutional layers, by default 0.3.
    num_heads : int, optional
        Number of attention heads in the transformer encoder, by default 8.
    feedforward_ratio : float, optional
        Ratio used to compute the feedforward dimension in the transformer,
        by default 1.
    drop_prob_trans : float, optional
        Dropout probability for the transformer, by default 0.5.
    num_layers : int, optional
        Number of transformer encoder layers, by default 2.
    activation : Type[nn.Module], optional
        Activation function class to use, by default nn.ELU.
    return_features : bool, optional
        Whether to also return the predictions of the branch classifiers,
        by default False.

    Notes
    -----
    This implementation is not guaranteed to be correct and has not been
    checked by the original authors; it is a reimplementation based on the
    original code [msvt2024code]_.

    References
    ----------
    .. [msvt2024] Liu, K., et al. (2024). MSVTNet: Multi-Scale Vision
        Transformer Neural Network for EEG-Based Motor Imagery Decoding.
        IEEE Journal of Biomedical and Health Informatics.
    .. [msvt2024code] Liu, K., et al. (2024). MSVTNet: Multi-Scale Vision
        Transformer Neural Network for EEG-Based Motor Imagery Decoding.
        Source Code: https://github.com/SheepTAO/MSVTNet
    """
|
|
    def __init__(
        self,
        # braindecode parameters
        n_chans: Optional[int] = None,
        n_outputs: Optional[int] = None,
        n_times: Optional[int] = None,
        input_window_seconds: Optional[float] = None,
        sfreq: Optional[float] = None,
        chs_info: Optional[List[Dict]] = None,
        # Model's parameters
        n_filters_list: Tuple[int, ...] = (9, 9, 9, 9),
        conv1_kernels_size: Tuple[int, ...] = (15, 31, 63, 125),
        conv2_kernel_size: int = 15,
        depth_multiplier: int = 2,
        pool1_size: int = 8,
        pool2_size: int = 7,
        drop_prob: float = 0.3,
        num_heads: int = 8,
        feedforward_ratio: float = 1,
        drop_prob_trans: float = 0.5,
        num_layers: int = 2,
        activation: Type[nn.Module] = nn.ELU,
        return_features: bool = False,
    ):
        super().__init__(
            n_outputs=n_outputs,
            n_chans=n_chans,
            chs_info=chs_info,
            n_times=n_times,
            input_window_seconds=input_window_seconds,
            sfreq=sfreq,
        )
        del n_outputs, n_chans, chs_info, n_times, input_window_seconds, sfreq

        self.return_features = return_features
        assert len(n_filters_list) == len(conv1_kernels_size), (
            "The length of n_filters_list and conv1_kernels_size should be equal."
        )

        self.ensure_dim = Rearrange("batch chans time -> batch 1 chans time")
        self.mstsconv = nn.ModuleList(
            [
                nn.Sequential(
                    _TSConv(
                        self.n_chans,
                        n_filters_list[b],
                        conv1_kernels_size[b],
                        conv2_kernel_size,
                        depth_multiplier,
                        pool1_size,
                        pool2_size,
                        drop_prob,
                        activation,
                    ),
                    Rearrange("batch channels 1 time -> batch time channels"),
                )
                for b in range(len(n_filters_list))
            ]
        )
        branch_linear_in = self._forward_flatten(cat=False)
        self.branch_head = nn.ModuleList(
            [
                _DenseLayers(branch_linear_in[b].shape[1], self.n_outputs)
                for b in range(len(n_filters_list))
            ]
        )

        seq_len, d_model = self._forward_mstsconv().shape[1:3]  # type: ignore
        self.transformer = _Transformer(
            seq_len,
            d_model,
            num_heads,
            feedforward_ratio,
            drop_prob_trans,
            num_layers,
        )

        linear_in = self._forward_flatten().shape[1]  # type: ignore
        self.flatten_layer = nn.Flatten()
        self.final_layer = nn.Linear(linear_in, self.n_outputs)
|
|
    def _forward_mstsconv(
        self, cat: bool = True
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        # Dummy forward pass used at build time to infer the branch output shapes.
        x = torch.randn(1, 1, self.n_chans, self.n_times)
        x = [tsconv(x) for tsconv in self.mstsconv]
        if cat:
            x = torch.cat(x, dim=2)
        return x

    def _forward_flatten(
        self, cat: bool = True
    ) -> Union[torch.Tensor, List[torch.Tensor]]:
        x = self._forward_mstsconv(cat)
        if cat:
            x = self.transformer(x)
            x = x.flatten(start_dim=1, end_dim=-1)
        else:
            x = [_.flatten(start_dim=1, end_dim=-1) for _ in x]
        return x
|
|
    def forward(
        self, x: torch.Tensor
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
        # x with shape: (batch, n_chans, n_times)
        x = self.ensure_dim(x)
        # x with shape: (batch, 1, n_chans, n_times)
        x_list = [tsconv(x) for tsconv in self.mstsconv]
        # x_list contains one tensor per branch, each of shape: (batch, seq_len, branch_embed_dim)
        branch_preds = [
            branch(x_list[idx]) for idx, branch in enumerate(self.branch_head)
        ]
        # branch_preds contains one tensor per branch, each of shape: (batch, n_outputs)
        x = torch.cat(x_list, dim=2)
        # x shape after concatenation: (batch, seq_len, total_embed_dim)
        x = self.transformer(x)
        # x shape after transformer: (batch, total_embed_dim)

        x = self.final_layer(x)
        return (x, branch_preds) if self.return_features else x
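

# A minimal usage sketch (illustrative only; the channel, class, and sample
# counts below are assumptions, not requirements of the model):
#
#     >>> model = MSVTNet(n_chans=22, n_outputs=4, n_times=1000)
#     >>> model(torch.randn(8, 22, 1000)).shape
#     torch.Size([8, 4])
#
# With ``return_features=True`` the forward pass also returns the per-branch
# predictions, which can be used, e.g., for auxiliary per-branch losses
# during training.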