Diff of /inference.py [000000] .. [277df6]

Switch to side-by-side view

--- a
+++ b/inference.py
@@ -0,0 +1,97 @@
+import tensorflow as tf
+# from utils import vis, load_batch#, load_data
+from utils import load_complete_data, show_batch_images
+from model import DCGAN, dist_train_step#, train_step
+from tqdm import tqdm
+import os
+import shutil
+import pickle
+from glob import glob
+from natsort import natsorted
+import wandb
+import numpy as np
+import cv2
+
+tf.random.set_seed(45)
+np.random.seed(45)
+# wandb.init(project='DCGAN_DiffAug_EDDisc_imagenet_128', entity="prajwal_15")
+
+os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'
+os.environ["CUDA_DEVICE_ORDER"]= "PCI_BUS_ID"
+os.environ["CUDA_VISIBLE_DEVICES"]= '1'
+clstoidx   = {}
+idxtocls   = {}
+
+
+# @tf.function
+def get_code(path):
+	path  = path.numpy().decode('utf-8')
+	code  = np.zeros(shape=(max(clstoidx.values())+1,), dtype=np.float32)
+	code[clstoidx[path.split(sep='/')[-2]]] = 1
+	return tf.cast(code, dtype=tf.float32)
+
+
+if __name__ == '__main__':
+
+	# if len(glob('experiments/*'))==0:
+	# 	os.makedirs('experiments/experiment_1/code/')
+	# 	exp_num = 1
+	# else:
+	# 	exp_num = len(glob('experiments/*'))+1
+	# 	os.makedirs('experiments/experiment_{}/code/'.format(exp_num))
+
+	# exp_dir = 'experiments/experiment_{}'.format(exp_num)
+	# for item in glob('*.py'):
+	# 	shutil.copy(item, exp_dir+'/code')
+	
+	gpus = tf.config.list_physical_devices('GPU')
+	mirrored_strategy = tf.distribute.MirroredStrategy(devices=['/GPU:0'], 
+		cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())
+	n_gpus = mirrored_strategy.num_replicas_in_sync
+	# print(n_gpus)
+
+	batch_size = 64
+	latent_dim = 128
+	input_res  = 64
+	data_path  = '../data/b2i_data/thoughtviz_eeg_data/*/*'
+
+	# train_batch = load_complete_data(data_path, input_res=input_res, batch_size=batch_size)
+	train_batch = load_complete_data(data_path, batch_size=batch_size)
+	X, latent_Y = next(iter(train_batch))
+	# print(latent_Y)
+	latent_Y = latent_Y[:16]
+	lr = 3e-4
+	with mirrored_strategy.scope():
+		model        = DCGAN()
+		model_gopt   = tf.keras.optimizers.Adam(learning_rate=lr, beta_1=0.2, beta_2=0.5)
+		model_copt   = tf.keras.optimizers.Adam(learning_rate=lr, beta_1=0.2, beta_2=0.5)
+		ckpt         = tf.train.Checkpoint(step=tf.Variable(1), model=model, gopt=model_gopt, copt=model_copt)
+		ckpt_manager = tf.train.CheckpointManager(ckpt, directory='experiments/best_ckpt', max_to_keep=30)
+		ckpt.restore(ckpt_manager.latest_checkpoint).expect_partial()
+
+	# print(ckpt.step.numpy())
+	START         = int(ckpt.step.numpy()) // len(train_batch) + 1
+	EPOCHS        = 1000#670#66
+	model_freq    = 14#200#40
+	t_visfreq     = 14#200#1500#40
+	
+	if ckpt_manager.latest_checkpoint:
+		print('Restored from last checkpoint epoch: {0}'.format(START))
+
+	for clidx in tqdm(range(10)):
+		code = np.zeros(shape=(10,), dtype=np.float32)
+		code[clidx] = 1
+		code = np.expand_dims(code, axis=0)
+		code = tf.cast(code, dtype=tf.float32)
+		
+		if not os.path.isdir('experiments/inference_result/{}'.format(clidx)):
+			os.makedirs('experiments/inference_result/{}'.format(clidx))
+
+		for _ in tqdm(range(256)):
+			latent = tf.random.uniform(shape=(1, latent_dim), minval=-1, maxval=1)
+			latent = tf.concat([latent, code], axis=-1)
+			fake_img = mirrored_strategy.run(model.gen, args=(latent,))
+			fake_img = fake_img[0].numpy()
+			fake_img = np.uint8(np.clip(255*(fake_img * 0.5 + 0.5), 0.0, 255.0))
+			fake_img = cv2.cvtColor(fake_img, cv2.COLOR_RGB2BGR)
+			cv2.imwrite('experiments/inference_result/{}/{}.png'.format(clidx, _), fake_img)