Diff of /lstm_kmean/train.py [000000] .. [277df6]

Switch to unified view

a b/lstm_kmean/train.py
1
import tensorflow as tf
2
import numpy as np
3
from glob import glob
4
from natsort import natsorted
5
import os
6
import pickle
7
from model import TripleNet, train_step, test_step
8
from utils import load_complete_data
9
from tqdm import tqdm
10
from sklearn.manifold import TSNE
11
import matplotlib.pyplot as plt
12
from matplotlib import style
13
import seaborn as sns
14
import pandas as pd
15
from sklearn.cluster import KMeans
16
17
# Plot styling. Matplotlib 3.6 renamed its bundled seaborn style sheets to
# 'seaborn-v0_8' and deprecated the bare 'seaborn' name, so prefer the new
# name when this install ships it and fall back to the legacy name otherwise.
style.use('seaborn-v0_8' if 'seaborn-v0_8' in style.available else 'seaborn')

# Enumerate GPUs in PCI bus order and expose only device '3' to this process.
# NOTE(review): the device index is hard-coded; adjust it for the target machine.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '3'

# Seed both NumPy and TensorFlow random number generators with the same value.
np.random.seed(45)
tf.random.set_seed(45)
24
25
if __name__ == '__main__':
    # Script entry point: load the pickled EEG splits, build the TripleNet
    # model, restore the newest checkpoint if one exists, then alternate one
    # training pass and one validation pass per epoch, checkpointing every
    # `cfreq` epochs.

    # Experiment configuration. n_channels and n_feat are not read by the
    # statements below; they are kept as part of the recorded configuration.
    n_channels  = 14
    n_feat      = 128
    batch_size  = 256
    test_batch_size  = 1
    n_classes   = 10

    # NOTE(review): unpickling can execute arbitrary code embedded in the
    # file, so this path must only ever point at a trusted dataset.
    with open('../../data/b2i_data/eeg/image/data.pkl', 'rb') as file:
        data = pickle.load(file, encoding='latin1')
    train_X = data['x_train']
    train_Y = data['y_train']
    test_X = data['x_test']
    test_Y = data['y_test']

    train_batch = load_complete_data(train_X, train_Y, batch_size=batch_size)
    # NOTE(review): the validation pipeline is built from the test split, so
    # the per-epoch "Test" loss below is measured on the same samples as
    # test_batch -- confirm this is intended.
    val_batch   = load_complete_data(test_X, test_Y, batch_size=batch_size)
    # test_batch is not consumed by the training loop in this block.
    test_batch  = load_complete_data(test_X, test_Y, batch_size=test_batch_size)
    # Fetch one training batch up front; the values are not used afterwards
    # (both loops below rebind X and Y).
    X, Y = next(iter(train_batch))

    triplenet = TripleNet(n_classes=n_classes)
    opt       = tf.keras.optimizers.Adam(learning_rate=3e-4)

    # `step` is stored in the checkpoint next to the model and optimizer so a
    # restored run knows how far training had progressed.
    triplenet_ckpt    = tf.train.Checkpoint(step=tf.Variable(1), model=triplenet, optimizer=opt)
    triplenet_ckptman = tf.train.CheckpointManager(triplenet_ckpt, directory='experiments/best_ckpt', max_to_keep=5000)
    triplenet_ckpt.restore(triplenet_ckptman.latest_checkpoint)

    # `step` starts at 1 and is incremented exactly once per finished epoch
    # (see the end of the epoch loop), so `step - 1` epochs are complete.
    # Dividing by the number of batches, as was done previously, treated it
    # as a per-batch counter and made a restored run start over near epoch 0.
    START = int(triplenet_ckpt.step) - 1
    if triplenet_ckptman.latest_checkpoint:
        print('Restored from the latest checkpoint, epoch: {}'.format(START))

    EPOCHS = 3000
    cfreq  = 10 # Checkpoint frequency

    for epoch in range(START, EPOCHS):
        # Fresh running means every epoch so the reported loss is per-epoch.
        train_loss = tf.keras.metrics.Mean()
        test_loss  = tf.keras.metrics.Mean()

        # Training pass: one optimizer update per batch.
        tq = tqdm(train_batch)
        for X, Y in tq:
            loss = train_step(triplenet, opt, X, Y)
            train_loss.update_state(loss)
            tq.set_description('Train Epoch: {}, Loss: {}'.format(epoch, train_loss.result()))

        # Validation pass: loss only, via test_step.
        tq = tqdm(val_batch)
        for X, Y in tq:
            loss = test_step(triplenet, X, Y)
            test_loss.update_state(loss)
            tq.set_description('Test Epoch: {}, Loss: {}'.format(epoch, test_loss.result()))

        # Count the finished epoch before saving so the stored `step`
        # already points past it when the checkpoint is restored.
        triplenet_ckpt.step.assign_add(1)
        if (epoch%cfreq) == 0:
            triplenet_ckptman.save()
