Diff of /conformer_BCIIV2b.py [000000] .. [8bbec7]


--- a
+++ b/conformer_BCIIV2b.py
@@ -0,0 +1,520 @@
+"""
+EEG conformer 
+Test on the datasets 2b
+"""
+
+
+import os
+
+gpus = [1]
+os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
+os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, gpus))
+
+import math
+import random
+import datetime
+import time
+
+import numpy as np
+import scipy.io
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch.autograd import Variable
+from torch.utils.data import DataLoader
+from torch.backends import cudnn
+
+from einops import rearrange
+from einops.layers.torch import Rearrange, Reduce
+# from common_spatial_pattern import csp
+# from torch.utils.tensorboard import SummaryWriter
+
+cudnn.benchmark = False
+cudnn.deterministic = True
+
+
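+# Convolutional patch embedding ("shallownet" stem): a temporal convolution over
+# 25-sample windows, a spatial convolution across the 3 EEG channels of dataset 2b,
+# then average pooling (kernel 75, stride 15) for temporal downsampling. A 1x1
+# convolution projects to the embedding size and the feature map is flattened into
+# a token sequence; for a (batch, 1, 3, 1000) input this yields (batch, 61, 40).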
+class PatchEmbedding(nn.Module):
+    def __init__(self, emb_size=40):
+        # self.patch_size = patch_size
+        super().__init__()
+
+        self.shallownet = nn.Sequential(
+            nn.Conv2d(1, 40, (1, 25), (1, 1)),
+            nn.Conv2d(40, 40, (3, 1), (1, 1)),
+            nn.BatchNorm2d(40),
+            nn.ELU(),
+            nn.AvgPool2d((1, 75), (1, 15)),
+            nn.Dropout(0.5),
+        )
+
+        self.projection = nn.Sequential(
+            nn.Conv2d(40, emb_size, (1, 1), stride=(1, 1)),  # 1x1 projection to the embedding dimension
+            Rearrange('b e (h) (w) -> b (h w) e'),
+        )
+
+
+    def forward(self, x: Tensor) -> Tensor:
+        b, _, _, _ = x.shape
+        x = self.shallownet(x)
+        x = self.projection(x)
+        return x
+
+
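+# Standard multi-head self-attention over the token sequence. Queries, keys and
+# values come from separate linear layers; attention scores are scaled by
+# sqrt(emb_size) (the full embedding width, not the per-head dimension) before the
+# softmax, followed by attention dropout and an output projection.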
+class MultiHeadAttention(nn.Module):
+    def __init__(self, emb_size, num_heads, dropout):
+        super().__init__()
+        self.emb_size = emb_size
+        self.num_heads = num_heads
+        self.keys = nn.Linear(emb_size, emb_size)
+        self.queries = nn.Linear(emb_size, emb_size)
+        self.values = nn.Linear(emb_size, emb_size)
+        self.att_drop = nn.Dropout(dropout)
+        self.projection = nn.Linear(emb_size, emb_size)
+
+    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
+        queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
+        keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
+        values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
+        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)  # batch, num_heads, query_len, key_len
+        if mask is not None:
+            fill_value = torch.finfo(torch.float32).min
+            energy = energy.masked_fill(~mask, fill_value)
+
+        scaling = self.emb_size ** (1 / 2)
+        att = F.softmax(energy / scaling, dim=-1)
+        att = self.att_drop(att)
+        out = torch.einsum('bhal, bhlv -> bhav ', att, values)
+        out = rearrange(out, "b h n d -> b n (h d)")
+        out = self.projection(out)
+        return out
+
+
+class ResidualAdd(nn.Module):
+    def __init__(self, fn):
+        super().__init__()
+        self.fn = fn
+
+    def forward(self, x, **kwargs):
+        res = x
+        x = self.fn(x, **kwargs)
+        x += res
+        return x
+
+
+class FeedForwardBlock(nn.Sequential):
+    def __init__(self, emb_size, expansion, drop_p):
+        super().__init__(
+            nn.Linear(emb_size, expansion * emb_size),
+            nn.GELU(),
+            nn.Dropout(drop_p),
+            nn.Linear(expansion * emb_size, emb_size),
+        )
+
+class GELU(nn.Module):
+    def forward(self, input: Tensor) -> Tensor:
+        return input*0.5*(1.0+torch.erf(input/math.sqrt(2.0)))
+
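+# One Transformer encoder block in pre-norm form: LayerNorm -> multi-head
+# self-attention -> dropout with a residual connection, followed by LayerNorm ->
+# feed-forward block -> dropout with a second residual connection.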
+class TransformerEncoderBlock(nn.Sequential):
+    def __init__(self,
+                 emb_size,
+                 num_heads=5,
+                 drop_p=0.5,
+                 forward_expansion=4,
+                 forward_drop_p=0.5):
+        super().__init__(
+            ResidualAdd(nn.Sequential(
+                nn.LayerNorm(emb_size),
+                MultiHeadAttention(emb_size, num_heads, drop_p),
+                nn.Dropout(drop_p)
+            )),
+            ResidualAdd(nn.Sequential(
+                nn.LayerNorm(emb_size),
+                FeedForwardBlock(
+                    emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
+                nn.Dropout(drop_p)
+            )
+            ))
+
+
+class TransformerEncoder(nn.Sequential):
+    def __init__(self, depth, emb_size):
+        super().__init__(*[TransformerEncoderBlock(emb_size) for _ in range(depth)])
+
+
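+# Classification head. Only self.fc is used by forward(): the (batch, 61, 40) token
+# sequence is flattened to 61 * 40 = 2440 features and passed through a small MLP
+# to the 2 motor-imagery classes. The cov/clshead/clshead_fc branches are defined
+# but never called. forward() returns both the flattened features and the logits.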
+class ClassificationHead(nn.Sequential):
+    def __init__(self, emb_size, n_classes):
+        super().__init__()
+        self.cov = nn.Sequential(
+            nn.Conv1d(190, 1, 1, 1),
+            nn.LeakyReLU(0.2),
+            nn.Dropout(0.5)
+        )
+        self.clshead = nn.Sequential(
+            Reduce('b n e -> b e', reduction='mean'),
+            nn.LayerNorm(emb_size),
+            nn.Linear(emb_size, n_classes)
+        )
+        self.clshead_fc = nn.Sequential(
+            Reduce('b n e -> b e', reduction='mean'),
+            nn.LayerNorm(emb_size),
+            nn.Linear(emb_size, 32),
+            nn.ELU(),
+            nn.Dropout(0.5),
+            nn.Linear(32, n_classes)
+        )
+        self.fc = nn.Sequential(
+            nn.Linear(2440, 256),
+            nn.ELU(),
+            nn.Dropout(0.5),
+            nn.Linear(256, 32),
+            nn.ELU(),
+            nn.Dropout(0.3),
+            nn.Linear(32, 2)
+        )
+
+    def forward(self, x):
+        x = x.contiguous().view(x.size(0), -1)
+        out = self.fc(x)
+        
+        return x, out
+
+
+# ! Rethink the use of Transformer for EEG signal
+class ViT(nn.Sequential):
+    def __init__(self, emb_size=40, depth=10, n_classes=2, **kwargs):
+        super().__init__(
+
+            PatchEmbedding(emb_size),
+            TransformerEncoder(depth, emb_size),
+            ClassificationHead(emb_size, n_classes)
+        )
+
+
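+# Training wrapper for a single subject. Despite the class name (apparently a
+# leftover), no GAN is trained here: the class loads the 2b data, applies
+# segmentation-based augmentation, and trains/evaluates the Conformer model.
+# Several attributes (img_shape, alpha, dimension, the class centers) appear to be
+# unused leftovers.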
+class ExGAN():
+    def __init__(self, nsub):
+        super(ExGAN, self).__init__()
+        self.batch_size = 100
+        self.n_epochs = 2000
+        self.img_height = 22
+        self.img_width = 600
+        self.channels = 1
+        self.c_dim = 4
+        self.lr = 0.0002
+        self.b1 = 0.5
+        self.b2 = 0.999
+        self.alpha = 0.0002
+        self.dimension = (190, 50)
+
+        self.nSub = nsub
+
+        self.start_epoch = 0
+        self.root = '/Data/strict_TE/2b/'
+
+        self.pretrain = False
+
+        self.log_write = open("/Code/CT/results/cf/2b/log_subject%d.txt" % self.nSub, "w")
+
+        self.img_shape = (self.channels, self.img_height, self.img_width)
+
+        self.Tensor = torch.cuda.FloatTensor
+        self.LongTensor = torch.cuda.LongTensor
+
+        self.criterion_l1 = torch.nn.L1Loss().cuda()
+        self.criterion_l2 = torch.nn.MSELoss().cuda()
+        self.criterion_cls = torch.nn.CrossEntropyLoss().cuda()
+
+        self.model = ViT().cuda()
+        self.model = nn.DataParallel(self.model, device_ids=[i for i in range(len(gpus))])
+        self.model = self.model.cuda()
+
+        self.centers = {}
+
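+    # Segmentation-and-recombination data augmentation: for each of the two classes,
+    # build batch_size/2 artificial trials by splitting the 1000-sample window into
+    # eight 125-sample segments and drawing each segment from a randomly chosen
+    # training trial of the same class. Returns the augmented batch as CUDA tensors
+    # with labels shifted from {1, 2} to {0, 1}.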
+    def interaug(self, timg, label):
+        aug_data = []
+        aug_label = []
+        for cls4aug in range(2):
+            cls_idx = np.where(label == cls4aug + 1)
+            tmp_data = timg[cls_idx]
+            tmp_label = label[cls_idx]
+
+            tmp_aug_data = np.zeros((int(self.batch_size / 2), 1, 3, 1000))
+            for ri in range(int(self.batch_size / 2)):
+                for rj in range(8):
+                    rand_idx = np.random.randint(0, tmp_data.shape[0], 8)
+                    tmp_aug_data[ri, :, :, rj * 125:(rj + 1) * 125] = tmp_data[rand_idx[rj], :, :,
+                                                                      rj * 125:(rj + 1) * 125]
+
+            aug_data.append(tmp_aug_data)
+            aug_label.append(tmp_label[:int(self.batch_size / 2)])
+        aug_data = np.concatenate(aug_data)
+        aug_label = np.concatenate(aug_label)
+        aug_shuffle = np.random.permutation(len(aug_data))
+        aug_data = aug_data[aug_shuffle, :, :]
+        aug_label = aug_label[aug_shuffle]
+
+        aug_data = torch.from_numpy(aug_data).cuda()
+        aug_data = aug_data.float()
+        aug_label = torch.from_numpy(aug_label-1).cuda()
+        aug_label = aug_label.long()
+        return aug_data, aug_label
+
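+    # Load BCI Competition IV dataset 2b for the target subject: the three training
+    # sessions B0{subject}0{1..3}T.mat and the two evaluation sessions
+    # B0{subject}0{4,5}E.mat. Trials are reshaped to (trials, 1, channels, samples),
+    # the training set is shuffled, and both sets are z-scored with the training
+    # mean and standard deviation.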
+    def get_source_data(self):
+
+        # to get the data of target subject
+        train_data = []
+        train_label = []
+        for session_index in range(3):
+            target_tmp = scipy.io.loadmat(self.root + 'B0%d0%dT.mat' % (self.nSub, session_index+1))
+            train_data_tmp = target_tmp['data']
+            train_label_tmp = target_tmp['label']
+            train_data_tmp = np.transpose(train_data_tmp, (2, 1, 0))
+            train_data_tmp = np.expand_dims(train_data_tmp, axis=1)
+            train_label_tmp = np.transpose(train_label_tmp)
+            train_label_tmp = train_label_tmp[0]
+            train_data.append(train_data_tmp)
+            train_label.append(train_label_tmp)
+
+        self.allData = np.concatenate(train_data)
+        self.allLabel = np.concatenate(train_label)
+
+        shuffle_num = np.random.permutation(len(self.allData))
+        self.allData = self.allData[shuffle_num, :, :, :]
+        self.allLabel = self.allLabel[shuffle_num]
+
+        # test data
+        test_data = []
+        test_label = []
+        for session_index in range(2):
+            test_tmp = scipy.io.loadmat(self.root + 'B0%d0%dE.mat' % (self.nSub, session_index+4))
+            test_data_tmp = test_tmp['data']
+            test_label_tmp = test_tmp['label']
+            test_data_tmp = np.transpose(test_data_tmp, (2, 1, 0))
+            test_data_tmp = np.expand_dims(test_data_tmp, axis=1)
+            test_label_tmp = np.transpose(test_label_tmp)
+            test_label_tmp = test_label_tmp[0]
+            test_data.append(test_data_tmp)
+            test_label.append(test_label_tmp)
+
+        self.testData = np.concatenate(test_data)
+        self.testLabel = np.concatenate(test_label)
+
+        # standardize
+        target_mean = np.mean(self.allData)
+        target_std = np.std(self.allData)
+        self.allData = (self.allData - target_mean) / target_std
+        self.testData = (self.testData - target_mean) / target_std
+
+        return self.allData, self.allLabel, self.testData, self.testLabel
+
+    def update_lr(self, optimizer, lr):
+        for param_group in optimizer.param_groups:
+            param_group['lr'] = lr
+
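+    # Whole-set segment recombination over four classes; this method is not called
+    # anywhere in this file (it appears to be a leftover from a four-class 2a
+    # pipeline, since dataset 2b only has two classes).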
+    def aug(self, img, label):
+        aug_data = []
+        aug_label = []
+        for cls4aug in range(4):
+            cls_idx = np.where(label == cls4aug + 1)
+            tmp_data = img[cls_idx]
+            tmp_label = label[cls_idx]
+
+            tmp_aug_data = np.zeros(tmp_data.shape)
+            for ri in range(tmp_data.shape[0]):
+                for rj in range(8):
+                    rand_idx = np.random.randint(0, tmp_data.shape[0], 8)
+                    tmp_aug_data[ri, :, :, rj * 125:(rj + 1) * 125] = tmp_data[rand_idx[rj], :, :, rj * 125:(rj + 1) * 125]
+
+            aug_data.append(tmp_aug_data)
+            aug_label.append(tmp_label)
+        aug_data = np.concatenate(aug_data)
+        aug_label = np.concatenate(aug_label)
+        aug_shuffle = np.random.permutation(len(aug_data))
+        aug_data = aug_data[aug_shuffle, :, :]
+        aug_label = aug_label[aug_shuffle]
+
+        return aug_data, aug_label
+
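+    # Per-class center update (as used by center-loss style objectives); also not
+    # called in this file, so it has no effect on training here.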
+    def update_centers(self, feature, label):
+        deltac = {}
+        count = {}
+        count[0] = 0
+        for i in range(len(label)):
+            l = label[i]
+            if l in deltac:
+                deltac[l] += self.centers[l] - feature[i]
+            else:
+                deltac[l] = self.centers[l] - feature[i]
+            if l in count:
+                count[l] += 1
+            else:
+                count[l] = 1
+
+        for ke in deltac.keys():
+            deltac[ke] = deltac[ke] / (count[ke] + 1)
+
+        return deltac
+
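+    # Main training/evaluation loop: Adam (lr = 2e-4) with cross-entropy loss. Each
+    # mini-batch is concatenated with an equally sized augmented batch from
+    # interaug(), the full test set is evaluated after every epoch, and the best and
+    # running-average test accuracies are logged and returned together with the
+    # predictions of the best epoch.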
+    def train(self):
+
+        img, label, test_data, test_label = self.get_source_data()
+
+        img = torch.from_numpy(img)
+        label = torch.from_numpy(label - 1)
+
+        dataset = torch.utils.data.TensorDataset(img, label)
+        self.dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)
+
+        test_data = torch.from_numpy(test_data)
+        test_label = torch.from_numpy(test_label - 1)
+        test_dataset = torch.utils.data.TensorDataset(test_data, test_label)
+        self.test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=self.batch_size, shuffle=True)
+
+        for i in range(self.c_dim):
+            self.centers[i] = torch.randn(self.dimension)
+            self.centers[i] = self.centers[i].cuda()
+
+        # Optimizers
+        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2))
+
+        test_data = Variable(test_data.type(self.Tensor))
+        test_label = Variable(test_label.type(self.LongTensor))
+
+        bestAcc = 0
+        averAcc = 0
+        num = 0
+        Y_true = 0
+        Y_pred = 0
+
+        # Train the cnn model
+        total_step = len(self.dataloader)
+        curr_lr = self.lr
+
+        for e in range(self.n_epochs):
+            in_epoch = time.time()
+            self.model.train()
+            for i, (img, label) in enumerate(self.dataloader):
+
+                img = Variable(img.cuda().type(self.Tensor))
+                # img = self.active_function(img)
+                label = Variable(label.cuda().type(self.LongTensor))
+
+                aug_data, aug_label = self.interaug(self.allData, self.allLabel)
+                img = torch.cat((img, aug_data))
+                label = torch.cat((label, aug_label))
+
+                tok, outputs = self.model(img)
+
+                loss = self.criterion_cls(outputs, label)
+
+                self.optimizer.zero_grad()
+                loss.backward()
+                self.optimizer.step()
+
+            out_epoch = time.time()
+
+            if (e + 1) % 1 == 0:
+                self.model.eval()
+                Tok, Cls = self.model(test_data)
+
+
+                loss_test = self.criterion_cls(Cls, test_label)
+                y_pred = torch.max(Cls, 1)[1]
+                acc = float((y_pred == test_label).cpu().numpy().astype(int).sum()) / float(test_label.size(0))
+                train_pred = torch.max(outputs, 1)[1]
+                train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0))
+                print('Epoch:', e,
+                      '  Train loss: %.6f' % loss.detach().cpu().numpy(),
+                      '  Test loss: %.6f' % loss_test.detach().cpu().numpy(),
+                      '  Train accuracy %.6f' % train_acc,
+                      '  Test accuracy is %.6f' % acc)
+                self.log_write.write(str(e) + "    " + str(acc) + "\n")
+                num = num + 1
+                averAcc = averAcc + acc
+                if acc > bestAcc:
+                    bestAcc = acc
+                    Y_true = test_label
+                    Y_pred = y_pred
+
+        torch.save(self.model.module.state_dict(), 'model.pth')
+        averAcc = averAcc / num
+        print('The average accuracy is:', averAcc)
+        print('The best accuracy is:', bestAcc)
+        self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n")
+        self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n")
+        return bestAcc, averAcc, Y_true, Y_pred
+
+
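+# Run all nine 2b subjects: draw a fresh random seed per subject, train one model,
+# and write the per-subject best/average accuracies plus the overall means to
+# sub_result.txt.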
+def main():
+    best = 0
+    aver = 0
+    result_write = open("/Code/CT/results/cf/2b/sub_result.txt", "w")
+
+    for i in range(9):
+        starttime = datetime.datetime.now()
+        seed_n = np.random.randint(2021)
+
+        print('seed is ' + str(seed_n))
+        random.seed(seed_n)
+        np.random.seed(seed_n)
+        torch.manual_seed(seed_n)
+        torch.cuda.manual_seed(seed_n)
+        torch.cuda.manual_seed_all(seed_n)
+        print('Subject %d' % (i+1))
+        exgan = ExGAN(i + 1)
+
+        bestAcc, averAcc, Y_true, Y_pred = exgan.train()
+        print('THE BEST ACCURACY IS ' + str(bestAcc))
+        result_write.write('Subject ' + str(i + 1) + ' : ' + 'Seed is: ' + str(seed_n) + "\n")
+        result_write.write('**Subject ' + str(i + 1) + ' : ' + 'The best accuracy is: ' + str(bestAcc) + "\n")
+        result_write.write('Subject ' + str(i + 1) + ' : ' + 'The average accuracy is: ' + str(averAcc) + "\n")
+
+        endtime = datetime.datetime.now()
+        print('subject %d duration: '%(i+1) + str(endtime - starttime))
+
+        best = best + bestAcc
+        aver = aver + averAcc
+        if i == 0:
+            yt = Y_true
+            yp = Y_pred
+        else:
+            yt = torch.cat((yt, Y_true))
+            yp = torch.cat((yp, Y_pred))
+
+
+    best = best / 9
+    aver = aver / 9
+
+    result_write.write('**The average Best accuracy is: ' + str(best) + "\n")
+    result_write.write('The average Aver accuracy is: ' + str(aver) + "\n")
+    result_write.close()
+
+
+if __name__ == "__main__":
+    print(time.asctime(time.localtime(time.time())))
+    main()
+    print(time.asctime(time.localtime(time.time())))