"""
Transformer for EEG classification.

The core idea is slicing: the signal is split along the time dimension, and
each slice plays the same role as a patch in a Vision Transformer.
"""
6
7
8
# Standard library
import os
import math
import random
import time

# Third-party
import numpy as np
import scipy.io
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch import nn
from torch import Tensor
from torch.autograd import Variable
from torch.backends import cudnn
from torch.utils.data import DataLoader
from torchsummary import summary
# from torch.utils.tensorboard import SummaryWriter

from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce

# Project-local
from common_spatial_pattern import csp
# from confusion_matrix import plot_confusion_matrix
# from cm_no_normal import plot_confusion_matrix_nn

# Turn off the cuDNN autotuner and request deterministic kernels so that
# seeded runs are repeatable.
cudnn.benchmark = False
cudnn.deterministic = True

# writer = SummaryWriter('./TensorBoardX/')

# torch.cuda.set_device(6)
# GPU indices made visible to this process; `gpus` is also read later when
# the model is wrapped in nn.DataParallel.
gpus = [0]
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, gpus))
44
45
46
class PatchEmbedding(nn.Module):
    """Convert a one-plane EEG "image" into a sequence of slice embeddings.

    The input is first convolved along its last (time) axis and then cut into
    non-overlapping slices, each projected to ``emb_size`` features.

    Args:
        emb_size: Number of features of every output token.
    """

    def __init__(self, emb_size):
        super().__init__()
        self.projection = nn.Sequential(
            # Temporal filtering: 1 -> 2 feature maps, kernel spans 51 steps
            # of the last axis and a single row.
            nn.Conv2d(1, 2, (1, 51), (1, 1)),
            nn.BatchNorm2d(2),
            nn.LeakyReLU(0.2),
            # Slicing: the kernel covers 16 rows and 5 time steps and moves
            # 5 steps at a time, so consecutive slices do not overlap.
            nn.Conv2d(2, emb_size, (16, 5), stride=(1, 5)),
            # Flatten the remaining spatial grid into a token axis.
            Rearrange('b e (h) (w) -> b (h w) e'),
        )
        # NOTE: forward() does not prepend this token. The parameter is kept
        # only so that the module's state_dict keys stay the same.
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        # self.positions = nn.Parameter(torch.randn((100 + 1, emb_size)))
        # self.positions = nn.Parameter(torch.randn((2200 + 1, emb_size)))

    def forward(self, x: Tensor) -> Tensor:
        """Embed ``x`` and return tokens shaped (batch, tokens, emb_size).

        Args:
            x: 4-D tensor with one input plane (the first convolution expects
                exactly one input channel).

        Returns:
            The projected tokens; no class token or positional embedding is
            added.
        """
        return self.projection(x)