93
94
    # kmeanacc = 0.0
95
    # tq = tqdm(test_batch)
96
    # feat_X = []
97
    # feat_Y = []
98
    # for idx, (X, Y) in enumerate(tq, start=1):
99
    #   _, feat = triplenet(X, training=False)
100
    #   feat_X.extend(feat.numpy())
101
    #   feat_Y.extend(Y.numpy())
102
    # feat_X = np.array(feat_X)
103
    # feat_Y = np.array(feat_Y)
104
    # print(feat_X.shape, feat_Y.shape)
105
    # # colors = list(plt.cm.get_cmap('viridis', 10))
106
    # # print(colors)
107
    # # colors  = [np.random.rand(3,) for _ in range(10)]
108
    # # print(colors)
109
    # # Y_color = [colors[label] for label in feat_Y]
110
111
    # tsne = TSNE(n_components=2, verbose=1, perplexity=40, n_iter=700)
112
    # tsne_results = tsne.fit_transform(feat_X)
113
    # df = pd.DataFrame()
114
    # df['label'] = feat_Y
115
    # df['x1'] = tsne_results[:, 0]
116
    # df['x2'] = tsne_results[:, 1]
117
    # # df['x3'] = tsne_results[:, 2]
118
    # df.to_csv('experiments/inference/triplet_embed2D.csv')
119
    
120
    # # df.to_csv('experiments/triplenet_embed3D.csv')
121
    # # df = pd.read_csv('experiments/triplenet_embed2D.csv')
122
    
123
    # df = pd.read_csv('experiments/inference/triplet_embed2D.csv')
124
125
    # plt.figure(figsize=(16,10))
126
    
127
    # # ax = plt.axes(projection='3d')
128
    # sns.scatterplot(
129
    #     x="x1", y="x2",
130
    #     data=df,
131
    #     hue='label',
132
    #     palette=sns.color_palette("hls", n_classes),
133
    #     legend="full",
134
    #     alpha=0.4
135
    #     )
136
137
    # plt.show()
138
    # # plt.savefig('experiments/inference/{}_embedding.png'.format(epoch))
139
140
    # kmeans = KMeans(n_clusters=n_classes,random_state=45)
141
    # kmeans.fit(feat_X)
142
    # labels = kmeans.labels_
143
    # # print(feat_Y, labels)
144
    # correct_labels = sum(feat_Y == labels)
145
    # print("Result: %d out of %d samples were correctly labeled." % (correct_labels, feat_Y.shape[0]))
146
    # kmeanacc = correct_labels/float(feat_Y.shape[0])
147
    # print('Accuracy score: {0:0.2f}'. format(kmeanacc))
148
149
    # with open('experiments/triplenet_log.txt', 'a') as file:
150
    #   file.write('E: {}, Train Loss: {}, Test Loss: {}, KM Acc: {}\n'.\
151
    #       format(epoch, train_loss.result(), test_loss.result(), kmeanacc))
152
    # break