"""
Transformer for EEG classification
The core idea is slicing, i.e. splitting the signal along the time dimension. A slice plays the same role as a patch in the Vision Transformer.
"""
import os
import numpy as np
import math
import random
import time
import scipy.io
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torchsummary import summary
import torch
import torch.nn.functional as F
from torch import nn
from torch import Tensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from common_spatial_pattern import csp
# from confusion_matrix import plot_confusion_matrix
# from cm_no_normal import plot_confusion_matrix_nn
# from torchsummary import summary
import matplotlib.pyplot as plt
# from torch.utils.tensorboard import SummaryWriter
from torch.backends import cudnn
# Disable the cuDNN autotuner and request deterministic kernels so that runs
# with the same seed are repeatable.
cudnn.benchmark = False
cudnn.deterministic = True
# writer = SummaryWriter('./TensorBoardX/')
# torch.cuda.set_device(6)
# Indices of the GPUs exposed to this process; the same list is used later to
# build the DataParallel device ids.
gpus = [0]
# Order devices by PCI bus id and restrict visibility to the GPUs listed above.
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, gpus))
class PatchEmbedding(nn.Module):
    """Convolutional front end that turns a trial into a sequence of tokens.

    The projection applies a (1, 51) convolution on the single input plane,
    batch norm and LeakyReLU, then a (16, 5) convolution with stride (1, 5)
    that produces ``emb_size`` feature maps, and finally flattens the two
    spatial axes into one token axis.

    Args:
        emb_size: Size of each output token embedding.
    """

    def __init__(self, emb_size):
        super().__init__()
        self.projection = nn.Sequential(
            nn.Conv2d(1, 2, (1, 51), (1, 1)),
            nn.BatchNorm2d(2),
            nn.LeakyReLU(0.2),
            nn.Conv2d(2, emb_size, (16, 5), stride=(1, 5)),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )
        # The class token is not consumed by forward(); the parameter is kept
        # so that the module's state_dict keys stay compatible with existing
        # checkpoints.
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))

    def forward(self, x: Tensor) -> Tensor:
        """Project ``x`` to tokens.

        Args:
            x: Input batch; assumed to be (batch, 1, 16, time) given the
                convolution kernels above -- TODO confirm against callers.

        Returns:
            Tensor of shape (batch, tokens, emb_size).
        """
        # No class token or positional embedding is added: downstream
        # classification averages over all tokens instead.
        return self.projection(x)
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over a token sequence.

    Args:
        emb_size: Embedding size of each token; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape (batch, tokens, emb_size).
            mask: Optional boolean tensor broadcastable to
                (batch, heads, query_len, key_len); positions where it is
                False are excluded from attention.

        Returns:
            Tensor of shape (batch, tokens, emb_size).
        """
        queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
        keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
        values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
        # batch, num_heads, query_len, key_len
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        if mask is not None:
            # masked_fill is out-of-place, so the result must be re-bound;
            # use the most negative value representable in energy's dtype so
            # masked positions get zero weight after the softmax.
            fill_value = torch.finfo(energy.dtype).min
            energy = energy.masked_fill(~mask, fill_value)

        # NOTE(review): scaling uses the full embedding size rather than the
        # per-head size; left unchanged so trained models behave the same.
        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        out = torch.einsum('bhal, bhlv -> bhav ', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out
class ResidualAdd(nn.Module):
    """Wrap a module with a skip connection: ``fn(x) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        """Run the wrapped module and add the input back onto its output.

        Extra keyword arguments are forwarded to the wrapped module. The
        addition is performed in place on the wrapped module's output.
        """
        skip = x
        branch = self.fn(skip, **kwargs)
        branch += skip
        return branch
