|
a |
|
b/lstm_kmean/train.py |
|
|
1 |
import tensorflow as tf |
|
|
2 |
import numpy as np |
|
|
3 |
from glob import glob |
|
|
4 |
from natsort import natsorted |
|
|
5 |
import os |
|
|
6 |
import pickle |
|
|
7 |
from model import TripleNet, train_step, test_step |
|
|
8 |
from utils import load_complete_data |
|
|
9 |
from tqdm import tqdm |
|
|
10 |
from sklearn.manifold import TSNE |
|
|
11 |
import matplotlib.pyplot as plt |
|
|
12 |
from matplotlib import style |
|
|
13 |
import seaborn as sns |
|
|
14 |
import pandas as pd |
|
|
15 |
from sklearn.cluster import KMeans |
|
|
16 |
|
|
|
17 |
# Plot style: the 'seaborn' style sheet was deprecated in matplotlib 3.6 and
# removed in 3.8 (renamed 'seaborn-v0_8'); fall back so the script still runs
# on current matplotlib versions.
try:
    style.use('seaborn')
except OSError:
    style.use('seaborn-v0_8')

# Pin training to a single GPU (device 3), addressing devices by PCI bus
# order so the index matches nvidia-smi.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '3'

# Fixed seeds for reproducible shuffling / weight initialization.
np.random.seed(45)
tf.random.set_seed(45)
|
|
24 |
|
|
|
25 |
if __name__ == '__main__':
    # ----- hyper-parameters -------------------------------------------------
    n_channels = 14       # EEG channel count (presumably; kept for reference — TODO confirm against data)
    n_feat = 128          # embedding width (presumably matches TripleNet — TODO confirm)
    batch_size = 256
    test_batch_size = 1
    n_classes = 10
    EPOCHS = 3000
    cfreq = 10            # checkpoint frequency, in epochs

    # ----- data -------------------------------------------------------------
    # Pickled EEG data with keys x_train/y_train/x_test/y_test.
    # encoding='latin1' is required to load pickles written under Python 2.
    with open('../../data/b2i_data/eeg/image/data.pkl', 'rb') as file:
        data = pickle.load(file, encoding='latin1')
    train_X = data['x_train']
    train_Y = data['y_train']
    test_X = data['x_test']
    test_Y = data['y_test']

    train_batch = load_complete_data(train_X, train_Y, batch_size=batch_size)
    # NOTE(review): the dataset ships no separate validation split, so the
    # test set doubles as the validation set here.
    val_batch = load_complete_data(test_X, test_Y, batch_size=batch_size)
    test_batch = load_complete_data(test_X, test_Y, batch_size=test_batch_size)

    # ----- model / optimizer / checkpointing --------------------------------
    triplenet = TripleNet(n_classes=n_classes)
    opt = tf.keras.optimizers.Adam(learning_rate=3e-4)
    triplenet_ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=triplenet, optimizer=opt)
    triplenet_ckptman = tf.train.CheckpointManager(triplenet_ckpt, directory='experiments/best_ckpt', max_to_keep=5000)
    triplenet_ckpt.restore(triplenet_ckptman.latest_checkpoint)

    # `step` starts at 1 and is incremented once per completed epoch (see the
    # bottom of the loop), so the next epoch to run is `step - 1`.  The
    # original code divided by len(train_batch), which always yielded ~0 and
    # silently restarted training from epoch 0 after a restore.
    START = int(triplenet_ckpt.step) - 1
    if triplenet_ckptman.latest_checkpoint:
        print('Restored from the latest checkpoint, epoch: {}'.format(START))

    # ----- training loop ----------------------------------------------------
    for epoch in range(START, EPOCHS):
        train_loss = tf.keras.metrics.Mean()
        test_loss = tf.keras.metrics.Mean()

        # Training pass: one optimizer step per batch, running mean of the loss.
        tq = tqdm(train_batch)
        for X, Y in tq:
            loss = train_step(triplenet, opt, X, Y)
            train_loss.update_state(loss)
            tq.set_description('Train Epoch: {}, Loss: {}'.format(epoch, train_loss.result()))

        # Validation pass: forward-only loss on the held-out batches.
        tq = tqdm(val_batch)
        for X, Y in tq:
            loss = test_step(triplenet, X, Y)
            test_loss.update_state(loss)
            tq.set_description('Test Epoch: {}, Loss: {}'.format(epoch, test_loss.result()))

        # Advance the persistent epoch counter, then checkpoint every `cfreq`
        # epochs (the counter is saved inside the checkpoint, enabling resume).
        triplenet_ckpt.step.assign_add(1)
        if (epoch % cfreq) == 0:
            triplenet_ckptman.save()