a b/models.py
1
""" 
2
Copyright (C) 2022 King Saud University, Saudi Arabia 
3
SPDX-License-Identifier: Apache-2.0 
4
5
Licensed under the Apache License, Version 2.0 (the "License"); you may not use
6
this file except in compliance with the License. You may obtain a copy of the 
7
License at
8
9
http://www.apache.org/licenses/LICENSE-2.0  
10
11
Unless required by applicable law or agreed to in writing, software distributed
12
under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 
13
CONDITIONS OF ANY KIND, either express or implied. See the License for the
14
specific language governing permissions and limitations under the License. 
15
16
Author:  Hamdi Altaheri 
17
"""
18
19
#%%
20
import tensorflow as tf
21
from tensorflow.keras.models import Model, Sequential
22
from tensorflow.keras.layers import Dense, Dropout, Activation, AveragePooling2D, MaxPooling2D
23
from tensorflow.keras.layers import Conv1D, Conv2D, SeparableConv2D, DepthwiseConv2D
24
from tensorflow.keras.layers import BatchNormalization, LayerNormalization, Flatten 
25
from tensorflow.keras.layers import Add, Concatenate, Lambda, Input, Permute
26
from tensorflow.keras.regularizers import L2
27
from tensorflow.keras.constraints import max_norm
28
29
from tensorflow.keras import backend as K
30
31
from attention_models import attention_block
32
33
#%% The proposed ATCNet model, https://doi.org/10.1109/TII.2022.3197419
34
def ATCNet_(n_classes, in_chans = 22, in_samples = 1125, n_windows = 5, attention = 'mha',
           eegn_F1 = 16, eegn_D = 2, eegn_kernelSize = 64, eegn_poolSize = 7, eegn_dropout=0.3,
           tcn_depth = 2, tcn_kernelSize = 4, tcn_filters = 32, tcn_dropout = 0.3,
           tcn_activation = 'elu', fuse = 'average'):
    """ ATCNet model from Altaheri et al 2023.
        See details at https://ieeexplore.ieee.org/abstract/document/9852687

        The input of shape (1, in_chans, in_samples) is passed through the
        convolutional block, the resulting sequence is cut into `n_windows`
        overlapping windows, and every window goes through an optional
        attention block followed by a TCN block. The last time step of each
        TCN output is kept and the per-window features are fused into one
        softmax output with `n_classes` units.

        Parameters
        ----------
        n_classes : number of units of the final Dense classifier(s).
        in_chans, in_samples : define the model input shape
            (1, in_chans, in_samples).
        n_windows : number of sliding windows (must be >= 1 and not larger
            than the sequence length produced by the convolutional block).
        attention : forwarded to `attention_block`; None disables attention.
            For 'se' and 'cbam' the window is permuted before and after the
            attention block.
        eegn_F1, eegn_D, eegn_kernelSize, eegn_poolSize, eegn_dropout :
            forwarded to `Conv_block_`.
        tcn_depth, tcn_kernelSize, tcn_filters, tcn_dropout, tcn_activation :
            forwarded to `TCN_block_`.
        fuse : 'average' (one Dense per window, outputs averaged) or
            'concat' (window features concatenated, then one Dense).

        Returns
        -------
        tensorflow.keras.models.Model mapping the input to softmax outputs.

        Raises
        ------
        ValueError
            If `fuse` is not 'average'/'concat', or `n_windows` is smaller
            than 1 or larger than the available sequence length.

        Notes
        -----
        The initial values in this model are based on the values identified by
        the authors

        References
        ----------
        .. H. Altaheri, G. Muhammad, and M. Alsulaiman. "Physics-informed
           attention temporal convolutional network for EEG-based motor imagery
           classification." IEEE Transactions on Industrial Informatics,
           vol. 19, no. 2, pp. 2249-2258, (2023)
           https://doi.org/10.1109/TII.2022.3197419
    """
    # Validate up front: an unknown `fuse` or an empty window list would
    # otherwise only surface as an obscure error deep inside Keras.
    if fuse not in ('average', 'concat'):
        raise ValueError("fuse must be 'average' or 'concat', got {!r}".format(fuse))
    if n_windows < 1:
        raise ValueError("n_windows must be >= 1, got {!r}".format(n_windows))

    input_1 = Input(shape = (1,in_chans, in_samples))   #     TensorShape([None, 1, 22, 1125])
    input_2 = Permute((3,2,1))(input_1)

    dense_weightDecay = 0.5
    conv_weightDecay = 0.009
    conv_maxNorm = 0.6
    from_logits = False   # fixed switch: the model always ends with softmax

    F2 = eegn_F1*eegn_D

    block1 = Conv_block_(input_layer = input_2, F1 = eegn_F1, D = eegn_D,
                        kernLength = eegn_kernelSize, poolSize = eegn_poolSize,
                        weightDecay = conv_weightDecay, maxNorm = conv_maxNorm,
                        in_chans = in_chans, dropout = eegn_dropout)
    # Keep only the last index of axis 2, giving a 3-D (batch, steps, features) tensor.
    block1 = Lambda(lambda x: x[:,:,-1,:])(block1)

    seq_len = block1.shape[1]
    if n_windows > seq_len:
        raise ValueError("n_windows ({}) exceeds the sequence length ({}) produced by "
                         "the convolutional block".format(n_windows, seq_len))
    # All windows share the same length; window i covers [i, i + win_len).
    win_len = seq_len - n_windows + 1

    dense_outputs = []   # per-window class scores (fuse == 'average')
    fused = None         # running concatenation of window features (fuse == 'concat')
    for i in range(n_windows):
        block2 = block1[:, i:i + win_len, :]

        # Attention_model
        if attention is not None:
            if attention in ('se', 'cbam'):
                block2 = Permute((2, 1))(block2) # shape=(None, 32, 16)
                block2 = attention_block(block2, attention)
                block2 = Permute((2, 1))(block2) # shape=(None, 16, 32)
            else:
                block2 = attention_block(block2, attention)

        # Temporal convolutional network (TCN)
        block3 = TCN_block_(input_layer = block2, input_dimension = F2, depth = tcn_depth,
                            kernel_size = tcn_kernelSize, filters = tcn_filters,
                            weightDecay = conv_weightDecay, maxNorm = conv_maxNorm,
                            dropout = tcn_dropout, activation = tcn_activation)
        # Get feature maps of the last sequence
        block3 = Lambda(lambda x: x[:,-1,:])(block3)

        # Outputs of sliding window: Average_after_dense or concatenate_then_dense
        if fuse == 'average':
            dense_outputs.append(Dense(n_classes, kernel_regularizer=L2(dense_weightDecay))(block3))
        elif i == 0:
            fused = block3
        else:
            fused = Concatenate()([fused, block3])

    if fuse == 'average':
        if len(dense_outputs) > 1: # more than one window
            fused = tf.keras.layers.Average()(dense_outputs)
        else: # one window (# windows = 1)
            fused = dense_outputs[0]
    else:
        fused = Dense(n_classes, kernel_regularizer=L2(dense_weightDecay))(fused)

    if from_logits:  # No activation here because we are using from_logits=True
        out = Activation('linear', name = 'linear')(fused)
    else:   # Using softmax activation
        out = Activation('softmax', name = 'softmax')(fused)

    return Model(inputs = input_1, outputs = out)
