Diff of /Models.py [000000] .. [367703]

'''
Created by Victor Delvigne
ISIA Lab, Faculty of Engineering University of Mons, Mons (Belgium)
victor.delvigne@umons.ac.be

Source: Bashivan, et al. "Learning Representations from EEG with Deep Recurrent-Convolutional Neural Networks." International Conference on Learning Representations (2016).

Copyright (C) 2019 - UMons

This library is free software; you can redistribute it and/or
modify it under the terms of the GNU Lesser General Public
License as published by the Free Software Foundation; either
version 2.1 of the License, or (at your option) any later version.

This library is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
Lesser General Public License for more details.

You should have received a copy of the GNU Lesser General Public
License along with this library; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA
'''

import torch

import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F


class BasicCNN(nn.Module):
    '''
    Build the basic model performing a classification with a CNN.

    param input_image: EEG image tensor [batch_size, n_channel, h, w]
    param kernel: kernel size used for the convolutional layers
    param stride: stride applied during the convolutions
    param padding: padding used during the convolutions
    param max_kernel: kernel used for the max-pooling steps
    param n_classes: number of classes
    return x: output of the last layer after the log softmax
    '''
    def __init__(self, input_image=torch.zeros(1, 3, 32, 32), kernel=(3, 3), stride=1, padding=1, max_kernel=(2, 2), n_classes=4):
        super(BasicCNN, self).__init__()

        n_channel = input_image.shape[1]

        self.conv1 = nn.Conv2d(n_channel, 32, kernel, stride=stride, padding=padding)
        self.conv2 = nn.Conv2d(32, 32, kernel, stride=stride, padding=padding)
        self.conv3 = nn.Conv2d(32, 32, kernel, stride=stride, padding=padding)
        self.conv4 = nn.Conv2d(32, 32, kernel, stride=stride, padding=padding)
        self.pool1 = nn.MaxPool2d(max_kernel)
        self.conv5 = nn.Conv2d(32, 64, kernel, stride=stride, padding=padding)
        self.conv6 = nn.Conv2d(64, 64, kernel, stride=stride, padding=padding)
        self.conv7 = nn.Conv2d(64, 128, kernel, stride=stride, padding=padding)

        self.pool = nn.MaxPool2d((1, 1))  # (1, 1) kernel: an identity op, kept from the original
        self.drop = nn.Dropout(p=0.5)

        self.fc1 = nn.Linear(2048, 512)  # 128 channels * 4 * 4 spatial positions
        self.fc2 = nn.Linear(512, n_classes)
        self.max = nn.LogSoftmax(dim=1)  # dim=1: normalize over the class dimension

    def forward(self, x):
        batch_size = x.shape[0]
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = self.pool1(x)
        x = F.relu(self.conv5(x))
        x = F.relu(self.conv6(x))
        x = self.pool1(x)
        x = F.relu(self.conv7(x))
        x = self.pool1(x)
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = self.pool(x)
        x = x.reshape(x.shape[0], -1)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.max(x)
        return x


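# --- Usage sketch (added for illustration, not in the original file) ---
# A minimal smoke test for BasicCNN under its default assumptions: single-window
# EEG images of shape (batch, 3, 32, 32). Three (2, 2) max-pools reduce 32x32 to
# 4x4, so the flattened feature size is 128 * 4 * 4 = 2048, matching fc1.
def _basic_cnn_example():
    model = BasicCNN(n_classes=4)
    x = torch.zeros(8, 3, 32, 32)   # (batch, channels, height, width)
    out = model(x)                  # log-probabilities, shape (8, 4)
    return out.shape

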
class MaxCNN(nn.Module):
    '''
    Build the max-pooling model performing a max-pool over the 7 parallel convnets.

    param input_image: stack of EEG images [batch_size, n_window, n_channel, h, w]
    param kernel: kernel size used for the convolutional layers
    param stride: stride applied during the convolutions
    param padding: padding used during the convolutions
    param max_kernel: kernel used for the max-pooling steps
    param n_classes: number of classes
    return x: output of the last layer after the log softmax
    '''
    def __init__(self, input_image=torch.zeros(1, 7, 3, 32, 32), kernel=(3, 3), stride=1, padding=1, max_kernel=(2, 2), n_classes=4):
        super(MaxCNN, self).__init__()

        n_window = input_image.shape[1]
        n_channel = input_image.shape[2]

        self.conv1 = nn.Conv2d(n_channel, 32, kernel, stride=stride, padding=padding)
        self.conv2 = nn.Conv2d(32, 32, kernel, stride=stride, padding=padding)
        self.conv3 = nn.Conv2d(32, 32, kernel, stride=stride, padding=padding)
        self.conv4 = nn.Conv2d(32, 32, kernel, stride=stride, padding=padding)
        self.pool1 = nn.MaxPool2d(max_kernel)
        self.conv5 = nn.Conv2d(32, 64, kernel, stride=stride, padding=padding)
        self.conv6 = nn.Conv2d(64, 64, kernel, stride=stride, padding=padding)
        self.conv7 = nn.Conv2d(64, 128, kernel, stride=stride, padding=padding)

        self.pool = nn.MaxPool2d((n_window, 1))
        self.drop = nn.Dropout(p=0.5)

        self.fc = nn.Linear(n_window * int(4 * 4 * 128 / n_window), 512)  # 7 * floor(2048/7) = 2044 with the defaults
        self.fc2 = nn.Linear(512, n_classes)
        self.max = nn.LogSoftmax(dim=1)  # normalize over the class dimension

    def forward(self, x):
        # Allocate the per-window feature buffer on the input's device; the original
        # x.get_device() == 0 test only covered the first GPU.
        tmp = torch.zeros(x.shape[0], x.shape[1], 128, 4, 4, device=x.device)
        # Apply the shared convnet to each window (unrolled from a single nested call).
        for i in range(x.shape[1]):
            img = F.relu(self.conv1(x[:, i]))
            img = F.relu(self.conv2(img))
            img = F.relu(self.conv3(img))
            img = F.relu(self.conv4(img))
            img = self.pool1(img)
            img = F.relu(self.conv5(img))
            img = F.relu(self.conv6(img))
            img = self.pool1(img)
            img = F.relu(self.conv7(img))
            tmp[:, i] = self.pool1(img)
        x = tmp.reshape(x.shape[0], x.shape[1], 4 * 128 * 4, 1)
        x = self.pool(x)
        x = x.view(x.shape[0], -1)
        x = self.fc2(self.fc(x))
        x = self.max(x)
        return x


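# --- Usage sketch (added for illustration, not in the original file) ---
# With the default 7 windows, each window yields 128 * 4 * 4 = 2048 features;
# the (n_window, 1) max-pool over the reshaped (batch, 7, 2048, 1) tensor keeps
# floor(2048 / 7) = 292 values per window, so fc expects 7 * 292 = 2044 inputs.
def _max_cnn_example():
    model = MaxCNN(n_classes=4)
    x = torch.zeros(2, 7, 3, 32, 32)   # (batch, n_window, channels, height, width)
    out = model(x)                     # log-probabilities, shape (2, 4)
    return out.shape

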
class TempCNN(nn.Module):
    '''
    Build the Conv1D model performing a temporal convolution over the 7 parallel convnets.

    param input_image: stack of EEG images [batch_size, n_window, n_channel, h, w]
    param kernel: kernel size used for the convolutional layers
    param stride: stride applied during the convolutions
    param padding: padding used during the convolutions
    param max_kernel: kernel used for the max-pooling steps
    param n_classes: number of classes
    return x: output of the last layer after the log softmax
    '''
    def __init__(self, input_image=torch.zeros(1, 7, 3, 32, 32), kernel=(3, 3), stride=1, padding=1, max_kernel=(2, 2), n_classes=4):
        super(TempCNN, self).__init__()

        n_window = input_image.shape[1]
        n_channel = input_image.shape[2]

        self.conv1 = nn.Conv2d(n_channel, 32, kernel, stride=stride, padding=padding)
        self.conv2 = nn.Conv2d(32, 32, kernel, stride=stride, padding=padding)
        self.conv3 = nn.Conv2d(32, 32, kernel, stride=stride, padding=padding)
        self.conv4 = nn.Conv2d(32, 32, kernel, stride=stride, padding=padding)
        self.pool1 = nn.MaxPool2d(max_kernel)
        self.conv5 = nn.Conv2d(32, 64, kernel, stride=stride, padding=padding)
        self.conv6 = nn.Conv2d(64, 64, kernel, stride=stride, padding=padding)
        self.conv7 = nn.Conv2d(64, 128, kernel, stride=stride, padding=padding)

        # Temporal CNN layer. The original declared this as nn.Conv1d with a 2-tuple
        # kernel, which builds a 2-D weight under the hood; nn.Conv2d makes that
        # explicit and is accepted by current PyTorch versions.
        self.conv8 = nn.Conv2d(n_window, 64, (4 * 4 * 128, 3), stride=stride, padding=padding)

        self.pool = nn.MaxPool2d((n_window, 1))
        self.drop = nn.Dropout(p=0.5)
        self.fc = nn.Linear(64 * 3, n_classes)  # conv8 yields (batch, 64, 3, 1) with the defaults
        self.max = nn.LogSoftmax(dim=1)

    def forward(self, x):
        # Allocate the per-window feature buffer on the input's device.
        tmp = torch.zeros(x.shape[0], x.shape[1], 128, 4, 4, device=x.device)
        # Apply the shared convnet to each window (unrolled from a single nested call).
        for i in range(x.shape[1]):
            img = F.relu(self.conv1(x[:, i]))
            img = F.relu(self.conv2(img))
            img = F.relu(self.conv3(img))
            img = F.relu(self.conv4(img))
            img = self.pool1(img)
            img = F.relu(self.conv5(img))
            img = F.relu(self.conv6(img))
            img = self.pool1(img)
            img = F.relu(self.conv7(img))
            tmp[:, i] = self.pool1(img)
        x = tmp.reshape(x.shape[0], x.shape[1], 4 * 128 * 4, 1)
        x = F.relu(self.conv8(x))
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        x = self.max(x)
        return x


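# --- Usage sketch (added for illustration, not in the original file) ---
# conv8 slides its (2048, 3) kernel over the (batch, 7, 2048, 1) window stack:
# with padding 1 the feature axis collapses to 3 positions and the width axis
# to 1, giving (batch, 64, 3, 1), i.e. 64 * 3 = 192 features into fc.
def _temp_cnn_example():
    model = TempCNN(n_classes=4)
    x = torch.zeros(2, 7, 3, 32, 32)   # (batch, n_window, channels, height, width)
    out = model(x)                     # log-probabilities, shape (2, 4)
    return out.shape

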
class LSTM(nn.Module):
    '''
    Build the LSTM model applying an RNN over the outputs of the 7 parallel convnets.

    param input_image: stack of EEG images [batch_size, n_window, n_channel, h, w]
    param kernel: kernel size used for the convolutional layers
    param stride: stride applied during the convolutions
    param padding: padding used during the convolutions
    param max_kernel: kernel used for the max-pooling steps
    param n_classes: number of classes
    param n_units: number of recurrent units
    return x: output of the last layer after the log softmax
    '''
    def __init__(self, input_image=torch.zeros(1, 7, 3, 32, 32), kernel=(3, 3), stride=1, padding=1, max_kernel=(2, 2), n_classes=4, n_units=128):
        super(LSTM, self).__init__()

        n_window = input_image.shape[1]
        n_channel = input_image.shape[2]

        self.conv1 = nn.Conv2d(n_channel, 32, kernel, stride=stride, padding=padding)
        self.conv2 = nn.Conv2d(32, 32, kernel, stride=stride, padding=padding)
        self.conv3 = nn.Conv2d(32, 32, kernel, stride=stride, padding=padding)
        self.conv4 = nn.Conv2d(32, 32, kernel, stride=stride, padding=padding)
        self.pool1 = nn.MaxPool2d(max_kernel)
        self.conv5 = nn.Conv2d(32, 64, kernel, stride=stride, padding=padding)
        self.conv6 = nn.Conv2d(64, 64, kernel, stride=stride, padding=padding)
        self.conv7 = nn.Conv2d(64, 128, kernel, stride=stride, padding=padding)

        # Recurrent layer (a vanilla nn.RNN despite the class name, stacked n_window
        # layers deep). batch_first=True so the input is read as
        # (batch, n_window, features); the original relied on the default ordering,
        # which treats the batch axis as time.
        self.rnn = nn.RNN(4 * 4 * 128, n_units, n_window, batch_first=True)
        self.rnn_out = torch.zeros(2, 7, 128)  # placeholder, overwritten in forward()

        self.pool = nn.MaxPool2d((n_window, 1))
        self.drop = nn.Dropout(p=0.5)
        self.fc = nn.Linear(896, n_classes)  # 7 windows * 128 hidden units
        self.max = nn.LogSoftmax(dim=1)

    def forward(self, x):
        # Allocate the per-window feature buffer on the input's device.
        tmp = torch.zeros(x.shape[0], x.shape[1], 128, 4, 4, device=x.device)
        for i in range(x.shape[1]):
            img = x[:, i]
            img = F.relu(self.conv1(img))
            img = F.relu(self.conv2(img))
            img = F.relu(self.conv3(img))
            img = F.relu(self.conv4(img))
            img = self.pool1(img)
            img = F.relu(self.conv5(img))
            img = F.relu(self.conv6(img))
            img = self.pool1(img)
            img = F.relu(self.conv7(img))
            tmp[:, i] = self.pool1(img)
            del img
        x = tmp.reshape(x.shape[0], x.shape[1], 4 * 128 * 4)
        del tmp
        self.rnn_out, _ = self.rnn(x)
        x = self.rnn_out.view(x.shape[0], -1)
        x = self.fc(x)
        x = self.max(x)
        return x


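# --- Usage sketch (added for illustration, not in the original file) ---
# The recurrent layer consumes the (batch, 7, 2048) window features and returns
# (batch, 7, 128) hidden states, flattened to 7 * 128 = 896 inputs for fc.
def _lstm_example():
    model = LSTM(n_classes=4, n_units=128)
    x = torch.zeros(2, 7, 3, 32, 32)   # (batch, n_window, channels, height, width)
    out = model(x)                     # log-probabilities, shape (2, 4)
    return out.shape

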
class Mix(nn.Module):
    '''
    Build the mixed model applying both an RNN and a CNN over the outputs of the 7 parallel convnets.

    param input_image: stack of EEG images [batch_size, n_window, n_channel, h, w]
    param kernel: kernel size used for the convolutional layers
    param stride: stride applied during the convolutions
    param padding: padding used during the convolutions
    param max_kernel: kernel used for the max-pooling steps
    param n_classes: number of classes
    param n_units: number of recurrent units
    return x: output of the last layer after the log softmax
    '''
    def __init__(self, input_image=torch.zeros(1, 7, 3, 32, 32), kernel=(3, 3), stride=1, padding=1, max_kernel=(2, 2), n_classes=4, n_units=128):
        super(Mix, self).__init__()

        n_window = input_image.shape[1]
        n_channel = input_image.shape[2]

        self.conv1 = nn.Conv2d(n_channel, 32, kernel, stride=stride, padding=padding)
        self.conv2 = nn.Conv2d(32, 32, kernel, stride=stride, padding=padding)
        self.conv3 = nn.Conv2d(32, 32, kernel, stride=stride, padding=padding)
        self.conv4 = nn.Conv2d(32, 32, kernel, stride=stride, padding=padding)
        self.pool1 = nn.MaxPool2d(max_kernel)
        self.conv5 = nn.Conv2d(32, 64, kernel, stride=stride, padding=padding)
        self.conv6 = nn.Conv2d(64, 64, kernel, stride=stride, padding=padding)
        self.conv7 = nn.Conv2d(64, 128, kernel, stride=stride, padding=padding)

        # Recurrent branch (a vanilla nn.RNN, as in the LSTM class;
        # batch_first=True reads the input as (batch, n_window, features)).
        self.rnn = nn.RNN(4 * 4 * 128, n_units, n_window, batch_first=True)
        self.rnn_out = torch.zeros(2, 7, 128)  # placeholder, overwritten in forward()

        # Temporal CNN branch, declared as nn.Conv2d for the same reason as in TempCNN.
        self.conv8 = nn.Conv2d(n_window, 64, (4 * 4 * 128, 3), stride=stride, padding=padding)

        self.pool = nn.MaxPool2d((n_window, 1))
        self.drop = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(1088, 512)  # 64*3 = 192 conv features + 7*128 = 896 RNN features
        self.fc2 = nn.Linear(512, n_classes)
        self.max = nn.LogSoftmax(dim=1)

    def forward(self, x):
        # Allocate the per-window feature buffer on the input's device.
        tmp = torch.zeros(x.shape[0], x.shape[1], 128, 4, 4, device=x.device)
        for i in range(x.shape[1]):
            img = x[:, i]
            img = F.relu(self.conv1(img))
            img = F.relu(self.conv2(img))
            img = F.relu(self.conv3(img))
            img = F.relu(self.conv4(img))
            img = self.pool1(img)
            img = F.relu(self.conv5(img))
            img = F.relu(self.conv6(img))
            img = self.pool1(img)
            img = F.relu(self.conv7(img))
            tmp[:, i] = self.pool1(img)
            del img

        temp_conv = F.relu(self.conv8(tmp.reshape(x.shape[0], x.shape[1], 4 * 128 * 4, 1)))
        temp_conv = temp_conv.reshape(temp_conv.shape[0], -1)

        self.lstm_out, _ = self.rnn(tmp.reshape(x.shape[0], x.shape[1], 4 * 128 * 4))
        del tmp
        lstm = self.lstm_out.view(x.shape[0], -1)

        x = torch.cat((temp_conv, lstm), 1)

        x = self.fc1(x)
        x = self.fc2(x)
        x = self.max(x)
        return x
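

# --- Smoke test (added for illustration, not in the original file) ---
# Runs every model once on zero tensors of the default shapes and prints the
# resulting output shapes; all five should end in n_classes = 4.
if __name__ == '__main__':
    single = torch.zeros(2, 3, 32, 32)        # BasicCNN input: one window per sample
    windowed = torch.zeros(2, 7, 3, 32, 32)   # multi-window input for the other models
    print('BasicCNN:', BasicCNN()(single).shape)
    print('MaxCNN:  ', MaxCNN()(windowed).shape)
    print('TempCNN: ', TempCNN()(windowed).shape)
    print('LSTM:    ', LSTM()(windowed).shape)
    print('Mix:     ', Mix()(windowed).shape)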