model.py

import tensorflow as tf
from tensorflow.keras import Model, layers, backend
from tensorflow.keras.constraints import Constraint
from losses import disc_hinge, disc_loss, gen_loss, gen_hinge
from diff_augment import diff_augment
from tensorflow_addons.layers import SpectralNormalization

tf.random.set_seed(45)
# np.random.seed(45)
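# Note: SpectralNormalization ships with the tensorflow-addons package
# (`pip install tensorflow-addons`); `losses` and `diff_augment` are local
# modules assumed to sit next to this file.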

class Generator(Model):
    def __init__(self, n_class=10, res=128):
        super(Generator, self).__init__()
        # filters = [1024, 512, 256, 128, 64, 32]#, 32, 16]
        # strides = [   4,   2,   2,   2,  2,  2]#,  2,  2]
        filters = [1024, 512, 256, 128, 64, 32]#, 16]
        strides = [   4,   2,   2,   2,  2,  2]#,  2]
        self.cnn_depth = len(filters)

        # For the discrete condition we use an Embedding layer.
        # (Currently unused: the conditioning path in call() is commented out;
        # the caller concatenates the condition to the latent instead.)
        self.cond_embedding = layers.Embedding(input_dim=n_class, output_dim=50)
        self.cond_flat = layers.Flatten()
        self.cond_dense = layers.Dense(units=(8 * 8 * 1))
        self.cond_reshape = layers.Reshape(target_shape=(64,))

        # Hyperparameter:
        # If only conv  : mean=0.0, stddev=0.02
        # If using bnorm: mean=1.0, stddev=0.02
        self.conv = [
            SpectralNormalization(layers.Conv2DTranspose(
                filters=filters[idx], kernel_size=3,
                strides=strides[idx], padding='same',
                kernel_initializer=tf.keras.initializers.TruncatedNormal(mean=0.0, stddev=0.02),
                use_bias=False))
            for idx in range(self.cnn_depth)]

        self.act = [layers.LeakyReLU() for _ in range(self.cnn_depth)]

        self.bnorm = [layers.BatchNormalization() for _ in range(self.cnn_depth)]

        # Final 3-channel projection; tanh maps pixel values into [-1, 1].
        self.last_conv = SpectralNormalization(layers.Conv2D(
            filters=3, kernel_size=3,
            strides=1, padding='same',
            activation='tanh',
            kernel_initializer=tf.keras.initializers.TruncatedNormal(mean=0.0, stddev=0.02),
            use_bias=False))

    @tf.function
    def call(self, X):
        # C = self.cond_reshape( self.cond_dense( self.cond_flat( self.cond_embedding( C ) ) ) )
        # X = tf.concat([C, X], axis=-1)

        # Treat the latent vector as a 1x1 spatial map, then upsample:
        # strides [4, 2, 2, 2, 2, 2] grow 1 -> 4 -> 8 -> 16 -> 32 -> 64 -> 128.
        X = tf.expand_dims(tf.expand_dims(X, axis=1), axis=1)
        X = self.act[0](self.conv[0](X))  # first block: no batch norm

        for idx in range(1, self.cnn_depth):
            X = self.act[idx](self.bnorm[idx](self.conv[idx](X)))
            # X = self.bnorm[idx]( self.act[idx]( self.conv[idx]( X ) ) )
            # X = self.act[idx]( self.conv[idx]( X ) )
        X = self.last_conv(X)
        return X

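# Usage sketch (illustrative; assumes latent_dim=96 and n_class=10, with the
# one-hot condition concatenated to the latent as in dist_train_step below):
#   z = tf.random.uniform((4, 96 + 10), minval=-1, maxval=1)
#   img = Generator()(z)   # -> (4, 128, 128, 3), values in [-1, 1]
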
class Discriminator(Model):
    def __init__(self, n_class=10, res=128):
        super(Discriminator, self).__init__()
        # filters = [32, 64, 128, 256, 256, 512, 512, 1]
        # strides = [ 2,  2,   2,   2,   2,   2,   1, 1]
        filters = [64, 128, 256, 512, 1024, 1]
        strides = [ 2,   2,   2,   2,    1, 1]
        self.cnn_depth = len(filters)

        # For the discrete condition we use an Embedding layer.
        # (Currently unused: call() receives the raw condition vector and
        # tiles it over the spatial grid instead.)
        self.cond_embedding = layers.Embedding(input_dim=n_class, output_dim=50)
        self.cond_flat = layers.Flatten()
        self.cond_dense = layers.Dense(units=(res * res * 1))
        self.cond_reshape = layers.Reshape(target_shape=(res, res, 1))

        self.cnn_conv = [
            layers.Conv2D(filters=filters[i], kernel_size=3,
                          strides=strides[i], padding='same',
                          kernel_initializer=tf.keras.initializers.TruncatedNormal(mean=0.0, stddev=0.02),
                          use_bias=False)
            for i in range(self.cnn_depth)]

        self.cnn_bnorm = [layers.BatchNormalization() for _ in range(self.cnn_depth)]

        self.cnn_act = [layers.LeakyReLU(alpha=0.2) for _ in range(self.cnn_depth)]

        # self.final_act = layers.Activation('sigmoid')  # not needed with hinge loss

        self.flat = layers.Flatten()

        self.disc_out = layers.Dense(units=1)

        # self.autoenc = Autoencoder()  # experimental reconstruction branch

    @tf.function
    def call(self, x, C):
        # C = self.cond_reshape( self.cond_dense( self.cond_flat( self.cond_embedding( C ) ) ) )

        # Broadcast the condition vector over the spatial grid and stack it
        # onto the image as extra channels.
        C = tf.expand_dims(tf.expand_dims(C, axis=1), axis=1)
        C = tf.tile(C, [1, x.shape[1], x.shape[2], 1])
        x = tf.concat([x, C], axis=-1)

        for layer_no in range(self.cnn_depth):
            x = self.cnn_act[layer_no](self.cnn_bnorm[layer_no](self.cnn_conv[layer_no](x)))
            # x = self.cnn_bnorm[layer_no]( self.cnn_act[layer_no]( self.cnn_conv[layer_no]( x ) ) )
            # x = self.cnn_act[layer_no]( self.cnn_conv[layer_no]( x ) )

        # Experimental variants kept for reference: a mem_bank of early feature
        # maps fed to an autoencoder branch that reconstructs x, e.g.
        # reconst_x = self.autoenc( x, mem_bank )
        reconst_x = None

        # x = self.final_act( x )
        x = self.disc_out(self.flat(x))

        return x, reconst_x

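# Usage sketch (illustrative; C is the raw per-example condition, e.g. a
# one-hot label of width n_class=10, broadcast over the 128x128 grid):
#   x = tf.random.uniform((4, 128, 128, 3), minval=-1, maxval=1)
#   c = tf.one_hot([0, 1, 2, 3], depth=10)
#   logits, _ = Discriminator()(x, c)   # -> (4, 1) real/fake logits
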
class DCGAN(Model):
    def __init__(self):
        super(DCGAN, self).__init__()
        self.gen = Generator()
        self.disc = Discriminator()

@tf.function
def dist_train_step(mirrored_strategy, model, model_gopt, model_copt, X, C, latent_dim=96, batch_size=64):

    diff_augment_policies = "color,translation"
    # Two latent draws per step; the pair drives the mode-seeking term below.
    noise_vector = tf.random.uniform(shape=(batch_size, latent_dim), minval=-1, maxval=1)
    noise_vector_2 = tf.random.uniform(shape=(batch_size, latent_dim), minval=-1, maxval=1)
    noise_vector = tf.concat([noise_vector, C], axis=-1)
    noise_vector_2 = tf.concat([noise_vector_2, C], axis=-1)

    def train_step_disc(model, model_gopt, model_copt, X, C, latent_dim=96, batch_size=64):
        with tf.GradientTape() as ctape:
            # noise_vector = tf.random.normal(shape=(batch_size, latent_dim))

            fake_img = model.gen(noise_vector, training=False)

            # DiffAugment: apply the same differentiable augmentations to the
            # real and fake batches before scoring them.
            X_aug = diff_augment(X, policy=diff_augment_policies)
            fake_img = diff_augment(fake_img, policy=diff_augment_policies)

            D_real, X_recon = model.disc(X_aug, C, training=True)
            D_fake, _ = model.disc(fake_img, C, training=True)

            # c_loss = disc_loss(D_real, D_fake) +\
            #          tf.reduce_mean(tf.keras.losses.MeanAbsoluteError(reduction=tf.keras.losses.Reduction.NONE)(X_aug, X_recon))
            c_loss = disc_hinge(D_real, D_fake)

        variables = model.disc.trainable_variables
        gradients = ctape.gradient(c_loss, variables)
        model_copt.apply_gradients(zip(gradients, variables))
        return c_loss

    def train_step_gen(model, model_gopt, model_copt, X, C, latent_dim=96, batch_size=64):
        with tf.GradientTape() as gtape:
            # noise_vector = tf.random.normal(shape=(batch_size, latent_dim))

            fake_img_o = model.gen(noise_vector, training=True)
            fake_img_2_o = model.gen(noise_vector_2, training=True)

            fake_img = diff_augment(fake_img_o, policy=diff_augment_policies)
            fake_img_2 = diff_augment(fake_img_2_o, policy=diff_augment_policies)

            D_fake, _ = model.disc(fake_img, C, training=False)
            D_fake_2, _ = model.disc(fake_img_2, C, training=False)
            # g_loss = gen_loss(D_fake)
            g_loss = gen_hinge(D_fake) + gen_hinge(D_fake_2)

            # Mode-seeking regularization (MSGAN-style): the ratio of image
            # distance to latent distance should be large, so we penalize its
            # reciprocal.
            mode_loss = tf.divide(tf.reduce_mean(tf.abs(tf.subtract(fake_img_2_o, fake_img_o))),
                                  tf.reduce_mean(tf.abs(tf.subtract(noise_vector_2, noise_vector))))
            mode_loss = tf.divide(1.0, mode_loss + 1e-5)
            g_loss = g_loss + 1.0 * mode_loss

        variables = model.gen.trainable_variables
        gradients = gtape.gradient(g_loss, variables)
        model_gopt.apply_gradients(zip(gradients, variables))
        return g_loss

    per_replica_loss_disc = mirrored_strategy.run(train_step_disc, args=(model, model_gopt, model_copt, X, C, latent_dim, batch_size))
    per_replica_loss_gen = mirrored_strategy.run(train_step_gen, args=(model, model_gopt, model_copt, X, C, latent_dim, batch_size))

    # Sum the per-replica losses into a single scalar per network.
    discriminator_loss = mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss_disc, axis=None)
    generator_loss = mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss_gen, axis=None)
    return generator_loss, discriminator_loss
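

# Minimal end-to-end sketch (illustrative, not part of the original file):
# the optimizer settings and dataset pipeline below are assumptions.
if __name__ == "__main__":
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        model = DCGAN()
        # Hypothetical TTUR-style learning rates with beta_1=0.5.
        gen_opt = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.5)
        disc_opt = tf.keras.optimizers.Adam(learning_rate=4e-4, beta_1=0.5)

    # A real run would distribute a tf.data.Dataset of (image, one_hot_label)
    # batches and iterate it, e.g.:
    # dist_ds = strategy.experimental_distribute_dataset(dataset)
    # for X, C in dist_ds:
    #     g_loss, d_loss = dist_train_step(strategy, model, gen_opt, disc_opt, X, C)
    #     print(float(g_loss), float(d_loss))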