Diff of /model.py [000000] .. [277df6]

Switch to side-by-side view

--- a
+++ b/model.py
@@ -0,0 +1,217 @@
+import tensorflow as tf
+from tensorflow.keras import Model, layers, backend
+from tensorflow.keras.constraints import Constraint
+from losses import disc_hinge, disc_loss, gen_loss, gen_hinge
+from diff_augment import diff_augment
+from tensorflow_addons.layers import SpectralNormalization
+
+tf.random.set_seed(45)
+# np.random.seed(45)
+
class Generator(Model):
	"""DCGAN-style generator mapping a latent vector to a 128x128 RGB image.

	The latent input is viewed as a 1x1 spatial map and repeatedly upsampled
	by spectrally-normalized transposed convolutions, ending in a
	tanh-activated 3-channel projection.
	"""

	def __init__(self, n_class=10, res=128):
		super(Generator, self).__init__()
		# Output channels and upsampling stride per stage
		# (1x1 -> 4x4 -> 8 -> 16 -> 32 -> 64 -> 128).
		stage_filters = [1024, 512, 256, 128, 64, 32]
		stage_strides = [4, 2, 2, 2, 2, 2]
		self.cnn_depth = len(stage_filters)

		# Discrete-condition branch (not used by call(); kept so the model's
		# attribute layout is unchanged).
		self.cond_embedding = layers.Embedding(input_dim=n_class, output_dim=50)
		self.cond_flat      = layers.Flatten()
		self.cond_dense     = layers.Dense(units=(8 * 8 * 1))
		self.cond_reshape   = layers.Reshape(target_shape=(64,))

		# Upsampling stack; truncated-normal N(0, 0.02) init is the usual
		# choice for conv layers without batch-norm directly after init.
		self.conv = []
		for stage in range(self.cnn_depth):
			self.conv.append(SpectralNormalization(layers.Conv2DTranspose(
				filters=stage_filters[stage], kernel_size=3,
				strides=stage_strides[stage], padding='same',
				kernel_initializer=tf.keras.initializers.TruncatedNormal(mean=0.0, stddev=0.02),
				use_bias=False)))

		# Default LeakyReLU slope (alpha=0.3) is intentional here; the
		# discriminator uses alpha=0.2 separately.
		self.act   = [layers.LeakyReLU() for _ in range(self.cnn_depth)]

		self.bnorm = [layers.BatchNormalization() for _ in range(self.cnn_depth)]

		# Final 3-channel projection squashed to [-1, 1] by tanh.
		self.last_conv = SpectralNormalization(layers.Conv2D(
			filters=3, kernel_size=3,
			strides=1, padding='same',
			activation='tanh',
			kernel_initializer=tf.keras.initializers.TruncatedNormal(mean=0.0, stddev=0.02),
			use_bias=False))

	@tf.function
	def call(self, X):
		"""Generate a batch of images from latent vectors X of shape (batch, dim)."""
		# Treat each latent vector as a 1x1 spatial feature map.
		X = tf.expand_dims(tf.expand_dims(X, axis=1), axis=1)
		# First stage deliberately skips batch-norm.
		X = self.act[0](self.conv[0](X))
		# Remaining stages: conv -> batch-norm -> leaky-relu.
		for conv, bnorm, act in zip(self.conv[1:], self.bnorm[1:], self.act[1:]):
			X = act(bnorm(conv(X)))
		return self.last_conv(X)
