[c2ab9a]: / visualization / CAT.py

Download this file

158 lines (115 with data), 4.8 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
"""
Class activation topography (CAT) for EEG model visualization, combining a class activation map (CAM) with a scalp topography.
Code: first the class activation map (CAM) is computed, then the CAT.
The Grad-CAM part follows this widely used GitHub repository:
https://github.com/WZMIAOMIAO/deep-learning-for-image-processing/tree/master/pytorch_classification/grad_cam
With thanks to every open-source researcher and developer!
"""
import argparse
import os
gpus = [1]
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, gpus))
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
import numpy as np
import math
import glob
import random
import itertools
import datetime
import time
import datetime
import sys
import scipy.io
import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torchsummary import summary
import torch.autograd as autograd
from torchvision.models import vgg19
import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.nn.init as init
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms
from sklearn.decomposition import PCA
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
# from common_spatial_pattern import csp
import matplotlib.pyplot as plt
from torch.backends import cudnn
# from tSNE import plt_tsne
# from grad_cam.utils import GradCAM, show_cam_on_image
from utils import GradCAM, show_cam_on_image
# Reproducibility: ask cuDNN for deterministic kernels and disable its
# benchmark-driven algorithm autotuning.
cudnn.deterministic = True
cudnn.benchmark = False
# keep the overall model class, omitted here
class ViT(nn.Sequential):
    """Model whose class activations are visualized by this script.

    The layer stack is intentionally omitted from this file. Replace the
    placeholder below with the exact layers used for training, because the
    checkpoint is restored with ``load_state_dict``, which needs matching keys.

    Args:
        emb_size: Presumably the embedding size (unused by the placeholder).
        depth: Presumably the number of encoder blocks (unused by the placeholder).
        n_classes: Presumably the number of output classes (unused by the placeholder).
        **kwargs: Additional options (unused by the placeholder).
    """

    def __init__(self, emb_size=40, depth=6, n_classes=4, **kwargs):
        super().__init__(
            # ... the model
        )
# Which subject's checkpoint to load and which class to explain.
nSub = 1
target_category = 2 # set the class (class activation mapping)
# EEG trials to visualize; axis 0 indexes trials/samples (see the CAM loop below).
data = np.load('./grad_cam/train_data.npy')
print(np.shape(data))
# ! A crucial step for adaptation on Transformer:
# Grad-CAM works on image-like feature maps, so token sequences are reshaped
# b 61 40 -> b 40 1 61
def reshape_transform(tensor):
    """Reshape transformer tokens into an image-like map for Grad-CAM.

    The tokens become the width of a 1-pixel-high map and the embedding
    dimension becomes the channel axis, e.g. ``(b, 61, 40) -> (b, 40, 1, 61)``.
    This is equivalent to
    ``einops.rearrange(tensor, 'b (h w) e -> b e h w', h=1)``.

    Args:
        tensor: 3-D tensor of shape ``(batch, tokens, embedding)``.

    Returns:
        Tensor of shape ``(batch, embedding, 1, tokens)`` in which element
        ``[b, e, 0, t]`` equals ``tensor[b, t, e]``.

    Raises:
        ValueError: If ``tensor`` is not 3-D.
    """
    # Without this check a 4-D input would silently yield a 5-D result.
    if tensor.dim() != 3:
        raise ValueError(
            f"expected a 3-D (batch, tokens, embedding) tensor, got shape {tuple(tensor.shape)}"
        )
    # (b, tokens, emb) -> (b, emb, tokens) -> (b, emb, 1, tokens)
    return tensor.transpose(1, 2).unsqueeze(2)
