Diff of /inference.py [000000] .. [277df6]

Switch to unified view

a b/inference.py
1
import tensorflow as tf
2
# from utils import vis, load_batch#, load_data
3
from utils import load_complete_data, show_batch_images
4
from model import DCGAN, dist_train_step#, train_step
5
from tqdm import tqdm
6
import os
7
import shutil
8
import pickle
9
from glob import glob
10
from natsort import natsorted
11
import wandb
12
import numpy as np
13
import cv2
14
15
tf.random.set_seed(45)
16
np.random.seed(45)
17
# wandb.init(project='DCGAN_DiffAug_EDDisc_imagenet_128', entity="prajwal_15")
18
19
os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'
20
os.environ["CUDA_DEVICE_ORDER"]= "PCI_BUS_ID"
21
os.environ["CUDA_VISIBLE_DEVICES"]= '1'
22
clstoidx   = {}
23
idxtocls   = {}
24
25
26
# @tf.function
27
def get_code(path):
28
    path  = path.numpy().decode('utf-8')
29
    code  = np.zeros(shape=(max(clstoidx.values())+1,), dtype=np.float32)
30
    code[clstoidx[path.split(sep='/')[-2]]] = 1
31
    return tf.cast(code, dtype=tf.float32)
32
33
34
if __name__ == '__main__':
35
36
    # if len(glob('experiments/*'))==0:
37
    #   os.makedirs('experiments/experiment_1/code/')
38
    #   exp_num = 1
39
    # else:
40
    #   exp_num = len(glob('experiments/*'))+1
41
    #   os.makedirs('experiments/experiment_{}/code/'.format(exp_num))
42
43
    # exp_dir = 'experiments/experiment_{}'.format(exp_num)
44
    # for item in glob('*.py'):
45
    #   shutil.copy(item, exp_dir+'/code')
46
    
47
    gpus = tf.config.list_physical_devices('GPU')
48
    mirrored_strategy = tf.distribute.MirroredStrategy(devices=['/GPU:0'], 
49
        cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())
50
    n_gpus = mirrored_strategy.num_replicas_in_sync
51
    # print(n_gpus)
52
53
    batch_size = 64
54
    latent_dim = 128
55
    input_res  = 64
56
    data_path  = '../data/b2i_data/thoughtviz_eeg_data/*/*'
57
58
    # train_batch = load_complete_data(data_path, input_res=input_res, batch_size=batch_size)
59
    train_batch = load_complete_data(data_path, batch_size=batch_size)
60
    X, latent_Y = next(iter(train_batch))
61
    # print(latent_Y)
62
    latent_Y = latent_Y[:16]
63
    lr = 3e-4
64
    with mirrored_strategy.scope():
65
        model        = DCGAN()
66
        model_gopt   = tf.keras.optimizers.Adam(learning_rate=lr, beta_1=0.2, beta_2=0.5)
67
        model_copt   = tf.keras.optimizers.Adam(learning_rate=lr, beta_1=0.2, beta_2=0.5)
68
        ckpt         = tf.train.Checkpoint(step=tf.Variable(1), model=model, gopt=model_gopt, copt=model_copt)
69
        ckpt_manager = tf.train.CheckpointManager(ckpt, directory='experiments/best_ckpt', max_to_keep=30)
70
        ckpt.restore(ckpt_manager.latest_checkpoint).expect_partial()
71
72
    # print(ckpt.step.numpy())
73
    START         = int(ckpt.step.numpy()) // len(train_batch) + 1
74
    EPOCHS        = 1000#670#66
75
    model_freq    = 14#200#40
76
    t_visfreq     = 14#200#1500#40
77
    
78
    if ckpt_manager.latest_checkpoint:
79
        print('Restored from last checkpoint epoch: {0}'.format(START))
80
81
    for clidx in tqdm(range(10)):
82
        code = np.zeros(shape=(10,), dtype=np.float32)
83
        code[clidx] = 1
84
        code = np.expand_dims(code, axis=0)
85
        code = tf.cast(code, dtype=tf.float32)
86
        
87
        if not os.path.isdir('experiments/inference_result/{}'.format(clidx)):
88
            os.makedirs('experiments/inference_result/{}'.format(clidx))
89
90
        for _ in tqdm(range(256)):
91
            latent = tf.random.uniform(shape=(1, latent_dim), minval=-1, maxval=1)
92
            latent = tf.concat([latent, code], axis=-1)
93
            fake_img = mirrored_strategy.run(model.gen, args=(latent,))
94
            fake_img = fake_img[0].numpy()
95
            fake_img = np.uint8(np.clip(255*(fake_img * 0.5 + 0.5), 0.0, 255.0))
96
            fake_img = cv2.cvtColor(fake_img, cv2.COLOR_RGB2BGR)
97
            cv2.imwrite('experiments/inference_result/{}/{}.png'.format(clidx, _), fake_img)