+
+
class Discriminator(Model):
	"""Conditional DCGAN discriminator (critic).

	The condition C is spatially tiled and concatenated to the image as
	extra channels before the conv stack. The head is a single linear unit
	with no sigmoid — callers apply hinge-style losses on raw scores.
	"""

	def __init__(self, n_class=10, res=128):
		super(Discriminator, self).__init__()
		# Output channels and stride per downsampling stage.
		stage_filters = [64, 128, 256, 512, 1024, 1]
		stage_strides = [2, 2, 2, 2, 1, 1]
		self.cnn_depth = len(stage_filters)

		# Discrete-condition branch (not used by call(); kept so the model's
		# attribute layout is unchanged).
		self.cond_embedding = layers.Embedding(input_dim=n_class, output_dim=50)
		self.cond_flat      = layers.Flatten()
		self.cond_dense     = layers.Dense(units=(res * res * 1))
		self.cond_reshape   = layers.Reshape(target_shape=(res, res, 1))

		# Plain (not spectrally-normalized) conv stack, DCGAN-style init.
		self.cnn_conv = [
			layers.Conv2D(filters=stage_filters[i], kernel_size=3,
						  strides=stage_strides[i], padding='same',
						  kernel_initializer=tf.keras.initializers.TruncatedNormal(mean=0.0, stddev=0.02),
						  use_bias=False)
			for i in range(self.cnn_depth)]

		self.cnn_bnorm = [layers.BatchNormalization() for _ in range(self.cnn_depth)]

		self.cnn_act   = [layers.LeakyReLU(alpha=0.2) for _ in range(self.cnn_depth)]

		self.flat      = layers.Flatten()

		# Single linear output unit; losses operate on the raw score.
		self.disc_out  = layers.Dense(units=1)

	@tf.function
	def call(self, x, C):
		"""Score images x given condition C.

		Returns a pair (scores, reconst_x). reconst_x is always None — the
		autoencoder reconstruction branch is disabled in this version.
		"""
		# Broadcast the per-sample condition over the spatial grid and
		# append it to the image channels.
		C = tf.expand_dims(tf.expand_dims(C, axis=1), axis=1)
		C = tf.tile(C, [1, x.shape[1], x.shape[2], 1])
		x = tf.concat([x, C], axis=-1)

		# Every stage: conv -> batch-norm -> leaky-relu.
		for conv, bnorm, act in zip(self.cnn_conv, self.cnn_bnorm, self.cnn_act):
			x = act(bnorm(conv(x)))

		# Placeholder for the disabled reconstruction output.
		reconst_x = None

		x = self.disc_out(self.flat(x))

		return x, reconst_x
+
class DCGAN(Model):
	"""Container bundling the generator and discriminator into one model.

	Exposes them as `self.gen` / `self.disc`; `dist_train_step` below
	relies on these attribute names.
	"""

	def __init__(self):
		super(DCGAN, self).__init__()
		# NOTE: construction order (generator first) matters for weight
		# initialization under the fixed global seed set at module import.
		self.gen    = Generator()
		self.disc   = Discriminator()
