Diff of /visualization/CAT.py [000000] .. [8bbec7]

Switch to unified view

a b/visualization/CAT.py
1
"""
2
Class activation topography (CAT) for EEG model visualization, combining class activity map and topography
3
Code: Class activation map (CAM) and then CAT
4
5
refer to high-star repo on github: 
6
https://github.com/WZMIAOMIAO/deep-learning-for-image-processing/tree/master/pytorch_classification/grad_cam
7
8
Salute every open-source researcher and developer!
9
"""
10
11
12
import argparse
13
import os
14
gpus = [1]
15
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
16
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, gpus))
17
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
18
import numpy as np
19
import math
20
import glob
21
import random
22
import itertools
23
import datetime
24
import time
25
import datetime
26
import sys
27
import scipy.io
28
29
import torchvision.transforms as transforms
30
from torchvision.utils import save_image, make_grid
31
32
from torch.utils.data import DataLoader
33
from torch.autograd import Variable
34
from torchsummary import summary
35
import torch.autograd as autograd
36
from torchvision.models import vgg19
37
38
import torch.nn as nn
39
import torch.nn.functional as F
40
import torch
41
import torch.nn.init as init
42
43
from torch.utils.data import Dataset
44
from PIL import Image
45
import torchvision.transforms as transforms
46
from sklearn.decomposition import PCA
47
48
import torch
49
import torch.nn.functional as F
50
import matplotlib.pyplot as plt
51
52
from torch import nn
53
from torch import Tensor
54
from PIL import Image
55
from torchvision.transforms import Compose, Resize, ToTensor
56
from einops import rearrange, reduce, repeat
57
from einops.layers.torch import Rearrange, Reduce
58
# from common_spatial_pattern import csp
59
60
import matplotlib.pyplot as plt
61
from torch.backends import cudnn
62
# from tSNE import plt_tsne
63
# from grad_cam.utils import GradCAM, show_cam_on_image
64
from utils import GradCAM, show_cam_on_image
65
66
cudnn.benchmark = False
67
cudnn.deterministic = True
68
69
70
# keep the overall model class, omitted here
71
class ViT(nn.Sequential):
72
    def __init__(self, emb_size=40, depth=6, n_classes=4, **kwargs):
73
        super().__init__(
74
            # ... the model
75
        )
76
77
78
data = np.load('./grad_cam/train_data.npy')  
79
print(np.shape(data))
80
81
82
nSub = 1
83
target_category = 2  # set the class (class activation mapping)
84
85
# ! A crucial step for adaptation on Transformer
86
# reshape_transform  b 61 40 -> b 40 1 61
87
def reshape_transform(tensor):
88
    result = rearrange(tensor, 'b (h w) e -> b e (h) (w)', h=1)
89
    return result
90
91
92
device = torch.device("cpu")
93
model = ViT()
94
95
# # used for cnn model without transformer
96
# model.load_state_dict(torch.load('./model/model_cnn.pth', map_location=device))
97
# target_layers = [model[0].projection]  # set the layer you want to visualize, you can use torchsummary here to find the layer index
98
# cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False)
99
100
model.load_state_dict(torch.load('./model/sub%d.pth'%nSub, map_location=device))
101
target_layers = [model[1]]  # set the target layer 
102
cam = GradCAM(model=model, target_layers=target_layers, use_cuda=False, reshape_transform=reshape_transform)
103
104
105
106
# TODO: Class Activation Topography (proposed in the paper)
107
import mne
108
from matplotlib import mlab as mlab
109
110
biosemi_montage = mne.channels.make_standard_montage('biosemi64')
111
index = [37, 9, 10, 46, 45, 44, 13, 12, 11, 47, 48, 49, 50, 17, 18, 31, 55, 54, 19, 30, 56, 29]  # for bci competition iv 2a
112
biosemi_montage.ch_names = [biosemi_montage.ch_names[i] for i in index]
113
biosemi_montage.dig = [biosemi_montage.dig[i+3] for i in index]
114
info = mne.create_info(ch_names=biosemi_montage.ch_names, sfreq=250., ch_types='eeg')
115
116
117
all_cam = []
118
# this loop is used to obtain the cam of each trial/sample
119
for i in range(288):
120
    test = torch.as_tensor(data[i:i+1, :, :, :], dtype=torch.float32)
121
    test = torch.autograd.Variable(test, requires_grad=True)
122
123
    grayscale_cam = cam(input_tensor=test)
124
    grayscale_cam = grayscale_cam[0, :]
125
    all_cam.append(grayscale_cam)
126
127
128
# the mean of all data
129
test_all_data = np.squeeze(np.mean(data, axis=0))
130
# test_all_data = (test_all_data - np.mean(test_all_data)) / np.std(test_all_data)
131
mean_all_test = np.mean(test_all_data, axis=1)
132
133
# the mean of all cam
134
test_all_cam = np.mean(all_cam, axis=0)
135
# test_all_cam = (test_all_cam - np.mean(test_all_cam)) / np.std(test_all_cam)
136
mean_all_cam = np.mean(test_all_cam, axis=1)
137
138
# apply cam on the input data
139
hyb_all = test_all_data * test_all_cam
140
# hyb_all = (hyb_all - np.mean(hyb_all)) / np.std(hyb_all)
141
mean_hyb_all = np.mean(hyb_all, axis=1)
142
143
evoked = mne.EvokedArray(test_all_data, info)
144
evoked.set_montage(biosemi_montage)
145
146
fig, [ax1, ax2] = plt.subplots(nrows=2)
147
148
# print(mean_all_test)
149
plt.subplot(211)
150
im1, cn1 = mne.viz.plot_topomap(mean_all_test, evoked.info, show=False, axes=ax1, res=1200)
151
152
153
plt.subplot(212)
154
im2, cn2 = mne.viz.plot_topomap(mean_hyb_all, evoked.info, show=False, axes=ax2, res=1200)
155
156
157