Diff of /visualization/tSNE.py [000000] .. [8bbec7]

Switch to unified view

a b/visualization/tSNE.py
1
# use t-SNE to show the feature distribution
2
3
import numpy as np
4
import matplotlib.pyplot as plt
5
from sklearn import manifold, datasets
6
from einops import reduce
7
import scipy.io
8
import torch
9
10
def plt_tsne(data, label, per):
11
    data = data.cpu().detach().numpy()
12
    data = reduce(data, 'b n e -> b e', reduction='mean')
13
    label = label.cpu().detach().numpy()
14
15
    tsne = manifold.TSNE(n_components=2, perplexity=per, init='pca', random_state=166)
16
    X_tsne = tsne.fit_transform(data)
17
18
    x_min, x_max = X_tsne.min(0), X_tsne.max(0)
19
    X_norm = (X_tsne - x_min) / (x_max - x_min)
20
    plt.figure()
21
    for i in range(X_norm.shape[0]):
22
        plt.scatter(X_norm[i, 0], X_norm[i, 1], color=plt.Set1(label[i]))
23
        plt.xticks([])
24
        plt.yticks([])
25
    # plt.show()
26
    plt.savefig('./test.png', dpi=600)