118
119
#%% Convolutional (CV) block used in the ATCNet model
120
def Conv_block(input_layer, F1=4, kernLength=64, poolSize=8, D=2, in_chans=22, dropout=0.1):
    """ Conv_block

        Three-stage convolutional feature extractor applied to `input_layer`:
        a (kernLength, 1) convolution with F1 filters, a (1, in_chans)
        depthwise convolution with depth multiplier D, and a (16, 1)
        convolution with F1*D filters. The last two stages are followed by
        batch normalization, ELU, average pooling ((8, 1) and (poolSize, 1)
        respectively) and dropout. Returns the output tensor of the last stage.

        Notes
        -----
        This block is the same as EEGNet with SeparableConv2D replaced by Conv2D
        The original code for this model is available at: https://github.com/vlawhern/arl-eegmodels
        See details at https://arxiv.org/abs/1611.08024
    """
    layout = 'channels_last'
    n_out_filters = F1 * D

    # Stage 1: convolution along the first spatial axis
    x = Conv2D(F1, (kernLength, 1), padding='same', data_format=layout,
               use_bias=False)(input_layer)
    x = BatchNormalization(axis=-1)(x)

    # Stage 2: depthwise convolution spanning `in_chans` along the second axis
    x = DepthwiseConv2D((1, in_chans), use_bias=False,
                        depth_multiplier=D,
                        data_format=layout,
                        depthwise_constraint=max_norm(1.))(x)
    x = BatchNormalization(axis=-1)(x)
    x = Activation('elu')(x)
    x = AveragePooling2D((8, 1), data_format=layout)(x)
    x = Dropout(dropout)(x)

    # Stage 3: regular convolution producing F1*D feature maps
    x = Conv2D(n_out_filters, (16, 1), data_format=layout,
               use_bias=False, padding='same')(x)
    x = BatchNormalization(axis=-1)(x)
    x = Activation('elu')(x)
    x = AveragePooling2D((poolSize, 1), data_format=layout)(x)
    x = Dropout(dropout)(x)
    return x
