a b/code_psd_fcnn/EEGConvNet.py
1
import torch
2
import torch.nn as nn
3
import torch.nn.functional as F
4
5
class EEGConvNet(nn.Module):
    """Small fully connected classifier producing 2 output logits per sample.

    Despite the name, the network has no convolutional layers. Its structure is
    ``Linear(in_features, 64) -> BatchNorm1d -> LeakyReLU -> Dropout(0.4)
    -> Linear(64, 32) -> LeakyReLU -> Dropout(0.5) -> Linear(32, 2)``.

    Args:
        reduced_sensors: If truthy, ``self.input_size`` is set to 8, otherwise 62.
            NOTE(review): ``input_size`` is only stored; it does not size any
            layer. The first layer uses ``in_features`` (48 by default)
            regardless of this flag. Since 48 == 8 * 6, this presumably means
            8 sensors x 6 features per sensor -- confirm before relying on
            ``reduced_sensors=False``.
        sfreq: Stored as ``self.sfreq``; not used inside this class.
        batch_size: Stored as ``self.batch_size``; not used inside this class.
        in_features: Number of features per sample after flattening. ``None``
            (the default) keeps the previously hard-coded value of 48.
    """

    # Historical hard-coded width of the first layer.
    _DEFAULT_IN_FEATURES = 48

    def __init__(self, reduced_sensors, sfreq=None, batch_size=32, in_features=None):
        super().__init__()

        self.sfreq = sfreq
        self.batch_size = batch_size
        self.input_size = 8 if reduced_sensors else 62
        if in_features is None:
            in_features = self._DEFAULT_IN_FEATURES

        self.fc_block1 = nn.Linear(in_features, 64)
        self.batchnorm1 = nn.BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        self.fc_block2 = nn.Linear(64, 32)
        self.fc_block3 = nn.Linear(32, 2)

        # Xavier-normal init of the weight matrices; biases keep PyTorch's
        # default Linear init. The layers are initialised in the original
        # order (1, 2, 3) so seeded runs draw the same random numbers as before.
        for layer in (self.fc_block1, self.fc_block2, self.fc_block3):
            nn.init.xavier_normal_(layer.weight, gain=1)

    def forward(self, x):
        """Return raw logits of shape ``(batch, 2)``.

        All dimensions after the first (batch) dimension are flattened. Their
        product must equal ``fc_block1.in_features``. In training mode,
        BatchNorm1d needs more than one sample per batch.
        """
        x = x.reshape(x.size(0), -1)

        x = self.fc_block1(x)
        x = F.leaky_relu(self.batchnorm1(x), negative_slope=0.01)
        x = F.dropout(x, p=0.4, training=self.training)

        x = F.leaky_relu(self.fc_block2(x), negative_slope=0.01)
        x = F.dropout(x, p=0.5, training=self.training)

        return self.fc_block3(x)