--- a +++ b/not_use_dummy_inference_inception.py @@ -0,0 +1,209 @@ +import tensorflow as tf +# from utils import vis, load_batch#, load_data +from utils import load_complete_data, show_batch_images +from model import DCGAN, dist_train_step#, train_step +from tqdm import tqdm +import os +import shutil +import pickle +from glob import glob +from natsort import natsorted +import wandb +import numpy as np +import cv2 +from lstm_kmean.model import TripleNet +import math +# from eval_utils import get_inception_score +tf.random.set_seed(45) +np.random.seed(45) + +clstoidx = {} +idxtocls = {} + +for idx, item in enumerate(natsorted(glob('data/images/train/*')), start=0): + clsname = os.path.basename(item) + clstoidx[clsname] = idx + idxtocls[idx] = clsname + +image_paths = natsorted(glob('data/images/train/*/*')) +imgdict = {} +for path in image_paths: + key = path.split(os.path.sep)[-2] + if key in imgdict: + imgdict[key].append(path) + else: + imgdict[key] = [path] + +# wandb.init(project='DCGAN_DiffAug_EDDisc_imagenet_128', entity="prajwal_15") + +os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices' +os.environ["CUDA_DEVICE_ORDER"]= "PCI_BUS_ID" +os.environ["CUDA_VISIBLE_DEVICES"]= '0' + +if __name__ == '__main__': + + n_channels = 14 + n_feat = 128 + batch_size = 128 + test_batch_size = 1 + n_classes = 10 + + # data_cls = natsorted(glob('data/thoughtviz_eeg_data/*')) + # cls2idx = {key.split(os.path.sep)[-1]:idx for idx, key in enumerate(data_cls, start=0)} + # idx2cls = {value:key for key, value in cls2idx.items()} + + with open('data/eeg/image/data.pkl', 'rb') as file: + data = pickle.load(file, encoding='latin1') + train_X = data['x_train'] + train_Y = data['y_train'] + test_X = data['x_test'] + test_Y = data['y_test'] + + train_path = [] + for X, Y in zip(train_X, train_Y): + train_path.append(np.random.choice(imgdict[idxtocls[np.argmax(Y)]], size=(1,) ,replace=True)[0]) + + test_path = [] + for X, Y in zip(test_X, test_Y): + test_path.append(np.random.choice(imgdict[idxtocls[np.argmax(Y)]], size=(1,) ,replace=True)[0]) + + train_batch = load_complete_data(train_X, train_Y, train_path, batch_size=batch_size) + test_batch = load_complete_data(test_X, test_Y, test_path, batch_size=test_batch_size) + X, Y, I = next(iter(train_batch)) + latent_label = Y[:16] + print(X.shape, Y.shape, I.shape) + + gpus = tf.config.list_physical_devices('GPU') + mirrored_strategy = tf.distribute.MirroredStrategy(devices=['/GPU:0'], + cross_device_ops=tf.distribute.HierarchicalCopyAllReduce()) + n_gpus = mirrored_strategy.num_replicas_in_sync + # print(n_gpus) + + # batch_size = 64 + latent_dim = 128 + input_res = 128 + + # print(latent_Y) + # latent_Y = latent_Y[:16] + # print + + triplenet = TripleNet(n_classes=n_classes) + opt = tf.keras.optimizers.Adam(learning_rate=3e-4) + triplenet_ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=triplenet, optimizer=opt) + triplenet_ckptman = tf.train.CheckpointManager(triplenet_ckpt, directory='lstm_kmean/experiments/best_ckpt', max_to_keep=5000) + triplenet_ckpt.restore(triplenet_ckptman.latest_checkpoint) + print('TripletNet restored from the latest checkpoint: {}'.format(triplenet_ckpt.step.numpy())) + _, latent_Y = triplenet(X, training=False) + + print('Extracting test eeg features:') + # test_eeg_features = np.array([np.squeeze(triplenet(E, training=False)[1].numpy()) for E, Y, X in tqdm(test_batch)]) + # test_eeg_y = np.array([Y.numpy()[0] for E, Y, X in tqdm(test_batch)]) + test_image_count = 500000 #// n_classes + # test_labels = np.tile(np.expand_dims(np.arange(0, 10), axis=-1), [1, test_image_count//n_classes]) + # test_labels = np.sort(test_labels.ravel()) + + test_eeg_cls = {} + for E, Y, X in tqdm(test_batch): + Y = Y.numpy()[0] + if Y not in test_eeg_cls: + test_eeg_cls[Y] = [np.squeeze(triplenet(E, training=False)[1].numpy())] + else: + test_eeg_cls[Y].append(np.squeeze(triplenet(E, training=False)[1].numpy())) + + for _ in range(n_classes): + test_eeg_cls[_] = np.array(test_eeg_cls[_]) + print(test_eeg_cls[_].shape) + + for cl in range(n_classes): + N = test_eeg_cls[cl].shape[0] + per_cls_image = int(math.ceil((test_image_count//n_classes) / N)) + test_eeg_cls[cl] = np.expand_dims(test_eeg_cls[cl], axis=1) + test_eeg_cls[cl] = np.tile(test_eeg_cls[cl], [1, per_cls_image, 1]) + test_eeg_cls[cl] = np.reshape(test_eeg_cls[cl], [-1, latent_dim]) + print(test_eeg_cls[cl].shape) + + # test_image_count = test_image_count // n_classes + # print(test_eeg_features.shape, test_eeg_y.shape) + + lr = 3e-4 + with mirrored_strategy.scope(): + model = DCGAN() + model_gopt = tf.keras.optimizers.Adam(learning_rate=lr, beta_1=0.2, beta_2=0.5) + model_copt = tf.keras.optimizers.Adam(learning_rate=lr, beta_1=0.2, beta_2=0.5) + ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=model, gopt=model_gopt, copt=model_copt) + ckpt_manager = tf.train.CheckpointManager(ckpt, directory='experiments/best_ckpt', max_to_keep=300) + ckpt.restore(ckpt_manager.latest_checkpoint).expect_partial() + + # print(ckpt.step.numpy()) + START = int(ckpt.step.numpy()) // len(train_batch) + 1 + EPOCHS = 300#670#66 + model_freq = 355#178#355#178#200#40 + t_visfreq = 355#178#355#178#200#1500#40 + latent = tf.random.uniform(shape=(16, latent_dim), minval=-0.2, maxval=0.2) + latent = tf.concat([latent, latent_Y[:16]], axis=-1) + print(latent_Y.shape, latent.shape) + + if ckpt_manager.latest_checkpoint: + print('Restored from last checkpoint epoch: {0}'.format(START)) + + if not os.path.isdir('experiments/results'): + os.makedirs('experiments/results') + + for epoch in range(START, EPOCHS): + t_gloss = tf.keras.metrics.Mean() + t_closs = tf.keras.metrics.Mean() + + # tq = tqdm(train_batch) + # for idx, (E, Y, X) in enumerate(tq, start=1): + # batch_size = X.shape[0] + # _, C = triplenet(E, training=False) + # gloss, closs = dist_train_step(mirrored_strategy, model, model_gopt, model_copt, X, C, latent_dim, batch_size) + # gloss = tf.reduce_mean(gloss) + # closs = tf.reduce_mean(closs) + # t_gloss.update_state(gloss) + # t_closs.update_state(closs) + # ckpt.step.assign_add(1) + # if (idx%model_freq)==0: + # ckpt_manager.save() + # if (idx%t_visfreq)==0: + # # latent_c = tf.concat([latent, C[:16]], axis=-1) + # X = mirrored_strategy.run(model.gen, args=(latent,)) + # # X = X.values[0] + # print(X.shape, latent_label.shape) + # show_batch_images(X, save_path='experiments/results/{}.png'.format(int(ckpt.step.numpy())), Y=latent_label) + + # tq.set_description('E: {}, gl: {:0.3f}, cl: {:0.3f}'.format(epoch, t_gloss.result(), t_closs.result())) + # # break + + # with open('experiments/log.txt', 'a') as file: + # file.write('Epoch: {0}\tT_gloss: {1}\tT_closs: {2}\n'.format(epoch, t_gloss.result(), t_closs.result())) + # print('Epoch: {0}\tT_gloss: {1}\tT_closs: {2}'.format(epoch, t_gloss.result(), t_closs.result())) + + + if (epoch%1)==0: + save_path = 'experiments/inference_inception/{}'.format(epoch) + + if not os.path.isdir(save_path): + os.makedirs(save_path) + + for cl in range(n_classes): + test_noise = np.random.uniform(size=(test_eeg_cls[cl].shape[0],128), low=-1, high=1) + noise_lst = np.concatenate([test_noise, test_eeg_cls[cl]], axis=-1) + + for idx, noise in enumerate(tqdm(noise_lst)): + X = mirrored_strategy.run(model.gen, args=(tf.expand_dims(noise, axis=0),)) + X = cv2.cvtColor(tf.squeeze(X).numpy(), cv2.COLOR_RGB2BGR) + X = np.uint8(np.clip((X*0.5 + 0.5)*255.0, 0, 255)) + cv2.imwrite(save_path+'/{}_{}.jpg'.format(cl, idx), X) + break + + # eeg_feature_vectors_test = np.array([test_eeg_features[np.random.choice(np.where(test_eeg_y == test_label)[0], size=(1,))[0]] for test_label in test_labels]) + # latent_var = np.concatenate([test_noise, eeg_feature_vectors_test], axis=-1) + # print(test_noise.shape, eeg_feature_vectors_test.shape, latent_var.shape) + # for idx, noise in enumerate(tqdm(latent_var)): + # X = mirrored_strategy.run(model.gen, args=(tf.expand_dims(noise, axis=0),)) + # X = cv2.cvtColor(tf.squeeze(X).numpy(), cv2.COLOR_RGB2BGR) + # X = np.uint8(np.clip((X*0.5 + 0.5)*255.0, 0, 255)) + # cv2.imwrite(save_path+'/{}_{}.jpg'.format(test_labels[idx], idx), X) + # print(X.shape) + # break