"""
conformer.py

EEG Conformer: Convolutional Transformer for EEG decoding.

Couples a CNN feature extractor with a Transformer encoder in a compact model.
"""
# NOTE: update the data and result paths below before running
# (see self.root and the ./results/ output files).

|
import argparse
import datetime
import glob
import itertools
import math
import os
import random
import sys
import time

gpus = [0]
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, gpus))

import numpy as np
import scipy.io
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.decomposition import PCA

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.autograd as autograd
import torchvision.transforms as transforms
from torch import Tensor
from torch.autograd import Variable
from torch.backends import cudnn
from torch.utils.data import DataLoader, Dataset
from torchsummary import summary
from torchvision.models import vgg19
from torchvision.transforms import Compose, Resize, ToTensor
from torchvision.utils import save_image, make_grid

from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
# from common_spatial_pattern import csp
# from torch.utils.tensorboard import SummaryWriter

cudnn.benchmark = False
cudnn.deterministic = True

# writer = SummaryWriter('./TensorBoardX/')
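
# Data layout used throughout this file: EEG trials are 4-D tensors shaped
# (batch, 1, 22, 1000), i.e. 22 electrodes x 1000 time samples per trial
# (presumably BCI Competition IV-2a: 4 s at 250 Hz), matching the (22, 1)
# spatial convolution in PatchEmbedding and the commented summary call below.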
|
|
# Convolution module
# use conv to capture local features, instead of position embedding.
class PatchEmbedding(nn.Module):
    def __init__(self, emb_size=40):
        # self.patch_size = patch_size
        super().__init__()

        self.shallownet = nn.Sequential(
            nn.Conv2d(1, 40, (1, 25), (1, 1)),
            nn.Conv2d(40, 40, (22, 1), (1, 1)),
            nn.BatchNorm2d(40),
            nn.ELU(),
            nn.AvgPool2d((1, 75), (1, 15)),  # pooling acts as slicing to obtain 'patches' along the time dimension, as in ViT
            nn.Dropout(0.5),
        )

        self.projection = nn.Sequential(
            nn.Conv2d(40, emb_size, (1, 1), stride=(1, 1)),  # 1x1 conv before the token rearrange; can slightly improve fitting ability
            Rearrange('b e (h) (w) -> b (h w) e'),
        )

    def forward(self, x: Tensor) -> Tensor:
        b, _, _, _ = x.shape
        x = self.shallownet(x)
        x = self.projection(x)
        return x
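
# Shape walk-through for the default 22 x 1000 input used in this file:
#   (b, 1, 22, 1000) -> Conv2d (1, 25)          -> (b, 40, 22, 976)
#                    -> Conv2d (22, 1)          -> (b, 40, 1, 976)
#                    -> AvgPool (1, 75)/(1, 15) -> (b, 40, 1, 61)
#   projection (1x1 conv + Rearrange)           -> (b, 61, 40)
# i.e. 61 temporal tokens of emb_size=40, and 61 * 40 = 2440 is exactly the
# input width of the first Linear layer in ClassificationHead below.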
|
|
class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
        keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
        values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy = energy.masked_fill(~mask, fill_value)

        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        out = torch.einsum('bhal, bhlv -> bhav', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out
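
# Attention details: Q, K, V are single Linear projections of width emb_size,
# split into num_heads heads by the rearrange calls; 'bhqd, bhkd -> bhqk'
# computes the per-head Q.K^T logits. The logits are scaled by sqrt(emb_size)
# (the full embedding width) rather than sqrt(emb_size / num_heads); with the
# defaults (emb_size=40, num_heads=10) each head has dimension 4.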
|
|
class ResidualAdd(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        res = x
        x = self.fn(x, **kwargs)
        x += res
        return x
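
# ResidualAdd computes x + fn(x); TransformerEncoderBlock below uses it to wrap
# the (LayerNorm -> attention -> dropout) and (LayerNorm -> feed-forward ->
# dropout) branches, giving a pre-norm residual layout.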
|
|
class FeedForwardBlock(nn.Sequential):
    def __init__(self, emb_size, expansion, drop_p):
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size, emb_size),
        )


class GELU(nn.Module):
    def forward(self, input: Tensor) -> Tensor:
        return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))
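
# This erf-based GELU class is unused in this file; FeedForwardBlock above
# relies on the built-in nn.GELU(), which computes the same function.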
|
|
class TransformerEncoderBlock(nn.Sequential):
    def __init__(self,
                 emb_size,
                 num_heads=10,
                 drop_p=0.5,
                 forward_expansion=4,
                 forward_drop_p=0.5):
        super().__init__(
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                MultiHeadAttention(emb_size, num_heads, drop_p),
                nn.Dropout(drop_p)
            )),
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                FeedForwardBlock(
                    emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
                nn.Dropout(drop_p)
            ))
        )


class TransformerEncoder(nn.Sequential):
    def __init__(self, depth, emb_size):
        super().__init__(*[TransformerEncoderBlock(emb_size) for _ in range(depth)])
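
# TransformerEncoder stacks `depth` identical blocks (depth=6 by default in
# Conformer below); every block uses num_heads=10 and dropout 0.5 as set above.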
|
|
class ClassificationHead(nn.Sequential):
    def __init__(self, emb_size, n_classes):
        super().__init__()

        # global average pooling head (not used in forward below)
        self.clshead = nn.Sequential(
            Reduce('b n e -> b e', reduction='mean'),
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, n_classes)
        )
        self.fc = nn.Sequential(
            nn.Linear(2440, 256),
            nn.ELU(),
            nn.Dropout(0.5),
            nn.Linear(256, 32),
            nn.ELU(),
            nn.Dropout(0.3),
            nn.Linear(32, n_classes)
        )

    def forward(self, x):
        x = x.contiguous().view(x.size(0), -1)
        out = self.fc(x)
        return x, out
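
# Note: self.clshead (token-mean pooling + LayerNorm + Linear) is defined but
# not used by forward; instead all tokens are flattened (61 x 40 = 2440) and
# passed through self.fc. forward returns both the flattened features and the
# logits, so callers unpack the output as `tok, outputs = model(x)` (see
# ExP.train below).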
|
|
class Conformer(nn.Sequential):
    def __init__(self, emb_size=40, depth=6, n_classes=4, **kwargs):
        super().__init__(
            PatchEmbedding(emb_size),
            TransformerEncoder(depth, emb_size),
            ClassificationHead(emb_size, n_classes)
        )
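
# Quick usage sketch (commented out, not executed; shapes assume the
# 22-channel, 1000-sample trials used in this file):
#   model = Conformer(emb_size=40, depth=6, n_classes=4)
#   feats, logits = model(torch.randn(2, 1, 22, 1000))
#   # feats: (2, 2440), logits: (2, 4)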
|
|
class ExP():
    def __init__(self, nsub):
        super(ExP, self).__init__()
        self.batch_size = 72
        self.n_epochs = 2000
        self.c_dim = 4
        self.lr = 0.0002
        self.b1 = 0.5
        self.b2 = 0.999
        self.dimension = (190, 50)
        self.nSub = nsub

        self.start_epoch = 0
        self.root = '/Data/strict_TE/'

        self.log_write = open("./results/log_subject%d.txt" % self.nSub, "w")

        self.Tensor = torch.cuda.FloatTensor
        self.LongTensor = torch.cuda.LongTensor

        self.criterion_l1 = torch.nn.L1Loss().cuda()
        self.criterion_l2 = torch.nn.MSELoss().cuda()
        self.criterion_cls = torch.nn.CrossEntropyLoss().cuda()

        self.model = Conformer().cuda()
        self.model = nn.DataParallel(self.model, device_ids=[i for i in range(len(gpus))])
        self.model = self.model.cuda()
        # summary(self.model, (1, 22, 1000))
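
    # b1/b2 are the Adam betas used in train(); interaug() below generates
    # batch_size / 4 augmented trials per class, so batch_size is assumed to
    # be divisible by 4.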
|
|
    # Segmentation and Reconstruction (S&R) data augmentation
    def interaug(self, timg, label):
        aug_data = []
        aug_label = []
        for cls4aug in range(4):
            cls_idx = np.where(label == cls4aug + 1)
            tmp_data = timg[cls_idx]
            tmp_label = label[cls_idx]

            tmp_aug_data = np.zeros((int(self.batch_size / 4), 1, 22, 1000))
            for ri in range(int(self.batch_size / 4)):
                for rj in range(8):
                    rand_idx = np.random.randint(0, tmp_data.shape[0], 8)
                    tmp_aug_data[ri, :, :, rj * 125:(rj + 1) * 125] = \
                        tmp_data[rand_idx[rj], :, :, rj * 125:(rj + 1) * 125]

            aug_data.append(tmp_aug_data)
            aug_label.append(tmp_label[:int(self.batch_size / 4)])
        aug_data = np.concatenate(aug_data)
        aug_label = np.concatenate(aug_label)
        aug_shuffle = np.random.permutation(len(aug_data))
        aug_data = aug_data[aug_shuffle, :, :]
        aug_label = aug_label[aug_shuffle]

        aug_data = torch.from_numpy(aug_data).cuda()
        aug_data = aug_data.float()
        aug_label = torch.from_numpy(aug_label - 1).cuda()
        aug_label = aug_label.long()
        return aug_data, aug_label
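
    # S&R in a nutshell: each 1000-sample trial is cut into eight 125-sample
    # segments, and a new artificial trial is assembled by taking segment j
    # from a randomly chosen training trial of the same class. Each call yields
    # batch_size / 4 = 18 augmented trials per class (72 in total), which
    # train() concatenates onto every mini-batch.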
|
|
    def get_source_data(self):
        # ! please recheck whether you need a validation set,
        # ! and match the data segmentation used by the compared methods

        # train data
        self.total_data = scipy.io.loadmat(self.root + 'A0%dT.mat' % self.nSub)
        self.train_data = self.total_data['data']
        self.train_label = self.total_data['label']

        self.train_data = np.transpose(self.train_data, (2, 1, 0))
        self.train_data = np.expand_dims(self.train_data, axis=1)
        self.train_label = np.transpose(self.train_label)

        self.allData = self.train_data
        self.allLabel = self.train_label[0]

        shuffle_num = np.random.permutation(len(self.allData))
        self.allData = self.allData[shuffle_num, :, :, :]
        self.allLabel = self.allLabel[shuffle_num]

        # test data
        self.test_tmp = scipy.io.loadmat(self.root + 'A0%dE.mat' % self.nSub)
        self.test_data = self.test_tmp['data']
        self.test_label = self.test_tmp['label']

        self.test_data = np.transpose(self.test_data, (2, 1, 0))
        self.test_data = np.expand_dims(self.test_data, axis=1)
        self.test_label = np.transpose(self.test_label)

        self.testData = self.test_data
        self.testLabel = self.test_label[0]

        # standardize
        target_mean = np.mean(self.allData)
        target_std = np.std(self.allData)
        self.allData = (self.allData - target_mean) / target_std
        self.testData = (self.testData - target_mean) / target_std

        # data shape: (trial, conv channel, electrode channel, time samples)
        return self.allData, self.allLabel, self.testData, self.testLabel
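
    # Both splits are standardized with the training-set mean/std (no test-set
    # statistics are used). Labels in the .mat files are 1..4 and are shifted
    # to 0..3 in train() and interaug() before CrossEntropyLoss is applied.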
|
|
    def train(self):

        img, label, test_data, test_label = self.get_source_data()

        img = torch.from_numpy(img)
        label = torch.from_numpy(label - 1)

        dataset = torch.utils.data.TensorDataset(img, label)
        self.dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)

        test_data = torch.from_numpy(test_data)
        test_label = torch.from_numpy(test_label - 1)
        test_dataset = torch.utils.data.TensorDataset(test_data, test_label)
        self.test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=self.batch_size, shuffle=True)

        # Optimizer
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2))

        test_data = Variable(test_data.type(self.Tensor))
        test_label = Variable(test_label.type(self.LongTensor))

        bestAcc = 0
        averAcc = 0
        num = 0
        Y_true = 0
        Y_pred = 0

        # Train the model
        total_step = len(self.dataloader)
        curr_lr = self.lr

        for e in range(self.n_epochs):
            # in_epoch = time.time()
            self.model.train()
            for i, (img, label) in enumerate(self.dataloader):

                img = Variable(img.cuda().type(self.Tensor))
                label = Variable(label.cuda().type(self.LongTensor))

                # data augmentation
                aug_data, aug_label = self.interaug(self.allData, self.allLabel)
                img = torch.cat((img, aug_data))
                label = torch.cat((label, aug_label))

                tok, outputs = self.model(img)

                loss = self.criterion_cls(outputs, label)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            # out_epoch = time.time()

            # test process
            if (e + 1) % 1 == 0:
                self.model.eval()
                with torch.no_grad():
                    Tok, Cls = self.model(test_data)

                loss_test = self.criterion_cls(Cls, test_label)
                y_pred = torch.max(Cls, 1)[1]
                acc = float((y_pred == test_label).cpu().numpy().astype(int).sum()) / float(test_label.size(0))
                train_pred = torch.max(outputs, 1)[1]
                train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0))

                print('Epoch:', e,
                      ' Train loss: %.6f' % loss.detach().cpu().numpy(),
                      ' Test loss: %.6f' % loss_test.detach().cpu().numpy(),
                      ' Train accuracy %.6f' % train_acc,
                      ' Test accuracy is %.6f' % acc)

                self.log_write.write(str(e) + " " + str(acc) + "\n")
                num = num + 1
                averAcc = averAcc + acc
                if acc > bestAcc:
                    bestAcc = acc
                    Y_true = test_label
                    Y_pred = y_pred

        torch.save(self.model.module.state_dict(), 'model.pth')
        averAcc = averAcc / num
        print('The average accuracy is:', averAcc)
        print('The best accuracy is:', bestAcc)
        self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n")
        self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n")

        return bestAcc, averAcc, Y_true, Y_pred
        # writer.close()
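
    # Note: bestAcc is the maximum test accuracy over all epochs, i.e. the test
    # set is also used for model selection here; see the reminder about a
    # validation set at the top of get_source_data().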
|
|
def main():
    best = 0
    aver = 0
    result_write = open("./results/sub_result.txt", "w")

    for i in range(9):
        starttime = datetime.datetime.now()

        seed_n = np.random.randint(2021)
        print('seed is ' + str(seed_n))
        random.seed(seed_n)
        np.random.seed(seed_n)
        torch.manual_seed(seed_n)
        torch.cuda.manual_seed(seed_n)
        torch.cuda.manual_seed_all(seed_n)

        print('Subject %d' % (i + 1))
        exp = ExP(i + 1)

        bestAcc, averAcc, Y_true, Y_pred = exp.train()
        print('THE BEST ACCURACY IS ' + str(bestAcc))
        result_write.write('Subject ' + str(i + 1) + ' : ' + 'Seed is: ' + str(seed_n) + "\n")
        result_write.write('Subject ' + str(i + 1) + ' : ' + 'The best accuracy is: ' + str(bestAcc) + "\n")
        result_write.write('Subject ' + str(i + 1) + ' : ' + 'The average accuracy is: ' + str(averAcc) + "\n")

        endtime = datetime.datetime.now()
        print('subject %d duration: ' % (i + 1) + str(endtime - starttime))
        best = best + bestAcc
        aver = aver + averAcc
        if i == 0:
            yt = Y_true
            yp = Y_pred
        else:
            yt = torch.cat((yt, Y_true))
            yp = torch.cat((yp, Y_pred))

    best = best / 9
    aver = aver / 9

    result_write.write('**The average Best accuracy is: ' + str(best) + "\n")
    result_write.write('The average Aver accuracy is: ' + str(aver) + "\n")
    result_write.close()
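
# main() trains all 9 subjects in sequence, each with a fresh random seed, and
# writes the per-subject results to ./results/sub_result.txt (the ./results/
# directory must exist beforehand).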
|
|
if __name__ == "__main__":
    print(time.asctime(time.localtime(time.time())))
    main()
    print(time.asctime(time.localtime(time.time())))