149
150
def Conv_block_(input_layer, F1=4, kernLength=64, poolSize=8, D=2, in_chans=22,
                weightDecay = 0.009, maxNorm = 0.6, dropout=0.25):
    """ Conv_block

        Same layer sequence as `Conv_block`, with every convolution
        additionally carrying an L2 weight penalty (`weightDecay`) and a
        max-norm constraint (`maxNorm`). Returns the output tensor of the
        last stage.

        Notes
        -----
        using  different regularization methods.
    """
    layout = 'channels_last'
    n_out_filters = F1 * D

    def _penalty():
        # Fresh L2 regularizer per layer.
        return L2(weightDecay)

    def _norm_cap():
        # For a channels_last Conv2D the kernel is laid out as
        # (rows, cols, input_depth, output_depth); axis=[0, 1, 2] therefore
        # bounds the norm of each individual filter.
        return max_norm(maxNorm, axis=[0,1,2])

    # Stage 1: convolution along the first spatial axis
    x = Conv2D(F1, (kernLength, 1), padding='same', data_format=layout,
               kernel_regularizer=_penalty(),
               kernel_constraint=_norm_cap(),
               use_bias=False)(input_layer)
    x = BatchNormalization(axis=-1)(x)  # features are on the last axis (channels_last)

    # Stage 2: depthwise convolution spanning `in_chans` along the second axis
    x = DepthwiseConv2D((1, in_chans),
                        depth_multiplier=D,
                        data_format=layout,
                        depthwise_regularizer=_penalty(),
                        depthwise_constraint=_norm_cap(),
                        use_bias=False)(x)
    x = BatchNormalization(axis=-1)(x)
    x = Activation('elu')(x)
    x = AveragePooling2D((8, 1), data_format=layout)(x)
    x = Dropout(dropout)(x)

    # Stage 3: regular convolution producing F1*D feature maps
    x = Conv2D(n_out_filters, (16, 1),
               data_format=layout,
               kernel_regularizer=_penalty(),
               kernel_constraint=_norm_cap(),
               use_bias=False, padding='same')(x)
    x = BatchNormalization(axis=-1)(x)
    x = Activation('elu')(x)
    x = AveragePooling2D((poolSize, 1), data_format=layout)(x)
    x = Dropout(dropout)(x)
    return x
192
193
#%% Temporal convolutional (TC) block used in the ATCNet model
194
def TCN_block(input_layer,input_dimension,depth,kernel_size,filters,dropout,activation='relu'):
    """ TCN_block from Bai et al 2018
        Temporal Convolutional Network (TCN)

        Builds `depth` residual blocks. Each block stacks two units of
        causal Conv1D -> BatchNormalization -> Activation -> Dropout and adds
        a shortcut before a final activation. The first block uses dilation 1
        and, when `input_dimension` differs from `filters`, a kernel-size-1
        Conv1D on the shortcut; the k-th following block (k = 1..depth-1)
        uses dilation 2**k and an identity shortcut.

        Notes
        -----
        The original code available at https://github.com/locuslab/TCN/blob/master/TCN/tcn.py
        This implementation has a slight modification from the original code
        and it is taken from the code by Ingolfsson et al at https://github.com/iis-eth-zurich/eeg-tcnet
        See details at https://arxiv.org/abs/2006.00622

        References
        ----------
        .. Bai, S., Kolter, J. Z., & Koltun, V. (2018).
           An empirical evaluation of generic convolutional and recurrent networks
           for sequence modeling.
           arXiv preprint arXiv:1803.01271.
    """

    def _causal_unit(tensor, dilation):
        # One Conv1D -> BN -> activation -> dropout unit with causal padding.
        tensor = Conv1D(filters, kernel_size=kernel_size, dilation_rate=dilation,
                        activation='linear', padding='causal',
                        kernel_initializer='he_uniform')(tensor)
        tensor = BatchNormalization()(tensor)
        tensor = Activation(activation)(tensor)
        tensor = Dropout(dropout)(tensor)
        return tensor

    # First residual block (dilation 1)
    branch = _causal_unit(input_layer, 1)
    branch = _causal_unit(branch, 1)
    if input_dimension == filters:
        shortcut = input_layer
    else:
        # Project the input so both operands of Add have `filters` features.
        shortcut = Conv1D(filters, kernel_size=1, padding='same')(input_layer)
    summed = Add()([branch, shortcut])
    out = Activation(activation)(summed)

    # Remaining residual blocks with exponentially growing dilation
    for level in range(1, depth):
        rate = 2 ** level
        branch = _causal_unit(out, rate)
        branch = _causal_unit(branch, rate)
        summed = Add()([branch, out])
        out = Activation(activation)(summed)

    return out
