|
a |
|
b/visualize.py |
|
|
1 |
#! /usr/bin/python |
|
|
2 |
# -*- coding: utf8 -*- |
|
|
3 |
|
|
|
4 |
import argparse |
|
|
5 |
import os |
|
|
6 |
import re |
|
|
7 |
import scipy.io as sio |
|
|
8 |
import numpy as np |
|
|
9 |
from sklearn.metrics import cohen_kappa_score |
|
|
10 |
from sklearn.metrics import confusion_matrix, f1_score |
|
|
11 |
import matplotlib.pyplot as plt |
|
|
12 |
W=0 |
|
|
13 |
N1=1 |
|
|
14 |
N2=2 |
|
|
15 |
N3=3 |
|
|
16 |
REM=4 |
|
|
17 |
classes= ['W','N1', 'N2','N3','REM'] |
|
|
18 |
n_classes = len(classes) |
|
|
19 |
|
|
|
20 |
def plot_attention(attention_map, input_tags = None, output_tags = None): |
|
|
21 |
attn_len = len(attention_map) |
|
|
22 |
|
|
|
23 |
# Plot the attention_map |
|
|
24 |
# plt.clf() |
|
|
25 |
f = plt.figure(figsize=(15, 10)) |
|
|
26 |
ax = f.add_subplot(1, 1, 1) |
|
|
27 |
ax.tick_params(axis='both', which='major', labelsize=15) |
|
|
28 |
ax.tick_params(axis='both', which='minor', labelsize=15) |
|
|
29 |
|
|
|
30 |
# Add image |
|
|
31 |
i = ax.imshow(attention_map, interpolation='nearest', cmap='gray') |
|
|
32 |
|
|
|
33 |
# Add colorbar |
|
|
34 |
cbaxes = f.add_axes([0.2, -0.02, 0.6, 0.03]) |
|
|
35 |
cbar = f.colorbar(i, cax=cbaxes, orientation='horizontal') |
|
|
36 |
cbar.ax.set_xlabel('Alpha value (Probability output of the "softmax")', labelpad=2) |
|
|
37 |
cbar.ax.tick_params(labelsize=15) |
|
|
38 |
cbar.ax.xaxis.label.set_size(16) |
|
|
39 |
|
|
|
40 |
# Add labels |
|
|
41 |
ax.set_yticks(range(attn_len)) |
|
|
42 |
if output_tags != None: |
|
|
43 |
ax.set_yticklabels(output_tags[:attn_len]) |
|
|
44 |
|
|
|
45 |
ax.set_xticks(range(attn_len)) |
|
|
46 |
if input_tags != None: |
|
|
47 |
ax.set_xticklabels(input_tags[:attn_len], rotation=45) |
|
|
48 |
|
|
|
49 |
|
|
|
50 |
ax.set_xlabel('Input: EEG Epochs (a sequence)') |
|
|
51 |
ax.xaxis.label.set_size(20) |
|
|
52 |
ax.set_ylabel('Output: Satge Scoring') |
|
|
53 |
ax.yaxis.label.set_size(20) |
|
|
54 |
|
|
|
55 |
# add grid and legend |
|
|
56 |
ax.grid() |
|
|
57 |
HERE = os.path.realpath(os.path.join(os.path.realpath(__file__), '..')) |
|
|
58 |
dir_save = os.path.join(HERE, 'attention_maps') |
|
|
59 |
if (os.path.exists(dir_save) == False): |
|
|
60 |
os.mkdir(dir_save) |
|
|
61 |
f.savefig(os.path.join(dir_save, 'a_map_6_5.pdf'), bbox_inches='tight') |
|
|
62 |
# f.show() |
|
|
63 |
plt.show() |
|
|
64 |
|
|
|
65 |
def visualize(data_dir): |
|
|
66 |
|
|
|
67 |
# Remove non-output files, and perform ascending sort |
|
|
68 |
allfiles = os.listdir(data_dir) |
|
|
69 |
outputfiles = [] |
|
|
70 |
for idx, f in enumerate(allfiles): |
|
|
71 |
if re.match("^output_.+\d+\.npz", f): |
|
|
72 |
outputfiles.append(os.path.join(data_dir, f)) |
|
|
73 |
outputfiles.sort() |
|
|
74 |
|
|
|
75 |
|
|
|
76 |
with np.load(outputfiles[0]) as f: |
|
|
77 |
y_true = f["y_true"] |
|
|
78 |
y_pred = f["y_pred"] |
|
|
79 |
alignments_alphas_all = f["alignments_alphas_all"] # (batch_num,B,max_time_step,max_time_step) |
|
|
80 |
|
|
|
81 |
batch_len = len(alignments_alphas_all) |
|
|
82 |
char2numY = dict(zip(classes, range(len(classes)))) |
|
|
83 |
num2charY = dict(zip(char2numY.values(), char2numY.keys())) |
|
|
84 |
shape = alignments_alphas_all.shape |
|
|
85 |
max_time_step = shape[2] |
|
|
86 |
batch_num = 6 |
|
|
87 |
alignments_alphas = alignments_alphas_all[batch_num] # get results for the batch of batch_num |
|
|
88 |
y_true_ = np.reshape(y_true,[-1,shape[1],shape[2]]) |
|
|
89 |
y_pred_ = np.reshape(y_pred, [-1, shape[1], shape[2]]) |
|
|
90 |
|
|
|
91 |
input_tags = [[num2charY[i] for i in seq] for seq in y_true_[batch_num,:]] |
|
|
92 |
output_tags = [[num2charY[i] for i in seq] for seq in y_pred_[batch_num,:]] |
|
|
93 |
|
|
|
94 |
plot_attention(alignments_alphas[5, :, :], input_tags[5], output_tags[5]) |
|
|
95 |
|
|
|
96 |
|
|
|
97 |
|
|
|
98 |
def main(): |
|
|
99 |
parser = argparse.ArgumentParser() |
|
|
100 |
parser.add_argument("--data_dir", type=str, default="outputs_2013/outputs_eeg_fpz_cz", |
|
|
101 |
help="Directory where to load prediction outputs") |
|
|
102 |
args = parser.parse_args() |
|
|
103 |
|
|
|
104 |
if args.data_dir is not None: |
|
|
105 |
visualize(data_dir=args.data_dir) |
|
|
106 |
|
|
|
107 |
|
|
|
108 |
if __name__ == "__main__": |
|
|
109 |
main() |