--- a +++ b/lstm_kmean/train.py @@ -0,0 +1,152 @@ +import tensorflow as tf +import numpy as np +from glob import glob +from natsort import natsorted +import os +import pickle +from model import TripleNet, train_step, test_step +from utils import load_complete_data +from tqdm import tqdm +from sklearn.manifold import TSNE +import matplotlib.pyplot as plt +from matplotlib import style +import seaborn as sns +import pandas as pd +from sklearn.cluster import KMeans + +style.use('seaborn') + +os.environ["CUDA_DEVICE_ORDER"]= "PCI_BUS_ID" +os.environ["CUDA_VISIBLE_DEVICES"]= '3' + +np.random.seed(45) +tf.random.set_seed(45) + +if __name__ == '__main__': + n_channels = 14 + n_feat = 128 + batch_size = 256 + test_batch_size = 1 + n_classes = 10 + + # data_cls = natsorted(glob('data/thoughtviz_eeg_data/*')) + # cls2idx = {key.split(os.path.sep)[-1]:idx for idx, key in enumerate(data_cls, start=0)} + # idx2cls = {value:key for key, value in cls2idx.items()} + + with open('../../data/b2i_data/eeg/image/data.pkl', 'rb') as file: + data = pickle.load(file, encoding='latin1') + train_X = data['x_train'] + train_Y = data['y_train'] + test_X = data['x_test'] + test_Y = data['y_test'] + + + # train_batch = load_complete_data('data/thoughtviz_eeg_data/*/train/*', batch_size=batch_size) + # val_batch = load_complete_data('data/thoughtviz_eeg_data/*/val/*', batch_size=batch_size) + # test_batch = load_complete_data('data/thoughtviz_eeg_data/*/test/*', batch_size=test_batch_size) + train_batch = load_complete_data(train_X, train_Y, batch_size=batch_size) + val_batch = load_complete_data(test_X, test_Y, batch_size=batch_size) + test_batch = load_complete_data(test_X, test_Y, batch_size=test_batch_size) + X, Y = next(iter(train_batch)) + + # print(X.shape, Y.shape) + triplenet = TripleNet(n_classes=n_classes) + opt = tf.keras.optimizers.Adam(learning_rate=3e-4) + triplenet_ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=triplenet, optimizer=opt) + triplenet_ckptman = tf.train.CheckpointManager(triplenet_ckpt, directory='experiments/best_ckpt', max_to_keep=5000) + triplenet_ckpt.restore(triplenet_ckptman.latest_checkpoint) + START = int(triplenet_ckpt.step) // len(train_batch) + if triplenet_ckptman.latest_checkpoint: + print('Restored from the latest checkpoint, epoch: {}'.format(START)) + EPOCHS = 3000 + cfreq = 10 # Checkpoint frequency + + for epoch in range(START, EPOCHS): + train_acc = tf.keras.metrics.SparseCategoricalAccuracy() + train_loss = tf.keras.metrics.Mean() + test_acc = tf.keras.metrics.SparseCategoricalAccuracy() + test_loss = tf.keras.metrics.Mean() + + tq = tqdm(train_batch) + for idx, (X, Y) in enumerate(tq, start=1): + loss = train_step(triplenet, opt, X, Y) + train_loss.update_state(loss) + # Y_cap = triplenet(X, training=False) + # train_acc.update_state(Y, Y_cap) + # tq.set_description('Train Epoch: {}, Loss: {}, Acc: {}'.format(epoch, train_loss.result(), train_acc.result())) + tq.set_description('Train Epoch: {}, Loss: {}'.format(epoch, train_loss.result())) + # break + + tq = tqdm(val_batch) + for idx, (X, Y) in enumerate(tq, start=1): + loss = test_step(triplenet, X, Y) + test_loss.update_state(loss) + # Y_cap = triplenet(X, training=False) + # test_acc.update_state(Y, Y_cap) + # tq.set_description('Test Epoch: {}, Loss: {}'.format(epoch, test_loss.result(), test_acc.result())) + tq.set_description('Test Epoch: {}, Loss: {}'.format(epoch, test_loss.result())) + # break + + triplenet_ckpt.step.assign_add(1) + if (epoch%cfreq) == 0: + triplenet_ckptman.save() + + # kmeanacc = 0.0 + # tq = tqdm(test_batch) + # feat_X = [] + # feat_Y = [] + # for idx, (X, Y) in enumerate(tq, start=1): + # _, feat = triplenet(X, training=False) + # feat_X.extend(feat.numpy()) + # feat_Y.extend(Y.numpy()) + # feat_X = np.array(feat_X) + # feat_Y = np.array(feat_Y) + # print(feat_X.shape, feat_Y.shape) + # # colors = list(plt.cm.get_cmap('viridis', 10)) + # # print(colors) + # # colors = [np.random.rand(3,) for _ in range(10)] + # # print(colors) + # # Y_color = [colors[label] for label in feat_Y] + + # tsne = TSNE(n_components=2, verbose=1, perplexity=40, n_iter=700) + # tsne_results = tsne.fit_transform(feat_X) + # df = pd.DataFrame() + # df['label'] = feat_Y + # df['x1'] = tsne_results[:, 0] + # df['x2'] = tsne_results[:, 1] + # # df['x3'] = tsne_results[:, 2] + # df.to_csv('experiments/inference/triplet_embed2D.csv') + + # # df.to_csv('experiments/triplenet_embed3D.csv') + # # df = pd.read_csv('experiments/triplenet_embed2D.csv') + + # df = pd.read_csv('experiments/inference/triplet_embed2D.csv') + + # plt.figure(figsize=(16,10)) + + # # ax = plt.axes(projection='3d') + # sns.scatterplot( + # x="x1", y="x2", + # data=df, + # hue='label', + # palette=sns.color_palette("hls", n_classes), + # legend="full", + # alpha=0.4 + # ) + + # plt.show() + # # plt.savefig('experiments/inference/{}_embedding.png'.format(epoch)) + + # kmeans = KMeans(n_clusters=n_classes,random_state=45) + # kmeans.fit(feat_X) + # labels = kmeans.labels_ + # # print(feat_Y, labels) + # correct_labels = sum(feat_Y == labels) + # print("Result: %d out of %d samples were correctly labeled." % (correct_labels, feat_Y.shape[0])) + # kmeanacc = correct_labels/float(feat_Y.shape[0]) + # print('Accuracy score: {0:0.2f}'. format(kmeanacc)) + + # with open('experiments/triplenet_log.txt', 'a') as file: + # file.write('E: {}, Train Loss: {}, Test Loss: {}, KM Acc: {}\n'.\ + # format(epoch, train_loss.result(), test_loss.result(), kmeanacc)) + # break \ No newline at end of file