+
@tf.function
def dist_train_step(mirrored_strategy, model, model_gopt, model_copt, X, C, latent_dim=96, batch_size=64):
	"""Run one distributed GAN training step: discriminator, then generator.

	Args:
		mirrored_strategy: tf.distribute strategy whose `.run` executes the
			per-replica step closures defined below.
		model: DCGAN container exposing `.gen` and `.disc`.
		model_gopt: optimizer for the generator.
		model_copt: optimizer for the discriminator (critic).
		X: batch of real images (per-replica input).
		C: per-sample condition vectors; appended to the noise fed to the
			generator and passed to the discriminator.
		latent_dim: size of the uniform noise vector before C is appended.
		batch_size: per-replica batch size used to shape the noise.

	Returns:
		(generator_loss, discriminator_loss), each SUM-reduced across
		replicas. NOTE(review): losses are not divided by the number of
		replicas before the SUM reduce — confirm this scaling is intended.
	"""

	diff_augment_policies = "color,translation"
	# Two independent noise draws are required by the mode-seeking
	# regularizer in train_step_gen. Both closures below capture these
	# vectors, so disc and gen steps see the same fake-image noise.
	noise_vector          = tf.random.uniform(shape=(batch_size, latent_dim), minval=-1, maxval=1)
	noise_vector_2        = tf.random.uniform(shape=(batch_size, latent_dim), minval=-1, maxval=1)
	# Condition is concatenated directly onto the noise (no embedding here).
	noise_vector          = tf.concat([noise_vector, C], axis=-1)
	noise_vector_2        = tf.concat([noise_vector_2, C], axis=-1)
	# @tf.function
	def train_step_disc(model, model_gopt, model_copt, X, C, latent_dim=96, batch_size=64):	
		# Per-replica discriminator update; returns the hinge loss.
		with tf.GradientTape() as ctape:
			# noise_vector = tf.random.uniform(shape=(batch_size, latent_dim), minval=-1, maxval=1)
			# noise_vector = tf.random.uniform(shape=(batch_size, latent_dim), minval=-1, maxval=1)
			# noise_vector = tf.random.normal(shape=(batch_size, latent_dim))

			# Generator frozen (training=False) during the critic update.
			fake_img     = model.gen(noise_vector, training=False)

			# DiffAugment both real and fake batches identically so the
			# discriminator cannot key on augmentation artifacts.
			X_aug        = diff_augment(X, policy=diff_augment_policies)
			fake_img     = diff_augment(fake_img, policy=diff_augment_policies)

			D_real, X_recon = model.disc(X_aug, C, training=True)
			D_fake, _       = model.disc(fake_img, C, training=True)

			# c_loss       = disc_loss(D_real, D_fake) +\
			# 			   tf.reduce_mean(tf.keras.losses.MeanAbsoluteError(reduction=tf.keras.losses.Reduction.NONE)(X_aug, X_recon))
			# c_loss       = disc_hinge(D_real, D_fake) +\
			# 			   tf.reduce_mean(tf.keras.losses.MeanAbsoluteError(reduction=tf.keras.losses.Reduction.NONE)(X_aug, X_recon))
			c_loss       = disc_hinge(D_real, D_fake)

		variables = model.disc.trainable_variables
		gradients = ctape.gradient(c_loss, variables)
		model_copt.apply_gradients(zip(gradients, variables))
		return c_loss

	# @tf.function
	def train_step_gen(model, model_gopt, model_copt, X, C, latent_dim=96, batch_size=64):
		# Per-replica generator update; hinge loss + mode-seeking term.
		with tf.GradientTape() as gtape:
			# noise_vector = tf.random.uniform(shape=(batch_size, latent_dim), minval=-1, maxval=1)
			# noise_vector = tf.random.normal(shape=(batch_size, latent_dim))
			
			fake_img_o   = model.gen(noise_vector, training=True)
			fake_img_2_o = model.gen(noise_vector_2, training=True)
			#D_fake       = model.disc(fake_img, H_hat, training=False)

			fake_img     = diff_augment(fake_img_o, policy=diff_augment_policies)
			fake_img_2   = diff_augment(fake_img_2_o, policy=diff_augment_policies)

			# NOTE(review): training=False means the discriminator's
			# batch-norm uses moving statistics here — confirm intended.
			D_fake, _    = model.disc(fake_img, C, training=False)
			D_fake_2, _  = model.disc(fake_img_2, C, training=False)
			# g_loss       = gen_loss(D_fake)
			g_loss       = gen_hinge(D_fake) + gen_hinge(D_fake_2)
			# Mode-seeking regularizer: ratio of image-space distance to
			# latent-space distance for the two noise draws...
			mode_loss    = tf.divide(tf.reduce_mean(tf.abs(tf.subtract(fake_img_2_o, fake_img_o))),\
									tf.reduce_mean(tf.abs(tf.subtract(noise_vector_2, noise_vector)))
									)
			# ...inverted (with epsilon) so minimizing it maximizes the
			# ratio, pushing distinct latents toward distinct images.
			mode_loss   = tf.divide(1.0, mode_loss + 1e-5)
			# Weight 1.0 on the mode-seeking term.
			g_loss      = g_loss + 1.0 * mode_loss

		variables = model.gen.trainable_variables #+ model.gcn.trainable_variables
		gradients = gtape.gradient(g_loss, variables)
		model_gopt.apply_gradients(zip(gradients, variables))
		return g_loss

	# Discriminator step first, then generator step, on every replica.
	per_replica_loss_disc = mirrored_strategy.run(train_step_disc, args=(model, model_gopt, model_copt, X, C, latent_dim, batch_size,))
	per_replica_loss_gen  = mirrored_strategy.run(train_step_gen, args=(model, model_gopt, model_copt, X, C, latent_dim, batch_size,))
	
	# print(per_replica_loss_disc)

	# print(mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss_disc, axis=0).numpy())

	# Sum per-replica scalar losses (axis=None reduces across replicas only).
	discriminator_loss = mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss_disc, axis=None)
	generator_loss = mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss_gen, axis=None)
	return generator_loss, discriminator_loss
\ No newline at end of file