class FeedForwardBlock(nn.Sequential):
    """Position-wise MLP: expand, GELU, dropout, project back.

    Args:
        emb_size: Input and output feature size.
        expansion: Multiplier for the hidden layer width.
        drop_p: Dropout probability applied after the activation.
    """

    def __init__(self, emb_size, expansion, drop_p):
        hidden_size = expansion * emb_size
        layers = [
            nn.Linear(emb_size, hidden_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(hidden_size, emb_size),
        ]
        super().__init__(*layers)
class GELU(nn.Module):
    """Exact (erf-based) Gaussian Error Linear Unit."""

    def forward(self, input: Tensor) -> Tensor:
        """Return ``input * 0.5 * (1 + erf(input / sqrt(2)))`` element-wise."""
        erf_term = torch.erf(input / math.sqrt(2.0))
        gate = 1.0 + erf_term
        return input * 0.5 * gate
class TransformerEncoderBlock(nn.Sequential):
    """One pre-norm transformer layer: attention then MLP, each with a skip.

    Args:
        emb_size: Token embedding size.
        num_heads: Number of attention heads.
        drop_p: Dropout probability used inside attention and after each
            sub-layer.
        forward_expansion: Hidden-width multiplier of the feed-forward block.
        forward_drop_p: Dropout probability inside the feed-forward block.
    """

    def __init__(self,
                 emb_size,
                 num_heads=5,
                 drop_p=0.5,
                 forward_expansion=4,
                 forward_drop_p=0.5):
        # Sub-layers are built in the same order as they run (attention
        # first) so parameter initialisation order is preserved.
        attention_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            MultiHeadAttention(emb_size, num_heads, drop_p),
            nn.Dropout(drop_p),
        )
        feed_forward_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            FeedForwardBlock(
                emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
            nn.Dropout(drop_p),
        )
        super().__init__(
            ResidualAdd(attention_branch),
            ResidualAdd(feed_forward_branch),
        )
class TransformerEncoder(nn.Sequential):
    """Stack of ``depth`` TransformerEncoderBlock layers with default settings."""

    def __init__(self, depth, emb_size):
        blocks = []
        for _ in range(depth):
            blocks.append(TransformerEncoderBlock(emb_size))
        super().__init__(*blocks)
class ClassificationHead(nn.Sequential):
    """Mean-pool the tokens, normalise, and map to class logits.

    Args:
        emb_size: Token embedding size.
        n_classes: Number of output classes.
    """

    def __init__(self, emb_size, n_classes):
        super().__init__()
        token_mean = Reduce('b n e -> b e', reduction='mean')
        norm = nn.LayerNorm(emb_size)
        to_logits = nn.Linear(emb_size, n_classes)
        self.clshead = nn.Sequential(token_mean, norm, to_logits)

    def forward(self, x):
        """Return ``(x, logits)``: the untouched input tokens and the class scores."""
        logits = self.clshead(x)
        return x, logits
