"""
Class activation topography (CAT) for EEG model visualization, combining class
activation map (CAM) and topography.
Code: Class activation map (CAM) and then CAT.

Refers to a high-star repo on GitHub:
https://github.com/WZMIAOMIAO/deep-learning-for-image-processing/tree/master/pytorch_classification/grad_cam

Salute every open-source researcher and developer!
"""


import os

# Pin the visible GPU(s) before torch initializes CUDA.
gpus = [1]
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, gpus))
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# NOTE(review): the original header imported several modules two or three
# times (datetime, torch, matplotlib.pyplot, transforms, nn, F); the list
# below keeps every unique import, grouped stdlib / third-party / local.
import argparse
import datetime
import glob
import itertools
import math
import random
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import scipy.io
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torchvision.transforms as transforms
from PIL import Image
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
from sklearn.decomposition import PCA
from torch import Tensor
from torch.autograd import Variable
from torch.backends import cudnn
from torch.utils.data import DataLoader, Dataset
from torchsummary import summary
from torchvision.models import vgg19
from torchvision.transforms import Compose, Resize, ToTensor
from torchvision.utils import save_image, make_grid
# from common_spatial_pattern import csp

# from tSNE import plt_tsne
# from grad_cam.utils import GradCAM, show_cam_on_image
from utils import GradCAM, show_cam_on_image

# Deterministic cuDNN behaviour so the CAMs are reproducible.
cudnn.benchmark = False
cudnn.deterministic = True


# keep the overall model class, omitted here
class ViT(nn.Sequential):
    """Transformer-based EEG classifier; the full architecture is omitted here."""

    def __init__(self, emb_size=40, depth=6, n_classes=4, **kwargs):
        super().__init__(
            # ... the model
        )


# EEG trials for one subject; presumably (trials, 1, channels, time) — TODO confirm.
data = np.load('./grad_cam/train_data.npy')
print(np.shape(data))


nSub = 1             # subject index used to select the checkpoint
target_category = 2  # set the class (class activation mapping)


# ! A crucial step for adaptation on Transformer:
# reshape token output b 61 40 -> b 40 1 61 so Grad-CAM sees a 2-D feature map.
def reshape_transform(tensor):
    """Rearrange transformer tokens (b, 61, 40) into CNN layout (b, 40, 1, 61)."""
    return rearrange(tensor, 'b (h w) e -> b e (h) (w)', h=1)


device = torch.device("cpu")
model = ViT()

# # used for cnn model without transformer
# model.load_state_dict(torch.load('./model/model_cnn.pth', map_location=device))
# target_layers = [model[0].projection] # set the layer you want to visualize, you can use torchsummary here to find the layer index
# cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)

model.load_state_dict(torch.load('./model/sub%d.pth' % nSub, map_location=device))
target_layers = [model[1]]  # set the target layer (inspect indices with torchsummary)
cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False,
              reshape_transform=reshape_transform)


# TODO: Class Activation Topography (proposed in the paper)
import mne
from matplotlib import mlab as mlab

# 22-channel subset of biosemi64 matching BCI Competition IV 2a electrode layout.
biosemi_montage = mne.channels.make_standard_montage('biosemi64')
index = [37, 9, 10, 46, 45, 44, 13, 12, 11, 47, 48, 49, 50, 17, 18, 31, 55, 54, 19, 30, 56, 29]  # for bci competition iv 2a
biosemi_montage.ch_names = [biosemi_montage.ch_names[i] for i in index]
# +3 skips what are presumably the three fiducial points in dig — TODO confirm.
biosemi_montage.dig = [biosemi_montage.dig[i + 3] for i in index]
info = mne.create_info(ch_names=biosemi_montage.ch_names, sfreq=250., ch_types='eeg')


all_cam = []
# Obtain the CAM of each trial/sample; iterate over however many trials the
# file holds (the original hard-coded 288, the BCI IV 2a trial count).
for i in range(data.shape[0]):
    test = torch.as_tensor(data[i:i + 1, :, :, :], dtype=torch.float32)
    # torch.autograd.Variable is deprecated; requires_grad_ is the modern form.
    test.requires_grad_(True)

    # Bug fix: target_category was defined above but never passed, so the CAM
    # was computed for the model's argmax class instead of the requested one.
    grayscale_cam = cam(input_tensor=test, target_category=target_category)
    grayscale_cam = grayscale_cam[0, :]
    all_cam.append(grayscale_cam)
# Grand averages across trials: raw data, CAM, and CAM-weighted ("hybrid") data.

# the mean of all data
test_all_data = np.squeeze(np.mean(data, axis=0))
# test_all_data = (test_all_data - np.mean(test_all_data)) / np.std(test_all_data)
mean_all_test = np.mean(test_all_data, axis=1)  # per-channel mean over time

# the mean of all cam
test_all_cam = np.mean(all_cam, axis=0)
# test_all_cam = (test_all_cam - np.mean(test_all_cam)) / np.std(test_all_cam)
mean_all_cam = np.mean(test_all_cam, axis=1)

# apply cam on the input data
hyb_all = test_all_data * test_all_cam
# hyb_all = (hyb_all - np.mean(hyb_all)) / np.std(hyb_all)
mean_hyb_all = np.mean(hyb_all, axis=1)

evoked = mne.EvokedArray(test_all_data, info)
evoked.set_montage(biosemi_montage)

# One figure: raw topography on top, CAM-weighted topography below.
fig, [ax1, ax2] = plt.subplots(nrows=2)

# Bug fix: the original also called plt.subplot(211)/plt.subplot(212) before
# each plot, which created overlapping empty axes even though axes=ax1/ax2
# are already passed explicitly; those calls are removed.
# print(mean_all_test)
im1, cn1 = mne.viz.plot_topomap(mean_all_test, evoked.info, show=False, axes=ax1, res=1200)

im2, cn2 = mne.viz.plot_topomap(mean_hyb_all, evoked.info, show=False, axes=ax2, res=1200)

# Bug fix: the script previously ended without rendering or saving the figure.
plt.show()