245
246
def TCN_block_(input_layer,input_dimension,depth,kernel_size,filters, dropout,
               weightDecay = 0.009, maxNorm = 0.6, activation='relu'):
    """ TCN_block from Bai et al 2018
        Temporal Convolutional Network (TCN)

        Same residual structure as `TCN_block`: `depth` blocks of two causal
        Conv1D -> BatchNormalization -> Activation -> Dropout units, dilation
        1 for the first block and 2**k for the k-th following block, with a
        kernel-size-1 Conv1D on the first shortcut when `input_dimension`
        differs from `filters`. Here every Conv1D also carries an L2 penalty
        (`weightDecay`) and a max-norm constraint (`maxNorm`, axis=[0, 1]).

        Notes
        -----
        using different regularization methods
    """

    def _regularized(**conv_kwargs):
        # Conv1D with the shared L2 penalty and max-norm constraint attached.
        return Conv1D(filters,
                      kernel_regularizer=L2(weightDecay),
                      kernel_constraint=max_norm(maxNorm, axis=[0,1]),
                      **conv_kwargs)

    def _causal_unit(tensor, dilation):
        # One Conv1D -> BN -> activation -> dropout unit with causal padding.
        tensor = _regularized(kernel_size=kernel_size, dilation_rate=dilation,
                              activation='linear', padding='causal',
                              kernel_initializer='he_uniform')(tensor)
        tensor = BatchNormalization()(tensor)
        tensor = Activation(activation)(tensor)
        tensor = Dropout(dropout)(tensor)
        return tensor

    # First residual block (dilation 1)
    branch = _causal_unit(input_layer, 1)
    branch = _causal_unit(branch, 1)
    if input_dimension == filters:
        shortcut = input_layer
    else:
        # Project the input so both operands of Add have `filters` features.
        shortcut = _regularized(kernel_size=1, padding='same')(input_layer)
    summed = Add()([branch, shortcut])
    out = Activation(activation)(summed)

    # Remaining residual blocks with exponentially growing dilation
    for level in range(1, depth):
        rate = 2 ** level
        branch = _causal_unit(out, rate)
        branch = _causal_unit(branch, rate)
        summed = Add()([branch, out])
        out = Activation(activation)(summed)

    return out
304
305
306
#%% Reproduced TCNet_Fusion model: https://doi.org/10.1016/j.bspc.2021.102826
307
def TCNet_Fusion(n_classes, Chans=22, Samples=1125, layers=2, kernel_s=4, filt=12,
                 dropout=0.3, activation='elu', F1=24, D=2, kernLength=32, dropout_eeg=0.3):
    """ TCNet_Fusion model from Musallam et al 2021.
    See details at https://doi.org/10.1016/j.bspc.2021.102826

    The EEGNet features are fed to a TCN block; the EEGNet features and the
    TCN output are concatenated and flattened, concatenated once more with
    the flattened EEGNet features, and classified by a max-norm constrained
    Dense layer followed by softmax. The model input has shape
    (1, Chans, Samples).

        Notes
        -----
        The initial values in this model are based on the values identified by
        the authors

        References
        ----------
        .. Musallam, Y.K., AlFassam, N.I., Muhammad, G., Amin, S.U., Alsulaiman,
           M., Abdul, W., Altaheri, H., Bencherif, M.A. and Algabri, M., 2021.
           Electroencephalography-based motor imagery classification
           using temporal convolutional network fusion.
           Biomedical Signal Processing and Control, 69, p.102826.
    """
    dense_max_norm = .25
    tcn_in_dim = F1 * D   # number of feature maps produced by EEGNet

    model_input = Input(shape=(1, Chans, Samples))
    permuted = Permute((3, 2, 1))(model_input)

    eeg_maps = EEGNet(input_layer=permuted, F1=F1, kernLength=kernLength, D=D,
                      Chans=Chans, dropout=dropout_eeg)
    # Keep only the last index of axis 2, giving a 3-D sequence tensor.
    eeg_seq = Lambda(lambda x: x[:,:,-1,:])(eeg_maps)
    eeg_flat = Flatten()(eeg_seq)

    tcn_out = TCN_block(input_layer=eeg_seq, input_dimension=tcn_in_dim, depth=layers,
                        kernel_size=kernel_s, filters=filt, dropout=dropout,
                        activation=activation)

    # First fusion: EEGNet sequence with TCN output; second: with flat EEGNet features.
    seq_fusion = Concatenate()([eeg_seq, tcn_out])
    seq_fusion_flat = Flatten()(seq_fusion)
    all_features = Concatenate()([seq_fusion_flat, eeg_flat])

    scores = Dense(n_classes, name='dense',
                   kernel_constraint=max_norm(dense_max_norm))(all_features)
    probs = Activation('softmax', name='softmax')(scores)

    return Model(inputs=model_input, outputs=probs)