# Restore the trained model of subject `nSub` on the CPU and wrap it for Grad-CAM.
device = torch.device("cpu")
model = ViT()
# Alternative for a CNN-only model (no transformer): load './model/model_cnn.pth',
# use `[model[0].projection]` as the target layer (torchsummary helps find the
# layer index), and build GradCAM without `reshape_transform`.
# NOTE: torch.load unpickles the checkpoint -- only load trusted files.
model.load_state_dict(
    torch.load('./model/sub%d.pth' % nSub, map_location=device)
)
# Layer whose activations/gradients Grad-CAM explains. Its output is presumably a
# (batch, tokens, embedding) sequence, which reshape_transform turns into a map.
target_layers = [model[1]]
cam = GradCAM(
    model=model,
    target_layers=target_layers,
    use_cuda=False,
    reshape_transform=reshape_transform,
)
# Class Activation Topography (CAT), the visualization proposed in the paper:
# per-electrode values are drawn on a scalp map using the montage built here.
import mne
from matplotlib import mlab as mlab  # NOTE(review): mlab is not used anywhere in this script
# Start from MNE's standard 64-channel BioSemi montage.
biosemi_montage = mne.channels.make_standard_montage('biosemi64')
# Positions (into the biosemi64 channel list) of the 22 electrodes kept for
# BCI Competition IV 2a. Their order must match the row (channel) order of the
# EEG data, because EvokedArray further down pairs data rows with these names.
index = [37, 9, 10, 46, 45, 44, 13, 12, 11, 47, 48, 49, 50, 17, 18, 31, 55, 54, 19, 30, 56, 29] # for bci competition iv 2a
# Keep only the selected channels. NOTE(review): this overwrites DigMontage
# attributes directly instead of going through a public MNE API; confirm it
# still works with the installed MNE version.
biosemi_montage.ch_names = [biosemi_montage.ch_names[i] for i in index]
# The +3 presumably skips the three fiducial points (nasion, LPA, RPA) that
# precede the electrode positions in `dig`; verify for the installed MNE version.
biosemi_montage.dig = [biosemi_montage.dig[i+3] for i in index]
# Channel metadata for the selected electrodes. sfreq=250 Hz only defines the
# time axis of the Evoked object built later; the topomaps plot per-channel means.
info = mne.create_info(ch_names=biosemi_montage.ch_names, sfreq=250., ch_types='eeg')
# Obtain the class activation map (CAM) of each trial/sample, one at a time.
all_cam = []
# Iterate over every trial in `data` (instead of a hard-coded 288), so the CAM
# average covers the same trials as the grand average of `data` computed below.
for i in range(data.shape[0]):
    # Slicing with i:i + 1 keeps a batch dimension of size 1.
    test = torch.as_tensor(data[i:i + 1, :, :, :], dtype=torch.float32)
    # Track gradients for this input (replaces the deprecated torch.autograd.Variable).
    test.requires_grad_(True)
    # Explain the class chosen by `target_category`. Without this argument the
    # GradCAM helper uses its own default category (presumably each trial's
    # predicted class), which left `target_category` unused.
    grayscale_cam = cam(input_tensor=test, target_category=target_category)
    # Drop the batch dimension of size 1.
    grayscale_cam = grayscale_cam[0, :]
    all_cam.append(grayscale_cam)
# Optional, currently disabled: z-score test_all_data, test_all_cam or hyb_all,
# i.e. (x - x.mean()) / x.std(), before taking the per-channel means below.

# Grand average of the input over trials (axis 0). Singleton axes are squeezed
# away so rows are electrodes and columns time points (presumably (22, T)).
test_all_data = np.mean(data, axis=0).squeeze()
mean_all_test = test_all_data.mean(axis=1)

# Grand average of the class activation maps over trials, plus per-channel means.
test_all_cam = np.asarray(all_cam).mean(axis=0)
mean_all_cam = test_all_cam.mean(axis=1)

# Element-wise weighting of the averaged input by the averaged CAM, then one
# value per electrode for the class activation topography.
hyb_all = test_all_data * test_all_cam
mean_hyb_all = hyb_all.mean(axis=1)
# Wrap the trial-averaged input in an Evoked object and attach the montage, so
# MNE can place each channel on the scalp.
evoked = mne.EvokedArray(test_all_data, info)
evoked.set_montage(biosemi_montage)

# Top panel: per-electrode mean of the trial-averaged input.
# Bottom panel: per-electrode mean of the CAM-weighted input (the CAT).
# The axes are passed explicitly, so no plt.subplot() calls are needed to select them.
fig, (ax1, ax2) = plt.subplots(nrows=2)
im1, cn1 = mne.viz.plot_topomap(mean_all_test, evoked.info, show=False, axes=ax1, res=1200)
im2, cn2 = mne.viz.plot_topomap(mean_hyb_all, evoked.info, show=False, axes=ax2, res=1200)
# show=False above defers rendering; display the figure once both panels are drawn
# (previously the figure was never shown when run as a script).
plt.show()