Diff of /conformer.py [000000] .. [8bbec7]

Switch to unified view

a b/conformer.py
1
"""
2
EEG Conformer 
3
4
Convolutional Transformer for EEG decoding
5
6
Couple CNN and Transformer in a concise manner with amazing results
7
"""
8
# remember to change paths
9
10
import argparse
11
import os
12
gpus = [0]
13
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
14
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, gpus))
15
import numpy as np
16
import math
17
import glob
18
import random
19
import itertools
20
import datetime
21
import time
22
import datetime
23
import sys
24
import scipy.io
25
26
import torchvision.transforms as transforms
27
from torchvision.utils import save_image, make_grid
28
29
from torch.utils.data import DataLoader
30
from torch.autograd import Variable
31
from torchsummary import summary
32
import torch.autograd as autograd
33
from torchvision.models import vgg19
34
35
import torch.nn as nn
36
import torch.nn.functional as F
37
import torch
38
import torch.nn.init as init
39
40
from torch.utils.data import Dataset
41
from PIL import Image
42
import torchvision.transforms as transforms
43
from sklearn.decomposition import PCA
44
45
import torch
46
import torch.nn.functional as F
47
import matplotlib.pyplot as plt
48
49
from torch import nn
50
from torch import Tensor
51
from PIL import Image
52
from torchvision.transforms import Compose, Resize, ToTensor
53
from einops import rearrange, reduce, repeat
54
from einops.layers.torch import Rearrange, Reduce
55
# from common_spatial_pattern import csp
56
57
import matplotlib.pyplot as plt
58
# from torch.utils.tensorboard import SummaryWriter
59
from torch.backends import cudnn
60
cudnn.benchmark = False
61
cudnn.deterministic = True
62
63
# writer = SummaryWriter('./TensorBoardX/')
64
65
66
# Convolution module
67
# use conv to capture local features, instead of position embedding.
68
class PatchEmbedding(nn.Module):
    """Convolutional front end that turns a 4-D input into a token sequence.

    ``shallownet`` applies a (1, 25) convolution, a (22, 1) convolution,
    batch norm, ELU, (1, 75)/(1, 15) average pooling and dropout in turn.
    ``projection`` then maps the 40 feature maps to ``emb_size`` channels
    with a 1x1 convolution and flattens the two trailing axes, producing a
    tensor laid out as ``(b, h * w, emb_size)``.

    Args:
        emb_size: Number of channels of each output token.
    """

    def __init__(self, emb_size=40):
        super().__init__()

        # The average pooling slides along the last axis, so every pooled
        # position plays the role of one 'patch', as in ViT.
        feature_layers = [
            nn.Conv2d(1, 40, (1, 25), (1, 1)),
            nn.Conv2d(40, 40, (22, 1), (1, 1)),
            nn.BatchNorm2d(40),
            nn.ELU(),
            nn.AvgPool2d((1, 75), (1, 15)),
            nn.Dropout(0.5),
        ]
        self.shallownet = nn.Sequential(*feature_layers)

        # The 1x1 convolution could enhance fitting ability slightly.
        token_layers = [
            nn.Conv2d(40, emb_size, (1, 1), stride=(1, 1)),
            Rearrange('b e (h) (w) -> b (h w) e'),
        ]
        self.projection = nn.Sequential(*token_layers)

    def forward(self, x: Tensor) -> Tensor:
        """Return the token sequence computed from the 4-D tensor ``x``."""
        _, _, _, _ = x.shape  # unpacking fails early unless x is 4-D
        return self.projection(self.shallownet(x))
