"""
Copyright (C) 2022 King Saud University, Saudi Arabia
SPDX-License-Identifier: Apache-2.0

Licensed under the Apache License, Version 2.0 (the "License"); you may not use
this file except in compliance with the License. You may obtain a copy of the
License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed
under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

Author: Hamdi Altaheri
"""

#%%
import tensorflow as tf
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Dense, Dropout, Activation, AveragePooling2D, MaxPooling2D
from tensorflow.keras.layers import Conv1D, Conv2D, SeparableConv2D, DepthwiseConv2D
from tensorflow.keras.layers import BatchNormalization, LayerNormalization, Flatten
from tensorflow.keras.layers import Add, Concatenate, Lambda, Input, Permute
from tensorflow.keras.regularizers import L2
from tensorflow.keras.constraints import max_norm

from tensorflow.keras import backend as K

from attention_models import attention_block

#%% The proposed ATCNet model, https://doi.org/10.1109/TII.2022.3197419
def ATCNet_(n_classes, in_chans = 22, in_samples = 1125, n_windows = 5, attention = 'mha',
            eegn_F1 = 16, eegn_D = 2, eegn_kernelSize = 64, eegn_poolSize = 7, eegn_dropout = 0.3,
            tcn_depth = 2, tcn_kernelSize = 4, tcn_filters = 32, tcn_dropout = 0.3,
            tcn_activation = 'elu', fuse = 'average'):
    """ ATCNet model from Altaheri et al 2023.
    See details at https://ieeexplore.ieee.org/abstract/document/9852687

    Notes
    -----
    The initial values in this model are based on the values identified by
    the authors.

    References
    ----------
    .. H. Altaheri, G. Muhammad, and M. Alsulaiman. "Physics-informed
       attention temporal convolutional network for EEG-based motor imagery
       classification." IEEE Transactions on Industrial Informatics,
       vol. 19, no. 2, pp. 2249-2258, (2023)
       https://doi.org/10.1109/TII.2022.3197419
    """
    input_1 = Input(shape = (1, in_chans, in_samples))   # TensorShape([None, 1, 22, 1125])
    input_2 = Permute((3, 2, 1))(input_1)

    dense_weightDecay = 0.5
    conv_weightDecay = 0.009
    conv_maxNorm = 0.6
    from_logits = False

    numFilters = eegn_F1
    F2 = numFilters * eegn_D

    block1 = Conv_block_(input_layer = input_2, F1 = eegn_F1, D = eegn_D,
                         kernLength = eegn_kernelSize, poolSize = eegn_poolSize,
                         weightDecay = conv_weightDecay, maxNorm = conv_maxNorm,
                         in_chans = in_chans, dropout = eegn_dropout)
    block1 = Lambda(lambda x: x[:, :, -1, :])(block1)

    # Sliding window
    sw_concat = []   # to store concatenated or averaged sliding-window outputs
    for i in range(n_windows):
        st = i
        end = block1.shape[1] - n_windows + i + 1
        block2 = block1[:, st:end, :]

        # Attention model
        if attention is not None:
            if attention in ('se', 'cbam'):
                block2 = Permute((2, 1))(block2)   # shape=(None, 32, 16)
                block2 = attention_block(block2, attention)
                block2 = Permute((2, 1))(block2)   # shape=(None, 16, 32)
            else:
                block2 = attention_block(block2, attention)

        # Temporal convolutional network (TCN)
        block3 = TCN_block_(input_layer = block2, input_dimension = F2, depth = tcn_depth,
                            kernel_size = tcn_kernelSize, filters = tcn_filters,
                            weightDecay = conv_weightDecay, maxNorm = conv_maxNorm,
                            dropout = tcn_dropout, activation = tcn_activation)
        # Get feature maps of the last sequence step
        block3 = Lambda(lambda x: x[:, -1, :])(block3)

        # Outputs of sliding window: average after dense or concatenate then dense
        if fuse == 'average':
            sw_concat.append(Dense(n_classes, kernel_regularizer = L2(dense_weightDecay))(block3))
        elif fuse == 'concat':
            if i == 0:
                sw_concat = block3
            else:
                sw_concat = Concatenate()([sw_concat, block3])

    if fuse == 'average':
        if len(sw_concat) > 1:   # more than one window
            sw_concat = tf.keras.layers.Average()(sw_concat[:])
        else:                    # one window (n_windows = 1)
            sw_concat = sw_concat[0]
    elif fuse == 'concat':
        sw_concat = Dense(n_classes, kernel_regularizer = L2(dense_weightDecay))(sw_concat)

    if from_logits:   # No activation here because we are using from_logits=True
        out = Activation('linear', name = 'linear')(sw_concat)
    else:             # Using softmax activation
        out = Activation('softmax', name = 'softmax')(sw_concat)

    return Model(inputs = input_1, outputs = out)

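# Example usage (an illustrative sketch, not part of the original code): build the
# default ATCNet for 4-class motor imagery on BCI Competition IV-2a-like data
# (22 channels, 1125 time samples); inputs have shape (batch, 1, in_chans, in_samples).
#
# model = ATCNet_(n_classes = 4)
# model.compile(loss = 'categorical_crossentropy', optimizer = 'adam', metrics = ['accuracy'])
# model.summary()
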
#%% Convolutional (CV) block used in the ATCNet model
def Conv_block(input_layer, F1=4, kernLength=64, poolSize=8, D=2, in_chans=22, dropout=0.1):
    """ Conv_block

    Notes
    -----
    This block is the same as EEGNet with SeparableConv2D replaced by Conv2D
    The original code for this model is available at: https://github.com/vlawhern/arl-eegmodels
    See details at https://arxiv.org/abs/1611.08024
    """
    F2 = F1 * D
    block1 = Conv2D(F1, (kernLength, 1), padding = 'same', data_format = 'channels_last',
                    use_bias = False)(input_layer)
    block1 = BatchNormalization(axis = -1)(block1)
    block2 = DepthwiseConv2D((1, in_chans), use_bias = False,
                             depth_multiplier = D,
                             data_format = 'channels_last',
                             depthwise_constraint = max_norm(1.))(block1)
    block2 = BatchNormalization(axis = -1)(block2)
    block2 = Activation('elu')(block2)
    block2 = AveragePooling2D((8, 1), data_format = 'channels_last')(block2)
    block2 = Dropout(dropout)(block2)
    block3 = Conv2D(F2, (16, 1),
                    data_format = 'channels_last',
                    use_bias = False, padding = 'same')(block2)
    block3 = BatchNormalization(axis = -1)(block3)
    block3 = Activation('elu')(block3)

    block3 = AveragePooling2D((poolSize, 1), data_format = 'channels_last')(block3)
    block3 = Dropout(dropout)(block3)
    return block3

def Conv_block_(input_layer, F1=4, kernLength=64, poolSize=8, D=2, in_chans=22,
                weightDecay = 0.009, maxNorm = 0.6, dropout=0.25):
    """ Conv_block

    Notes
    -----
    Same as Conv_block, but with L2 weight decay and max-norm constraints
    applied to the convolutional kernels.
    """
    F2 = F1 * D
    block1 = Conv2D(F1, (kernLength, 1), padding = 'same', data_format = 'channels_last',
                    kernel_regularizer = L2(weightDecay),
                    # In a Conv2D layer with data_format="channels_last", the weight tensor has shape
                    # (rows, cols, input_depth, output_depth); set axis to [0, 1, 2] to constrain
                    # the weights of each filter tensor of size (rows, cols, input_depth).
                    kernel_constraint = max_norm(maxNorm, axis=[0, 1, 2]),
                    use_bias = False)(input_layer)
    block1 = BatchNormalization(axis = -1)(block1)   # bn_axis = -1 if data_format == 'channels_last' else 1

    block2 = DepthwiseConv2D((1, in_chans),
                             depth_multiplier = D,
                             data_format = 'channels_last',
                             depthwise_regularizer = L2(weightDecay),
                             depthwise_constraint = max_norm(maxNorm, axis=[0, 1, 2]),
                             use_bias = False)(block1)
    block2 = BatchNormalization(axis = -1)(block2)
    block2 = Activation('elu')(block2)
    block2 = AveragePooling2D((8, 1), data_format = 'channels_last')(block2)
    block2 = Dropout(dropout)(block2)

    block3 = Conv2D(F2, (16, 1),
                    data_format = 'channels_last',
                    kernel_regularizer = L2(weightDecay),
                    kernel_constraint = max_norm(maxNorm, axis=[0, 1, 2]),
                    use_bias = False, padding = 'same')(block2)
    block3 = BatchNormalization(axis = -1)(block3)
    block3 = Activation('elu')(block3)

    block3 = AveragePooling2D((poolSize, 1), data_format = 'channels_last')(block3)
    block3 = Dropout(dropout)(block3)
    return block3

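# Shape sketch (illustrative): with the ATCNet_ defaults (in_samples = 1125,
# F1 = 16, D = 2, poolSize = 7), the two average-pooling stages reduce the time
# axis to 1125//8 = 140 and then 140//7 = 20 steps, so Conv_block_ returns
# feature maps of shape (batch, 20, 1, 32), i.e. 20 time steps with
# F2 = F1*D = 32 channels.
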
#%% Temporal convolutional (TC) block used in the ATCNet model
def TCN_block(input_layer, input_dimension, depth, kernel_size, filters, dropout, activation='relu'):
    """ TCN_block from Bai et al 2018
    Temporal Convolutional Network (TCN)

    Notes
    -----
    The original code is available at https://github.com/locuslab/TCN/blob/master/TCN/tcn.py
    This implementation differs slightly from the original code; it is taken
    from the code by Ingolfsson et al at https://github.com/iis-eth-zurich/eeg-tcnet
    See details at https://arxiv.org/abs/2006.00622

    References
    ----------
    .. Bai, S., Kolter, J. Z., & Koltun, V. (2018).
       An empirical evaluation of generic convolutional and recurrent networks
       for sequence modeling.
       arXiv preprint arXiv:1803.01271.
    """
    block = Conv1D(filters, kernel_size=kernel_size, dilation_rate=1, activation='linear',
                   padding = 'causal', kernel_initializer='he_uniform')(input_layer)
    block = BatchNormalization()(block)
    block = Activation(activation)(block)
    block = Dropout(dropout)(block)
    block = Conv1D(filters, kernel_size=kernel_size, dilation_rate=1, activation='linear',
                   padding = 'causal', kernel_initializer='he_uniform')(block)
    block = BatchNormalization()(block)
    block = Activation(activation)(block)
    block = Dropout(dropout)(block)
    if input_dimension != filters:
        # 1x1 convolution to match the residual dimensions
        conv = Conv1D(filters, kernel_size=1, padding='same')(input_layer)
        added = Add()([block, conv])
    else:
        added = Add()([block, input_layer])
    out = Activation(activation)(added)

    for i in range(depth - 1):
        block = Conv1D(filters, kernel_size=kernel_size, dilation_rate=2**(i+1), activation='linear',
                       padding = 'causal', kernel_initializer='he_uniform')(out)
        block = BatchNormalization()(block)
        block = Activation(activation)(block)
        block = Dropout(dropout)(block)
        block = Conv1D(filters, kernel_size=kernel_size, dilation_rate=2**(i+1), activation='linear',
                       padding = 'causal', kernel_initializer='he_uniform')(block)
        block = BatchNormalization()(block)
        block = Activation(activation)(block)
        block = Dropout(dropout)(block)
        added = Add()([block, out])
        out = Activation(activation)(added)

    return out

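# Receptive-field sketch (illustrative): each level of the TCN applies two causal
# convolutions and the dilation rate doubles per level (1, 2, 4, ...), so the
# receptive field is 1 + 2*(kernel_size-1)*(2**depth - 1); with the ATCNet_
# defaults (kernel_size = 4, depth = 2) that is 1 + 2*3*3 = 19 time steps,
# covering the full 16-step sliding window.
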
def TCN_block_(input_layer, input_dimension, depth, kernel_size, filters, dropout,
               weightDecay = 0.009, maxNorm = 0.6, activation='relu'):
    """ TCN_block from Bai et al 2018
    Temporal Convolutional Network (TCN)

    Notes
    -----
    Same as TCN_block, but with L2 weight decay and max-norm constraints
    applied to the convolutional kernels.
    """
    block = Conv1D(filters, kernel_size=kernel_size, dilation_rate=1, activation='linear',
                   kernel_regularizer = L2(weightDecay),
                   kernel_constraint = max_norm(maxNorm, axis=[0, 1]),
                   padding = 'causal', kernel_initializer='he_uniform')(input_layer)
    block = BatchNormalization()(block)
    block = Activation(activation)(block)
    block = Dropout(dropout)(block)
    block = Conv1D(filters, kernel_size=kernel_size, dilation_rate=1, activation='linear',
                   kernel_regularizer = L2(weightDecay),
                   kernel_constraint = max_norm(maxNorm, axis=[0, 1]),
                   padding = 'causal', kernel_initializer='he_uniform')(block)
    block = BatchNormalization()(block)
    block = Activation(activation)(block)
    block = Dropout(dropout)(block)
    if input_dimension != filters:
        # 1x1 convolution to match the residual dimensions
        conv = Conv1D(filters, kernel_size=1,
                      kernel_regularizer = L2(weightDecay),
                      kernel_constraint = max_norm(maxNorm, axis=[0, 1]),
                      padding='same')(input_layer)
        added = Add()([block, conv])
    else:
        added = Add()([block, input_layer])
    out = Activation(activation)(added)

    for i in range(depth - 1):
        block = Conv1D(filters, kernel_size=kernel_size, dilation_rate=2**(i+1), activation='linear',
                       kernel_regularizer = L2(weightDecay),
                       kernel_constraint = max_norm(maxNorm, axis=[0, 1]),
                       padding = 'causal', kernel_initializer='he_uniform')(out)
        block = BatchNormalization()(block)
        block = Activation(activation)(block)
        block = Dropout(dropout)(block)
        block = Conv1D(filters, kernel_size=kernel_size, dilation_rate=2**(i+1), activation='linear',
                       kernel_regularizer = L2(weightDecay),
                       kernel_constraint = max_norm(maxNorm, axis=[0, 1]),
                       padding = 'causal', kernel_initializer='he_uniform')(block)
        block = BatchNormalization()(block)
        block = Activation(activation)(block)
        block = Dropout(dropout)(block)
        added = Add()([block, out])
        out = Activation(activation)(added)

    return out

#%% Reproduced TCNet_Fusion model: https://doi.org/10.1016/j.bspc.2021.102826
def TCNet_Fusion(n_classes, Chans=22, Samples=1125, layers=2, kernel_s=4, filt=12,
                 dropout=0.3, activation='elu', F1=24, D=2, kernLength=32, dropout_eeg=0.3):
    """ TCNet_Fusion model from Musallam et al 2021.
    See details at https://doi.org/10.1016/j.bspc.2021.102826

    Notes
    -----
    The initial values in this model are based on the values identified by
    the authors

    References
    ----------
    .. Musallam, Y.K., AlFassam, N.I., Muhammad, G., Amin, S.U., Alsulaiman,
       M., Abdul, W., Altaheri, H., Bencherif, M.A. and Algabri, M., 2021.
       Electroencephalography-based motor imagery classification
       using temporal convolutional network fusion.
       Biomedical Signal Processing and Control, 69, p.102826.
    """
    input1 = Input(shape = (1, Chans, Samples))
    input2 = Permute((3, 2, 1))(input1)
    regRate = .25

    numFilters = F1
    F2 = numFilters * D

    EEGNet_sep = EEGNet(input_layer=input2, F1=F1, kernLength=kernLength, D=D, Chans=Chans, dropout=dropout_eeg)
    block2 = Lambda(lambda x: x[:, :, -1, :])(EEGNet_sep)
    FC = Flatten()(block2)

    outs = TCN_block(input_layer=block2, input_dimension=F2, depth=layers, kernel_size=kernel_s,
                     filters=filt, dropout=dropout, activation=activation)

    Con1 = Concatenate()([block2, outs])
    out = Flatten()(Con1)
    Con2 = Concatenate()([out, FC])
    dense = Dense(n_classes, name = 'dense', kernel_constraint = max_norm(regRate))(Con2)
    softmax = Activation('softmax', name = 'softmax')(dense)

    return Model(inputs=input1, outputs=softmax)

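# Fusion note (illustrative, based on the code above): TCNet_Fusion concatenates the
# EEGNet feature maps with the TCN output (Con1), flattens the result, and
# concatenates it again with the flattened EEGNet features (Con2), so the classifier
# sees both the raw EEGNet features and their temporal summary, not the TCN output alone.
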
#%% Reproduced EEGTCNet model: https://arxiv.org/abs/2006.00622
def EEGTCNet(n_classes, Chans=22, Samples=1125, layers=2, kernel_s=4, filt=12,
             dropout=0.3, activation='elu', F1=8, D=2, kernLength=32, dropout_eeg=0.2):
    """ EEGTCNet model from Ingolfsson et al 2020.
    See details at https://arxiv.org/abs/2006.00622

    The original code for this model is available at https://github.com/iis-eth-zurich/eeg-tcnet

    Notes
    -----
    The initial values in this model are based on the values identified by the authors

    References
    ----------
    .. Ingolfsson, T. M., Hersche, M., Wang, X., Kobayashi, N.,
       Cavigelli, L., & Benini, L. (2020, October).
       EEG-TCNet: An accurate temporal convolutional network
       for embedded motor-imagery brain–machine interfaces.
       In 2020 IEEE International Conference on Systems,
       Man, and Cybernetics (SMC) (pp. 2958-2965). IEEE.
    """
    input1 = Input(shape = (1, Chans, Samples))
    input2 = Permute((3, 2, 1))(input1)
    regRate = .25
    numFilters = F1
    F2 = numFilters * D

    EEGNet_sep = EEGNet(input_layer=input2, F1=F1, kernLength=kernLength, D=D, Chans=Chans, dropout=dropout_eeg)
    block2 = Lambda(lambda x: x[:, :, -1, :])(EEGNet_sep)
    outs = TCN_block(input_layer=block2, input_dimension=F2, depth=layers, kernel_size=kernel_s,
                     filters=filt, dropout=dropout, activation=activation)
    out = Lambda(lambda x: x[:, -1, :])(outs)
    dense = Dense(n_classes, name = 'dense', kernel_constraint = max_norm(regRate))(out)
    softmax = Activation('softmax', name = 'softmax')(dense)

    return Model(inputs=input1, outputs=softmax)

#%% Reproduced MBEEG_SENet model: https://doi.org/10.3390/diagnostics12040995
def MBEEG_SENet(nb_classes, Chans, Samples, D=2):
    """ MBEEG_SENet model from Altuwaijri et al 2022.
    See details at https://doi.org/10.3390/diagnostics12040995

    Notes
    -----
    The initial values in this model are based on the values identified by
    the authors

    References
    ----------
    .. G. Altuwaijri, G. Muhammad, H. Altaheri, & M. Alsulaiman.
       A Multi-Branch Convolutional Neural Network with Squeeze-and-Excitation
       Attention Blocks for EEG-Based Motor Imagery Signals Classification.
       Diagnostics, 12(4), 995, (2022).
       https://doi.org/10.3390/diagnostics12040995
    """
    input1 = Input(shape = (1, Chans, Samples))
    input2 = Permute((3, 2, 1))(input1)
    regRate = .25

    # Three EEGNet branches with increasing temporal kernel lengths and filter counts
    EEGNet_sep1 = EEGNet(input_layer=input2, F1=4, kernLength=16, D=D, Chans=Chans, dropout=0)
    EEGNet_sep2 = EEGNet(input_layer=input2, F1=8, kernLength=32, D=D, Chans=Chans, dropout=0.1)
    EEGNet_sep3 = EEGNet(input_layer=input2, F1=16, kernLength=64, D=D, Chans=Chans, dropout=0.2)

    # Squeeze-and-excitation attention on each branch
    SE1 = attention_block(EEGNet_sep1, 'se', ratio=4)
    SE2 = attention_block(EEGNet_sep2, 'se', ratio=4)
    SE3 = attention_block(EEGNet_sep3, 'se', ratio=2)

    FC1 = Flatten()(SE1)
    FC2 = Flatten()(SE2)
    FC3 = Flatten()(SE3)

    CON = Concatenate()([FC1, FC2, FC3])

    dense1 = Dense(nb_classes, name = 'dense1', kernel_constraint = max_norm(regRate))(CON)
    softmax = Activation('softmax', name = 'softmax')(dense1)

    return Model(inputs=input1, outputs=softmax)

#%% Reproduced EEGNeX model: https://arxiv.org/abs/2207.12369
def EEGNeX_8_32(n_timesteps, n_features, n_outputs):
    """ EEGNeX model from Chen et al 2022.
    See details at https://arxiv.org/abs/2207.12369

    The original code for this model is available at https://github.com/chenxiachan/EEGNeX

    References
    ----------
    .. Chen, X., Teng, X., Chen, H., Pan, Y., & Geyer, P. (2022).
       Toward reliable signals decoding for electroencephalogram:
       A benchmark study to EEGNeX. arXiv preprint arXiv:2207.12369.
    """
    model = Sequential()
    model.add(Input(shape=(1, n_features, n_timesteps)))

    model.add(Conv2D(filters=8, kernel_size=(1, 32), use_bias = False, padding='same', data_format="channels_first"))
    model.add(LayerNormalization())
    model.add(Activation(activation='elu'))
    model.add(Conv2D(filters=32, kernel_size=(1, 32), use_bias = False, padding='same', data_format="channels_first"))
    model.add(LayerNormalization())
    model.add(Activation(activation='elu'))

    model.add(DepthwiseConv2D(kernel_size=(n_features, 1), depth_multiplier=2, use_bias = False,
                              depthwise_constraint=max_norm(1.), data_format="channels_first"))
    model.add(LayerNormalization())
    model.add(Activation(activation='elu'))
    model.add(AveragePooling2D(pool_size=(1, 4), padding='same', data_format="channels_first"))
    model.add(Dropout(0.5))

    model.add(Conv2D(filters=32, kernel_size=(1, 16), use_bias = False, padding='same',
                     dilation_rate=(1, 2), data_format='channels_first'))
    model.add(LayerNormalization())
    model.add(Activation(activation='elu'))

    model.add(Conv2D(filters=8, kernel_size=(1, 16), use_bias = False, padding='same',
                     dilation_rate=(1, 4), data_format='channels_first'))
    model.add(LayerNormalization())
    model.add(Activation(activation='elu'))
    model.add(Dropout(0.5))

    model.add(Flatten())
    model.add(Dense(n_outputs, kernel_constraint=max_norm(0.25)))
    model.add(Activation(activation='softmax'))

    # save a plot of the model
    # plot_model(model, show_shapes=True, to_file='EEGNeX_8_32.png')
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model

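# Example usage (illustrative): note that EEGNeX_8_32 takes (n_timesteps, n_features,
# n_outputs) rather than the (n_classes, Chans, Samples) order used elsewhere in this
# file, and returns an already compiled model:
# model = EEGNeX_8_32(n_timesteps = 1125, n_features = 22, n_outputs = 4)
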
#%% Reproduced EEGNet model: https://arxiv.org/abs/1611.08024
def EEGNet_classifier(n_classes, Chans=22, Samples=1125, F1=8, D=2, kernLength=64, dropout_eeg=0.25):
    input1 = Input(shape = (1, Chans, Samples))
    input2 = Permute((3, 2, 1))(input1)
    regRate = .25

    eegnet = EEGNet(input_layer=input2, F1=F1, kernLength=kernLength, D=D, Chans=Chans, dropout=dropout_eeg)
    eegnet = Flatten()(eegnet)
    dense = Dense(n_classes, name = 'dense', kernel_constraint = max_norm(regRate))(eegnet)
    softmax = Activation('softmax', name = 'softmax')(dense)

    return Model(inputs=input1, outputs=softmax)

def EEGNet(input_layer, F1=8, kernLength=64, D=2, Chans=22, dropout=0.25):
    """ EEGNet model from Lawhern et al 2018
    See details at https://arxiv.org/abs/1611.08024

    The original code for this model is available at: https://github.com/vlawhern/arl-eegmodels

    Notes
    -----
    The initial values in this model are based on the values identified by the authors

    References
    ----------
    .. Lawhern, V. J., Solon, A. J., Waytowich, N. R., Gordon,
       S. M., Hung, C. P., & Lance, B. J. (2018).
       EEGNet: A Compact Convolutional Network for EEG-based
       Brain-Computer Interfaces.
       arXiv preprint arXiv:1611.08024.
    """
    F2 = F1 * D
    block1 = Conv2D(F1, (kernLength, 1), padding = 'same', data_format = 'channels_last',
                    use_bias = False)(input_layer)
    block1 = BatchNormalization(axis = -1)(block1)
    block2 = DepthwiseConv2D((1, Chans), use_bias = False,
                             depth_multiplier = D,
                             data_format = 'channels_last',
                             depthwise_constraint = max_norm(1.))(block1)
    block2 = BatchNormalization(axis = -1)(block2)
    block2 = Activation('elu')(block2)
    block2 = AveragePooling2D((8, 1), data_format = 'channels_last')(block2)
    block2 = Dropout(dropout)(block2)
    block3 = SeparableConv2D(F2, (16, 1),
                             data_format = 'channels_last',
                             use_bias = False, padding = 'same')(block2)
    block3 = BatchNormalization(axis = -1)(block3)
    block3 = Activation('elu')(block3)
    block3 = AveragePooling2D((8, 1), data_format = 'channels_last')(block3)
    block3 = Dropout(dropout)(block3)
    return block3

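# Shape sketch (illustrative): for a (batch, Samples, Chans, 1) input, the depthwise
# convolution collapses the channel axis and the two (8, 1) average-pooling stages
# reduce the time axis by a factor of 64, so EEGNet returns feature maps of shape
# (batch, Samples//64, 1, F1*D), e.g. (batch, 17, 1, 16) for Samples = 1125 with the
# defaults F1 = 8, D = 2.
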
#%% Reproduced DeepConvNet model: https://doi.org/10.1002/hbm.23730
def DeepConvNet(nb_classes, Chans = 64, Samples = 256,
                dropoutRate = 0.5):
    """ Keras implementation of the Deep Convolutional Network as described in
    Schirrmeister et. al. (2017), Human Brain Mapping.
    See details at https://onlinelibrary.wiley.com/doi/full/10.1002/hbm.23730

    The original code for this model is available at: https://github.com/braindecode/braindecode

    Notes
    -----
    The initial values in this model are based on the values identified by the authors

    This implementation is taken from code by the Army Research Laboratory (ARL)
    at https://github.com/vlawhern/arl-eegmodels

    References
    ----------
    .. Schirrmeister, R. T., Springenberg, J. T., Fiederer, L. D. J.,
       Glasstetter, M., Eggensperger, K., Tangermann, M., ... & Ball, T. (2017).
       Deep learning with convolutional neural networks for EEG decoding
       and visualization. Human brain mapping, 38(11), 5391-5420.
    """
    # start the model
    # input_main = Input((Chans, Samples, 1))
    input_main = Input((1, Chans, Samples))
    input_2 = Permute((2, 3, 1))(input_main)

    block1 = Conv2D(25, (1, 10),
                    input_shape=(Chans, Samples, 1),
                    kernel_constraint = max_norm(2., axis=(0, 1, 2)))(input_2)
    block1 = Conv2D(25, (Chans, 1),
                    kernel_constraint = max_norm(2., axis=(0, 1, 2)))(block1)
    block1 = BatchNormalization(epsilon=1e-05, momentum=0.9)(block1)
    block1 = Activation('elu')(block1)
    block1 = MaxPooling2D(pool_size=(1, 3), strides=(1, 3))(block1)
    block1 = Dropout(dropoutRate)(block1)

    block2 = Conv2D(50, (1, 10),
                    kernel_constraint = max_norm(2., axis=(0, 1, 2)))(block1)
    block2 = BatchNormalization(epsilon=1e-05, momentum=0.9)(block2)
    block2 = Activation('elu')(block2)
    block2 = MaxPooling2D(pool_size=(1, 3), strides=(1, 3))(block2)
    block2 = Dropout(dropoutRate)(block2)

    block3 = Conv2D(100, (1, 10),
                    kernel_constraint = max_norm(2., axis=(0, 1, 2)))(block2)
    block3 = BatchNormalization(epsilon=1e-05, momentum=0.9)(block3)
    block3 = Activation('elu')(block3)
    block3 = MaxPooling2D(pool_size=(1, 3), strides=(1, 3))(block3)
    block3 = Dropout(dropoutRate)(block3)

    block4 = Conv2D(200, (1, 10),
                    kernel_constraint = max_norm(2., axis=(0, 1, 2)))(block3)
    block4 = BatchNormalization(epsilon=1e-05, momentum=0.9)(block4)
    block4 = Activation('elu')(block4)
    block4 = MaxPooling2D(pool_size=(1, 3), strides=(1, 3))(block4)
    block4 = Dropout(dropoutRate)(block4)

    flatten = Flatten()(block4)

    dense = Dense(nb_classes, kernel_constraint = max_norm(0.5))(flatten)
    softmax = Activation('softmax')(dense)

    return Model(inputs=input_main, outputs=softmax)

#%% need these for ShallowConvNet
def square(x):
    return K.square(x)

def log(x):
    # clip to avoid taking the log of zero
    return K.log(K.clip(x, min_value = 1e-7, max_value = 10000))

#%% Reproduced ShallowConvNet model: https://doi.org/10.1002/hbm.23730
def ShallowConvNet(nb_classes, Chans = 64, Samples = 128, dropoutRate = 0.5):
    """ Keras implementation of the Shallow Convolutional Network as described
    in Schirrmeister et. al. (2017), Human Brain Mapping.
    See details at https://onlinelibrary.wiley.com/doi/full/10.1002/hbm.23730

    The original code for this model is available at: https://github.com/braindecode/braindecode

    Notes
    -----
    The initial values in this model are based on the values identified by the authors

    This implementation is taken from code by the Army Research Laboratory (ARL)
    at https://github.com/vlawhern/arl-eegmodels

    References
    ----------
    .. Schirrmeister, R. T., Springenberg, J. T., Fiederer, L. D. J.,
       Glasstetter, M., Eggensperger, K., Tangermann, M., ... & Ball, T. (2017).
       Deep learning with convolutional neural networks for EEG decoding
       and visualization. Human brain mapping, 38(11), 5391-5420.
    """
    # start the model
    # input_main = Input((Chans, Samples, 1))
    input_main = Input((1, Chans, Samples))
    input_2 = Permute((2, 3, 1))(input_main)

    block1 = Conv2D(40, (1, 25),
                    input_shape=(Chans, Samples, 1),
                    kernel_constraint = max_norm(2., axis=(0, 1, 2)))(input_2)
    block1 = Conv2D(40, (Chans, 1), use_bias=False,
                    kernel_constraint = max_norm(2., axis=(0, 1, 2)))(block1)
    block1 = BatchNormalization(epsilon=1e-05, momentum=0.9)(block1)
    block1 = Activation(square)(block1)
    block1 = AveragePooling2D(pool_size=(1, 75), strides=(1, 15))(block1)
    block1 = Activation(log)(block1)
    block1 = Dropout(dropoutRate)(block1)
    flatten = Flatten()(block1)
    dense = Dense(nb_classes, kernel_constraint = max_norm(0.5))(flatten)
    softmax = Activation('softmax')(dense)

    return Model(inputs=input_main, outputs=softmax)
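
#%% Quick sanity check (an illustrative addition, not part of the original code):
# build a few of the models above with BCI Competition IV-2a-like dimensions and
# print their parameter counts. Assumes attention_models.py is importable, since
# ATCNet_ uses attention_block ('mha' by default).
if __name__ == "__main__":
    for name, model in [("ATCNet", ATCNet_(n_classes = 4)),
                        ("EEGTCNet", EEGTCNet(n_classes = 4)),
                        ("EEGNet", EEGNet_classifier(n_classes = 4))]:
        print(f"{name}: {model.count_params():,} parameters")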