Diff of /visualize.py [000000] .. [3f9044]

Switch to unified view

a b/visualize.py
1
#! /usr/bin/python
2
# -*- coding: utf8 -*-
3
4
import argparse
5
import os
6
import re
7
import scipy.io as sio
8
import numpy as np
9
from sklearn.metrics import cohen_kappa_score
10
from sklearn.metrics import confusion_matrix, f1_score
11
import matplotlib.pyplot as plt
12
W=0
13
N1=1
14
N2=2
15
N3=3
16
REM=4
17
classes= ['W','N1', 'N2','N3','REM']
18
n_classes = len(classes)
19
20
def plot_attention(attention_map, input_tags = None, output_tags = None):
21
    attn_len = len(attention_map)
22
23
    # Plot the attention_map
24
    # plt.clf()
25
    f = plt.figure(figsize=(15, 10))
26
    ax = f.add_subplot(1, 1, 1)
27
    ax.tick_params(axis='both', which='major', labelsize=15)
28
    ax.tick_params(axis='both', which='minor', labelsize=15)
29
30
    # Add image
31
    i = ax.imshow(attention_map, interpolation='nearest', cmap='gray')
32
33
    # Add colorbar
34
    cbaxes = f.add_axes([0.2, -0.02, 0.6, 0.03])
35
    cbar = f.colorbar(i, cax=cbaxes, orientation='horizontal')
36
    cbar.ax.set_xlabel('Alpha value (Probability output of the "softmax")', labelpad=2)
37
    cbar.ax.tick_params(labelsize=15)
38
    cbar.ax.xaxis.label.set_size(16)
39
40
    # Add labels
41
    ax.set_yticks(range(attn_len))
42
    if output_tags != None:
43
      ax.set_yticklabels(output_tags[:attn_len])
44
45
    ax.set_xticks(range(attn_len))
46
    if input_tags != None:
47
      ax.set_xticklabels(input_tags[:attn_len], rotation=45)
48
49
50
    ax.set_xlabel('Input: EEG Epochs (a sequence)')
51
    ax.xaxis.label.set_size(20)
52
    ax.set_ylabel('Output: Satge Scoring')
53
    ax.yaxis.label.set_size(20)
54
55
    # add grid and legend
56
    ax.grid()
57
    HERE = os.path.realpath(os.path.join(os.path.realpath(__file__), '..'))
58
    dir_save = os.path.join(HERE, 'attention_maps')
59
    if (os.path.exists(dir_save) == False):
60
        os.mkdir(dir_save)
61
    f.savefig(os.path.join(dir_save, 'a_map_6_5.pdf'), bbox_inches='tight')
62
    # f.show()
63
    plt.show()
64
65
def visualize(data_dir):
66
67
    # Remove non-output files, and perform ascending sort
68
    allfiles = os.listdir(data_dir)
69
    outputfiles = []
70
    for idx, f in enumerate(allfiles):
71
        if re.match("^output_.+\d+\.npz", f):
72
            outputfiles.append(os.path.join(data_dir, f))
73
    outputfiles.sort()
74
75
76
    with np.load(outputfiles[0]) as f:
77
        y_true = f["y_true"]
78
        y_pred = f["y_pred"]
79
        alignments_alphas_all = f["alignments_alphas_all"] # (batch_num,B,max_time_step,max_time_step)
80
81
    batch_len = len(alignments_alphas_all)
82
    char2numY = dict(zip(classes, range(len(classes))))
83
    num2charY = dict(zip(char2numY.values(), char2numY.keys()))
84
    shape = alignments_alphas_all.shape
85
    max_time_step = shape[2]
86
    batch_num = 6
87
    alignments_alphas = alignments_alphas_all[batch_num] # get results for the batch of batch_num
88
    y_true_ = np.reshape(y_true,[-1,shape[1],shape[2]])
89
    y_pred_ = np.reshape(y_pred, [-1, shape[1], shape[2]])
90
91
    input_tags = [[num2charY[i] for i in seq] for seq in y_true_[batch_num,:]]
92
    output_tags = [[num2charY[i] for i in seq] for seq in y_pred_[batch_num,:]]
93
94
    plot_attention(alignments_alphas[5, :, :], input_tags[5], output_tags[5])
95
96
97
98
def main():
99
    parser = argparse.ArgumentParser()
100
    parser.add_argument("--data_dir", type=str, default="outputs_2013/outputs_eeg_fpz_cz",
101
                        help="Directory where to load prediction outputs")
102
    args = parser.parse_args()
103
104
    if args.data_dir is not None:
105
        visualize(data_dir=args.data_dir)
106
107
108
if __name__ == "__main__":
109
    main()