--- /dev/null
+++ b/model.py
@@ -0,0 +1,220 @@
+import tensorflow as tf
+from tensorflow.keras import Model, layers
+from losses import disc_hinge, gen_hinge
+from diff_augment import diff_augment
+from tensorflow_addons.layers import SpectralNormalization
+
+tf.random.set_seed(45)
+
+
+class Generator(Model):
+    def __init__(self, n_class=10, res=128):
+        super(Generator, self).__init__()
+        # One transposed-conv block per entry: the first maps the 1x1
+        # latent tensor to 4x4, each later block doubles the resolution.
+        filters = [1024, 512, 256, 128, 64, 32]
+        strides = [4, 2, 2, 2, 2, 2]
+        self.cnn_depth = len(filters)
+
+        # Embedding head for the discrete condition. Currently unused in
+        # call(); the one-hot label is concatenated to the noise instead.
+        self.cond_embedding = layers.Embedding(input_dim=n_class, output_dim=50)
+        self.cond_flat = layers.Flatten()
+        self.cond_dense = layers.Dense(units=(8 * 8 * 1))
+        self.cond_reshape = layers.Reshape(target_shape=(64,))
+
+        # Initializer hyperparameter:
+        #   conv only      : mean=0.0, stddev=0.02
+        #   with batchnorm : mean=1.0, stddev=0.02
+        self.conv = [SpectralNormalization(layers.Conv2DTranspose(
+                         filters=filters[idx], kernel_size=3,
+                         strides=strides[idx], padding='same',
+                         kernel_initializer=tf.keras.initializers.TruncatedNormal(mean=0.0, stddev=0.02),
+                         use_bias=False))
+                     for idx in range(self.cnn_depth)]
+
+        self.act = [layers.LeakyReLU() for _ in range(self.cnn_depth)]
+
+        self.bnorm = [layers.BatchNormalization() for _ in range(self.cnn_depth)]
+
+        self.last_conv = SpectralNormalization(layers.Conv2D(
+            filters=3, kernel_size=3, strides=1, padding='same',
+            activation='tanh',
+            kernel_initializer=tf.keras.initializers.TruncatedNormal(mean=0.0, stddev=0.02),
+            use_bias=False))
+
+    @tf.function
+    def call(self, X):
+        # (batch, latent) -> (batch, 1, 1, latent) so the transposed
+        # convolutions can grow the spatial dimensions.
+        X = tf.expand_dims(tf.expand_dims(X, axis=1), axis=1)
+        X = self.act[0](self.conv[0](X))  # no batchnorm on the first block
+
+        for idx in range(1, self.cnn_depth):
+            X = self.act[idx](self.bnorm[idx](self.conv[idx](X)))
+        X = self.last_conv(X)
+        return X
+
+
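+# Shape walk-through (a sketch, not executed anywhere in this file; batch
+# size 4 and the latent_dim=96 / n_class=10 defaults of dist_train_step
+# below are illustrative):
+#
+#   z = tf.random.uniform((4, 96), minval=-1, maxval=1)  # latent batch
+#   c = tf.one_hot([0, 1, 2, 3], depth=10)               # one-hot labels
+#   img = Generator()(tf.concat([z, c], axis=-1))
+#   # (4, 106) -> (4, 1, 1, 106) -> (4, 4, 4, 1024) -> ... -> (4, 128, 128, 3)
+
+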
+class Discriminator(Model):
+    def __init__(self, n_class=10, res=128):
+        super(Discriminator, self).__init__()
+        filters = [64, 128, 256, 512, 1024, 1]
+        strides = [2, 2, 2, 2, 1, 1]
+        self.cnn_depth = len(filters)
+
+        # Embedding head for the discrete condition. Currently unused in
+        # call(); the one-hot label is tiled and concatenated instead.
+        self.cond_embedding = layers.Embedding(input_dim=n_class, output_dim=50)
+        self.cond_flat = layers.Flatten()
+        self.cond_dense = layers.Dense(units=(res * res * 1))
+        self.cond_reshape = layers.Reshape(target_shape=(res, res, 1))
+
+        self.cnn_conv = [layers.Conv2D(
+                             filters=filters[i], kernel_size=3,
+                             strides=strides[i], padding='same',
+                             kernel_initializer=tf.keras.initializers.TruncatedNormal(mean=0.0, stddev=0.02),
+                             use_bias=False)
+                         for i in range(self.cnn_depth)]
+
+        self.cnn_bnorm = [layers.BatchNormalization() for _ in range(self.cnn_depth)]
+
+        self.cnn_act = [layers.LeakyReLU(alpha=0.2) for _ in range(self.cnn_depth)]
+
+        self.flat = layers.Flatten()
+
+        self.disc_out = layers.Dense(units=1)
+
+    @tf.function
+    def call(self, x, C):
+        # Broadcast the one-hot label over the spatial grid and concatenate
+        # it to the image as extra input channels.
+        C = tf.expand_dims(tf.expand_dims(C, axis=1), axis=1)
+        C = tf.tile(C, [1, x.shape[1], x.shape[2], 1])
+        x = tf.concat([x, C], axis=-1)
+
+        for layer_no in range(self.cnn_depth):
+            x = self.cnn_act[layer_no](self.cnn_bnorm[layer_no](self.cnn_conv[layer_no](x)))
+
+        # Placeholder for a (currently disabled) autoencoder reconstruction head.
+        reconst_x = None
+
+        x = self.disc_out(self.flat(x))
+
+        return x, reconst_x
+
+
+class DCGAN(Model):
+    def __init__(self):
+        super(DCGAN, self).__init__()
+        self.gen = Generator()
+        self.disc = Discriminator()
+
+
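+# disc_hinge / gen_hinge come from the local losses module (not shown in
+# this diff). Assuming the standard hinge-GAN formulation, they would be:
+#
+#   def disc_hinge(d_real, d_fake):
+#       return tf.reduce_mean(tf.nn.relu(1.0 - d_real)) + \
+#              tf.reduce_mean(tf.nn.relu(1.0 + d_fake))
+#
+#   def gen_hinge(d_fake):
+#       return -tf.reduce_mean(d_fake)
+
+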
+@tf.function
+def dist_train_step(mirrored_strategy, model, model_gopt, model_copt, X, C,
+                    latent_dim=96, batch_size=64):
+    diff_augment_policies = "color,translation"
+
+    # Two noise batches are drawn so the generator step can add a mode-
+    # seeking regularizer; both are conditioned on the same labels C.
+    noise_vector = tf.random.uniform(shape=(batch_size, latent_dim), minval=-1, maxval=1)
+    noise_vector_2 = tf.random.uniform(shape=(batch_size, latent_dim), minval=-1, maxval=1)
+    noise_vector = tf.concat([noise_vector, C], axis=-1)
+    noise_vector_2 = tf.concat([noise_vector_2, C], axis=-1)
+
+    def train_step_disc(model, model_gopt, model_copt, X, C, latent_dim=96, batch_size=64):
+        with tf.GradientTape() as ctape:
+            fake_img = model.gen(noise_vector, training=False)
+
+            # DiffAugment is applied identically to real and fake batches.
+            X_aug = diff_augment(X, policy=diff_augment_policies)
+            fake_img = diff_augment(fake_img, policy=diff_augment_policies)
+
+            D_real, _ = model.disc(X_aug, C, training=True)
+            D_fake, _ = model.disc(fake_img, C, training=True)
+
+            c_loss = disc_hinge(D_real, D_fake)
+
+        variables = model.disc.trainable_variables
+        gradients = ctape.gradient(c_loss, variables)
+        model_copt.apply_gradients(zip(gradients, variables))
+        return c_loss
+
+    def train_step_gen(model, model_gopt, model_copt, X, C, latent_dim=96, batch_size=64):
+        with tf.GradientTape() as gtape:
+            fake_img_o = model.gen(noise_vector, training=True)
+            fake_img_2_o = model.gen(noise_vector_2, training=True)
+
+            fake_img = diff_augment(fake_img_o, policy=diff_augment_policies)
+            fake_img_2 = diff_augment(fake_img_2_o, policy=diff_augment_policies)
+
+            D_fake, _ = model.disc(fake_img, C, training=False)
+            D_fake_2, _ = model.disc(fake_img_2, C, training=False)
+
+            g_loss = gen_hinge(D_fake) + gen_hinge(D_fake_2)
+
+            # Mode-seeking regularizer (MSGAN-style): penalize a small image
+            # distance per unit latent distance, so nearby latent codes do
+            # not collapse onto the same image.
+            mode_loss = tf.divide(
+                tf.reduce_mean(tf.abs(tf.subtract(fake_img_2_o, fake_img_o))),
+                tf.reduce_mean(tf.abs(tf.subtract(noise_vector_2, noise_vector))))
+            mode_loss = tf.divide(1.0, mode_loss + 1e-5)
+            g_loss = g_loss + 1.0 * mode_loss
+
+        variables = model.gen.trainable_variables
+        gradients = gtape.gradient(g_loss, variables)
+        model_gopt.apply_gradients(zip(gradients, variables))
+        return g_loss
+
+    per_replica_loss_disc = mirrored_strategy.run(
+        train_step_disc, args=(model, model_gopt, model_copt, X, C, latent_dim, batch_size))
+    per_replica_loss_gen = mirrored_strategy.run(
+        train_step_gen, args=(model, model_gopt, model_copt, X, C, latent_dim, batch_size))
+
+    discriminator_loss = mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss_disc, axis=None)
+    generator_loss = mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss_gen, axis=None)
+    return generator_loss, discriminator_loss
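+
+
+if __name__ == '__main__':
+    # Smoke test: a minimal sketch of one distributed train step on random
+    # data. MirroredStrategy falls back to a single device when no extra
+    # GPUs are present; the Adam settings here are illustrative, not the
+    # project's tuned values.
+    strategy = tf.distribute.MirroredStrategy()
+    with strategy.scope():
+        model = DCGAN()
+        gen_opt = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
+        disc_opt = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
+
+    batch_size, n_class = 8, 10
+    X = tf.random.uniform((batch_size, 128, 128, 3), minval=-1, maxval=1)
+    labels = tf.random.uniform((batch_size,), maxval=n_class, dtype=tf.int32)
+    C = tf.one_hot(labels, depth=n_class)
+
+    g_loss, d_loss = dist_train_step(strategy, model, gen_opt, disc_opt, X, C,
+                                     latent_dim=96, batch_size=batch_size)
+    print('gen loss:', float(g_loss), 'disc loss:', float(d_loss))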