--- /dev/null
+++ b/conformer.py
@@ -0,0 +1,463 @@
+"""
+EEG Conformer 
+
+Convolutional Transformer for EEG decoding
+
+Couple CNN and Transformer in a concise manner with amazing results
+"""
+# remember to change paths
+
+import os
+
+gpus = [0]
+os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
+os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, gpus))
+
+import math
+import random
+import datetime
+import time
+
+import numpy as np
+import scipy.io
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.data
+from torch import Tensor
+from torch.backends import cudnn
+
+from einops import rearrange
+from einops.layers.torch import Rearrange, Reduce
+# from common_spatial_pattern import csp
+# from torch.utils.tensorboard import SummaryWriter
+
+cudnn.benchmark = False
+cudnn.deterministic = True
+
+# writer = SummaryWriter('./TensorBoardX/')
+
+
+# Convolution module
+# use convolutions to capture local features instead of a position embedding.
+class PatchEmbedding(nn.Module):
+    def __init__(self, emb_size=40):
+        super().__init__()
+
+        self.shallownet = nn.Sequential(
+            nn.Conv2d(1, 40, (1, 25), (1, 1)),   # temporal convolution
+            nn.Conv2d(40, 40, (22, 1), (1, 1)),  # spatial convolution across all 22 electrodes
+            nn.BatchNorm2d(40),
+            nn.ELU(),
+            # pooling acts as slicing to obtain 'patches' along the time dimension, as in ViT
+            nn.AvgPool2d((1, 75), (1, 15)),
+            nn.Dropout(0.5),
+        )
+
+        self.projection = nn.Sequential(
+            # 1x1 conv projection; slightly improves fitting ability over a plain transpose
+            nn.Conv2d(40, emb_size, (1, 1), stride=(1, 1)),
+            Rearrange('b e (h) (w) -> b (h w) e'),
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.shallownet(x)
+        x = self.projection(x)
+        return x
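+
+# Shape trace for the input size assumed throughout this script
+# (batch B, 1 conv channel, 22 electrodes, 1000 time samples):
+#   (B, 1, 22, 1000) -> Conv2d(1, 25)              -> (B, 40, 22, 976)
+#                    -> Conv2d(22, 1)              -> (B, 40, 1, 976)
+#                    -> AvgPool2d((1, 75), s=15)   -> (B, 40, 1, 61)
+#                    -> 1x1 conv + rearrange       -> (B, 61, emb_size)
+# With emb_size=40 this flattens to 61 * 40 = 2440 features, which is where
+# the hard-coded 2440 in ClassificationHead.fc below comes from.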
+
+
+class MultiHeadAttention(nn.Module):
+    def __init__(self, emb_size, num_heads, dropout):
+        super().__init__()
+        self.emb_size = emb_size
+        self.num_heads = num_heads
+        self.keys = nn.Linear(emb_size, emb_size)
+        self.queries = nn.Linear(emb_size, emb_size)
+        self.values = nn.Linear(emb_size, emb_size)
+        self.att_drop = nn.Dropout(dropout)
+        self.projection = nn.Linear(emb_size, emb_size)
+
+    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
+        # split the embedding into num_heads heads: (batch, heads, tokens, head_dim)
+        queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
+        keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
+        values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
+        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)  # Q @ K^T
+        if mask is not None:
+            fill_value = torch.finfo(torch.float32).min
+            energy = energy.masked_fill(~mask, fill_value)
+
+        # note: scaling uses sqrt(emb_size) rather than the per-head sqrt(d)
+        scaling = self.emb_size ** (1 / 2)
+        att = F.softmax(energy / scaling, dim=-1)
+        att = self.att_drop(att)
+        out = torch.einsum('bhal, bhlv -> bhav', att, values)
+        out = rearrange(out, "b h n d -> b n (h d)")
+        out = self.projection(out)
+        return out
+
+
+class ResidualAdd(nn.Module):
+    def __init__(self, fn):
+        super().__init__()
+        self.fn = fn
+
+    def forward(self, x, **kwargs):
+        res = x
+        x = self.fn(x, **kwargs)
+        x += res
+        return x
+
+
+class FeedForwardBlock(nn.Sequential):
+    def __init__(self, emb_size, expansion, drop_p):
+        super().__init__(
+            nn.Linear(emb_size, expansion * emb_size),
+            nn.GELU(),
+            nn.Dropout(drop_p),
+            nn.Linear(expansion * emb_size, emb_size),
+        )
+
+
+class GELU(nn.Module):
+    # exact erf-based GELU, equivalent to nn.GELU(); kept for reference but unused above
+    def forward(self, input: Tensor) -> Tensor:
+        return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))
+
+
+class TransformerEncoderBlock(nn.Sequential):
+    def __init__(self,
+                 emb_size,
+                 num_heads=10,
+                 drop_p=0.5,
+                 forward_expansion=4,
+                 forward_drop_p=0.5):
+        super().__init__(
+            ResidualAdd(nn.Sequential(
+                nn.LayerNorm(emb_size),
+                MultiHeadAttention(emb_size, num_heads, drop_p),
+                nn.Dropout(drop_p)
+            )),
+            ResidualAdd(nn.Sequential(
+                nn.LayerNorm(emb_size),
+                FeedForwardBlock(
+                    emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
+                nn.Dropout(drop_p)
+            )),
+        )
+
+
+class TransformerEncoder(nn.Sequential):
+    def __init__(self, depth, emb_size):
+        super().__init__(*[TransformerEncoderBlock(emb_size) for _ in range(depth)])
+
+
+class ClassificationHead(nn.Module):
+    def __init__(self, emb_size, n_classes):
+        super().__init__()
+
+        # alternative head: global average pooling over tokens (unused in forward)
+        self.clshead = nn.Sequential(
+            Reduce('b n e -> b e', reduction='mean'),
+            nn.LayerNorm(emb_size),
+            nn.Linear(emb_size, n_classes)
+        )
+        # head actually used: flatten all tokens (61 * 40 = 2440) into an MLP
+        self.fc = nn.Sequential(
+            nn.Linear(2440, 256),
+            nn.ELU(),
+            nn.Dropout(0.5),
+            nn.Linear(256, 32),
+            nn.ELU(),
+            nn.Dropout(0.3),
+            nn.Linear(32, n_classes)
+        )
+
+    def forward(self, x):
+        x = x.contiguous().view(x.size(0), -1)
+        out = self.fc(x)
+        # returns (flattened features, class logits)
+        return x, out
+
+
+class Conformer(nn.Sequential):
+    def __init__(self, emb_size=40, depth=6, n_classes=4, **kwargs):
+        super().__init__(
+            PatchEmbedding(emb_size),
+            TransformerEncoder(depth, emb_size),
+            ClassificationHead(emb_size, n_classes)
+        )
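+
+# Minimal shape sanity check (hypothetical usage; runnable on CPU):
+#   model = Conformer(emb_size=40, depth=6, n_classes=4)
+#   feats, logits = model(torch.zeros(2, 1, 22, 1000))
+#   # feats: (2, 2440) flattened tokens, logits: (2, 4) class scores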
+
+
+class ExP():
+    def __init__(self, nsub):
+        super(ExP, self).__init__()
+        self.batch_size = 72
+        self.n_epochs = 2000
+        self.c_dim = 4  # number of classes
+        self.lr = 0.0002
+        self.b1 = 0.5   # Adam beta1
+        self.b2 = 0.999  # Adam beta2
+        self.dimension = (190, 50)  # unused
+        self.nSub = nsub
+
+        self.start_epoch = 0
+        self.root = '/Data/strict_TE/'  # path to the BCI Competition IV 2a .mat files; adjust as needed
+
+        self.log_write = open("./results/log_subject%d.txt" % self.nSub, "w")
+
+        self.Tensor = torch.cuda.FloatTensor
+        self.LongTensor = torch.cuda.LongTensor
+
+        self.criterion_l1 = torch.nn.L1Loss().cuda()
+        self.criterion_l2 = torch.nn.MSELoss().cuda()
+        self.criterion_cls = torch.nn.CrossEntropyLoss().cuda()
+
+        self.model = Conformer().cuda()
+        self.model = nn.DataParallel(self.model, device_ids=[i for i in range(len(gpus))])
+        self.model = self.model.cuda()
+        # from torchsummary import summary; summary(self.model, (1, 22, 1000))
+
+
+    # Segmentation and Reconstruction (S&R) data augmentation
+    def interaug(self, timg, label):  
+        aug_data = []
+        aug_label = []
+        for cls4aug in range(4):
+            cls_idx = np.where(label == cls4aug + 1)
+            tmp_data = timg[cls_idx]
+            tmp_label = label[cls_idx]
+
+            tmp_aug_data = np.zeros((int(self.batch_size / 4), 1, 22, 1000))
+            for ri in range(int(self.batch_size / 4)):
+                for rj in range(8):
+                    rand_idx = np.random.randint(0, tmp_data.shape[0], 8)
+                    tmp_aug_data[ri, :, :, rj * 125:(rj + 1) * 125] = tmp_data[rand_idx[rj], :, :,
+                                                                      rj * 125:(rj + 1) * 125]
+
+            aug_data.append(tmp_aug_data)
+            aug_label.append(tmp_label[:int(self.batch_size / 4)])
+        aug_data = np.concatenate(aug_data)
+        aug_label = np.concatenate(aug_label)
+        aug_shuffle = np.random.permutation(len(aug_data))
+        aug_data = aug_data[aug_shuffle, :, :]
+        aug_label = aug_label[aug_shuffle]
+
+        aug_data = torch.from_numpy(aug_data).cuda()
+        aug_data = aug_data.float()
+        aug_label = torch.from_numpy(aug_label-1).cuda()
+        aug_label = aug_label.long()
+        return aug_data, aug_label
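+
+    # With batch_size=72 this yields 18 augmented trials per class, i.e. one
+    # (72, 1, 22, 1000) batch; train() concatenates it with the real batch,
+    # so each optimization step sees up to 144 examples.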
+
+    def get_source_data(self):
+        # NOTE: re-check whether you need a validation set, and match the data
+        # segmentation used by the methods you compare against.
+
+        # train data
+        self.total_data = scipy.io.loadmat(self.root + 'A0%dT.mat' % self.nSub)
+        self.train_data = self.total_data['data']
+        self.train_label = self.total_data['label']
+
+        self.train_data = np.transpose(self.train_data, (2, 1, 0))
+        self.train_data = np.expand_dims(self.train_data, axis=1)
+        self.train_label = np.transpose(self.train_label)
+
+        self.allData = self.train_data
+        self.allLabel = self.train_label[0]
+
+        shuffle_num = np.random.permutation(len(self.allData))
+        self.allData = self.allData[shuffle_num, :, :, :]
+        self.allLabel = self.allLabel[shuffle_num]
+
+        # test data
+        self.test_tmp = scipy.io.loadmat(self.root + 'A0%dE.mat' % self.nSub)
+        self.test_data = self.test_tmp['data']
+        self.test_label = self.test_tmp['label']
+
+        self.test_data = np.transpose(self.test_data, (2, 1, 0))
+        self.test_data = np.expand_dims(self.test_data, axis=1)
+        self.test_label = np.transpose(self.test_label)
+
+        self.testData = self.test_data
+        self.testLabel = self.test_label[0]
+
+
+        # standardize with training-set statistics (also applied to test data, to avoid leakage)
+        target_mean = np.mean(self.allData)
+        target_std = np.std(self.allData)
+        self.allData = (self.allData - target_mean) / target_std
+        self.testData = (self.testData - target_mean) / target_std
+
+        # data shape: (trial, conv channel, electrode channel, time samples)
+        return self.allData, self.allLabel, self.testData, self.testLabel
+
+
+    def train(self):
+
+        img, label, test_data, test_label = self.get_source_data()
+
+        img = torch.from_numpy(img)
+        label = torch.from_numpy(label - 1)
+
+        dataset = torch.utils.data.TensorDataset(img, label)
+        self.dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)
+
+        test_data = torch.from_numpy(test_data)
+        test_label = torch.from_numpy(test_label - 1)
+        test_dataset = torch.utils.data.TensorDataset(test_data, test_label)
+        self.test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=self.batch_size, shuffle=True)  # unused; evaluation below runs on the full test tensor
+
+        # Optimizers
+        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2))
+
+        test_data = test_data.type(self.Tensor)
+        test_label = test_label.type(self.LongTensor)
+
+        bestAcc = 0
+        averAcc = 0
+        num = 0
+        Y_true = 0
+        Y_pred = 0
+
+        # train the model
+
+        for e in range(self.n_epochs):
+            # in_epoch = time.time()
+            self.model.train()
+            for i, (img, label) in enumerate(self.dataloader):
+
+                img = img.cuda().type(self.Tensor)
+                label = label.cuda().type(self.LongTensor)
+
+                # data augmentation
+                aug_data, aug_label = self.interaug(self.allData, self.allLabel)
+                img = torch.cat((img, aug_data))
+                label = torch.cat((label, aug_label))
+
+
+                tok, outputs = self.model(img)
+
+                loss = self.criterion_cls(outputs, label) 
+
+                self.optimizer.zero_grad()
+                loss.backward()
+                self.optimizer.step()
+
+
+            # out_epoch = time.time()
+
+
+            # test process (runs every epoch)
+            if (e + 1) % 1 == 0:
+                self.model.eval()
+                with torch.no_grad():
+                    Tok, Cls = self.model(test_data)
+                    loss_test = self.criterion_cls(Cls, test_label)
+
+                y_pred = torch.max(Cls, 1)[1]
+                acc = float((y_pred == test_label).cpu().numpy().astype(int).sum()) / float(test_label.size(0))
+                # train metrics are computed from the last training batch of this epoch
+                train_pred = torch.max(outputs, 1)[1]
+                train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0))
+
+                print('Epoch:', e,
+                      '  Train loss: %.6f' % loss.detach().cpu().numpy(),
+                      '  Test loss: %.6f' % loss_test.detach().cpu().numpy(),
+                      '  Train accuracy %.6f' % train_acc,
+                      '  Test accuracy is %.6f' % acc)
+
+                self.log_write.write(str(e) + "    " + str(acc) + "\n")
+                num = num + 1
+                averAcc = averAcc + acc
+                if acc > bestAcc:
+                    bestAcc = acc
+                    Y_true = test_label
+                    Y_pred = y_pred
+
+
+        torch.save(self.model.module.state_dict(), 'model.pth')
+        averAcc = averAcc / num
+        print('The average accuracy is:', averAcc)
+        print('The best accuracy is:', bestAcc)
+        self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n")
+        self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n")
+
+        return bestAcc, averAcc, Y_true, Y_pred
+        # writer.close()
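+
+    # Note: bestAcc is selected on the test set across epochs and averAcc
+    # averages test accuracy over all epochs; neither uses a held-out
+    # validation set (see the note in get_source_data).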
+
+
+def main():
+    best = 0
+    aver = 0
+    result_write = open("./results/sub_result.txt", "w")
+
+    for i in range(9):
+        starttime = datetime.datetime.now()
+
+
+        seed_n = np.random.randint(2021)
+        print('seed is ' + str(seed_n))
+        random.seed(seed_n)
+        np.random.seed(seed_n)
+        torch.manual_seed(seed_n)
+        torch.cuda.manual_seed(seed_n)
+        torch.cuda.manual_seed_all(seed_n)
+
+
+        print('Subject %d' % (i+1))
+        exp = ExP(i + 1)
+
+        bestAcc, averAcc, Y_true, Y_pred = exp.train()
+        print('THE BEST ACCURACY IS ' + str(bestAcc))
+        result_write.write('Subject ' + str(i + 1) + ' : ' + 'Seed is: ' + str(seed_n) + "\n")
+        result_write.write('Subject ' + str(i + 1) + ' : ' + 'The best accuracy is: ' + str(bestAcc) + "\n")
+        result_write.write('Subject ' + str(i + 1) + ' : ' + 'The average accuracy is: ' + str(averAcc) + "\n")
+
+        endtime = datetime.datetime.now()
+        print('subject %d duration: '%(i+1) + str(endtime - starttime))
+        best = best + bestAcc
+        aver = aver + averAcc
+        if i == 0:
+            yt = Y_true
+            yp = Y_pred
+        else:
+            yt = torch.cat((yt, Y_true))
+            yp = torch.cat((yp, Y_pred))
+
+
+    best = best / 9
+    aver = aver / 9
+
+    result_write.write('The average best accuracy across subjects is: ' + str(best) + "\n")
+    result_write.write('The average mean accuracy across subjects is: ' + str(aver) + "\n")
+    result_write.close()
+
+
+if __name__ == "__main__":
+    print(time.asctime(time.localtime(time.time())))
+    main()
+    print(time.asctime(time.localtime(time.time())))