345
346
347
#%% Reproduced EEGTCNet model: https://arxiv.org/abs/2006.00622
348
def EEGTCNet(n_classes, Chans=22, Samples=1125, layers=2, kernel_s=4, filt=12, dropout=0.3, activation='elu', F1=8, D=2, kernLength=32, dropout_eeg=0.2):
    """ EEGTCNet model from Ingolfsson et al 2020.
    See details at https://arxiv.org/abs/2006.00622

    The original code for this model is available at https://github.com/iis-eth-zurich/eeg-tcnet

    EEGNet features are passed through a TCN block; the last time step of the
    TCN output is classified by a max-norm constrained Dense layer followed
    by softmax. The model input has shape (1, Chans, Samples).

        Notes
        -----
        The initial values in this model are based on the values identified by the authors

        References
        ----------
        .. Ingolfsson, T. M., Hersche, M., Wang, X., Kobayashi, N.,
           Cavigelli, L., & Benini, L. (2020, October).
           Eeg-tcnet: An accurate temporal convolutional network
           for embedded motor-imagery brain–machine interfaces.
           In 2020 IEEE International Conference on Systems,
           Man, and Cybernetics (SMC) (pp. 2958-2965). IEEE.
    """
    dense_max_norm = .25
    tcn_in_dim = F1 * D   # number of feature maps produced by EEGNet

    model_input = Input(shape=(1, Chans, Samples))
    permuted = Permute((3, 2, 1))(model_input)

    eeg_maps = EEGNet(input_layer=permuted, F1=F1, kernLength=kernLength, D=D,
                      Chans=Chans, dropout=dropout_eeg)
    # Keep only the last index of axis 2, giving a 3-D sequence tensor.
    eeg_seq = Lambda(lambda x: x[:,:,-1,:])(eeg_maps)
    tcn_out = TCN_block(input_layer=eeg_seq, input_dimension=tcn_in_dim, depth=layers,
                        kernel_size=kernel_s, filters=filt, dropout=dropout,
                        activation=activation)
    # Only the final time step of the TCN output is classified.
    last_step = Lambda(lambda x: x[:,-1,:])(tcn_out)

    scores = Dense(n_classes, name='dense',
                   kernel_constraint=max_norm(dense_max_norm))(last_step)
    probs = Activation('softmax', name='softmax')(scores)

    return Model(inputs=model_input, outputs=probs)
381
382
#%% Reproduced MBEEG_SENet model: https://doi.org/10.3390/diagnostics12040995
383
def MBEEG_SENet(nb_classes, Chans, Samples, D=2):
    """ MBEEG_SENet model from Altuwaijri et al 2022.
    See details at https://doi.org/10.3390/diagnostics12040995

    Three EEGNet branches with different F1 / kernel length / dropout
    settings share the same input of shape (1, Chans, Samples). Each branch
    is followed by an 'se' attention block, flattened, and the three results
    are concatenated and classified by a max-norm constrained Dense layer
    followed by softmax.

        Notes
        -----
        The initial values in this model are based on the values identified by
        the authors

        References
        ----------
        .. G. Altuwaijri, G. Muhammad, H. Altaheri, & M. Alsulaiman.
           A Multi-Branch Convolutional Neural Network with Squeeze-and-Excitation
           Attention Blocks for EEG-Based Motor Imagery Signals Classification.
           Diagnostics, 12(4), 995, (2022).
           https://doi.org/10.3390/diagnostics12040995
    """
    dense_max_norm = .25

    # (F1, kernLength, dropout, attention ratio) for each branch
    branch_specs = (
        (4, 16, 0, 4),
        (8, 32, 0.1, 4),
        (16, 64, 0.2, 2),
    )

    model_input = Input(shape=(1, Chans, Samples))
    permuted = Permute((3, 2, 1))(model_input)

    # Build all EEGNet branches first, then attention, then flatten
    # (same construction order as a fully unrolled version).
    branch_maps = [
        EEGNet(input_layer=permuted, F1=n_f1, kernLength=k_len, D=D, Chans=Chans, dropout=drop)
        for n_f1, k_len, drop, _ in branch_specs
    ]
    attended = [
        attention_block(maps, 'se', ratio=spec[3])
        for maps, spec in zip(branch_maps, branch_specs)
    ]
    flattened = [Flatten()(att) for att in attended]

    merged = Concatenate()(flattened)

    scores = Dense(nb_classes, name='dense1',
                   kernel_constraint=max_norm(dense_max_norm))(merged)
    probs = Activation('softmax', name='softmax')(scores)

    return Model(inputs=model_input, outputs=probs)
