a b/conformer_BCIIV2b.py
1
"""
EEG Conformer.
Training and evaluation on BCI Competition IV dataset 2b.
"""
5
6
7
import argparse
8
import os
9
# Physical GPU indices this process may use. PCI_BUS_ID ordering makes the
# indices match the bus order (as nvidia-smi lists them); this must be set
# before torch initializes CUDA.
gpus = [1]
os.environ.update({
    'CUDA_DEVICE_ORDER': 'PCI_BUS_ID',
    'CUDA_VISIBLE_DEVICES': ','.join(str(gpu_id) for gpu_id in gpus),
})
12
import numpy as np
13
import math
14
import glob
15
import random
16
import itertools
17
import datetime
18
import time
19
import datetime
20
import sys
21
import scipy.io
22
23
import torchvision.transforms as transforms
24
from torchvision.utils import save_image, make_grid
25
26
from torch.utils.data import DataLoader
27
from torch.autograd import Variable
28
from torchsummary import summary
29
import torch.autograd as autograd
30
from torchvision.models import vgg19
31
32
import torch.nn as nn
33
import torch.nn.functional as F
34
import torch
35
import torch.nn.init as init
36
37
from torch.utils.data import Dataset
38
from PIL import Image
39
import torchvision.transforms as transforms
40
from sklearn.decomposition import PCA
41
42
import torch
43
import torch.nn.functional as F
44
import matplotlib.pyplot as plt
45
46
from torch import nn
47
from torch import Tensor
48
from PIL import Image
49
from torchvision.transforms import Compose, Resize, ToTensor
50
from einops import rearrange, reduce, repeat
51
from einops.layers.torch import Rearrange, Reduce
52
# from common_spatial_pattern import csp
53
54
import matplotlib.pyplot as plt
55
# from torch.utils.tensorboard import SummaryWriter
56
from torch.backends import cudnn
57
# Favour run-to-run reproducibility over speed: request deterministic cuDNN
# algorithms and disable cuDNN's benchmark-based algorithm selection.
cudnn.deterministic = True
cudnn.benchmark = False
59
60
61
class PatchEmbedding(nn.Module):
    """Convolutional front end that turns a trial into a sequence of tokens.

    A shallow conv stack (temporal-style (1, 25) conv, a (3, 1) conv that
    collapses the 3-row axis, BatchNorm, ELU, average pooling, dropout) is
    followed by a 1x1 conv projection to ``emb_size`` channels. The result is
    rearranged to ``(batch, tokens, emb_size)``.

    Args:
        emb_size: width of each output token.
    """

    def __init__(self, emb_size=40):
        super().__init__()
        n_filters = 40

        self.shallownet = nn.Sequential(
            # kernel spans 25 samples along the last axis (presumably time)
            nn.Conv2d(1, n_filters, (1, 25), (1, 1)),
            # kernel spans 3 rows, i.e. all rows of the (1, 3, T) input
            # used elsewhere in this script (assumed to be EEG channels)
            nn.Conv2d(n_filters, n_filters, (3, 1), (1, 1)),
            nn.BatchNorm2d(n_filters),
            nn.ELU(),
            nn.AvgPool2d((1, 75), (1, 15)),
            nn.Dropout(0.5),
        )

        self.projection = nn.Sequential(
            nn.Conv2d(n_filters, emb_size, (1, 1), stride=(1, 1)),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Map a 4-D batch ``x`` to tokens of shape (batch, tokens, emb_size)."""
        # Unpacking doubles as a rank check: anything but a 4-D input raises.
        _batch, _, _, _ = x.shape
        features = self.shallownet(x)
        return self.projection(features)
86
87
88
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over a token sequence.

    Args:
        emb_size: token width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.

    Note:
        Scores are divided by ``sqrt(emb_size)`` rather than the per-head
        ``sqrt(emb_size / num_heads)``. This is kept as-is so trained models
        behave the same.
    """

    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        # Creation order is kept so seeded initialisation and state_dict keys
        # are unchanged.
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def _split_heads(self, t: Tensor) -> Tensor:
        """Reshape (b, n, h*d) -> (b, h, n, d), with the head index as the outer factor."""
        b, n, _ = t.shape
        return t.reshape(b, n, self.num_heads, -1).permute(0, 2, 1, 3)

    @staticmethod
    def _merge_heads(t: Tensor) -> Tensor:
        """Inverse of ``_split_heads``: (b, h, n, d) -> (b, n, h*d)."""
        b, h, n, d = t.shape
        return t.permute(0, 2, 1, 3).reshape(b, n, h * d)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        """Apply self-attention to ``x`` of shape (batch, tokens, emb_size).

        Args:
            x: input tokens.
            mask: optional boolean tensor broadcastable to
                (batch, heads, query, key). ``True`` keeps a score and
                ``False`` masks it out.

        Returns:
            Tensor with the same shape as ``x``.
        """
        queries = self._split_heads(self.queries(x))
        keys = self._split_heads(self.keys(x))
        values = self._split_heads(self.values(x))
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)  # batch, num_heads, query_len, key_len
        if mask is not None:
            # masked_fill is out-of-place, so the result must be reassigned.
            # Use the energy dtype's minimum so half precision does not overflow.
            fill_value = torch.finfo(energy.dtype).min
            energy = energy.masked_fill(~mask, fill_value)

        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        out = torch.einsum('bhal, bhlv -> bhav ', att, values)
        out = self._merge_heads(out)
        out = self.projection(out)
        return out
