Diff of /Trans.py [000000] .. [9c6ad1]

--- a
+++ b/Trans.py
@@ -0,0 +1,453 @@
+"""
+Transformer for EEG classification
+
+The core idea is slicing: the signal is split along the time dimension into
+short segments, analogous to the patches in a Vision Transformer.
+"""
+
+
+import os
+import numpy as np
+import math
+import random
+import time
+import scipy.io
+
+from torch.utils.data import DataLoader
+from torchsummary import summary
+
+import torch
+import torch.nn.functional as F
+
+from torch import nn
+from torch import Tensor
+
+from einops import rearrange, reduce, repeat
+from einops.layers.torch import Rearrange, Reduce
+from common_spatial_pattern import csp
+# from confusion_matrix import plot_confusion_matrix
+# from cm_no_normal import plot_confusion_matrix_nn
+# from torchsummary import summary
+
+import matplotlib.pyplot as plt
+# from torch.utils.tensorboard import SummaryWriter
+from torch.backends import cudnn
+cudnn.benchmark = False
+cudnn.deterministic = True  # prefer reproducibility over cuDNN autotuning speed
+
+# writer = SummaryWriter('./TensorBoardX/')
+
+# torch.cuda.set_device(6)
+gpus = [0]
+os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
+os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, gpus))
+
+
+class PatchEmbedding(nn.Module):
+    def __init__(self, emb_size):
+        # self.patch_size = patch_size
+        super().__init__()
+        self.projection = nn.Sequential(
+            nn.Conv2d(1, 2, (1, 51), (1, 1)),                # temporal filtering along the time axis
+            nn.BatchNorm2d(2),
+            nn.LeakyReLU(0.2),
+            nn.Conv2d(2, emb_size, (16, 5), stride=(1, 5)),  # spatial filter + slicing: stride 5 cuts time into slices
+            Rearrange('b e h w -> b (h w) e'),               # flatten the slice grid into a token sequence
+        )
+        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
+        # self.positions = nn.Parameter(torch.randn((100 + 1, emb_size)))
+        # self.positions = nn.Parameter(torch.randn((2200 + 1, emb_size)))
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.projection(x)
+        # the class token is never prepended: ClassificationHead mean-pools over
+        # all tokens, so cls_token (and the positional embedding) stays unused
+        # cls_tokens = repeat(self.cls_token, '() n e -> b n e', b=x.shape[0])
+        # x = torch.cat((cls_tokens, x), dim=1)
+        # x += self.positions
+        return x
+
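+# Shape check for PatchEmbedding (a sketch; assumes 16-channel, 1000-sample trials):
+#   tokens = PatchEmbedding(emb_size=10)(torch.randn(2, 1, 16, 1000))
+#   print(tokens.shape)  # torch.Size([2, 190, 10])
+# time axis: 1000 -> 950 after the (1, 51) conv, then (950 - 5) // 5 + 1 = 190 slices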
+
+class MultiHeadAttention(nn.Module):
+    def __init__(self, emb_size, num_heads, dropout):
+        super().__init__()
+        self.emb_size = emb_size
+        self.num_heads = num_heads
+        self.keys = nn.Linear(emb_size, emb_size)
+        self.queries = nn.Linear(emb_size, emb_size)
+        self.values = nn.Linear(emb_size, emb_size)
+        self.att_drop = nn.Dropout(dropout)
+        self.projection = nn.Linear(emb_size, emb_size)
+
+    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
+        queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
+        keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
+        values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
+        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)  # batch, num_heads, query_len, key_len
+        if mask is not None:
+            fill_value = torch.finfo(torch.float32).min
+            energy = energy.masked_fill(~mask, fill_value)
+
+        scaling = self.emb_size ** (1 / 2)  # note: sqrt(emb_size) rather than the usual per-head sqrt(d_k)
+        att = F.softmax(energy / scaling, dim=-1)
+        att = self.att_drop(att)
+        out = torch.einsum('bhal, bhlv -> bhav', att, values)
+        out = rearrange(out, "b h n d -> b n (h d)")
+        out = self.projection(out)
+        return out
+
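+# Attention shape walk-through (illustrative; n tokens, h heads, d = emb_size // h):
+#   queries/keys/values: (b, h, n, d)
+#   energy: (b, h, n, n) token-to-token scores from 'bhqd, bhkd -> bhqk'
+#   out: (b, h, n, d) attention-weighted values, merged back to (b, n, emb_size)
+# emb_size must be divisible by num_heads for the (h d) split to work.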
+
+class ResidualAdd(nn.Module):
+    def __init__(self, fn):
+        super().__init__()
+        self.fn = fn
+
+    def forward(self, x, **kwargs):
+        res = x
+        x = self.fn(x, **kwargs)
+        x += res
+        return x
+
+
+class FeedForwardBlock(nn.Sequential):
+    def __init__(self, emb_size, expansion, drop_p):
+        super().__init__(
+            nn.Linear(emb_size, expansion * emb_size),
+            nn.GELU(),
+            nn.Dropout(drop_p),
+            nn.Linear(expansion * emb_size, emb_size),
+        )
+
+class GELU(nn.Module):
+    """Exact erf-based GELU, numerically equivalent to nn.GELU(); unused above."""
+    def forward(self, input: Tensor) -> Tensor:
+        return input * 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))
+
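+# Sanity check (a sketch): the erf formulation above matches PyTorch's default GELU.
+#   t = torch.randn(8)
+#   assert torch.allclose(GELU()(t), nn.GELU()(t), atol=1e-6)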
+
+class TransformerEncoderBlock(nn.Sequential):
+    def __init__(self,
+                 emb_size,
+                 num_heads=5,
+                 drop_p=0.5,
+                 forward_expansion=4,
+                 forward_drop_p=0.5):
+        super().__init__(
+            ResidualAdd(nn.Sequential(
+                nn.LayerNorm(emb_size),
+                MultiHeadAttention(emb_size, num_heads, drop_p),
+                nn.Dropout(drop_p)
+            )),
+            ResidualAdd(nn.Sequential(
+                nn.LayerNorm(emb_size),
+                FeedForwardBlock(
+                    emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
+                nn.Dropout(drop_p)
+            )
+            ))
+
+
+class TransformerEncoder(nn.Sequential):
+    def __init__(self, depth, emb_size):
+        super().__init__(*[TransformerEncoderBlock(emb_size) for _ in range(depth)])
+
+
+class ClassificationHead(nn.Module):
+    def __init__(self, emb_size, n_classes):
+        super().__init__()
+        self.clshead = nn.Sequential(
+            Reduce('b n e -> b e', reduction='mean'),  # mean-pool over all tokens
+            nn.LayerNorm(emb_size),
+            nn.Linear(emb_size, n_classes)
+        )
+
+    def forward(self, x):
+        out = self.clshead(x)
+        return x, out  # return (tokens, logits) so callers can unpack both
+
+
+class ViT(nn.Sequential):
+    def __init__(self, emb_size=10, depth=3, n_classes=4, **kwargs):
+        super().__init__(
+            # channel_attention(),
+            ResidualAdd(
+                nn.Sequential(
+                    nn.LayerNorm(1000),
+                    channel_attention(),  # defined below; looked up when ViT() is instantiated
+                    nn.Dropout(0.5),
+                )
+            ),
+
+            PatchEmbedding(emb_size),
+            TransformerEncoder(depth, emb_size),
+            ClassificationHead(emb_size, n_classes)
+        )
+
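+# End-to-end sketch (illustrative, default arguments): a (b, 1, 16, 1000) batch
+# passes through channel attention (shape preserved), PatchEmbedding -> (b, 190, 10)
+# tokens, depth=3 encoder blocks, then ClassificationHead returns (tokens, logits)
+# with logits of shape (b, 4).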
+
+class channel_attention(nn.Module):
+    def __init__(self, sequence_num=1000, inter=30):
+        super(channel_attention, self).__init__()
+        self.sequence_num = sequence_num
+        self.inter = inter
+        self.extract_sequence = int(self.sequence_num / self.inter)  # pooled sequence length; pooling reduces the attention cost
+
+        self.query = nn.Sequential(
+            nn.Linear(16, 16),
+            nn.LayerNorm(16),  # normalizing here tends to give a modest improvement
+            nn.Dropout(0.3)
+        )
+        self.key = nn.Sequential(
+            nn.Linear(16, 16),
+            # nn.LeakyReLU(),
+            nn.LayerNorm(16),
+            nn.Dropout(0.3)
+        )
+
+        # self.value = self.key
+        self.projection = nn.Sequential(
+            nn.Linear(16, 16),
+            # nn.LeakyReLU(),
+            nn.LayerNorm(16),
+            nn.Dropout(0.3),
+        )
+
+        self.drop_out = nn.Dropout(0)
+        self.pooling = nn.AvgPool2d(kernel_size=(1, self.inter), stride=(1, self.inter))
+
+        for m in self.modules():
+            if isinstance(m, nn.Linear):
+                nn.init.xavier_normal_(m.weight)
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0.0)
+
+    def forward(self, x):
+        temp = rearrange(x, 'b o c s -> b o s c')
+        temp_query = rearrange(self.query(temp), 'b o s c -> b o c s')
+        temp_key = rearrange(self.key(temp), 'b o s c -> b o c s')
+
+        channel_query = self.pooling(temp_query)
+        channel_key = self.pooling(temp_key)
+
+        scaling = self.extract_sequence ** (1 / 2)
+
+        channel_atten = torch.einsum('b o c s, b o m s -> b o c m', channel_query, channel_key) / scaling
+
+        channel_atten_score = F.softmax(channel_atten, dim=-1)
+        channel_atten_score = self.drop_out(channel_atten_score)
+
+        # mix channels: weight each output channel by its attention over all input
+        # channels (summing over m; the earlier 'b o c s, b o c m -> b o c s'
+        # pattern summed the softmax weights to 1 and left x unchanged)
+        out = torch.einsum('b o c m, b o m s -> b o c s', channel_atten_score, x)
+        # projecting before or after the attention weighting gives nearly identical results
+        out = rearrange(out, 'b o c s -> b o s c')
+        out = self.projection(out)
+        out = rearrange(out, 'b o s c -> b o c s')
+        return out
+
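+# channel_attention shape walk-through (illustrative; defaults inter=30, 16 channels):
+#   x: (b, 1, 16, 1000) -> pooled query/key: (b, 1, 16, 33)
+#   channel_atten_score: (b, 1, 16, 16) channel-to-channel weights
+#   out: (b, 1, 16, 1000), same shape as the input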
+
+class Trans():
+    def __init__(self, nsub):
+        super(Trans, self).__init__()
+        self.batch_size = 50
+        self.n_epochs = 1000
+        self.img_height = 22
+        self.img_width = 600
+        self.channels = 1
+        self.c_dim = 4
+        self.lr = 0.0002
+        self.b1 = 0.5
+        self.b2 = 0.9
+        self.nSub = nsub
+        self.start_epoch = 0
+        self.root = '...'  # the path of data
+
+        self.pretrain = False
+
+        self.log_write = open("./results/log_subject%d.txt" % self.nSub, "w")
+
+        self.img_shape = (self.channels, self.img_height, self.img_width)  # unused
+
+        self.Tensor = torch.cuda.FloatTensor
+        self.LongTensor = torch.cuda.LongTensor
+
+        self.criterion_l1 = torch.nn.L1Loss().cuda()
+        self.criterion_l2 = torch.nn.MSELoss().cuda()
+        self.criterion_cls = torch.nn.CrossEntropyLoss().cuda()
+
+        self.model = ViT().cuda()
+        self.model = nn.DataParallel(self.model, device_ids=[i for i in range(len(gpus))])
+        self.model = self.model.cuda()
+        summary(self.model, (1, 16, 1000))
+
+        self.centers = {}
+
+    def get_source_data(self):
+
+        # to get the data of target subject
+        self.total_data = scipy.io.loadmat(self.root + 'A0%dT.mat' % self.nSub)
+        self.train_data = self.total_data['data']
+        self.train_label = self.total_data['label']
+
+        self.train_data = np.transpose(self.train_data, (2, 1, 0))
+        self.train_data = np.expand_dims(self.train_data, axis=1)
+        self.train_label = np.transpose(self.train_label)
+
+        self.allData = self.train_data
+        self.allLabel = self.train_label[0]
+
+        # test data
+        # to get the data of target subject
+        self.test_tmp = scipy.io.loadmat(self.root + 'A0%dE.mat' % self.nSub)
+        self.test_data = self.test_tmp['data']
+        self.test_label = self.test_tmp['label']
+
+        # self.train_data = self.train_data[250:1000, :, :]
+        self.test_data = np.transpose(self.test_data, (2, 1, 0))
+        self.test_data = np.expand_dims(self.test_data, axis=1)
+        self.test_label = np.transpose(self.test_label)
+
+        self.testData = self.test_data
+        self.testLabel = self.test_label[0]
+
+        # standardize
+        target_mean = np.mean(self.allData)
+        target_std = np.std(self.allData)
+        self.allData = (self.allData - target_mean) / target_std
+        self.testData = (self.testData - target_mean) / target_std
+
+        tmp_alldata = np.transpose(np.squeeze(self.allData), (0, 2, 1))
+        Wb = csp(tmp_alldata, self.allLabel-1)  # common spatial pattern
+        self.allData = np.einsum('abcd, ce -> abed', self.allData, Wb)
+        self.testData = np.einsum('abcd, ce -> abed', self.testData, Wb)
+        return self.allData, self.allLabel, self.testData, self.testLabel
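+
+    # Expected .mat layout (an assumption inferred from the transposes above,
+    # consistent with BCI Competition IV 2a): 'data' of shape (time, channels,
+    # trials), 'label' of shape (1, trials). Trials become (trials, 1, channels,
+    # time), and csp returns a (channels, 16) projection so the network sees 16
+    # spatially filtered channels.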
+
+    def update_lr(self, optimizer, lr):
+        for param_group in optimizer.param_groups:
+            param_group['lr'] = lr
+
+    # Data augmentation is a potential way to improve generalization; left as an empty stub here.
+    def aug(self, img, label):
+        aug_data = []
+        aug_label = []
+        return aug_data, aug_label
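+
+    # One possible implementation (a sketch, not the authors' method): append
+    # noise-perturbed copies of each trial.
+    #   def aug(self, img, label):
+    #       noise = 0.01 * np.random.randn(*img.shape)
+    #       return np.concatenate((img, img + noise)), np.concatenate((label, label))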
+
+    def train(self):
+
+
+        img, label, test_data, test_label = self.get_source_data()
+        img = torch.from_numpy(img)
+        label = torch.from_numpy(label - 1)
+
+
+        dataset = torch.utils.data.TensorDataset(img, label)
+        self.dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)
+
+        test_data = torch.from_numpy(test_data)
+        test_label = torch.from_numpy(test_label - 1)
+        test_dataset = torch.utils.data.TensorDataset(test_data, test_label)
+        self.test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=self.batch_size, shuffle=True)
+
+        # Optimizers
+        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2))
+
+        test_data = test_data.type(self.Tensor)
+        test_label = test_label.type(self.LongTensor)
+
+        bestAcc = 0
+        averAcc = 0
+        num = 0
+        Y_true = 0
+        Y_pred = 0
+
+        # Train the cnn model
+        total_step = len(self.dataloader)
+        curr_lr = self.lr
+        # a better optimization strategy is worth exploring; training sometimes over-fits badly
+
+
+        for e in range(self.n_epochs):
+            in_epoch = time.time()
+            self.model.train()
+            for i, (img, label) in enumerate(self.dataloader):
+
+                img = img.cuda().type(self.Tensor)
+                label = label.cuda().type(self.LongTensor)
+                tok, outputs = self.model(img)
+                loss = self.criterion_cls(outputs, label)
+                self.optimizer.zero_grad()
+                loss.backward()
+                self.optimizer.step()
+
+            out_epoch = time.time()
+
+            if (e + 1) % 1 == 0:  # evaluate every epoch; raise the modulus to evaluate less often
+                self.model.eval()
+                with torch.no_grad():  # no gradients needed for evaluation
+                    Tok, Cls = self.model(test_data)
+                    loss_test = self.criterion_cls(Cls, test_label)
+                y_pred = torch.max(Cls, 1)[1]
+                acc = float((y_pred == test_label).cpu().numpy().astype(int).sum()) / float(test_label.size(0))
+                train_pred = torch.max(outputs, 1)[1]
+                train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0))
+                print('Epoch:', e,
+                      '  Train loss:', loss.detach().cpu().numpy(),
+                      '  Test loss:', loss_test.detach().cpu().numpy(),
+                      '  Train accuracy:', train_acc,
+                      '  Test accuracy is:', acc)
+                self.log_write.write(str(e) + "    " + str(acc) + "\n")
+                num = num + 1
+                averAcc = averAcc + acc
+                if acc > bestAcc:
+                    bestAcc = acc
+                    Y_true = test_label
+                    Y_pred = y_pred
+
+        torch.save(self.model.module.state_dict(), 'model.pth')
+        averAcc = averAcc / num
+        print('The average accuracy is:', averAcc)
+        print('The best accuracy is:', bestAcc)
+        self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n")
+        self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n")
+        self.log_write.close()
+
+        return bestAcc, averAcc, Y_true, Y_pred
+
+
+def main():
+    best = 0
+    aver = 0
+    result_write = open("./results/sub_result.txt", "w")
+
+    for i in range(9):
+        seed_n = np.random.randint(500)
+        print('seed is ' + str(seed_n))
+        random.seed(seed_n)
+        np.random.seed(seed_n)
+        torch.manual_seed(seed_n)
+        torch.cuda.manual_seed(seed_n)
+        torch.cuda.manual_seed_all(seed_n)
+        print('Subject %d' % (i+1))
+        trans = Trans(i + 1)
+        bestAcc, averAcc, Y_true, Y_pred = trans.train()
+        print('THE BEST ACCURACY IS ' + str(bestAcc))
+        result_write.write('Subject ' + str(i + 1) + ' : ' + 'Seed is: ' + str(seed_n) + "\n")
+        result_write.write('**Subject ' + str(i + 1) + ' : ' + 'The best accuracy is: ' + str(bestAcc) + "\n")
+        result_write.write('Subject ' + str(i + 1) + ' : ' + 'The average accuracy is: ' + str(averAcc) + "\n")
+        # plot_confusion_matrix(Y_true, Y_pred, i+1)
+        best = best + bestAcc
+        aver = aver + averAcc
+        if i == 0:
+            yt = Y_true
+            yp = Y_pred
+        else:
+            yt = torch.cat((yt, Y_true))
+            yp = torch.cat((yp, Y_pred))
+
+
+    best = best / 9
+    aver = aver / 9
+    # plot_confusion_matrix(yt, yp, 666)
+    result_write.write('**The average Best accuracy is: ' + str(best) + "\n")
+    result_write.write('The average Aver accuracy is: ' + str(aver) + "\n")
+    result_write.close()
+
+
+if __name__ == "__main__":
+    main()