424
425
426
427
#%% Reproduced EEGNeX model: https://arxiv.org/abs/2207.12369
428
def EEGNeX_8_32(n_timesteps, n_features, n_outputs):
    """ EEGNeX model from Chen et al 2022.
    See details at https://arxiv.org/abs/2207.12369

    The original code for this model is available at https://github.com/chenxiachan/EEGNeX

    Builds a Sequential model on inputs of shape (1, n_features, n_timesteps)
    ending in a softmax over `n_outputs` units. The returned model is already
    compiled with categorical cross-entropy loss, the 'adam' optimizer and
    the 'accuracy' metric.

        References
        ----------
        .. Chen, X., Teng, X., Chen, H., Pan, Y., & Geyer, P. (2022).
           Toward reliable signals decoding for electroencephalogram:
           A benchmark study to EEGNeX. arXiv preprint arXiv:2207.12369.
    """
    layout = "channels_first"

    model = Sequential()

    layer_stack = (
        Input(shape=(1, n_features, n_timesteps)),

        # Two stacked convolutions with (1, 32) kernels
        Conv2D(filters=8, kernel_size=(1, 32), use_bias=False, padding='same', data_format=layout),
        LayerNormalization(),
        Activation(activation='elu'),
        Conv2D(filters=32, kernel_size=(1, 32), use_bias=False, padding='same', data_format=layout),
        LayerNormalization(),
        Activation(activation='elu'),

        # Depthwise convolution with an (n_features, 1) kernel
        DepthwiseConv2D(kernel_size=(n_features, 1), depth_multiplier=2, use_bias=False,
                        depthwise_constraint=max_norm(1.), data_format=layout),
        LayerNormalization(),
        Activation(activation='elu'),
        AveragePooling2D(pool_size=(1, 4), padding='same', data_format=layout),
        Dropout(0.5),

        # Dilated convolutions
        Conv2D(filters=32, kernel_size=(1, 16), use_bias=False, padding='same',
               dilation_rate=(1, 2), data_format=layout),
        LayerNormalization(),
        Activation(activation='elu'),

        Conv2D(filters=8, kernel_size=(1, 16), use_bias=False, padding='same',
               dilation_rate=(1, 4), data_format=layout),
        LayerNormalization(),
        Activation(activation='elu'),
        Dropout(0.5),

        # Classifier head
        Flatten(),
        Dense(n_outputs, kernel_constraint=max_norm(0.25)),
        Activation(activation='softmax'),
    )
    for layer in layer_stack:
        model.add(layer)

    # save a plot of the model
    # plot_model(model, show_shapes=True, to_file='EEGNeX_8_32.png')
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model
475
476
#%% Reproduced EEGNet model: https://arxiv.org/abs/1611.08024
477
def EEGNet_classifier(n_classes, Chans=22, Samples=1125, F1=8, D=2, kernLength=64, dropout_eeg=0.25):
    """ Full EEGNet classifier: the `EEGNet` feature extractor followed by
    Flatten, a max-norm constrained Dense layer with `n_classes` units and
    softmax. The model input has shape (1, Chans, Samples).
    """
    dense_max_norm = .25

    model_input = Input(shape=(1, Chans, Samples))
    # (1, Chans, Samples) -> (Samples, Chans, 1), the layout passed to EEGNet
    permuted = Permute((3, 2, 1))(model_input)

    features = EEGNet(input_layer=permuted, F1=F1, kernLength=kernLength, D=D,
                      Chans=Chans, dropout=dropout_eeg)
    features = Flatten()(features)
    scores = Dense(n_classes, name='dense',
                   kernel_constraint=max_norm(dense_max_norm))(features)
    probs = Activation('softmax', name='softmax')(scores)

    return Model(inputs=model_input, outputs=probs)