93
94
95
class MultiHeadAttention(nn.Module):
    """Multi-head dot-product self-attention over a ``(b, n, emb_size)`` input.

    Queries, keys and values are separate linear projections of the input,
    split into ``num_heads`` heads along the embedding axis. The attention
    output is merged back to ``emb_size`` and passed through a final linear
    projection.

    Args:
        emb_size: Embedding width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def _split_heads(self, t: Tensor) -> Tensor:
        """Reshape ``(b, n, h * d)`` into ``(b, h, n, d)``."""
        b, n, _ = t.shape
        return t.reshape(b, n, self.num_heads, -1).permute(0, 2, 1, 3)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        """Apply self-attention to ``x``.

        Args:
            x: Input of shape ``(b, n, emb_size)``.
            mask: Optional mask broadcastable to ``(b, h, n, n)``. Positions
                where it is true take part in attention; positions where it
                is false are suppressed.

        Returns:
            Tensor with the same shape as ``x``.
        """
        queries = self._split_heads(self.queries(x))
        keys = self._split_heads(self.keys(x))
        values = self._split_heads(self.values(x))

        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        if mask is not None:
            # masked_fill is out-of-place, so its result must be kept; use
            # the most negative value of the energy dtype so the softmax
            # weight of a masked key becomes zero.
            fill_value = torch.finfo(energy.dtype).min
            energy = energy.masked_fill(~mask.to(torch.bool), fill_value)

        # NOTE: scaled by sqrt(emb_size), not sqrt(head dim); left unchanged
        # so that the numerical behavior of the model stays the same.
        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)

        out = torch.einsum('bhal, bhlv -> bhav', att, values)
        b, h, n, d = out.shape
        out = out.permute(0, 2, 1, 3).reshape(b, n, h * d)
        return self.projection(out)
122
123
124
class ResidualAdd(nn.Module):
    """Wrap a module with a skip connection: ``fn(x) + x``.

    Args:
        fn: Callable module applied to the input; its output must be
            broadcast-compatible with the input.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``self.fn(x, **kwargs) + x`` without modifying ``x``."""
        # Out-of-place add: an in-place ``+=`` would write into the tensor
        # returned by ``fn``; if that tensor aliases the input, the caller's
        # data would be silently overwritten.
        return self.fn(x, **kwargs) + x
134
135
136
class FeedForwardBlock(nn.Sequential):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear.

    Args:
        emb_size: Width of the input and of the output.
        expansion: Multiplier giving the hidden width ``expansion * emb_size``.
        drop_p: Dropout probability applied after the activation.
    """

    def __init__(self, emb_size, expansion, drop_p):
        hidden_size = expansion * emb_size
        layers = [
            nn.Linear(emb_size, hidden_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(hidden_size, emb_size),
        ]
        super().__init__(*layers)
144
145
146
class GELU(nn.Module):
    """Exact (erf-based) Gaussian Error Linear Unit activation."""

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input * 0.5 * (1 + erf(input / sqrt(2)))`` element-wise."""
        erf_term = 1.0 + torch.erf(input / math.sqrt(2.0))
        return input * 0.5 * erf_term
149
150
151
class TransformerEncoderBlock(nn.Sequential):
    """One pre-norm encoder layer made of two residual sub-layers.

    The first sub-layer is LayerNorm -> multi-head attention -> dropout and
    the second is LayerNorm -> feed-forward block -> dropout; each one is
    wrapped in ``ResidualAdd``.

    Args:
        emb_size: Token embedding width.
        num_heads: Number of attention heads.
        drop_p: Dropout probability used inside the attention and after
            both sub-layers.
        forward_expansion: Hidden-width multiplier of the feed-forward block.
        forward_drop_p: Dropout probability inside the feed-forward block.
    """

    def __init__(self,
                 emb_size,
                 num_heads=10,
                 drop_p=0.5,
                 forward_expansion=4,
                 forward_drop_p=0.5):
        attention_path = nn.Sequential(
            nn.LayerNorm(emb_size),
            MultiHeadAttention(emb_size, num_heads, drop_p),
            nn.Dropout(drop_p),
        )
        feed_forward_path = nn.Sequential(
            nn.LayerNorm(emb_size),
            FeedForwardBlock(
                emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
            nn.Dropout(drop_p),
        )
        super().__init__(
            ResidualAdd(attention_path),
            ResidualAdd(feed_forward_path),
        )
171
172
173
class TransformerEncoder(nn.Sequential):
    """Stack of ``depth`` encoder blocks applied one after another.

    Args:
        depth: Number of ``TransformerEncoderBlock`` layers.
        emb_size: Token embedding width passed to every block.
    """

    def __init__(self, depth, emb_size):
        blocks = []
        for _ in range(depth):
            blocks.append(TransformerEncoderBlock(emb_size))
        super().__init__(*blocks)
176
177
178
class ClassificationHead(nn.Sequential):
    """Fully connected classifier applied to the flattened token sequence.

    Args:
        emb_size: Width of each incoming token.
        n_classes: Number of output classes (size of the logits).
        n_tokens: Number of tokens per sample. The default of 61 together
            with ``emb_size=40`` gives the 2440 input features that were
            previously hard-coded.
    """

    def __init__(self, emb_size, n_classes, n_tokens=61):
        super().__init__()

        # Global-average-pooling head. It is not used by forward(); it is
        # kept so that the parameter names of existing checkpoints still
        # match.
        self.clshead = nn.Sequential(
            Reduce('b n e -> b e', reduction='mean'),
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, n_classes)
        )
        self.fc = nn.Sequential(
            nn.Linear(n_tokens * emb_size, 256),
            nn.ELU(),
            nn.Dropout(0.5),
            nn.Linear(256, 32),
            nn.ELU(),
            nn.Dropout(0.3),
            # Honor n_classes instead of a fixed 4 outputs.
            nn.Linear(32, n_classes)
        )

    def forward(self, x):
        """Flatten ``x`` per sample and classify it.

        Returns:
            A tuple ``(features, logits)`` where ``features`` is the
            flattened input of shape ``(batch, -1)`` and ``logits`` is the
            output of the fully connected stack.
        """
        x = x.contiguous().view(x.size(0), -1)
        out = self.fc(x)
        return x, out
202
203
204
class Conformer(nn.Sequential):
    """EEG Conformer: patch embedding, transformer encoder, classifier.

    Args:
        emb_size: Token embedding width shared by all three stages.
        depth: Number of transformer encoder blocks.
        n_classes: Number of classes passed to the classification head.
        **kwargs: Accepted for call-site compatibility; not used.
    """

    def __init__(self, emb_size=40, depth=6, n_classes=4, **kwargs):
        stages = (
            PatchEmbedding(emb_size),
            TransformerEncoder(depth, emb_size),
            ClassificationHead(emb_size, n_classes),
        )
        super().__init__(*stages)
212
213
214
class ExP():
    """Training and evaluation experiment for one subject.

    The constructor builds a CUDA ``Conformer`` wrapped in ``DataParallel``
    and opens ``./results/log_subject<nsub>.txt``; ``train()`` loads the
    subject's data, trains for ``n_epochs`` epochs, evaluates on the test
    set after every epoch and closes the log file when it finishes.

    Args:
        nsub: Subject number; used in the data file names and the log name.
    """

    def __init__(self, nsub):
        super(ExP, self).__init__()
        self.batch_size = 72
        self.n_epochs = 2000
        self.c_dim = 4
        self.lr = 0.0002
        self.b1 = 0.5
        self.b2 = 0.999
        self.dimension = (190, 50)
        self.nSub = nsub

        self.start_epoch = 0
        self.root = '/Data/strict_TE/'

        self.Tensor = torch.cuda.FloatTensor
        self.LongTensor = torch.cuda.LongTensor

        self.criterion_l1 = torch.nn.L1Loss().cuda()
        self.criterion_l2 = torch.nn.MSELoss().cuda()
        self.criterion_cls = torch.nn.CrossEntropyLoss().cuda()

        self.model = Conformer().cuda()
        self.model = nn.DataParallel(self.model, device_ids=[i for i in range(len(gpus))])
        self.model = self.model.cuda()

        # Open the log last so that a failure while building the model does
        # not leave an open handle, and make sure the directory exists.
        os.makedirs("./results", exist_ok=True)
        self.log_write = open("./results/log_subject%d.txt" % self.nSub, "w")

    # Segmentation and Reconstruction (S&R) data augmentation
    def interaug(self, timg, label):
        """Create one augmented batch by recombining same-class segments.

        For each of the four classes (labels 1..4), ``batch_size / 4``
        artificial trials are built. Each trial consists of 8 consecutive
        125-sample segments, every segment copied from a randomly chosen
        real trial of that class.

        Args:
            timg: NumPy array of trials indexed as
                ``(trial, 1, 22, 1000)``.
            label: 1-D NumPy array of labels in 1..4, aligned with ``timg``.

        Returns:
            ``(aug_data, aug_label)`` as CUDA tensors: float data and long
            labels shifted to 0..3, shuffled jointly.
        """
        n_per_class = int(self.batch_size / 4)
        aug_data = []
        aug_label = []
        for cls4aug in range(4):
            cls_idx = np.where(label == cls4aug + 1)
            tmp_data = timg[cls_idx]

            tmp_aug_data = np.zeros((n_per_class, 1, 22, 1000))
            for ri in range(n_per_class):
                for rj in range(8):
                    # Drawn once per segment (not hoisted) so the random
                    # stream, and thus seeded runs, stay unchanged.
                    rand_idx = np.random.randint(0, tmp_data.shape[0], 8)
                    seg = slice(rj * 125, (rj + 1) * 125)
                    tmp_aug_data[ri, :, :, seg] = tmp_data[rand_idx[rj], :, :, seg]

            aug_data.append(tmp_aug_data)
            # Every generated trial belongs to this class. Build the labels
            # directly so their count always matches the data, even when the
            # class has fewer than n_per_class real trials.
            aug_label.append(np.full(n_per_class, cls4aug + 1, dtype=label.dtype))
        aug_data = np.concatenate(aug_data)
        aug_label = np.concatenate(aug_label)
        aug_shuffle = np.random.permutation(len(aug_data))
        aug_data = aug_data[aug_shuffle, :, :]
        aug_label = aug_label[aug_shuffle]

        aug_data = torch.from_numpy(aug_data).cuda()
        aug_data = aug_data.float()
        aug_label = torch.from_numpy(aug_label - 1).cuda()
        aug_label = aug_label.long()
        return aug_data, aug_label

    def get_source_data(self):
        """Load, reshape, shuffle and standardize the subject's data.

        Reads ``A0<nSub>T.mat`` (training) and ``A0<nSub>E.mat`` (test) from
        ``self.root``, using their ``'data'`` and ``'label'`` entries. The
        data axes are reversed and a singleton axis is inserted at position
        1. The training set is shuffled, and both sets are standardized with
        the mean and standard deviation of the training data.

        Returns:
            ``(allData, allLabel, testData, testLabel)`` as NumPy arrays.
        """
        # ! please please recheck if you need validation set
        # ! and the data segment compared methods used

        # train data
        self.total_data = scipy.io.loadmat(self.root + 'A0%dT.mat' % self.nSub)
        self.train_data = self.total_data['data']
        self.train_label = self.total_data['label']

        self.train_data = np.transpose(self.train_data, (2, 1, 0))
        self.train_data = np.expand_dims(self.train_data, axis=1)
        self.train_label = np.transpose(self.train_label)

        self.allData = self.train_data
        self.allLabel = self.train_label[0]

        shuffle_num = np.random.permutation(len(self.allData))
        self.allData = self.allData[shuffle_num, :, :, :]
        self.allLabel = self.allLabel[shuffle_num]

        # test data
        self.test_tmp = scipy.io.loadmat(self.root + 'A0%dE.mat' % self.nSub)
        self.test_data = self.test_tmp['data']
        self.test_label = self.test_tmp['label']

        self.test_data = np.transpose(self.test_data, (2, 1, 0))
        self.test_data = np.expand_dims(self.test_data, axis=1)
        self.test_label = np.transpose(self.test_label)

        self.testData = self.test_data
        self.testLabel = self.test_label[0]

        # standardize with training statistics only
        target_mean = np.mean(self.allData)
        target_std = np.std(self.allData)
        self.allData = (self.allData - target_mean) / target_std
        self.testData = (self.testData - target_mean) / target_std

        # data shape: (trial, conv channel, electrode channel, time samples)
        return self.allData, self.allLabel, self.testData, self.testLabel

    def train(self):
        """Train the model and evaluate it on the test set after each epoch.

        The log file opened by the constructor is closed when training
        ends, whether it finishes normally or raises.

        Returns:
            ``(bestAcc, averAcc, Y_true, Y_pred)``: the best and the mean
            test accuracy over all epochs, plus the test labels and the
            predictions of the best epoch.
        """
        try:
            return self._train_and_evaluate()
        finally:
            self.log_write.close()

    def _train_and_evaluate(self):
        """Run the optimization loop; see ``train`` for the return value."""
        img, label, test_data, test_label = self.get_source_data()

        img = torch.from_numpy(img)
        label = torch.from_numpy(label - 1)

        dataset = torch.utils.data.TensorDataset(img, label)
        self.dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)

        test_data = torch.from_numpy(test_data)
        test_label = torch.from_numpy(test_label - 1)
        test_dataset = torch.utils.data.TensorDataset(test_data, test_label)
        self.test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=self.batch_size, shuffle=True)

        # Optimizers
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2))

        # The whole test set is evaluated in a single forward pass.
        test_data = test_data.type(self.Tensor)
        test_label = test_label.type(self.LongTensor)

        bestAcc = 0
        averAcc = 0
        num = 0
        Y_true = 0
        Y_pred = 0

        for e in range(self.n_epochs):
            self.model.train()
            for img, label in self.dataloader:
                img = img.cuda().type(self.Tensor)
                label = label.cuda().type(self.LongTensor)

                # data augmentation: append one S&R batch to the real batch
                aug_data, aug_label = self.interaug(self.allData, self.allLabel)
                img = torch.cat((img, aug_data))
                label = torch.cat((label, aug_label))

                tok, outputs = self.model(img)

                loss = self.criterion_cls(outputs, label)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            # test process; no_grad avoids building an autograd graph over
            # the full test set every epoch
            self.model.eval()
            with torch.no_grad():
                Tok, Cls = self.model(test_data)
                loss_test = self.criterion_cls(Cls, test_label)

            y_pred = torch.max(Cls, 1)[1]
            acc = float((y_pred == test_label).cpu().numpy().astype(int).sum()) / float(test_label.size(0))
            # Training accuracy is measured on the last batch of the epoch.
            train_pred = torch.max(outputs, 1)[1]
            train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0))

            print('Epoch:', e,
                  '  Train loss: %.6f' % loss.detach().cpu().numpy(),
                  '  Test loss: %.6f' % loss_test.detach().cpu().numpy(),
                  '  Train accuracy %.6f' % train_acc,
                  '  Test accuracy is %.6f' % acc)

            self.log_write.write(str(e) + "    " + str(acc) + "\n")
            num = num + 1
            averAcc = averAcc + acc
            if acc > bestAcc:
                bestAcc = acc
                Y_true = test_label
                Y_pred = y_pred

        torch.save(self.model.module.state_dict(), 'model.pth')
        # Guard against a zero-epoch configuration.
        averAcc = averAcc / num if num else 0.0
        print('The average accuracy is:', averAcc)
        print('The best accuracy is:', bestAcc)
        self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n")
        self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n")

        return bestAcc, averAcc, Y_true, Y_pred
410
        # writer.close()
411
412
413
def main():
    """Run the experiment for every subject and write a summary file.

    For each of the 9 subjects a fresh random seed is drawn and applied to
    ``random``, NumPy and PyTorch, an ``ExP`` is trained, and the seed and
    accuracies are appended to ``./results/sub_result.txt``. The mean best
    and mean average accuracy over all subjects are written at the end.
    """
    n_subjects = 9
    best = 0
    aver = 0

    # The output directory may not exist yet; the `with` block guarantees
    # the summary file is closed even if a subject's run raises.
    os.makedirs("./results", exist_ok=True)
    with open("./results/sub_result.txt", "w") as result_write:
        for i in range(n_subjects):
            starttime = datetime.datetime.now()

            seed_n = np.random.randint(2021)
            print('seed is ' + str(seed_n))
            random.seed(seed_n)
            np.random.seed(seed_n)
            torch.manual_seed(seed_n)
            torch.cuda.manual_seed(seed_n)
            torch.cuda.manual_seed_all(seed_n)

            print('Subject %d' % (i+1))
            exp = ExP(i + 1)

            bestAcc, averAcc, Y_true, Y_pred = exp.train()
            print('THE BEST ACCURACY IS ' + str(bestAcc))
            result_write.write('Subject ' + str(i + 1) + ' : ' + 'Seed is: ' + str(seed_n) + "\n")
            result_write.write('Subject ' + str(i + 1) + ' : ' + 'The best accuracy is: ' + str(bestAcc) + "\n")
            result_write.write('Subject ' + str(i + 1) + ' : ' + 'The average accuracy is: ' + str(averAcc) + "\n")

            endtime = datetime.datetime.now()
            print('subject %d duration: '%(i+1) + str(endtime - starttime))
            best = best + bestAcc
            aver = aver + averAcc
            # Labels and predictions of each subject's best epoch are
            # accumulated across subjects; they are not used further here.
            if i == 0:
                yt = Y_true
                yp = Y_pred
            else:
                yt = torch.cat((yt, Y_true))
                yp = torch.cat((yp, Y_pred))

        best = best / n_subjects
        aver = aver / n_subjects

        result_write.write('**The average Best accuracy is: ' + str(best) + "\n")
        result_write.write('The average Aver accuracy is: ' + str(aver) + "\n")
458
459
460
if __name__ == "__main__":
    # Print a wall-clock timestamp before and after the full run.
    started_at = time.localtime(time.time())
    print(time.asctime(started_at))
    main()
    finished_at = time.localtime(time.time())
    print(time.asctime(finished_at))