|
a |
|
b/lstm_kmean/inference.py |
|
|
1 |
import tensorflow as tf |
|
|
2 |
import numpy as np |
|
|
3 |
from glob import glob |
|
|
4 |
from natsort import natsorted |
|
|
5 |
import os |
|
|
6 |
from model import TripleNet, train_step, test_step |
|
|
7 |
from utils import load_complete_data |
|
|
8 |
from tqdm import tqdm |
|
|
9 |
from sklearn.manifold import TSNE |
|
|
10 |
import matplotlib.pyplot as plt |
|
|
11 |
from matplotlib import style |
|
|
12 |
import seaborn as sns |
|
|
13 |
import pandas as pd |
|
|
14 |
import pickle |
|
|
15 |
from sklearn.cluster import KMeans |
|
|
16 |
from scipy.optimize import linear_sum_assignment as linear_assignment |
|
|
17 |
|
|
|
18 |
# Matplotlib 3.6 renamed its bundled seaborn style sheets to 'seaborn-v0_8*'
# and newer releases no longer accept the old 'seaborn' name. Pick whichever
# name the installed Matplotlib actually provides so the script runs on both.
style.use('seaborn-v0_8' if 'seaborn-v0_8' in style.available else 'seaborn')

# Order CUDA devices by PCI bus id and expose only device '0' to this process.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
|
|
22 |
|
|
|
23 |
|
|
|
24 |
# Thanks to: https://github.com/k-han/DTC/blob/master/utils/util.py |
|
|
25 |
def cluster_acc(y_true, y_pred):
    """
    Calculate clustering accuracy under the best one-to-one label matching.

    Cluster ids are arbitrary, so a contingency table of (predicted, true)
    label pairs is built and the pairing of predicted ids with true ids that
    maximises the number of agreeing samples is found with SciPy's
    `linear_sum_assignment` (imported in this file as `linear_assignment`).
    Requires NumPy and SciPy.

    # Arguments
        y_true: true labels, array-like with shape `(n_samples,)`
        y_pred: predicted labels, array-like with shape `(n_samples,)`
    # Return
        accuracy, in [0,1]
    # Raises
        AssertionError: if `y_true` and `y_pred` differ in size
        ValueError: if the inputs are empty
    """
    # Both arrays are used as indices into the contingency table, so both
    # must be integral; flattening keeps (n, 1) label columns from
    # broadcasting against (n,) predictions.
    y_true = np.asarray(y_true).astype(np.int64).reshape(-1)
    y_pred = np.asarray(y_pred).astype(np.int64).reshape(-1)
    assert y_pred.size == y_true.size
    if y_pred.size == 0:
        raise ValueError("cluster_acc requires at least one sample")

    n_labels = max(y_pred.max(), y_true.max()) + 1
    # w[p, t] counts the samples predicted as cluster p whose true class is t.
    w = np.zeros((n_labels, n_labels), dtype=np.int64)
    # Unbuffered add so repeated (p, t) pairs are each counted.
    np.add.at(w, (y_pred, y_true), 1)

    # linear_assignment minimises cost; turn counts into costs so the chosen
    # matching maximises the total number of agreeing samples.
    row_ind, col_ind = linear_assignment(w.max() - w)
    return w[row_ind, col_ind].sum() * 1.0 / y_pred.size
|
|
42 |
|
|
|
43 |
if __name__ == '__main__':
    # Inference script: restore a TripleNet checkpoint, embed the test split,
    # score the embedding with k-means clustering accuracy, then project it
    # to 2-D with t-SNE and save both the coordinates and a scatter plot.

    n_channels = 14
    n_feat = 128
    batch_size = 256
    test_batch_size = 256
    n_classes = 10

    # NOTE(security): pickle.load can execute arbitrary code while
    # deserialising; only point this at a trusted data file.
    with open('../../data/b2i_data/eeg/image/data.pkl', 'rb') as file:
        data = pickle.load(file, encoding='latin1')
    train_X = data['x_train']
    train_Y = data['y_train']
    test_X = data['x_test']
    test_Y = data['y_test']

    # NOTE(review): train_batch and val_batch are never consumed below, and
    # val_batch is built from the test split. They are kept because
    # load_complete_data is defined elsewhere -- confirm it has no required
    # side effects before removing them.
    train_batch = load_complete_data(train_X, train_Y, batch_size=batch_size)
    val_batch = load_complete_data(test_X, test_Y, batch_size=batch_size)
    test_batch = load_complete_data(test_X, test_Y, batch_size=test_batch_size)

    # The checkpoint object mirrors the step/model/optimizer layout passed
    # here so the saved values can be matched on restore.
    triplenet = TripleNet(n_classes=n_classes)
    opt = tf.keras.optimizers.Adam(learning_rate=3e-4)
    triplenet_ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=triplenet, optimizer=opt)
    triplenet_ckpt.restore('experiments/best_ckpt/ckpt-89')

    # Collect per-batch outputs in lists and concatenate once at the end;
    # concatenating inside the loop re-copies the growing array every batch.
    feat_chunks = []
    label_chunks = []
    for X, Y in tqdm(test_batch):
        # The model returns a pair; only the second element is used here.
        _, feat = triplenet(X, training=False)
        feat_chunks.append(feat.numpy())
        label_chunks.append(Y.numpy())

    # An empty test iterator yields empty arrays, as the previous
    # incremental accumulation did.
    feat_X = np.concatenate(feat_chunks, axis=0) if feat_chunks else np.array([])
    feat_Y = np.concatenate(label_chunks, axis=0) if label_chunks else np.array([])

    print(feat_X.shape, feat_Y.shape)

    # Cluster the embeddings into n_classes groups and score the clusters
    # against the labels under the best cluster-to-class matching.
    kmeans = KMeans(n_clusters=n_classes, random_state=45)
    kmeans.fit(feat_X)
    labels = kmeans.labels_
    kmeanacc = cluster_acc(feat_Y, labels)
    print('Accuracy score: {0:0.2f}'. format(kmeanacc))

    # NOTE(review): scikit-learn 1.5 renamed TSNE's `n_iter` to `max_iter`;
    # `n_iter` is kept here for older installations -- confirm the target
    # scikit-learn version before switching.
    tsne = TSNE(n_components=2, verbose=1, perplexity=40, n_iter=700)
    tsne_results = tsne.fit_transform(feat_X)

    df = pd.DataFrame()
    df['label'] = feat_Y
    df['x1'] = tsne_results[:, 0]
    df['x2'] = tsne_results[:, 1]
    df.to_csv('experiments/infer_triplet_embed2D.csv')

    # The embedding is written to CSV and read back before plotting.
    df = pd.read_csv('experiments/infer_triplet_embed2D.csv')

    plt.figure(figsize=(16, 10))
    sns.scatterplot(
        x="x1", y="x2",
        data=df,
        hue='label',
        palette=sns.color_palette("hls", n_classes),
        legend="full",
        alpha=0.4
    )

    plt.title('k-means accuracy: {}%'.format(kmeanacc*100))
    plt.savefig('experiments/embedding.png')
    # plt.show()