488
489
def EEGNet(input_layer, F1=8, kernLength=64, D=2, Chans=22, dropout=0.25):
    """ EEGNet model from Lawhern et al 2018
    See details at https://arxiv.org/abs/1611.08024

    The original code for this model is available at: https://github.com/vlawhern/arl-eegmodels

    Feature-extraction part only: a (kernLength, 1) convolution with F1
    filters, a (1, Chans) depthwise convolution with depth multiplier D, and
    a (16, 1) separable convolution with F1*D filters; the last two stages
    are each followed by batch normalization, ELU, (8, 1) average pooling and
    dropout. Returns the output tensor of the last stage. Callers in this
    file pass a tensor obtained by permuting a (1, Chans, Samples) input with
    Permute((3, 2, 1)).

        Notes
        -----
        The initial values in this model are based on the values identified by the authors

        References
        ----------
        .. Lawhern, V. J., Solon, A. J., Waytowich, N. R., Gordon,
           S. M., Hung, C. P., & Lance, B. J. (2018).
           EEGNet: A Compact Convolutional Network for EEG-based
           Brain-Computer Interfaces.
           arXiv preprint arXiv:1611.08024.
    """
    layout = 'channels_last'
    n_sep_filters = F1 * D

    # Stage 1: convolution along the first spatial axis
    x = Conv2D(F1, (kernLength, 1), padding='same', data_format=layout,
               use_bias=False)(input_layer)
    x = BatchNormalization(axis=-1)(x)

    # Stage 2: depthwise convolution spanning `Chans` along the second axis
    x = DepthwiseConv2D((1, Chans), use_bias=False,
                        depth_multiplier=D,
                        data_format=layout,
                        depthwise_constraint=max_norm(1.))(x)
    x = BatchNormalization(axis=-1)(x)
    x = Activation('elu')(x)
    x = AveragePooling2D((8, 1), data_format=layout)(x)
    x = Dropout(dropout)(x)

    # Stage 3: separable convolution producing F1*D feature maps
    x = SeparableConv2D(n_sep_filters, (16, 1),
                        data_format=layout,
                        use_bias=False, padding='same')(x)
    x = BatchNormalization(axis=-1)(x)
    x = Activation('elu')(x)
    x = AveragePooling2D((8, 1), data_format=layout)(x)
    x = Dropout(dropout)(x)
    return x
526
527
528
#%% Reproduced DeepConvNet model: https://doi.org/10.1002/hbm.23730
529
def DeepConvNet(nb_classes, Chans = 64, Samples = 256,
                dropoutRate = 0.5):
    """ Keras implementation of the Deep Convolutional Network as described in
    Schirrmeister et. al. (2017), Human Brain Mapping.
    See details at https://onlinelibrary.wiley.com/doi/full/10.1002/hbm.23730

    The original code for this model is available at: https://github.com/braindecode/braindecode

    The network consists of four stages, each ending with batch
    normalization, ELU, (1, 3) max pooling and dropout. The first stage
    applies a (1, 10) convolution followed by a (Chans, 1) convolution with
    25 filters each; the following stages apply a (1, 10) convolution with
    50, 100 and 200 filters. The result is flattened and classified by a
    max-norm constrained Dense layer followed by softmax.

        Parameters
        ----------
        nb_classes : number of units of the final Dense layer.
        Chans, Samples : define the model input shape (1, Chans, Samples).
        dropoutRate : rate used by every Dropout layer.

        Returns
        -------
        tensorflow.keras.models.Model mapping the input to softmax outputs.

        Notes
        -----
        The initial values in this model are based on the values identified by the authors

        This implementation is taken from code by the Army Research Laboratory (ARL)
        at https://github.com/vlawhern/arl-eegmodels

        Every stage pools its own output. (An earlier revision re-pooled the
        first stage's tensor in stages 2-4 and discarded the result, so those
        stages were effectively left unpooled.)

        NOTE(review): the convolutions use the layer's default padding and a
        width-10 kernel, and each pooling divides the length by 3, so
        `Samples` must be long enough to survive four stages (1125 works; the
        default 256 appears too short once all pooling stages are active) --
        confirm the intended default.

        References
        ----------
        .. Schirrmeister, R. T., Springenberg, J. T., Fiederer, L. D. J.,
           Glasstetter, M., Eggensperger, K., Tangermann, M., ... & Ball, T. (2017).
           Deep learning with convolutional neural networks for EEG decoding
           and visualization. Human brain mapping, 38(11), 5391-5420.

    """

    # start the model
    # input_main   = Input((Chans, Samples, 1))
    input_main   = Input((1, Chans, Samples))
    # (1, Chans, Samples) -> (Chans, Samples, 1)
    input_2 = Permute((2,3,1))(input_main)

    # Stage 1: (1, 10) convolution followed by a convolution over all channels
    block = Conv2D(25, (1, 10),
                   input_shape=(Chans, Samples, 1),
                   kernel_constraint = max_norm(2., axis=(0,1,2)))(input_2)
    block = Conv2D(25, (Chans, 1),
                   kernel_constraint = max_norm(2., axis=(0,1,2)))(block)
    block = BatchNormalization(epsilon=1e-05, momentum=0.9)(block)
    block = Activation('elu')(block)
    block = MaxPooling2D(pool_size=(1, 3), strides=(1, 3))(block)
    block = Dropout(dropoutRate)(block)

    # Stages 2-4: identical structure with a growing number of filters. Using
    # a single running tensor guarantees each pooling acts on its own stage.
    for n_filters in (50, 100, 200):
        block = Conv2D(n_filters, (1, 10),
                       kernel_constraint = max_norm(2., axis=(0,1,2)))(block)
        block = BatchNormalization(epsilon=1e-05, momentum=0.9)(block)
        block = Activation('elu')(block)
        block = MaxPooling2D(pool_size=(1, 3), strides=(1, 3))(block)
        block = Dropout(dropoutRate)(block)

    flatten      = Flatten()(block)

    dense        = Dense(nb_classes, kernel_constraint = max_norm(0.5))(flatten)
    softmax      = Activation('softmax')(dense)

    return Model(inputs=input_main, outputs=softmax)
