--- a
+++ b/conformer_BCIIV2b.py
@@ -0,0 +1,520 @@
+"""
+EEG Conformer
+Evaluation on BCI Competition IV dataset 2b
+"""
+
+
+import os
+
+# Select the GPU before torch initializes CUDA.
+gpus = [1]
+os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
+os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, gpus))
+
+import math
+import random
+import datetime
+import time
+
+import numpy as np
+import scipy.io
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch.autograd import Variable
+from torch.backends import cudnn
+
+from einops import rearrange
+from einops.layers.torch import Rearrange, Reduce
+
+cudnn.benchmark = False
+cudnn.deterministic = True
+
+
+class PatchEmbedding(nn.Module):
+    def __init__(self, emb_size=40):
+        super().__init__()
+
+        self.shallownet = nn.Sequential(
+            nn.Conv2d(1, 40, (1, 25), (1, 1)),
+            nn.Conv2d(40, 40, (3, 1), (1, 1)),  # spatial filter across the 3 EEG channels of 2b
+            nn.BatchNorm2d(40),
+            nn.ELU(),
+            nn.AvgPool2d((1, 75), (1, 15)),
+            nn.Dropout(0.5),
+        )
+
+        self.projection = nn.Sequential(
+            nn.Conv2d(40, emb_size, (1, 1), stride=(1, 1)),  # 5 is better than 1
+            Rearrange('b e (h) (w) -> b (h w) e'),
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        # x: (batch, 1, channels, time)
+        x = self.shallownet(x)
+        x = self.projection(x)  # (batch, tokens, emb_size)
+        return x
+
+
+class MultiHeadAttention(nn.Module):
+    def __init__(self, emb_size, num_heads, dropout):
+        super().__init__()
+        self.emb_size = emb_size
+        self.num_heads = num_heads
+        self.keys = nn.Linear(emb_size, emb_size)
+        self.queries = nn.Linear(emb_size, emb_size)
+        self.values = nn.Linear(emb_size, emb_size)
+        self.att_drop = nn.Dropout(dropout)
+        self.projection = nn.Linear(emb_size, emb_size)
+
+    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
+        queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
+        keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
+        values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
+        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)  # (batch, num_heads, query_len, key_len)
+        if mask is not None:
+            fill_value = torch.finfo(torch.float32).min
+            energy = energy.masked_fill(~mask, fill_value)
+
+        scaling = self.emb_size ** (1 / 2)
+        att = F.softmax(energy / scaling, dim=-1)
+        att = self.att_drop(att)
+        out = torch.einsum('bhal, bhlv -> bhav', att, values)
+        out = rearrange(out, "b h n d -> b n (h d)")
+        out = self.projection(out)
+        return out
+
+
+class ResidualAdd(nn.Module):
+    def __init__(self, fn):
+        super().__init__()
+        self.fn = fn
+
+    def forward(self, x, **kwargs):
+        res = x
+        x = self.fn(x, **kwargs)
+        x += res
+        return x
+
+
+class FeedForwardBlock(nn.Sequential):
+    def __init__(self, emb_size, expansion, drop_p):
+        super().__init__(
+            nn.Linear(emb_size, expansion * emb_size),
+            nn.GELU(),
+            nn.Dropout(drop_p),
+            nn.Linear(expansion * emb_size, emb_size),
+        )
+
+
+class GELU(nn.Module):
+    def forward(self, input: Tensor) -> Tensor:
+        return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))
+
+
+class TransformerEncoderBlock(nn.Sequential):
+    def __init__(self,
+                 emb_size,
+                 num_heads=5,
+                 drop_p=0.5,
+                 forward_expansion=4,
+                 forward_drop_p=0.5):
+        super().__init__(
+            ResidualAdd(nn.Sequential(
+                nn.LayerNorm(emb_size),
+                MultiHeadAttention(emb_size, num_heads, drop_p),
+                nn.Dropout(drop_p)
+            )),
+            ResidualAdd(nn.Sequential(
+                nn.LayerNorm(emb_size),
+                FeedForwardBlock(
+                    emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
+                nn.Dropout(drop_p)
+            ))
+        )
+
+
+class TransformerEncoder(nn.Sequential):
+    def __init__(self, depth, emb_size):
+        super().__init__(*[TransformerEncoderBlock(emb_size) for _ in range(depth)])
+
+
+class ClassificationHead(nn.Sequential):
+    def __init__(self, emb_size, n_classes):
+        super().__init__()
+        # Only self.fc is used in forward(); the other heads are kept for reference.
+        self.cov = nn.Sequential(
+            nn.Conv1d(190, 1, 1, 1),
+            nn.LeakyReLU(0.2),
+            nn.Dropout(0.5)
+        )
+        self.clshead = nn.Sequential(
+            Reduce('b n e -> b e', reduction='mean'),
+            nn.LayerNorm(emb_size),
+            nn.Linear(emb_size, n_classes)
+        )
+        self.clshead_fc = nn.Sequential(
+            Reduce('b n e -> b e', reduction='mean'),
+            nn.LayerNorm(emb_size),
+            nn.Linear(emb_size, 32),
+            nn.ELU(),
+            nn.Dropout(0.5),
+            nn.Linear(32, n_classes)
+        )
+        self.fc = nn.Sequential(
+            nn.Linear(2440, 256),
+            nn.ELU(),
+            nn.Dropout(0.5),
+            nn.Linear(256, 32),
+            nn.ELU(),
+            nn.Dropout(0.3),
+            nn.Linear(32, 2)
+        )
+
+    def forward(self, x):
+        x = x.contiguous().view(x.size(0), -1)
+        out = self.fc(x)
+        return x, out
+
+
+# ! Rethink the use of Transformer for EEG signal
+class ViT(nn.Sequential):
+    def __init__(self, emb_size=40, depth=10, n_classes=2, **kwargs):
+        super().__init__(
+            PatchEmbedding(emb_size),
+            TransformerEncoder(depth, emb_size),
+            ClassificationHead(emb_size, n_classes)
+        )
+
+
+# Per-subject training wrapper for the Conformer model.
+class ExGAN():
+    def __init__(self, nsub):
+        super(ExGAN, self).__init__()
+        self.batch_size = 100
+        self.n_epochs = 2000
+        self.img_height = 22
+        self.img_width = 600
+        self.channels = 1
+        self.c_dim = 4
+        self.lr = 0.0002
+        self.b1 = 0.5
+        self.b2 = 0.999
+        self.alpha = 0.0002
+        self.dimension = (190, 50)
+
+        self.nSub = nsub
+
+        self.start_epoch = 0
+        self.root = '/Data/strict_TE/2b/'
+
+        self.pretrain = False
+
+        self.log_write = open("/Code/CT/results/cf/2b/log_subject%d.txt" % self.nSub, "w")
+
+        self.img_shape = (self.channels, self.img_height, self.img_width)
+
+        self.Tensor = torch.cuda.FloatTensor
+        self.LongTensor = torch.cuda.LongTensor
+
+        self.criterion_l1 = torch.nn.L1Loss().cuda()
+        self.criterion_l2 = torch.nn.MSELoss().cuda()
+        self.criterion_cls = torch.nn.CrossEntropyLoss().cuda()
+
+        self.model = ViT().cuda()
+        self.model = nn.DataParallel(self.model, device_ids=[i for i in range(len(gpus))])
+        self.model = self.model.cuda()
+
+        self.centers = {}
+
+    def interaug(self, timg, label):
+        # Segment-and-recombine augmentation: each artificial trial is stitched
+        # together from 8 segments of 125 samples taken from random real trials
+        # of the same class.
+        aug_data = []
+        aug_label = []
+        for cls4aug in range(2):
+            cls_idx = np.where(label == cls4aug + 1)
+            tmp_data = timg[cls_idx]
+            tmp_label = label[cls_idx]
+
+            tmp_aug_data = np.zeros((int(self.batch_size / 2), 1, 3, 1000))
+            for ri in range(int(self.batch_size / 2)):
+                for rj in range(8):
+                    rand_idx = np.random.randint(0, tmp_data.shape[0], 8)
+                    tmp_aug_data[ri, :, :, rj * 125:(rj + 1) * 125] = tmp_data[rand_idx[rj], :, :,
+                                                                      rj * 125:(rj + 1) * 125]
+
+            aug_data.append(tmp_aug_data)
+            aug_label.append(tmp_label[:int(self.batch_size / 2)])
+        aug_data = np.concatenate(aug_data)
+        aug_label = np.concatenate(aug_label)
+        aug_shuffle = np.random.permutation(len(aug_data))
+        aug_data = aug_data[aug_shuffle, :, :]
+        aug_label = aug_label[aug_shuffle]
+
+        aug_data = torch.from_numpy(aug_data).cuda()
+        aug_data = aug_data.float()
+        aug_label = torch.from_numpy(aug_label - 1).cuda()
+        aug_label = aug_label.long()
+        return aug_data, aug_label
+
+    def get_source_data(self):
+
+        # training data of the target subject: the three training sessions
+        # (e.g. B0101T.mat-B0103T.mat for subject 1)
+        train_data = []
+        train_label = []
+        for session_index in range(3):
+            target_tmp = scipy.io.loadmat(self.root + 'B0%d0%dT.mat' % (self.nSub, session_index + 1))
+            train_data_tmp = target_tmp['data']
+            train_label_tmp = target_tmp['label']
+            train_data_tmp = np.transpose(train_data_tmp, (2, 1, 0))
+            train_data_tmp = np.expand_dims(train_data_tmp, axis=1)
+            train_label_tmp = np.transpose(train_label_tmp)
+            train_label_tmp = train_label_tmp[0]
+            train_data.append(train_data_tmp)
+            train_label.append(train_label_tmp)
+
+        self.allData = np.concatenate(train_data)
+        self.allLabel = np.concatenate(train_label)
+
+        shuffle_num = np.random.permutation(len(self.allData))
+        self.allData = self.allData[shuffle_num, :, :, :]
+        self.allLabel = self.allLabel[shuffle_num]
+
+        # test data: the two evaluation sessions (e.g. B0104E.mat, B0105E.mat)
+        test_data = []
+        test_label = []
+        for session_index in range(2):
+            test_tmp = scipy.io.loadmat(self.root + 'B0%d0%dE.mat' % (self.nSub, session_index + 4))
+            test_data_tmp = test_tmp['data']
+            test_label_tmp = test_tmp['label']
+            test_data_tmp = np.transpose(test_data_tmp, (2, 1, 0))
+            test_data_tmp = np.expand_dims(test_data_tmp, axis=1)
+            test_label_tmp = np.transpose(test_label_tmp)
+            test_label_tmp = test_label_tmp[0]
+            test_data.append(test_data_tmp)
+            test_label.append(test_label_tmp)
+
+        self.testData = np.concatenate(test_data)
+        self.testLabel = np.concatenate(test_label)
+
+        # standardize both splits with the training statistics
+        target_mean = np.mean(self.allData)
+        target_std = np.std(self.allData)
+        self.allData = (self.allData - target_mean) / target_std
+        self.testData = (self.testData - target_mean) / target_std
+
+        return self.allData, self.allLabel, self.testData, self.testLabel
+
+    def update_lr(self, optimizer, lr):
+        for param_group in optimizer.param_groups:
+            param_group['lr'] = lr
+
+    # Unused in this script (iterates over 4 classes, while dataset 2b has 2).
+    def aug(self, img, label):
+        aug_data = []
+        aug_label = []
+        for cls4aug in range(4):
+            cls_idx = np.where(label == cls4aug + 1)
+            tmp_data = img[cls_idx]
+            tmp_label = label[cls_idx]
+
+            tmp_aug_data = np.zeros(tmp_data.shape)
+            for ri in range(tmp_data.shape[0]):
+                for rj in range(8):
+                    rand_idx = np.random.randint(0, tmp_data.shape[0], 8)
+                    tmp_aug_data[ri, :, :, rj * 125:(rj + 1) * 125] = tmp_data[rand_idx[rj], :, :, rj * 125:(rj + 1) * 125]
+
+            aug_data.append(tmp_aug_data)
+            aug_label.append(tmp_label)
+        aug_data = np.concatenate(aug_data)
+        aug_label = np.concatenate(aug_label)
+        aug_shuffle = np.random.permutation(len(aug_data))
+        aug_data = aug_data[aug_shuffle, :, :]
+        aug_label = aug_label[aug_shuffle]
+
+        return aug_data, aug_label
+
+    # Unused in this script.
+    def update_centers(self, feature, label):
+        deltac = {}
+        count = {}
+        count[0] = 0
+        for i in range(len(label)):
+            l = label[i]
+            if l in deltac:
+                deltac[l] += self.centers[l] - feature[i]
+            else:
+                deltac[l] = self.centers[l] - feature[i]
+            if l in count:
+                count[l] += 1
+            else:
+                count[l] = 1
+
+        for ke in deltac.keys():
+            deltac[ke] = deltac[ke] / (count[ke] + 1)
+
+        return deltac
+
+    def train(self):
+
+        img, label, test_data, test_label = self.get_source_data()
+
+        img = torch.from_numpy(img)
+        label = torch.from_numpy(label - 1)
+
+        dataset = torch.utils.data.TensorDataset(img, label)
+        self.dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)
+
+        test_data = torch.from_numpy(test_data)
+        test_label = torch.from_numpy(test_label - 1)
+        test_dataset = torch.utils.data.TensorDataset(test_data, test_label)
+        self.test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=self.batch_size, shuffle=True)
+
+        for i in range(self.c_dim):
+            self.centers[i] = torch.randn(self.dimension)
+            self.centers[i] = self.centers[i].cuda()
+
+        # Optimizer
+        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2))
+
+        test_data = Variable(test_data.type(self.Tensor))
+        test_label = Variable(test_label.type(self.LongTensor))
+
+        bestAcc = 0
+        averAcc = 0
+        num = 0
+        Y_true = 0
+        Y_pred = 0
+
+        # Train the model
+        total_step = len(self.dataloader)
+        curr_lr = self.lr
+
+        for e in range(self.n_epochs):
+            in_epoch = time.time()
+            self.model.train()
+            for i, (img, label) in enumerate(self.dataloader):
+
+                img = Variable(img.cuda().type(self.Tensor))
+                label = Variable(label.cuda().type(self.LongTensor))
+
+                # extend each mini-batch with segment-recombined artificial trials
+                aug_data, aug_label = self.interaug(self.allData, self.allLabel)
+                img = torch.cat((img, aug_data))
+                label = torch.cat((label, aug_label))
+
+                tok, outputs = self.model(img)
+
+                loss = self.criterion_cls(outputs, label)
+
+                self.optimizer.zero_grad()
+                loss.backward()
+                self.optimizer.step()
+
+            out_epoch = time.time()
+
+            if (e + 1) % 1 == 0:
+                self.model.eval()
+                with torch.no_grad():
+                    Tok, Cls = self.model(test_data)
+
+                loss_test = self.criterion_cls(Cls, test_label)
+                y_pred = torch.max(Cls, 1)[1]
+                acc = float((y_pred == test_label).cpu().numpy().astype(int).sum()) / float(test_label.size(0))
+                train_pred = torch.max(outputs, 1)[1]
+                train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0))
+                print('Epoch:', e,
+                      ' Train loss: %.6f' % loss.detach().cpu().numpy(),
+                      ' Test loss: %.6f' % loss_test.detach().cpu().numpy(),
+                      ' Train accuracy %.6f' % train_acc,
+                      ' Test accuracy is %.6f' % acc)
+                self.log_write.write(str(e) + " " + str(acc) + "\n")
+                num = num + 1
+                averAcc = averAcc + acc
+                if acc > bestAcc:
+                    bestAcc = acc
+                    Y_true = test_label
+                    Y_pred = y_pred
+
+        torch.save(self.model.module.state_dict(), 'model.pth')
+        averAcc = averAcc / num
+        print('The average accuracy is:', averAcc)
+        print('The best accuracy is:', bestAcc)
+        self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n")
+        self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n")
+        return bestAcc, averAcc, Y_true, Y_pred
+
+
+def main():
+    best = 0
+    aver = 0
+    result_write = open("/Code/CT/results/cf/2b/sub_result.txt", "w")
+
+    for i in range(9):
+        starttime = datetime.datetime.now()
+        seed_n = np.random.randint(2021)
+
+        print('seed is ' + str(seed_n))
+        random.seed(seed_n)
+        np.random.seed(seed_n)
+        torch.manual_seed(seed_n)
+        torch.cuda.manual_seed(seed_n)
+        torch.cuda.manual_seed_all(seed_n)
+        print('Subject %d' % (i + 1))
+        exgan = ExGAN(i + 1)
+
+        bestAcc, averAcc, Y_true, Y_pred = exgan.train()
+        print('THE BEST ACCURACY IS ' + str(bestAcc))
+        result_write.write('Subject ' + str(i + 1) + ' : ' + 'Seed is: ' + str(seed_n) + "\n")
+        result_write.write('**Subject ' + str(i + 1) + ' : ' + 'The best accuracy is: ' + str(bestAcc) + "\n")
+        result_write.write('Subject ' + str(i + 1) + ' : ' + 'The average accuracy is: ' + str(averAcc) + "\n")
+
+        endtime = datetime.datetime.now()
+        print('subject %d duration: ' % (i + 1) + str(endtime - starttime))
+
+        best = best + bestAcc
+        aver = aver + averAcc
+        if i == 0:
+            yt = Y_true
+            yp = Y_pred
+        else:
+            yt = torch.cat((yt, Y_true))
+            yp = torch.cat((yp, Y_pred))
+
+
+    best = best / 9
+    aver = aver / 9
+
+    result_write.write('**The average Best accuracy is: ' + str(best) + "\n")
+    result_write.write('The average Aver accuracy is: ' + str(aver) + "\n")
+    result_write.close()
+
+
+if __name__ == "__main__":
+    print(time.asctime(time.localtime(time.time())))
+    main()
+    print(time.asctime(time.localtime(time.time())))
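+
+
+# --- Illustrative shape check (added sketch, never called by the code above) --
+# A minimal sanity check for the Conformer model defined above, assuming the
+# BCI IV-2b layout used in this script: trials shaped (batch, 1, 3 channels,
+# 1000 samples). With these dimensions PatchEmbedding yields 61 tokens of size
+# 40, which matches the 2440-unit input of ClassificationHead.fc. The function
+# name below is illustrative; it runs on CPU and is not part of training.
+def _shape_sanity_check():
+    model = ViT(emb_size=40, depth=10, n_classes=2)
+    dummy = torch.randn(2, 1, 3, 1000)   # (batch, 1, channels, time)
+    features, logits = model(dummy)
+    print(features.shape)                # expected: torch.Size([2, 2440])
+    print(logits.shape)                  # expected: torch.Size([2, 2])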
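+
+
+# --- Data-layout note (assumption, for reference only) ------------------------
+# get_source_data() expects per-subject .mat files under self.root such as
+# B0101T.mat-B0103T.mat (training sessions) and B0104E.mat, B0105E.mat
+# (evaluation sessions). From the transposes above, each file is assumed to hold
+#   'data'  : (time, channels, trials), reshaped here to (trials, 1, channels, time)
+#   'label' : (trials, 1) with class labels in {1, 2}
+# The preprocessing that produces these files is outside this script. The
+# hypothetical helper below just prints what a given file actually contains,
+# e.g. _inspect_mat('/Data/strict_TE/2b/B0101T.mat').
+def _inspect_mat(path):
+    mat = scipy.io.loadmat(path)
+    print({k: v.shape for k, v in mat.items() if hasattr(v, 'shape')})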
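+
+
+# --- Augmentation sketch (illustrative only) ----------------------------------
+# Re-statement of the idea behind ExGAN.interaug(): an artificial trial is built
+# by stitching together 8 consecutive 125-sample segments, each copied from a
+# randomly chosen real trial of the same class. `trials` is assumed to already
+# contain only one class, shaped (n_trials, 1, channels, 1000) as in this
+# script; the function name and signature are illustrative, not original code.
+def _recombine_one_trial(trials, n_segments=8, seg_len=125):
+    new_trial = np.zeros_like(trials[0])
+    for s in range(n_segments):
+        donor = np.random.randint(0, trials.shape[0])
+        new_trial[..., s * seg_len:(s + 1) * seg_len] = trials[donor, ..., s * seg_len:(s + 1) * seg_len]
+    return new_trial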