69
70
71
class MultiHeadAttention(nn.Module):
    """Multi-head dot-product self-attention over a token sequence.

    Args:
        emb_size: Feature size of each token. It is split evenly across the
            heads, so it must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        """Apply self-attention to ``x``.

        Args:
            x: Tokens shaped (batch, tokens, emb_size).
            mask: Optional boolean tensor broadcastable to
                (batch, heads, query_len, key_len). Positions where it is
                ``False`` are excluded from the attention.

        Returns:
            Tensor with the same shape as ``x``.
        """
        queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
        keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
        values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
        # batch, num_heads, query_len, key_len
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        if mask is not None:
            # masked_fill is out-of-place, so the result has to be assigned.
            # The fill value follows energy's dtype so it is representable.
            fill_value = torch.finfo(energy.dtype).min
            energy = energy.masked_fill(~mask, fill_value)

        # The scores are scaled by sqrt(emb_size), i.e. the full embedding
        # size rather than the per-head size.
        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        out = torch.einsum('bhal, bhlv -> bhav', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out
98
99
100
class ResidualAdd(nn.Module):
    """Wrap a module with a skip connection: ``fn(x) + x``.

    Args:
        fn: Module (or callable) whose output has a shape that can be added
            to its input.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``self.fn(x, **kwargs) + x``.

        The sum is computed out of place. An in-place add would overwrite the
        caller's tensor whenever ``fn`` returns its input unchanged, and would
        break the backward pass of any ``fn`` that saves its output for
        gradient computation.
        """
        return self.fn(x, **kwargs) + x
110
111
112
class FeedForwardBlock(nn.Sequential):
    """Position-wise feed-forward network: Linear -> GELU -> Dropout -> Linear.

    Args:
        emb_size: Input and output feature size.
        expansion: Multiplier for the hidden layer width
            (hidden size is ``expansion * emb_size``).
        drop_p: Dropout probability applied after the activation.
    """

    def __init__(self, emb_size, expansion, drop_p):
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size, emb_size),
        )
120
121
class GELU(nn.Module):
    """Exact (erf-based) GELU activation: ``x * 0.5 * (1 + erf(x / sqrt(2)))``."""

    def forward(self, input: Tensor) -> Tensor:
        """Apply the activation element-wise and return a tensor of the same shape."""
        return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))
124
125
126
class TransformerEncoderBlock(nn.Sequential):
    """One pre-norm transformer encoder layer.

    The layer is two residual sub-blocks applied in sequence:
    LayerNorm -> multi-head attention -> Dropout, then
    LayerNorm -> feed-forward -> Dropout.

    Args:
        emb_size: Token feature size.
        num_heads: Number of attention heads.
        drop_p: Dropout probability used inside the attention and after each
            sub-block.
        forward_expansion: Hidden-width multiplier of the feed-forward block.
        forward_drop_p: Dropout probability inside the feed-forward block.
    """

    def __init__(self,
                 emb_size,
                 num_heads=5,
                 drop_p=0.5,
                 forward_expansion=4,
                 forward_drop_p=0.5):
        super().__init__(
            # Attention sub-block with skip connection.
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                MultiHeadAttention(emb_size, num_heads, drop_p),
                nn.Dropout(drop_p)
            )),
            # Feed-forward sub-block with skip connection.
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                FeedForwardBlock(
                    emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
                nn.Dropout(drop_p)
            ))
        )
146
147
148
class TransformerEncoder(nn.Sequential):
    """Stack of ``depth`` encoder blocks built with their default settings.

    Args:
        depth: Number of ``TransformerEncoderBlock`` layers.
        emb_size: Token feature size passed to every block.
    """

    def __init__(self, depth, emb_size):
        super().__init__(*[TransformerEncoderBlock(emb_size) for _ in range(depth)])
151
152
153
class ClassificationHead(nn.Sequential):
    """Mean-pool the tokens and map them to class scores.

    Args:
        emb_size: Token feature size.
        n_classes: Number of output classes.
    """

    def __init__(self, emb_size, n_classes):
        super().__init__()
        self.clshead = nn.Sequential(
            # Average over the token axis.
            Reduce('b n e -> b e', reduction='mean'),
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, n_classes)
        )

    def forward(self, x):
        """Return ``(x, scores)``: the untouched input tokens and the class scores."""
        out = self.clshead(x)
        return x, out
165
166
167
class ViT(nn.Sequential):
    """Full model: channel attention -> slice embedding -> encoder -> head.

    Args:
        emb_size: Token feature size used by the embedding, encoder and head.
        depth: Number of encoder blocks.
        n_classes: Number of output classes.
        **kwargs: Accepted but not used.

    Calling the model returns the ``(tokens, class scores)`` pair produced by
    ``ClassificationHead``.
    """

    def __init__(self, emb_size=10, depth=3, n_classes=4, **kwargs):
        super().__init__(
            # channel_attention(),
            # Residual channel-attention stage. The LayerNorm normalizes the
            # last axis, which must therefore have length 1000.
            ResidualAdd(
                nn.Sequential(
                    nn.LayerNorm(1000),
                    channel_attention(),
                    nn.Dropout(0.5),
                )
            ),

            PatchEmbedding(emb_size),
            TransformerEncoder(depth, emb_size),
            ClassificationHead(emb_size, n_classes)
        )
183
184
185
class channel_attention(nn.Module):
    """Attention across the channel axis of a (b, o, c, s) tensor.

    Queries and keys are linear maps over the channel axis (fixed at 16
    channels), average-pooled along the last axis to reduce computation. The
    resulting channel-by-channel weights re-mix the channels of the input,
    and the result goes through a final projection.

    Args:
        sequence_num: Length of the last (sequence) axis of the input.
        inter: Pooling window and stride along the sequence axis.
    """

    def __init__(self, sequence_num=1000, inter=30):
        super(channel_attention, self).__init__()
        self.sequence_num = sequence_num
        self.inter = inter
        # Number of pooled steps; only used to scale the attention scores.
        # You could choose to do that for less computation.
        self.extract_sequence = int(self.sequence_num / self.inter)

        self.query = nn.Sequential(
            nn.Linear(16, 16),
            nn.LayerNorm(16),  # also may introduce improvement to a certain extent
            nn.Dropout(0.3)
        )
        self.key = nn.Sequential(
            nn.Linear(16, 16),
            # nn.LeakyReLU(),
            nn.LayerNorm(16),
            nn.Dropout(0.3)
        )

        # self.value = self.key
        self.projection = nn.Sequential(
            nn.Linear(16, 16),
            # nn.LeakyReLU(),
            nn.LayerNorm(16),
            nn.Dropout(0.3),
        )

        self.drop_out = nn.Dropout(0)
        self.pooling = nn.AvgPool2d(kernel_size=(1, self.inter), stride=(1, self.inter))

        # Xavier-normal weights and zero biases for every linear layer.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)

    def forward(self, x):
        """Re-weight the channels of ``x`` and return a tensor of the same shape.

        Args:
            x: Tensor shaped (b, o, c, s) with ``c == 16``.
        """
        # Move channels last so the linear layers act on the channel axis.
        temp = rearrange(x, 'b o c s->b o s c')
        temp_query = rearrange(self.query(temp), 'b o s c -> b o c s')
        temp_key = rearrange(self.key(temp), 'b o s c -> b o c s')

        channel_query = self.pooling(temp_query)
        channel_key = self.pooling(temp_key)

        scaling = self.extract_sequence ** (1 / 2)

        channel_atten = torch.einsum('b o c s, b o m s -> b o c m', channel_query, channel_key) / scaling

        channel_atten_score = F.softmax(channel_atten, dim=-1)
        channel_atten_score = self.drop_out(channel_atten_score)

        # Weighted sum over the source channels m. Contracting m on the input
        # is required: summing the softmax row on its own always gives 1,
        # which would leave x unchanged and give query/key no gradient.
        out = torch.einsum('b o m s, b o c m -> b o c s', x, channel_atten_score)
        # Projections after or before multiplying with the attention score
        # are almost the same.
        out = rearrange(out, 'b o c s -> b o s c')
        out = self.projection(out)
        out = rearrange(out, 'b o s c -> b o c s')
        return out
244
245
246
class Trans():
    """Train and evaluate the ViT model on the recordings of one subject.

    Constructing an instance requires CUDA: the model and the loss modules
    are moved to the GPU in ``__init__``.

    Args:
        nsub: Subject number. It selects the ``A0<nsub>T.mat`` and
            ``A0<nsub>E.mat`` files under ``self.root`` and names the
            per-subject log file.
    """

    def __init__(self, nsub):
        super(Trans, self).__init__()
        self.batch_size = 50
        self.n_epochs = 1000
        self.img_height = 22
        self.img_width = 600
        self.channels = 1
        self.c_dim = 4
        self.lr = 0.0002
        self.b1 = 0.5
        self.b2 = 0.9
        self.nSub = nsub
        self.start_epoch = 0
        self.root = '...'  # the path of data

        self.pretrain = False

        # Create the output directory first so opening the log cannot fail
        # just because it is missing.
        os.makedirs("./results", exist_ok=True)
        self.log_path = "./results/log_subject%d.txt" % self.nSub
        self.log_write = open(self.log_path, "w")

        self.img_shape = (self.channels, self.img_height, self.img_width)  # something no use

        self.Tensor = torch.cuda.FloatTensor
        self.LongTensor = torch.cuda.LongTensor

        self.criterion_l1 = torch.nn.L1Loss().cuda()
        self.criterion_l2 = torch.nn.MSELoss().cuda()
        self.criterion_cls = torch.nn.CrossEntropyLoss().cuda()

        self.model = ViT().cuda()
        self.model = nn.DataParallel(self.model, device_ids=[i for i in range(len(gpus))])
        self.model = self.model.cuda()
        summary(self.model, (1, 16, 1000))

        self.centers = {}

    def get_source_data(self):
        """Load, standardize and spatially filter the subject's data.

        Reads the ``data`` and ``label`` entries of the training (``T``) and
        evaluation (``E``) ``.mat`` files, reverses the axis order of the 3-D
        data, inserts a singleton axis at position 1, standardizes both sets
        with the mean/std of the training set, and applies the CSP projection
        computed from the training set.

        Returns:
            Tuple ``(train_data, train_label, test_data, test_label)`` of
            NumPy arrays. Labels are returned as stored in the files.
        """
        # to get the data of target subject
        self.total_data = scipy.io.loadmat(self.root + 'A0%dT.mat' % self.nSub)
        self.train_data = self.total_data['data']
        self.train_label = self.total_data['label']

        self.train_data = np.transpose(self.train_data, (2, 1, 0))
        self.train_data = np.expand_dims(self.train_data, axis=1)
        self.train_label = np.transpose(self.train_label)

        self.allData = self.train_data
        self.allLabel = self.train_label[0]

        # test data
        # to get the data of target subject
        self.test_tmp = scipy.io.loadmat(self.root + 'A0%dE.mat' % self.nSub)
        self.test_data = self.test_tmp['data']
        self.test_label = self.test_tmp['label']

        # self.train_data = self.train_data[250:1000, :, :]
        self.test_data = np.transpose(self.test_data, (2, 1, 0))
        self.test_data = np.expand_dims(self.test_data, axis=1)
        self.test_label = np.transpose(self.test_label)

        self.testData = self.test_data
        self.testLabel = self.test_label[0]

        # standardize (statistics come from the training set only)
        target_mean = np.mean(self.allData)
        target_std = np.std(self.allData)
        self.allData = (self.allData - target_mean) / target_std
        self.testData = (self.testData - target_mean) / target_std

        # common spatial pattern: fit on the training set, then project
        # axis 2 of both sets with the resulting matrix.
        tmp_alldata = np.transpose(np.squeeze(self.allData), (0, 2, 1))
        Wb = csp(tmp_alldata, self.allLabel-1)
        self.allData = np.einsum('abcd, ce -> abed', self.allData, Wb)
        self.testData = np.einsum('abcd, ce -> abed', self.testData, Wb)
        return self.allData, self.allLabel, self.testData, self.testLabel

    def update_lr(self, optimizer, lr):
        """Set the learning rate of every parameter group of ``optimizer`` to ``lr``."""
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    # Do some data augmentation is a potential way to improve the generalization ability
    def aug(self, img, label):
        """Placeholder for data augmentation; currently returns two empty lists."""
        aug_data = []
        aug_label = []
        return aug_data, aug_label

    def train(self):
        """Train the model and evaluate it on the test set after every epoch.

        Saves the final weights to ``model.pth`` and writes per-epoch test
        accuracy to the subject's log file, which is closed on exit.

        Returns:
            Tuple ``(bestAcc, averAcc, Y_true, Y_pred)``: the best and the
            mean test accuracy over all epochs, and the test labels and
            predictions of the best epoch (both stay ``0`` if no epoch
            scored above zero).
        """
        img, label, test_data, test_label = self.get_source_data()
        img = torch.from_numpy(img)
        label = torch.from_numpy(label - 1)  # shift labels to start at 0

        dataset = torch.utils.data.TensorDataset(img, label)
        self.dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)

        test_data = torch.from_numpy(test_data)
        test_label = torch.from_numpy(test_label - 1)
        test_dataset = torch.utils.data.TensorDataset(test_data, test_label)
        self.test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=self.batch_size, shuffle=True)

        # Optimizers
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2))

        # The whole test set is evaluated in a single forward pass.
        test_data = test_data.type(self.Tensor)
        test_label = test_label.type(self.LongTensor)

        bestAcc = 0
        averAcc = 0
        num = 0
        Y_true = 0
        Y_pred = 0

        # Reopen the log in append mode if an earlier train() call closed it.
        if self.log_write.closed:
            self.log_write = open(self.log_path, "a")

        # some better optimization strategy is worthy to explore. Sometimes terrible over-fitting.
        try:
            for e in range(self.n_epochs):
                self.model.train()
                for img, label in self.dataloader:
                    img = img.cuda().type(self.Tensor)
                    label = label.cuda().type(self.LongTensor)
                    tok, outputs = self.model(img)
                    loss = self.criterion_cls(outputs, label)
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                self.model.eval()
                # No gradients are needed for evaluation; skipping them avoids
                # building an autograd graph over the whole test set.
                with torch.no_grad():
                    Tok, Cls = self.model(test_data)
                    loss_test = self.criterion_cls(Cls, test_label)

                y_pred = torch.max(Cls, 1)[1]
                acc = float((y_pred == test_label).cpu().numpy().astype(int).sum()) / float(test_label.size(0))
                # Training loss/accuracy are those of the last mini-batch only.
                train_pred = torch.max(outputs, 1)[1]
                train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0))
                print('Epoch:', e,
                      '  Train loss:', loss.detach().cpu().numpy(),
                      '  Test loss:', loss_test.detach().cpu().numpy(),
                      '  Train accuracy:', train_acc,
                      '  Test accuracy is:', acc)
                self.log_write.write(str(e) + "    " + str(acc) + "\n")
                num = num + 1
                averAcc = averAcc + acc
                if acc > bestAcc:
                    bestAcc = acc
                    Y_true = test_label
                    Y_pred = y_pred

            torch.save(self.model.module.state_dict(), 'model.pth')
            # Guard against n_epochs == 0, where nothing was evaluated.
            averAcc = averAcc / num if num else 0.0
            print('The average accuracy is:', averAcc)
            print('The best accuracy is:', bestAcc)
            self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n")
            self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n")
        finally:
            # Always flush and release the log file, even if training fails.
            self.log_write.close()

        return bestAcc, averAcc, Y_true, Y_pred
411
412
413
def main():
    """Train one model per subject and write a summary of the accuracies.

    For each of the nine subjects a random seed is drawn and applied to the
    Python, NumPy and PyTorch generators, a ``Trans`` instance is trained,
    and its best/average accuracy is written to
    ``./results/sub_result.txt``. The file ends with the accuracies averaged
    over all subjects.
    """
    n_subjects = 9
    best = 0
    aver = 0
    # Labels/predictions of each subject's best epoch, kept for the optional
    # confusion-matrix plots below.
    all_true = []
    all_pred = []

    os.makedirs("./results", exist_ok=True)
    # The context manager closes the file even if a subject's run fails.
    with open("./results/sub_result.txt", "w") as result_write:
        for i in range(n_subjects):
            seed_n = np.random.randint(500)
            print('seed is ' + str(seed_n))
            random.seed(seed_n)
            np.random.seed(seed_n)
            torch.manual_seed(seed_n)
            torch.cuda.manual_seed(seed_n)
            torch.cuda.manual_seed_all(seed_n)
            print('Subject %d' % (i+1))
            trans = Trans(i + 1)
            bestAcc, averAcc, Y_true, Y_pred = trans.train()
            print('THE BEST ACCURACY IS ' + str(bestAcc))
            result_write.write('Subject ' + str(i + 1) + ' : ' + 'Seed is: ' + str(seed_n) + "\n")
            result_write.write('**Subject ' + str(i + 1) + ' : ' + 'The best accuracy is: ' + str(bestAcc) + "\n")
            result_write.write('Subject ' + str(i + 1) + ' : ' + 'The average accuracy is: ' + str(averAcc) + "\n")
            # plot_confusion_matrix(Y_true, Y_pred, i+1)
            best = best + bestAcc
            aver = aver + averAcc
            # train() returns the integer 0 as a placeholder when no epoch
            # scored above zero; only real tensors are collected.
            if torch.is_tensor(Y_true) and torch.is_tensor(Y_pred):
                all_true.append(Y_true)
                all_pred.append(Y_pred)

        best = best / n_subjects
        aver = aver / n_subjects
        # plot_confusion_matrix(torch.cat(all_true), torch.cat(all_pred), 666)
        result_write.write('**The average Best accuracy is: ' + str(best) + "\n")
        result_write.write('The average Aver accuracy is: ' + str(aver) + "\n")
450
451
452
# Run the full per-subject experiment only when executed as a script.
if __name__ == "__main__":
    main()