b/inference.py

import tensorflow as tf
# from utils import vis, load_batch#, load_data
from utils import load_complete_data, show_batch_images  # project-local data/visualisation helpers
from model import DCGAN, dist_train_step  # , train_step  # project-local model definition
from tqdm import tqdm
import os
import shutil
import pickle
from glob import glob
from natsort import natsorted
import wandb
import numpy as np
import cv2

# fixed seeds so repeated inference runs produce the same samples
tf.random.set_seed(45)
np.random.seed(45)
# wandb.init(project='DCGAN_DiffAug_EDDisc_imagenet_128', entity="prajwal_15")

# enable XLA devices and pin the job to a single visible GPU
os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = '1'

# class-name <-> index lookup tables (empty here; get_code() below is not called during inference)
clstoidx = {}
idxtocls = {}

# @tf.function
def get_code(path):
    # map a file path of the form '.../<class_name>/<file>' to a one-hot class code
    path = path.numpy().decode('utf-8')
    code = np.zeros(shape=(max(clstoidx.values()) + 1,), dtype=np.float32)
    code[clstoidx[path.split(sep='/')[-2]]] = 1
    return tf.cast(code, dtype=tf.float32)
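
# Note: get_code() relies on .numpy(), so it can only run eagerly. A minimal
# sketch (hypothetical, not part of this script) of wiring it into a tf.data
# pipeline would wrap it with tf.py_function:
#
#     def path_to_code(path):
#         code = tf.py_function(get_code, inp=[path], Tout=tf.float32)
#         code.set_shape((len(clstoidx),))
#         return code
#
#     # dataset = tf.data.Dataset.list_files(data_path).map(path_to_code)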
|
|
|

if __name__ == '__main__':

    # commented-out experiment-folder bookkeeping, kept for reference:
    # if len(glob('experiments/*')) == 0:
    #     os.makedirs('experiments/experiment_1/code/')
    #     exp_num = 1
    # else:
    #     exp_num = len(glob('experiments/*')) + 1
    #     os.makedirs('experiments/experiment_{}/code/'.format(exp_num))

    # exp_dir = 'experiments/experiment_{}'.format(exp_num)
    # for item in glob('*.py'):
    #     shutil.copy(item, exp_dir + '/code')

    # run everything under a single-GPU MirroredStrategy with explicit cross-device ops
    gpus = tf.config.list_physical_devices('GPU')
    mirrored_strategy = tf.distribute.MirroredStrategy(
        devices=['/GPU:0'],
        cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())
    n_gpus = mirrored_strategy.num_replicas_in_sync
    # print(n_gpus)

    batch_size = 64
    latent_dim = 128
    input_res = 64
    data_path = '../data/b2i_data/thoughtviz_eeg_data/*/*'

    # train_batch = load_complete_data(data_path, input_res=input_res, batch_size=batch_size)
    train_batch = load_complete_data(data_path, batch_size=batch_size)
    # fetch one batch and keep the first 16 condition vectors
    X, latent_Y = next(iter(train_batch))
    # print(latent_Y)
    latent_Y = latent_Y[:16]
    lr = 3e-4
    with mirrored_strategy.scope():
        model = DCGAN()
        # optimizers are recreated only so the checkpoint layout matches training;
        # expect_partial() suppresses warnings about variables not needed for inference
        model_gopt = tf.keras.optimizers.Adam(learning_rate=lr, beta_1=0.2, beta_2=0.5)
        model_copt = tf.keras.optimizers.Adam(learning_rate=lr, beta_1=0.2, beta_2=0.5)
        ckpt = tf.train.Checkpoint(step=tf.Variable(1), model=model, gopt=model_gopt, copt=model_copt)
        ckpt_manager = tf.train.CheckpointManager(ckpt, directory='experiments/best_ckpt', max_to_keep=30)
        ckpt.restore(ckpt_manager.latest_checkpoint).expect_partial()

    # print(ckpt.step.numpy())
    # recover the last completed epoch from the stored global step
    START = int(ckpt.step.numpy()) // len(train_batch) + 1
    EPOCHS = 1000  # 670 # 66
    model_freq = 14  # 200 # 40
    t_visfreq = 14  # 200 # 1500 # 40

    if ckpt_manager.latest_checkpoint:
        print('Restored from last checkpoint epoch: {0}'.format(START))

    # generate 256 samples for each of the 10 classes
    for clidx in tqdm(range(10)):
        # one-hot class code for the current class, shaped (1, 10)
        code = np.zeros(shape=(10,), dtype=np.float32)
        code[clidx] = 1
        code = np.expand_dims(code, axis=0)
        code = tf.cast(code, dtype=tf.float32)

        if not os.path.isdir('experiments/inference_result/{}'.format(clidx)):
            os.makedirs('experiments/inference_result/{}'.format(clidx))

        for img_idx in tqdm(range(256)):
            # sample a latent vector in [-1, 1) and append the class code as the condition
            latent = tf.random.uniform(shape=(1, latent_dim), minval=-1, maxval=1)
            latent = tf.concat([latent, code], axis=-1)
            # run the generator under the distribution strategy
            fake_img = mirrored_strategy.run(model.gen, args=(latent,))
            fake_img = fake_img[0].numpy()
            # generator output is in [-1, 1]; map it to uint8 [0, 255] and convert RGB -> BGR for OpenCV
            fake_img = np.uint8(np.clip(255 * (fake_img * 0.5 + 0.5), 0.0, 255.0))
            fake_img = cv2.cvtColor(fake_img, cv2.COLOR_RGB2BGR)
            cv2.imwrite('experiments/inference_result/{}/{}.png'.format(clidx, img_idx), fake_img)
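
# Optional sketch (hypothetical, not part of the original script): tile the 256
# saved samples of a class into a 16x16 preview grid using only numpy/cv2 and
# the files written above. Names and paths are illustrative.
#
#     def save_grid(clidx, rows=16, cols=16):
#         paths = natsorted(glob('experiments/inference_result/{}/*.png'.format(clidx)))
#         imgs = [cv2.imread(p) for p in paths[:rows * cols]]
#         grid = np.concatenate([np.concatenate(imgs[r * cols:(r + 1) * cols], axis=1)
#                                for r in range(rows)], axis=0)
#         cv2.imwrite('experiments/inference_result/{}_grid.png'.format(clidx), grid)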