595
596
#%% need these for ShallowConvNet
597
def square(x):
    """Return the element-wise square of `x` computed with the Keras backend.

    Used as a custom activation in ShallowConvNet.
    """
    squared = K.square(x)
    return squared
599
600
def log(x):
    """Return the element-wise logarithm of `x` computed with the Keras backend.

    `x` is first clipped to the range [1e-7, 10000], which keeps the
    logarithm's argument strictly positive. Used as a custom activation in
    ShallowConvNet.
    """
    bounded = K.clip(x, min_value = 1e-7, max_value = 10000)
    return K.log(bounded)
602
603
#%% Reproduced ShallowConvNet model: https://doi.org/10.1002/hbm.23730
604
def ShallowConvNet(nb_classes, Chans = 64, Samples = 128, dropoutRate = 0.5):
    """ Keras implementation of the Shallow Convolutional Network as described
    in Schirrmeister et. al. (2017), Human Brain Mapping.
    See details at https://onlinelibrary.wiley.com/doi/full/10.1002/hbm.23730

    The original code for this model is available at: https://github.com/braindecode/braindecode

    Layer sequence: (1, 25) convolution, (Chans, 1) convolution, batch
    normalization, `square` activation, (1, 75) average pooling with stride
    (1, 15), `log` activation, dropout, flatten, and a max-norm constrained
    Dense layer with `nb_classes` units followed by softmax. The model input
    has shape (1, Chans, Samples).

        Notes
        -----
        The initial values in this model are based on the values identified by the authors

        This implementation is taken from code by the Army Research Laboratory (ARL)
        at https://github.com/vlawhern/arl-eegmodels

        References
        ----------
        .. Schirrmeister, R. T., Springenberg, J. T., Fiederer, L. D. J.,
           Glasstetter, M., Eggensperger, K., Tangermann, M., ... & Ball, T. (2017).
           Deep learning with convolutional neural networks for EEG decoding
           and visualization. Human brain mapping, 38(11), 5391-5420.

    """
    # start the model
    # input_main   = Input((Chans, Samples, 1))
    input_main = Input((1, Chans, Samples))
    # (1, Chans, Samples) -> (Chans, Samples, 1)
    feats = Permute((2,3,1))(input_main)

    # Convolution with a (1, 25) kernel, then one spanning all channels
    feats = Conv2D(40, (1, 25),
                   input_shape=(Chans, Samples, 1),
                   kernel_constraint=max_norm(2., axis=(0,1,2)))(feats)
    feats = Conv2D(40, (Chans, 1), use_bias=False,
                   kernel_constraint=max_norm(2., axis=(0,1,2)))(feats)
    feats = BatchNormalization(epsilon=1e-05, momentum=0.9)(feats)

    # square -> average pool -> log
    feats = Activation(square)(feats)
    feats = AveragePooling2D(pool_size=(1, 75), strides=(1, 15))(feats)
    feats = Activation(log)(feats)
    feats = Dropout(dropoutRate)(feats)

    flat = Flatten()(feats)
    scores = Dense(nb_classes, kernel_constraint=max_norm(0.5))(flat)
    probs = Activation('softmax')(scores)

    return Model(inputs=input_main, outputs=probs)