115
116
117
class ResidualAdd(nn.Module):
    """Wrap module ``fn`` with a skip connection, computing ``fn(x) + x``.

    Args:
        fn: module applied to the input; extra keyword arguments given to
            ``forward`` are passed through to it.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        # Out-of-place add. If ``fn`` returns its input object unchanged
        # (e.g. nn.Identity), an in-place ``+=`` would overwrite the caller's
        # tensor and could corrupt tensors autograd still needs.
        return self.fn(x, **kwargs) + x
127
128
129
class FeedForwardBlock(nn.Sequential):
    """Token-wise MLP: Linear -> GELU -> Dropout -> Linear.

    Args:
        emb_size: input and output width.
        expansion: hidden width multiplier (hidden = expansion * emb_size).
        drop_p: dropout probability after the activation.
    """

    def __init__(self, emb_size, expansion, drop_p):
        hidden_size = expansion * emb_size
        layers = [
            nn.Linear(emb_size, hidden_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(hidden_size, emb_size),
        ]
        super().__init__(*layers)
137
138
class GELU(nn.Module):
    """Exact (erf-based) Gaussian Error Linear Unit: x * Phi(x)."""

    def forward(self, input: Tensor) -> Tensor:
        # Phi(x), the standard normal CDF, expressed through erf.
        normal_cdf = 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))
        return input * normal_cdf
141
142
class TransformerEncoderBlock(nn.Sequential):
    """One pre-norm Transformer encoder layer.

    The layer is two residual sub-layers applied in order:
    ``x + Dropout(MHA(LayerNorm(x)))`` and then
    ``x + Dropout(FFN(LayerNorm(x)))``.

    Args:
        emb_size: token width.
        num_heads: attention heads.
        drop_p: dropout used in attention and after each sub-layer.
        forward_expansion: FFN hidden-width multiplier.
        forward_drop_p: dropout inside the FFN.
    """

    def __init__(self,
                 emb_size,
                 num_heads=5,
                 drop_p=0.5,
                 forward_expansion=4,
                 forward_drop_p=0.5):
        # Sub-modules are built in the same order as before, so seeded
        # initialisation is unchanged.
        attention_sublayer = ResidualAdd(nn.Sequential(
            nn.LayerNorm(emb_size),
            MultiHeadAttention(emb_size, num_heads, drop_p),
            nn.Dropout(drop_p),
        ))
        feedforward_sublayer = ResidualAdd(nn.Sequential(
            nn.LayerNorm(emb_size),
            FeedForwardBlock(emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
            nn.Dropout(drop_p),
        ))
        super().__init__(attention_sublayer, feedforward_sublayer)
162
163
164
class TransformerEncoder(nn.Sequential):
    """Stack of ``depth`` TransformerEncoderBlock layers with default settings."""

    def __init__(self, depth, emb_size):
        layers = (TransformerEncoderBlock(emb_size) for _ in range(depth))
        super().__init__(*layers)
167
168
169
class ClassificationHead(nn.Sequential):
    """Classifier on top of the Transformer encoder output.

    ``forward`` flattens the (batch, tokens, emb_size) sequence and feeds it to
    ``self.fc``. ``cov``, ``clshead`` and ``clshead_fc`` are not used by
    ``forward``. They are kept, in their original order, so seeded
    initialisation and ``state_dict`` keys match previously saved checkpoints.

    Args:
        emb_size: token width.
        n_classes: number of output logits produced by ``fc``.
        flatten_size: input width of ``fc`` (tokens * emb_size). The default
            2440 equals 61 * 40: the token count PatchEmbedding yields for
            1000-sample inputs times emb_size 40.
    """

    def __init__(self, emb_size, n_classes, flatten_size=2440):
        super().__init__()
        self.cov = nn.Sequential(
            nn.Conv1d(190, 1, 1, 1),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.5)
        )
        self.clshead = nn.Sequential(
            Reduce('b n e -> b e', reduction='mean'),
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, n_classes)
        )
        self.clshead_fc = nn.Sequential(
            Reduce('b n e -> b e', reduction='mean'),
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, 32),
            nn.ELU(),
            nn.Dropout(0.5),
            nn.Linear(32, n_classes)
        )
        self.fc = nn.Sequential(
            nn.Linear(flatten_size, 256),
            nn.ELU(),
            nn.Dropout(0.5),
            nn.Linear(256, 32),
            nn.ELU(),
            nn.Dropout(0.3),
            # Honour n_classes. It was previously fixed at 2, which is still
            # the default used by ViT.
            nn.Linear(32, n_classes)
        )

    def forward(self, x):
        """Return ``(flattened_features, logits)`` for encoder output ``x``."""
        x = x.contiguous().view(x.size(0), -1)
        out = self.fc(x)
        return x, out
205
206
207
# ! Rethink the use of Transformer for EEG signal
208
class ViT(nn.Sequential):
    """EEG Conformer: PatchEmbedding -> TransformerEncoder -> ClassificationHead.

    Calling the model returns whatever ClassificationHead returns, i.e. a
    ``(features, logits)`` pair. Extra keyword arguments are accepted and ignored.
    """

    def __init__(self, emb_size=40, depth=10, n_classes=2, **kwargs):
        # Built in pipeline order so seeded initialisation is unchanged.
        embedding = PatchEmbedding(emb_size)
        encoder = TransformerEncoder(depth, emb_size)
        head = ClassificationHead(emb_size, n_classes)
        super().__init__(embedding, encoder, head)
216
217
218
class ExGAN:
    """Per-subject trainer/evaluator for the Conformer model (``ViT``) on dataset 2b.

    Despite the name, only a classifier is trained here (cross-entropy on
    ``self.model``). The model, losses and batches are moved to CUDA
    unconditionally.

    Args:
        nsub: 1-based subject number. It selects the
            ``B0<nsub>0<session>{T,E}.mat`` files under ``self.root`` and
            names the per-subject log file.
    """

    def __init__(self, nsub):
        self.batch_size = 100
        self.n_epochs = 2000
        # NOTE(review): img_height/img_width/channels/img_shape, alpha,
        # pretrain, start_epoch and the L1/L2 criteria are not read by any
        # method of this class.
        self.img_height = 22
        self.img_width = 600
        self.channels = 1
        self.c_dim = 4
        self.lr = 0.0002
        self.b1 = 0.5
        self.b2 = 0.999
        self.alpha = 0.0002
        self.dimension = (190, 50)

        self.nSub = nsub

        self.start_epoch = 0
        self.root = '/Data/strict_TE/2b/'

        self.pretrain = False

        # Per-epoch test-accuracy log. train() flushes it; the owner closes it.
        self.log_write = open("/Code/CT/results/cf/2b/log_subject%d.txt" % self.nSub, "w")

        self.img_shape = (self.channels, self.img_height, self.img_width)

        self.Tensor = torch.cuda.FloatTensor
        self.LongTensor = torch.cuda.LongTensor

        self.criterion_l1 = torch.nn.L1Loss().cuda()
        self.criterion_l2 = torch.nn.MSELoss().cuda()
        self.criterion_cls = torch.nn.CrossEntropyLoss().cuda()

        self.model = ViT().cuda()
        self.model = nn.DataParallel(self.model, device_ids=list(range(len(gpus))))
        self.model = self.model.cuda()

        self.centers = {}

    def interaug(self, timg, label):
        """Synthesize a class-balanced batch by recombining time segments.

        For each class (labels 1 and 2), ``batch_size // 2`` new trials are
        built. The last axis is split into 8 equal segments, and each segment
        is copied from a randomly chosen trial of the same class. This is a
        segmentation-and-reconstruction style augmentation. If the last axis
        is not divisible by 8, the trailing remainder stays zero.

        Args:
            timg: 4-D numpy array of trials (trial axis first).
            label: 1-D numpy array of 1-based labels aligned with ``timg``.

        Returns:
            (data, labels) as CUDA tensors: float data with the same
            per-trial shape as ``timg``, and 0-based long labels, shuffled
            together.
        """
        n_per_class = self.batch_size // 2
        n_segments = 8
        seg_len = timg.shape[-1] // n_segments

        aug_data = []
        aug_label = []
        for cls4aug in range(2):
            cls_idx = np.where(label == cls4aug + 1)
            tmp_data = timg[cls_idx]
            tmp_label = label[cls_idx]

            tmp_aug_data = np.zeros((n_per_class,) + timg.shape[1:])
            for ri in range(n_per_class):
                for rj in range(n_segments):
                    # NOTE: this draws n_segments indices per segment but uses
                    # only one. Kept so seeded runs reproduce earlier results.
                    rand_idx = np.random.randint(0, tmp_data.shape[0], n_segments)
                    seg = slice(rj * seg_len, (rj + 1) * seg_len)
                    tmp_aug_data[ri, :, :, seg] = tmp_data[rand_idx[rj], :, :, seg]

            aug_data.append(tmp_aug_data)
            # One label per synthetic trial. Slicing tmp_label came up short
            # whenever the class had fewer than n_per_class real trials.
            aug_label.append(np.full(n_per_class, cls4aug + 1, dtype=tmp_label.dtype))
        aug_data = np.concatenate(aug_data)
        aug_label = np.concatenate(aug_label)
        aug_shuffle = np.random.permutation(len(aug_data))
        aug_data = aug_data[aug_shuffle]
        aug_label = aug_label[aug_shuffle]

        aug_data = torch.from_numpy(aug_data).cuda().float()
        aug_label = torch.from_numpy(aug_label - 1).cuda().long()
        return aug_data, aug_label

    def _load_sessions(self, session_ids, suffix):
        """Load and concatenate ``B0<nSub>0<id><suffix>.mat`` files from ``self.root``.

        Returns:
            (data, labels). ``data`` has its axes reversed relative to the
            .mat array, so the trial axis comes first, and a singleton axis
            is inserted at position 1. ``labels`` is 1-D.
        """
        data, labels = [], []
        for session_id in session_ids:
            mat = scipy.io.loadmat(self.root + 'B0%d0%d%s.mat' % (self.nSub, session_id, suffix))
            session_data = np.expand_dims(np.transpose(mat['data'], (2, 1, 0)), axis=1)
            session_labels = np.transpose(mat['label'])[0]
            data.append(session_data)
            labels.append(session_labels)
        return np.concatenate(data), np.concatenate(labels)

    def get_source_data(self):
        """Load, shuffle and standardize this subject's train and test data.

        Training data comes from sessions 1-3 (``T`` files) and is shuffled.
        Test data comes from sessions 4-5 (``E`` files). Both are
        standardized with the scalar mean/std of the training data only.
        Results are also stored on ``self.allData``, ``self.allLabel``,
        ``self.testData`` and ``self.testLabel``.

        Returns:
            (allData, allLabel, testData, testLabel); labels are 1-based.
        """
        self.allData, self.allLabel = self._load_sessions(range(1, 4), 'T')

        shuffle_num = np.random.permutation(len(self.allData))
        self.allData = self.allData[shuffle_num, :, :, :]
        self.allLabel = self.allLabel[shuffle_num]

        self.testData, self.testLabel = self._load_sessions(range(4, 6), 'E')

        # standardize with training statistics only
        target_mean = np.mean(self.allData)
        target_std = np.std(self.allData)
        self.allData = (self.allData - target_mean) / target_std
        self.testData = (self.testData - target_mean) / target_std

        return self.allData, self.allLabel, self.testData, self.testLabel

    def update_lr(self, optimizer, lr):
        """Set the learning rate of every param group of ``optimizer`` to ``lr``."""
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    def aug(self, img, label):
        """Segment-recombination augmentation over the whole dataset (not called in this file).

        Every trial of labels 1-4 is replaced by a synthetic one whose 8
        segments of 125 samples come from random same-class trials.
        Returns numpy (data, labels), shuffled together.
        """
        aug_data = []
        aug_label = []
        for cls4aug in range(4):
            cls_idx = np.where(label == cls4aug + 1)
            tmp_data = img[cls_idx]
            tmp_label = label[cls_idx]

            tmp_aug_data = np.zeros(tmp_data.shape)
            for ri in range(tmp_data.shape[0]):
                for rj in range(8):
                    rand_idx = np.random.randint(0, tmp_data.shape[0], 8)
                    tmp_aug_data[ri, :, :, rj * 125:(rj + 1) * 125] = tmp_data[rand_idx[rj], :, :, rj * 125:(rj + 1) * 125]

            aug_data.append(tmp_aug_data)
            aug_label.append(tmp_label)
        aug_data = np.concatenate(aug_data)
        aug_label = np.concatenate(aug_label)
        aug_shuffle = np.random.permutation(len(aug_data))
        aug_data = aug_data[aug_shuffle, :, :]
        aug_label = aug_label[aug_shuffle]

        return aug_data, aug_label

    def update_centers(self, feature, label):
        """Return, per label, sum(center - feature) / (count + 1) (not called in this file).

        NOTE(review): labels are used as dict keys into ``self.centers``, so
        this expects hashable Python ints rather than tensor elements.
        """
        deltac = {}
        count = {}
        count[0] = 0
        for i in range(len(label)):
            l = label[i]
            if l in deltac:
                deltac[l] += self.centers[l] - feature[i]
            else:
                deltac[l] = self.centers[l] - feature[i]
            if l in count:
                count[l] += 1
            else:
                count[l] = 1

        for ke in deltac:
            deltac[ke] = deltac[ke] / (count[ke] + 1)

        return deltac

    def train(self):
        """Train for ``self.n_epochs`` epochs and evaluate on the test set after each one.

        Every training batch is extended with ``interaug`` samples.

        Returns:
            (bestAcc, averAcc, Y_true, Y_pred): the best and the mean test
            accuracy over all epochs, plus the test labels and predictions
            from the best epoch. The first epoch always counts as a candidate,
            so these are tensors whenever at least one epoch ran.
        """
        img, label, test_data, test_label = self.get_source_data()

        img = torch.from_numpy(img)
        label = torch.from_numpy(label - 1)  # 1-based labels -> 0-based class indices

        dataset = torch.utils.data.TensorDataset(img, label)
        self.dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)

        test_data = torch.from_numpy(test_data)
        test_label = torch.from_numpy(test_label - 1)
        test_dataset = torch.utils.data.TensorDataset(test_data, test_label)
        self.test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=self.batch_size, shuffle=True)

        for i in range(self.c_dim):
            self.centers[i] = torch.randn(self.dimension).cuda()

        # Optimizers
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2))

        test_data = test_data.type(self.Tensor)
        test_label = test_label.type(self.LongTensor)

        bestAcc = 0
        averAcc = 0
        num = 0
        Y_true = 0
        Y_pred = 0

        for e in range(self.n_epochs):
            self.model.train()
            for img, label in self.dataloader:
                img = img.cuda().type(self.Tensor)
                label = label.cuda().type(self.LongTensor)

                aug_data, aug_label = self.interaug(self.allData, self.allLabel)
                img = torch.cat((img, aug_data))
                label = torch.cat((label, aug_label))

                _, outputs = self.model(img)
                loss = self.criterion_cls(outputs, label)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            # Evaluate after every epoch. no_grad avoids building an autograd
            # graph for the whole test set.
            self.model.eval()
            with torch.no_grad():
                _, Cls = self.model(test_data)
                loss_test = self.criterion_cls(Cls, test_label)
                y_pred = torch.max(Cls, 1)[1]
            acc = float((y_pred == test_label).cpu().numpy().astype(int).sum()) / float(test_label.size(0))
            # Training accuracy is measured on the last (augmented) batch only.
            train_pred = torch.max(outputs, 1)[1]
            train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0))
            print('Epoch:', e,
                  '  Train loss: %.6f' % loss.detach().cpu().numpy(),
                  '  Test loss: %.6f' % loss_test.detach().cpu().numpy(),
                  '  Train accuracy %.6f' % train_acc,
                  '  Test accuracy is %.6f' % acc)
            self.log_write.write(str(e) + "    " + str(acc) + "\n")
            num = num + 1
            averAcc = averAcc + acc
            # Also record the first epoch, so Y_true/Y_pred are tensors even
            # if accuracy never rises above 0.
            if acc > bestAcc or num == 1:
                bestAcc = acc
                Y_true = test_label
                Y_pred = y_pred

        # NOTE(review): every subject overwrites the same 'model.pth' in the
        # current working directory.
        torch.save(self.model.module.state_dict(), 'model.pth')
        averAcc = averAcc / num if num else 0.0
        print('The average accuracy is:', averAcc)
        print('The best accuracy is:', bestAcc)
        self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n")
        self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n")
        self.log_write.flush()
        return bestAcc, averAcc, Y_true, Y_pred
470
471
472
def main(n_subjects=9):
    """Train one model per subject (1..n_subjects) and write a results summary.

    Each subject gets a fresh seed, drawn from NumPy's global RNG, applied to
    ``random``, NumPy and torch, and recorded. Each subject also gets a fresh
    ``ExGAN`` trainer. Per-subject best/average test accuracies and their
    means over subjects are written to ``sub_result.txt``.

    Args:
        n_subjects: number of subjects to run (default 9).

    Raises:
        ValueError: if ``n_subjects`` is not positive.
    """
    if n_subjects <= 0:
        raise ValueError('n_subjects must be positive, got %r' % (n_subjects,))

    best = 0
    aver = 0
    with open("/Code/CT/results/cf/2b/sub_result.txt", "w") as result_write:
        for i in range(n_subjects):
            subject = i + 1
            starttime = datetime.datetime.now()
            seed_n = np.random.randint(2021)

            print('seed is ' + str(seed_n))
            random.seed(seed_n)
            np.random.seed(seed_n)
            torch.manual_seed(seed_n)
            torch.cuda.manual_seed(seed_n)
            torch.cuda.manual_seed_all(seed_n)
            print('Subject %d' % subject)
            exgan = ExGAN(subject)
            try:
                bestAcc, averAcc, _, _ = exgan.train()
            finally:
                # ExGAN opens a per-subject log file in __init__; release it
                # once this subject is done, even if training failed.
                exgan.log_write.close()

            print('THE BEST ACCURACY IS ' + str(bestAcc))
            result_write.write('Subject ' + str(subject) + ' : ' + 'Seed is: ' + str(seed_n) + "\n")
            result_write.write('**Subject ' + str(subject) + ' : ' + 'The best accuracy is: ' + str(bestAcc) + "\n")
            result_write.write('Subject ' + str(subject) + ' : ' + 'The average accuracy is: ' + str(averAcc) + "\n")
            # Persist each subject's results so a later crash does not lose them.
            result_write.flush()

            endtime = datetime.datetime.now()
            print('subject %d duration: ' % subject + str(endtime - starttime))

            best = best + bestAcc
            aver = aver + averAcc

        best = best / n_subjects
        aver = aver / n_subjects

        result_write.write('**The average Best accuracy is: ' + str(best) + "\n")
        result_write.write('The average Aver accuracy is: ' + str(aver) + "\n")
515
516
517
if __name__ == "__main__":
    # Bracket the whole run with local wall-clock timestamps.
    # time.asctime() with no argument formats time.localtime() of "now".
    print(time.asctime())
    main()
    print(time.asctime())