class ViT(nn.Sequential):
    """Full model: channel attention, patch embedding, encoder, classifier.

    Because the last stage is ClassificationHead, calling the model returns
    a ``(tokens, logits)`` tuple.

    Args:
        emb_size: Token embedding size.
        depth: Number of transformer encoder blocks.
        n_classes: Number of output classes.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(self, emb_size=10, depth=3, n_classes=4, **kwargs):
        # Residual channel-attention stage; the LayerNorm normalises over a
        # last axis of length 1000.
        channel_stage = ResidualAdd(
            nn.Sequential(
                nn.LayerNorm(1000),
                channel_attention(),
                nn.Dropout(0.5),
            )
        )
        stages = [
            channel_stage,
            PatchEmbedding(emb_size),
            TransformerEncoder(depth, emb_size),
            ClassificationHead(emb_size, n_classes),
        ]
        super().__init__(*stages)
class channel_attention(nn.Module):
    """Self-attention across the channel axis of a (b, o, c, s) input.

    Queries and keys are linear projections over the channel axis (size 16),
    average-pooled along the last axis with window ``inter`` before the
    channel-by-channel score matrix is computed.

    Args:
        sequence_num: Length of the last (time) axis; only used to derive the
            score scaling factor.
        inter: Pooling window and stride along the last axis.
    """

    def __init__(self, sequence_num=1000, inter=30):
        super(channel_attention, self).__init__()
        self.sequence_num = sequence_num
        self.inter = inter
        # Number of pooled time steps; pooling reduces the attention cost.
        self.extract_sequence = int(self.sequence_num / self.inter)

        self.query = nn.Sequential(
            nn.Linear(16, 16),
            nn.LayerNorm(16),  # also may introduce improvement to a certain extent
            nn.Dropout(0.3)
        )
        self.key = nn.Sequential(
            nn.Linear(16, 16),
            # nn.LeakyReLU(),
            nn.LayerNorm(16),
            nn.Dropout(0.3)
        )

        # self.value = self.key
        self.projection = nn.Sequential(
            nn.Linear(16, 16),
            # nn.LeakyReLU(),
            nn.LayerNorm(16),
            nn.Dropout(0.3),
        )

        self.drop_out = nn.Dropout(0)
        self.pooling = nn.AvgPool2d(kernel_size=(1, self.inter), stride=(1, self.inter))

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)

    def forward(self, x):
        """Re-weight channels of ``x`` by channel-to-channel attention.

        Args:
            x: Tensor laid out as (b, o, c, s) with c == 16 (required by the
                linear layers).

        Returns:
            Tensor with the same layout as ``x``.
        """
        temp = rearrange(x, 'b o c s->b o s c')
        temp_query = rearrange(self.query(temp), 'b o s c -> b o c s')
        temp_key = rearrange(self.key(temp), 'b o s c -> b o c s')

        channel_query = self.pooling(temp_query)
        channel_key = self.pooling(temp_key)

        scaling = self.extract_sequence ** (1 / 2)

        # Score of query channel c against key channel m.
        channel_atten = torch.einsum('b o c s, b o m s -> b o c m', channel_query, channel_key) / scaling

        channel_atten_score = F.softmax(channel_atten, dim=-1)
        channel_atten_score = self.drop_out(channel_atten_score)

        # Each output channel c is the score-weighted sum over source channels
        # m. Indexing x by c here instead would sum the softmax row (== 1) and
        # return x unchanged, leaving the query/key branches without effect.
        out = torch.einsum('b o c m, b o m s -> b o c s', channel_atten_score, x)
        # Projections after or before multiplying with the attention score
        # are almost the same.
        out = rearrange(out, 'b o c s -> b o s c')
        out = self.projection(out)
        out = rearrange(out, 'b o s c -> b o c s')
        return out
class Trans():
    """Per-subject training and evaluation driver for the ViT classifier.

    Construction builds the CUDA model and loss functions and opens a
    per-subject log file under ``./results``. ``train`` loads the data,
    optimises the model and evaluates it on the test file after every epoch.
    """

    def __init__(self, nsub):
        """Set hyper-parameters and build the model for subject ``nsub``.

        Args:
            nsub: Subject number; selects the ``A0<nsub>T.mat`` and
                ``A0<nsub>E.mat`` files and names the log file.
        """
        super(Trans, self).__init__()
        self.batch_size = 50
        self.n_epochs = 1000
        self.img_height = 22
        self.img_width = 600
        self.channels = 1
        self.c_dim = 4
        self.lr = 0.0002
        self.b1 = 0.5
        self.b2 = 0.9
        self.nSub = nsub
        self.start_epoch = 0
        self.root = '...'  # the path of data

        self.pretrain = False

        # Make sure the output directory exists before opening the log.
        os.makedirs("./results", exist_ok=True)
        self.log_write = open("./results/log_subject%d.txt" % self.nSub, "w")

        self.img_shape = (self.channels, self.img_height, self.img_width)  # something no use

        self.Tensor = torch.cuda.FloatTensor
        self.LongTensor = torch.cuda.LongTensor

        self.criterion_l1 = torch.nn.L1Loss().cuda()
        self.criterion_l2 = torch.nn.MSELoss().cuda()
        self.criterion_cls = torch.nn.CrossEntropyLoss().cuda()

        self.model = ViT().cuda()
        self.model = nn.DataParallel(self.model, device_ids=list(range(len(gpus))))
        self.model = self.model.cuda()
        summary(self.model, (1, 16, 1000))

        self.centers = {}

    def get_source_data(self):
        """Load, standardise and spatially filter the subject's data.

        Reads ``A0<nSub>T.mat`` (training) and ``A0<nSub>E.mat`` (test) from
        ``self.root``; each file must hold ``data`` and ``label`` entries.

        Returns:
            Tuple ``(train_data, train_label, test_data, test_label)`` as
            NumPy arrays. Labels are returned as stored in the files; callers
            subtract 1 to obtain zero-based classes.
        """
        # to get the data of target subject
        self.total_data = scipy.io.loadmat(self.root + 'A0%dT.mat' % self.nSub)
        self.train_data = self.total_data['data']
        self.train_label = self.total_data['label']

        # Reverse the three stored axes, then insert a singleton axis at
        # position 1 -- presumably giving (trial, 1, channel, time); verify
        # against the .mat layout.
        self.train_data = np.transpose(self.train_data, (2, 1, 0))
        self.train_data = np.expand_dims(self.train_data, axis=1)
        self.train_label = np.transpose(self.train_label)

        self.allData = self.train_data
        self.allLabel = self.train_label[0]

        # test data
        # to get the data of target subject
        self.test_tmp = scipy.io.loadmat(self.root + 'A0%dE.mat' % self.nSub)
        self.test_data = self.test_tmp['data']
        self.test_label = self.test_tmp['label']

        self.test_data = np.transpose(self.test_data, (2, 1, 0))
        self.test_data = np.expand_dims(self.test_data, axis=1)
        self.test_label = np.transpose(self.test_label)

        self.testData = self.test_data
        self.testLabel = self.test_label[0]

        # standardize: statistics come from the training data only and are
        # applied to both splits.
        target_mean = np.mean(self.allData)
        target_std = np.std(self.allData)
        self.allData = (self.allData - target_mean) / target_std
        self.testData = (self.testData - target_mean) / target_std

        # Fit the common-spatial-pattern projection on the training data and
        # apply it along axis 2 of both splits.
        tmp_alldata = np.transpose(np.squeeze(self.allData), (0, 2, 1))
        Wb = csp(tmp_alldata, self.allLabel - 1)  # common spatial pattern
        self.allData = np.einsum('abcd, ce -> abed', self.allData, Wb)
        self.testData = np.einsum('abcd, ce -> abed', self.testData, Wb)
        return self.allData, self.allLabel, self.testData, self.testLabel

    def update_lr(self, optimizer, lr):
        """Set the learning rate of every parameter group of ``optimizer`` to ``lr``."""
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    # Data augmentation is a potential way to improve the generalization ability.
    def aug(self, img, label):
        """Placeholder for data augmentation; currently returns two empty lists."""
        aug_data = []
        aug_label = []
        return aug_data, aug_label

    def train(self):
        """Train the model and evaluate it on the test data after every epoch.

        The per-subject log file is closed when training finishes or fails;
        if ``train`` is called again the log is reopened in append mode.

        Returns:
            Tuple ``(bestAcc, averAcc, Y_true, Y_pred)``: best and mean test
            accuracy over epochs, and the test labels / predictions of the
            best epoch (both are the integer 0 if accuracy never exceeded 0).
        """
        if self.log_write.closed:
            # A previous train() call closed the log; append to keep its records.
            self.log_write = open(self.log_write.name, "a")
        try:
            return self._run_training()
        finally:
            self.log_write.close()

    def _run_training(self):
        """Run the optimisation / evaluation loop; see ``train`` for the contract."""
        img, label, test_data, test_label = self.get_source_data()

        img = torch.from_numpy(img)
        label = torch.from_numpy(label - 1)
        dataset = torch.utils.data.TensorDataset(img, label)
        self.dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)

        test_data = torch.from_numpy(test_data)
        test_label = torch.from_numpy(test_label - 1)
        test_dataset = torch.utils.data.TensorDataset(test_data, test_label)
        self.test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=self.batch_size, shuffle=True)

        # Optimizers
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2))

        # The whole test set is evaluated in one forward pass on the GPU.
        test_data = test_data.type(self.Tensor)
        test_label = test_label.type(self.LongTensor)

        bestAcc = 0
        averAcc = 0
        num = 0
        Y_true = 0
        Y_pred = 0

        # some better optimization strategy is worthy to explore. Sometimes terrible over-fitting.
        for e in range(self.n_epochs):
            self.model.train()
            for img, label in self.dataloader:
                img = img.cuda().type(self.Tensor)
                label = label.cuda().type(self.LongTensor)
                _, outputs = self.model(img)
                loss = self.criterion_cls(outputs, label)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            # Evaluate after every epoch.
            if (e + 1) % 1 == 0:
                self.model.eval()
                # No gradients are needed for evaluation; skipping the graph
                # avoids holding activations for the full test set.
                with torch.no_grad():
                    _, Cls = self.model(test_data)
                    loss_test = self.criterion_cls(Cls, test_label)
                y_pred = torch.max(Cls, 1)[1]
                acc = float((y_pred == test_label).cpu().numpy().astype(int).sum()) / float(test_label.size(0))
                # Training accuracy/loss are those of the last mini-batch only.
                train_pred = torch.max(outputs, 1)[1]
                train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0))
                print('Epoch:', e,
                      '  Train loss:', loss.detach().cpu().numpy(),
                      '  Test loss:', loss_test.detach().cpu().numpy(),
                      '  Train accuracy:', train_acc,
                      '  Test accuracy is:', acc)
                self.log_write.write(str(e) + "    " + str(acc) + "\n")
                num = num + 1
                averAcc = averAcc + acc
                if acc > bestAcc:
                    bestAcc = acc
                    Y_true = test_label
                    Y_pred = y_pred

        # Saves the weights of the final epoch (not necessarily the best one).
        torch.save(self.model.module.state_dict(), 'model.pth')
        if num:
            averAcc = averAcc / num
        print('The average accuracy is:', averAcc)
        print('The best accuracy is:', bestAcc)
        self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n")
        self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n")

        return bestAcc, averAcc, Y_true, Y_pred
def main():
    """Train and evaluate one model per subject and write a summary file.

    For each of the nine subjects a random seed below 500 is drawn and
    applied to Python, NumPy and PyTorch, a ``Trans`` instance is trained,
    and the per-subject best / average accuracies are written to
    ``./results/sub_result.txt`` followed by their means over subjects.
    """
    n_subjects = 9
    best = 0
    aver = 0
    true_parts = []
    pred_parts = []

    # Make sure the output directory exists before opening the summary file.
    os.makedirs("./results", exist_ok=True)
    # The context manager closes the summary even if a subject's run fails.
    with open("./results/sub_result.txt", "w") as result_write:
        for i in range(n_subjects):
            seed_n = np.random.randint(500)
            print('seed is ' + str(seed_n))
            random.seed(seed_n)
            np.random.seed(seed_n)
            torch.manual_seed(seed_n)
            torch.cuda.manual_seed(seed_n)
            torch.cuda.manual_seed_all(seed_n)
            print('Subject %d' % (i + 1))
            trans = Trans(i + 1)
            bestAcc, averAcc, Y_true, Y_pred = trans.train()
            print('THE BEST ACCURACY IS ' + str(bestAcc))
            result_write.write('Subject ' + str(i + 1) + ' : ' + 'Seed is: ' + str(seed_n) + "\n")
            result_write.write('**Subject ' + str(i + 1) + ' : ' + 'The best accuracy is: ' + str(bestAcc) + "\n")
            result_write.write('Subject ' + str(i + 1) + ' : ' + 'The average accuracy is: ' + str(averAcc) + "\n")
            # plot_confusion_matrix(Y_true, Y_pred, i+1)
            best = best + bestAcc
            aver = aver + averAcc
            # train() returns the integer placeholder 0 when no epoch ever
            # improved on zero accuracy; only real tensors can be pooled.
            if isinstance(Y_true, torch.Tensor) and isinstance(Y_pred, torch.Tensor):
                true_parts.append(Y_true)
                pred_parts.append(Y_pred)

        best = best / n_subjects
        aver = aver / n_subjects
        # Pooled labels / predictions over all subjects, kept for the
        # optional confusion-matrix plot below.
        yt = torch.cat(true_parts) if true_parts else None
        yp = torch.cat(pred_parts) if pred_parts else None
        # plot_confusion_matrix(yt, yp, 666)
        result_write.write('**The average Best accuracy is: ' + str(best) + "\n")
        result_write.write('The average Aver accuracy is: ' + str(aver) + "\n")
# Entry point: run the experiment only when this file is executed as a script.
if __name__ == "__main__":
    main()