|
a |
|
b/visualization/CAT.py |
|
|
1 |
""" |
|
|
2 |
Class activation topography (CAT) for EEG model visualization, combining class activity map and topography |
|
|
3 |
Code: Class activation map (CAM) and then CAT |
|
|
4 |
|
|
|
5 |
refer to high-star repo on github: |
|
|
6 |
https://github.com/WZMIAOMIAO/deep-learning-for-image-processing/tree/master/pytorch_classification/grad_cam |
|
|
7 |
|
|
|
8 |
Salute every open-source researcher and developer! |
|
|
9 |
""" |
|
|
10 |
|
|
|
11 |
|
|
|
12 |
import argparse |
|
|
13 |
import os |
|
|
14 |
gpus = [1] |
|
|
15 |
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID' |
|
|
16 |
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, gpus)) |
|
|
17 |
os.environ['CUDA_LAUNCH_BLOCKING'] = '1' |
|
|
18 |
import numpy as np |
|
|
19 |
import math |
|
|
20 |
import glob |
|
|
21 |
import random |
|
|
22 |
import itertools |
|
|
23 |
import datetime |
|
|
24 |
import time |
|
|
25 |
import datetime |
|
|
26 |
import sys |
|
|
27 |
import scipy.io |
|
|
28 |
|
|
|
29 |
import torchvision.transforms as transforms |
|
|
30 |
from torchvision.utils import save_image, make_grid |
|
|
31 |
|
|
|
32 |
from torch.utils.data import DataLoader |
|
|
33 |
from torch.autograd import Variable |
|
|
34 |
from torchsummary import summary |
|
|
35 |
import torch.autograd as autograd |
|
|
36 |
from torchvision.models import vgg19 |
|
|
37 |
|
|
|
38 |
import torch.nn as nn |
|
|
39 |
import torch.nn.functional as F |
|
|
40 |
import torch |
|
|
41 |
import torch.nn.init as init |
|
|
42 |
|
|
|
43 |
from torch.utils.data import Dataset |
|
|
44 |
from PIL import Image |
|
|
45 |
import torchvision.transforms as transforms |
|
|
46 |
from sklearn.decomposition import PCA |
|
|
47 |
|
|
|
48 |
import torch |
|
|
49 |
import torch.nn.functional as F |
|
|
50 |
import matplotlib.pyplot as plt |
|
|
51 |
|
|
|
52 |
from torch import nn |
|
|
53 |
from torch import Tensor |
|
|
54 |
from PIL import Image |
|
|
55 |
from torchvision.transforms import Compose, Resize, ToTensor |
|
|
56 |
from einops import rearrange, reduce, repeat |
|
|
57 |
from einops.layers.torch import Rearrange, Reduce |
|
|
58 |
# from common_spatial_pattern import csp |
|
|
59 |
|
|
|
60 |
import matplotlib.pyplot as plt |
|
|
61 |
from torch.backends import cudnn |
|
|
62 |
# from tSNE import plt_tsne |
|
|
63 |
# from grad_cam.utils import GradCAM, show_cam_on_image |
|
|
64 |
from utils import GradCAM, show_cam_on_image |
|
|
65 |
|
|
|
66 |
cudnn.benchmark = False |
|
|
67 |
cudnn.deterministic = True |
|
|
68 |
|
|
|
69 |
|
|
|
70 |
# keep the overall model class, omitted here |
|
|
71 |
class ViT(nn.Sequential): |
|
|
72 |
def __init__(self, emb_size=40, depth=6, n_classes=4, **kwargs): |
|
|
73 |
super().__init__( |
|
|
74 |
# ... the model |
|
|
75 |
) |
|
|
76 |
|
|
|
77 |
|
|
|
78 |
data = np.load('./grad_cam/train_data.npy') |
|
|
79 |
print(np.shape(data)) |
|
|
80 |
|
|
|
81 |
|
|
|
82 |
nSub = 1 |
|
|
83 |
target_category = 2 # set the class (class activation mapping) |
|
|
84 |
|
|
|
85 |
# ! A crucial step for adaptation on Transformer |
|
|
86 |
# reshape_transform b 61 40 -> b 40 1 61 |
|
|
87 |
def reshape_transform(tensor): |
|
|
88 |
result = rearrange(tensor, 'b (h w) e -> b e (h) (w)', h=1) |
|
|
89 |
return result |
|
|
90 |
|
|
|
91 |
|
|
|
92 |
device = torch.device("cpu") |
|
|
93 |
model = ViT() |
|
|
94 |
|
|
|
95 |
# # used for cnn model without transformer |
|
|
96 |
# model.load_state_dict(torch.load('./model/model_cnn.pth', map_location=device)) |
|
|
97 |
# target_layers = [model[0].projection] # set the layer you want to visualize, you can use torchsummary here to find the layer index |
|
|
98 |
# cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False) |
|
|
99 |
|
|
|
100 |
model.load_state_dict(torch.load('./model/sub%d.pth'%nSub, map_location=device)) |
|
|
101 |
target_layers = [model[1]] # set the target layer |
|
|
102 |
cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False, reshape_transform=reshape_transform) |
|
|
103 |
|
|
|
104 |
|
|
|
105 |
|
|
|
106 |
# TODO: Class Activation Topography (proposed in the paper) |
|
|
107 |
import mne |
|
|
108 |
from matplotlib import mlab as mlab |
|
|
109 |
|
|
|
110 |
biosemi_montage = mne.channels.make_standard_montage('biosemi64') |
|
|
111 |
index = [37, 9, 10, 46, 45, 44, 13, 12, 11, 47, 48, 49, 50, 17, 18, 31, 55, 54, 19, 30, 56, 29] # for bci competition iv 2a |
|
|
112 |
biosemi_montage.ch_names = [biosemi_montage.ch_names[i] for i in index] |
|
|
113 |
biosemi_montage.dig = [biosemi_montage.dig[i+3] for i in index] |
|
|
114 |
info = mne.create_info(ch_names=biosemi_montage.ch_names, sfreq=250., ch_types='eeg') |
|
|
115 |
|
|
|
116 |
|
|
|
117 |
all_cam = [] |
|
|
118 |
# this loop is used to obtain the cam of each trial/sample |
|
|
119 |
for i in range(288): |
|
|
120 |
test = torch.as_tensor(data[i:i+1, :, :, :], dtype=torch.float32) |
|
|
121 |
test = torch.autograd.Variable(test, requires_grad=True) |
|
|
122 |
|
|
|
123 |
grayscale_cam = cam(input_tensor=test) |
|
|
124 |
grayscale_cam = grayscale_cam[0, :] |
|
|
125 |
all_cam.append(grayscale_cam) |
|
|
126 |
|
|
|
127 |
|
|
|
128 |
# the mean of all data |
|
|
129 |
test_all_data = np.squeeze(np.mean(data, axis=0)) |
|
|
130 |
# test_all_data = (test_all_data - np.mean(test_all_data)) / np.std(test_all_data) |
|
|
131 |
mean_all_test = np.mean(test_all_data, axis=1) |
|
|
132 |
|
|
|
133 |
# the mean of all cam |
|
|
134 |
test_all_cam = np.mean(all_cam, axis=0) |
|
|
135 |
# test_all_cam = (test_all_cam - np.mean(test_all_cam)) / np.std(test_all_cam) |
|
|
136 |
mean_all_cam = np.mean(test_all_cam, axis=1) |
|
|
137 |
|
|
|
138 |
# apply cam on the input data |
|
|
139 |
hyb_all = test_all_data * test_all_cam |
|
|
140 |
# hyb_all = (hyb_all - np.mean(hyb_all)) / np.std(hyb_all) |
|
|
141 |
mean_hyb_all = np.mean(hyb_all, axis=1) |
|
|
142 |
|
|
|
143 |
evoked = mne.EvokedArray(test_all_data, info) |
|
|
144 |
evoked.set_montage(biosemi_montage) |
|
|
145 |
|
|
|
146 |
fig, [ax1, ax2] = plt.subplots(nrows=2) |
|
|
147 |
|
|
|
148 |
# print(mean_all_test) |
|
|
149 |
plt.subplot(211) |
|
|
150 |
im1, cn1 = mne.viz.plot_topomap(mean_all_test, evoked.info, show=False, axes=ax1, res=1200) |
|
|
151 |
|
|
|
152 |
|
|
|
153 |
plt.subplot(212) |
|
|
154 |
im2, cn2 = mne.viz.plot_topomap(mean_hyb_all, evoked.info, show=False, axes=ax2, res=1200) |
|
|
155 |
|
|
|
156 |
|
|
|
157 |
|