import tensorflow as tf
# from utils import vis, load_batch#, load_data
from utils import load_complete_data, show_batch_images
from model import DCGAN, dist_train_step#, train_step
from tqdm import tqdm
import os
import shutil
import pickle
from glob import glob
from natsort import natsorted
import wandb
import numpy as np
import cv2
from lstm_kmean.model import TripleNet
import math
# from eval_utils import get_inception_score
# Seed NumPy's and TensorFlow's global RNGs with the same constant so that the
# random test-image picks and the generator noise drawn later are repeatable.
np.random.seed(45)
tf.random.set_seed(45)
# Assign every entry under data/images/test a contiguous index in natural-sort
# order (class name -> index), and build the reverse lookup (index -> name).
clstoidx = {
    os.path.basename(class_dir): position
    for position, class_dir in enumerate(natsorted(glob('data/images/test/*')), start=0)
}
idxtocls = {position: class_name for class_name, position in clstoidx.items()}
# Group every test image path by the name of its parent directory (the class
# folder), preserving natural-sort order within each group.
image_paths = natsorted(glob('data/images/test/*/*'))
imgdict = {}
for image_path in image_paths:
    class_name = image_path.split(os.path.sep)[-2]
    imgdict.setdefault(class_name, []).append(image_path)
# Runtime configuration for XLA devices and GPU selection.
# NOTE(review): these are assigned after `import tensorflow`; they presumably
# only take effect if TensorFlow has not yet initialised its devices, so moving
# them above the TensorFlow import would be the safer placement -- confirm.
os.environ.update({
    'TF_XLA_FLAGS': '--tf_xla_enable_xla_devices',
    'CUDA_DEVICE_ORDER': 'PCI_BUS_ID',
    'CUDA_VISIBLE_DEVICES': '1',
})
if __name__ == '__main__':
n_channels = 14
n_feat = 128
batch_size = 128
test_batch_size = 1
n_classes = 10
# data_cls = natsorted(glob('data/thoughtviz_eeg_data/*'))
# cls2idx = {key.split(os.path.sep)[-1]:idx for idx, key in enumerate(data_cls, start=0)}
# idx2cls = {value:key for key, value in cls2idx.items()}
with open('data/eeg/image/data.pkl', 'rb') as file:
data = pickle.load(file, encoding='latin1')
train_X = data['x_train']
train_Y = data['y_train']
test_X = data['x_test']
test_Y = data['y_test']
print(test_X.shape, test_Y.shape)
test_path = []
for X, Y in zip(test_X, test_Y):
test_path.append(np.random.choice(imgdict[idxtocls[np.argmax(Y)]], size=(1,) ,replace=True)[0])
test_batch = load_complete_data(test_X, test_Y, test_path, batch_size=test_batch_size)
X, Y, I = next(iter(test_batch))
# latent_label = Y[:16]
print(X.shape, Y.shape, I.shape)
gpus = tf.config.list_physical_devices('GPU')
mirrored_strategy = tf.distribute.MirroredStrategy(devices=['/GPU:1'],
cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())
n_gpus = mirrored_strategy.num_replicas_in_sync
# print(n_gpus)
# batch_size = 64
latent_dim = 128
input_res = 128
# print(latent_Y)
# latent_Y = latent_Y[:16]
# print
triplenet = TripleNet(n_classes=n_classes)
opt = tf.keras.optimizers.Adam(learning_rate=3e-4)
triplenet_ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=triplenet, optimizer=opt)
triplenet_ckptman = tf.train.CheckpointManager(triplenet_ckpt, directory='lstm_kmean/experiments/best_ckpt', max_to_keep=5000)
triplenet_ckpt.restore(triplenet_ckptman.latest_checkpoint)
print('TripletNet restored from the latest checkpoint: {}'.format(triplenet_ckpt.step.numpy()))
_, latent_Y = triplenet(X, training=False)
print('Extracting test eeg features:')
# test_eeg_features = np.array([np.squeeze(triplenet(E, training=False)[1].numpy()) for E, Y, X in tqdm(test_batch)])
# test_eeg_y = np.array([Y.numpy()[0] for E, Y, X in tqdm(test_batch)])
test_image_count = 50000 #// n_classes
# test_labels = np.tile(np.expand_dims(np.arange(0, 10), axis=-1), [1, test_image_count//n_classes])
# test_labels = np.sort(test_labels.ravel())
test_eeg_cls = {}
for E, Y, X in tqdm(test_batch):
Y = Y.numpy()[0]
if Y not in test_eeg_cls:
# print(E.shape)
test_eeg_cls[Y] = [np.squeeze(triplenet(E, training=False)[1].numpy())]
else:
test_eeg_cls[Y].append(np.squeeze(triplenet(E, training=False)[1].numpy()))
for _ in range(n_classes):
test_eeg_cls[_] = np.array(test_eeg_cls[_])
print(test_eeg_cls[_].shape)
for cl in range(n_classes):
N = test_eeg_cls[cl].shape[0]
per_cls_image = int(math.ceil((test_image_count//n_classes) / N))
test_eeg_cls[cl] = np.expand_dims(test_eeg_cls[cl], axis=1)
test_eeg_cls[cl] = np.tile(test_eeg_cls[cl], [1, per_cls_image, 1])
test_eeg_cls[cl] = np.reshape(test_eeg_cls[cl], [-1, latent_dim])
# print(test_eeg_cls[cl].shape)
# test_image_count = test_image_count // n_classes
# print(test_eeg_features.shape, test_eeg_y.shape)
lr = 3e-4
with mirrored_strategy.scope():
model = DCGAN()
model_gopt = tf.keras.optimizers.Adam(learning_rate=lr, beta_1=0.2, beta_2=0.5)
model_copt = tf.keras.optimizers.Adam(learning_rate=lr, beta_1=0.2, beta_2=0.5)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=model, gopt=model_gopt, copt=model_copt)
ckpt_manager = tf.train.CheckpointManager(ckpt, directory='experiments/best_ckpt', max_to_keep=300)
ckpt.restore(ckpt_manager.latest_checkpoint).expect_partial()
# print(ckpt.step.numpy())
START = int(ckpt.step.numpy())# // len(train_batch) + 1
# EPOCHS = 300#670#66
# model_freq = 355#178#355#178#200#40
# t_visfreq = 355#178#355#178#200#1500#40
# latent = tf.random.uniform(shape=(16, latent_dim), minval=-0.2, maxval=0.2)
# latent = tf.concat([latent, latent_Y[:16]], axis=-1)
# print(latent_Y.shape, latent.shape)
if ckpt_manager.latest_checkpoint:
print('Restored from last checkpoint epoch: {0}'.format(START))
for cl in range(n_classes):
test_noise = np.random.uniform(size=(test_eeg_cls[cl].shape[0],128), low=-1, high=1)
noise_lst = np.concatenate([test_noise, test_eeg_cls[cl]], axis=-1)
save_path = 'experiments/finalversion/{}/{}'.format(210, cl)
if not os.path.isdir(save_path):
os.makedirs(save_path)
for idx, noise in enumerate(tqdm(noise_lst)):
X = mirrored_strategy.run(model.gen, args=(tf.expand_dims(noise, axis=0),))
X = cv2.cvtColor(tf.squeeze(X).numpy(), cv2.COLOR_RGB2BGR)
X = np.uint8(np.clip((X*0.5 + 0.5)*255.0, 0, 255))
cv2.imwrite(save_path+'/{}_{}.jpg'.format(cl, idx), X)
# # eeg_feature_vectors_test = np.array([test_eeg_features[np.random.choice(np.where(test_eeg_y == test_label)[0], size=(1,))[0]] for test_label in test_labels])
# # latent_var = np.concatenate([test_noise, eeg_feature_vectors_test], axis=-1)
# # print(test_noise.shape, eeg_feature_vectors_test.shape, latent_var.shape)
# # for idx, noise in enumerate(tqdm(latent_var)):
# # X = mirrored_strategy.run(model.gen, args=(tf.expand_dims(noise, axis=0),))
# # X = cv2.cvtColor(tf.squeeze(X).numpy(), cv2.COLOR_RGB2BGR)
# # X = np.uint8(np.clip((X*0.5 + 0.5)*255.0, 0, 255))
# # cv2.imwrite(save_path+'/{}_{}.jpg'.format(test_labels[idx], idx), X)
# # print(X.shape)
# # break