|
a |
|
b/visualization/tSNE.py |
|
|
1 |
# use t-SNE to show the feature distribution |
|
|
2 |
|
|
|
3 |
import numpy as np |
|
|
4 |
import matplotlib.pyplot as plt |
|
|
5 |
from sklearn import manifold, datasets |
|
|
6 |
from einops import reduce |
|
|
7 |
import scipy.io |
|
|
8 |
import torch |
|
|
9 |
|
|
|
10 |
def plt_tsne(data, label, per): |
|
|
11 |
data = data.cpu().detach().numpy() |
|
|
12 |
data = reduce(data, 'b n e -> b e', reduction='mean') |
|
|
13 |
label = label.cpu().detach().numpy() |
|
|
14 |
|
|
|
15 |
tsne = manifold.TSNE(n_components=2, perplexity=per, init='pca', random_state=166) |
|
|
16 |
X_tsne = tsne.fit_transform(data) |
|
|
17 |
|
|
|
18 |
x_min, x_max = X_tsne.min(0), X_tsne.max(0) |
|
|
19 |
X_norm = (X_tsne - x_min) / (x_max - x_min) |
|
|
20 |
plt.figure() |
|
|
21 |
for i in range(X_norm.shape[0]): |
|
|
22 |
plt.scatter(X_norm[i, 0], X_norm[i, 1], color=plt.Set1(label[i])) |
|
|
23 |
plt.xticks([]) |
|
|
24 |
plt.yticks([]) |
|
|
25 |
# plt.show() |
|
|
26 |
plt.savefig('./test.png', dpi=600) |