Diff of /model.py [000000] .. [277df6]

import tensorflow as tf
from tensorflow.keras import Model, layers, backend
from tensorflow.keras.constraints import Constraint
from losses import disc_hinge, disc_loss, gen_loss, gen_hinge
from diff_augment import diff_augment
from tensorflow_addons.layers import SpectralNormalization

tf.random.set_seed(45)
# np.random.seed(45)

class Generator(Model):
    def __init__(self, n_class=10, res=128):
        super(Generator, self).__init__()
        # res is currently unused; the output resolution is fixed by the
        # filter/stride lists below (earlier runs also tried deeper variants).
        filters   = [1024, 512, 256, 128, 64, 32]  # , 16]
        strides   = [   4,   2,   2,   2,  2,  2]  # ,  2]
        self.cnn_depth = len(filters)

        # For a discrete condition we use an Embedding.
        self.cond_embedding = layers.Embedding(input_dim=n_class, output_dim=50)
        self.cond_flat      = layers.Flatten()
        self.cond_dense     = layers.Dense(units=(8 * 8 * 1))
        self.cond_reshape   = layers.Reshape(target_shape=(64,))

        # Initializer hyperparameters:
        # conv only      : mean=0.0, stddev=0.02
        # with batch norm: mean=1.0, stddev=0.02
        self.conv  = [SpectralNormalization(layers.Conv2DTranspose(
                          filters=filters[idx], kernel_size=3,
                          strides=strides[idx], padding='same',
                          kernel_initializer=tf.keras.initializers.TruncatedNormal(mean=0.0, stddev=0.02),
                          use_bias=False))
                      for idx in range(self.cnn_depth)]

        self.act   = [layers.LeakyReLU() for _ in range(self.cnn_depth)]

        self.bnorm = [layers.BatchNormalization() for _ in range(self.cnn_depth)]

        self.last_conv = SpectralNormalization(layers.Conv2D(
                             filters=3, kernel_size=3,
                             strides=1, padding='same',
                             activation='tanh',
                             kernel_initializer=tf.keras.initializers.TruncatedNormal(mean=0.0, stddev=0.02),
                             use_bias=False))

    @tf.function
    def call(self, X):
        # Embedding-based conditioning is disabled here; the caller
        # concatenates the condition onto the noise vector instead.
        # C = self.cond_reshape( self.cond_dense( self.cond_flat( self.cond_embedding( C ) ) ) )
        # X = tf.concat([C, X], axis=-1)

        X = tf.expand_dims(tf.expand_dims(X, axis=1), axis=1)
        X = self.act[0]( self.conv[0]( X ) )

        for idx in range(1, self.cnn_depth):
            X = self.act[idx]( self.bnorm[idx]( self.conv[idx]( X ) ) )
            # X = self.bnorm[idx]( self.act[idx]( self.conv[idx]( X ) ) )
            # X = self.act[idx]( self.conv[idx]( X ) )
        X = self.last_conv(X)
        return X

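# Shape check (a reading of the code, not part of the original file): the latent
# vector of shape (batch, latent_dim + n_class) is reshaped to (batch, 1, 1, C);
# the first stride-4 transposed conv yields a 4x4 map and the five stride-2
# stages double it to 128x128. last_conv maps to 3 channels with tanh, so the
# generator's outputs lie in [-1, 1].
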
class Discriminator(Model):
    def __init__(self, n_class=10, res=128):
        super(Discriminator, self).__init__()
        # filters    = [32, 64, 128, 256, 256, 512, 512, 1]
        # strides    = [ 2,  2,   2,   2,   2,   2,   1, 1]
        filters    = [ 64, 128, 256, 512, 1024, 1]
        strides    = [  2,   2,   2,   2,    1, 1]
        self.cnn_depth = len(filters)

        # For a discrete condition we use an Embedding.
        self.cond_embedding = layers.Embedding(input_dim=n_class, output_dim=50)
        self.cond_flat      = layers.Flatten()
        self.cond_dense     = layers.Dense(units=(res * res * 1))
        self.cond_reshape   = layers.Reshape(target_shape=(res, res, 1))

        self.cnn_conv  = [layers.Conv2D(filters=filters[i], kernel_size=3,
                                        strides=strides[i], padding='same',
                                        kernel_initializer=tf.keras.initializers.TruncatedNormal(mean=0.0, stddev=0.02),
                                        use_bias=False)
                          for i in range(self.cnn_depth)]

        self.cnn_bnorm = [layers.BatchNormalization() for _ in range(self.cnn_depth)]

        self.cnn_act   = [layers.LeakyReLU(alpha=0.2) for _ in range(self.cnn_depth)]

        # self.final_act = layers.Activation('sigmoid')

        self.flat      = layers.Flatten()

        self.disc_out  = layers.Dense(units=1)

        # self.autoenc   = Autoencoder()

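    # Note: the embedding/dense/reshape conditioning stack built above is not
    # used by call() below, which instead tiles the raw condition vector across
    # the spatial grid; the layers are kept for the commented-out path.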
    @tf.function
    def call(self, x, C):
        # x        = self.cnn_merge( x )
        # x        = self.cnn_exp( x )
        # mem_bank = []
        # Embedding-based conditioning is disabled; instead C is broadcast
        # spatially and concatenated to the image as extra input channels.
        # C = self.cond_reshape( self.cond_dense( self.cond_flat( self.cond_embedding( C ) ) ) )
        C = tf.expand_dims( tf.expand_dims(C, axis=1), axis=1)
        C = tf.tile(C, [1, x.shape[1], x.shape[2], 1])
        x = tf.concat([x, C], axis=-1)

        for layer_no in range(self.cnn_depth):
            # print(x.shape)
            x = self.cnn_act[layer_no]( self.cnn_bnorm[layer_no]( self.cnn_conv[layer_no]( x ) ) )
            # x = self.cnn_bnorm[layer_no]( self.cnn_act[layer_no]( self.cnn_conv[layer_no]( x ) ) )
            # x = self.cnn_act[layer_no]( self.cnn_conv[layer_no]( x ) )
            # if layer_no in (0, 1):
            #     mem_bank.append( x )

        # Earlier experiments with an autoencoder head, extra conv blocks
        # (layer_no+1 / layer_no+2 variants), and bottleneck conditioning are
        # disabled; reconst_x is kept as None so the return signature is stable.
        # reconst_x = self.autoenc( x, mem_bank )
        reconst_x = None

        # x = self.final_act( x )
        x = self.disc_out( self.flat( x ) )

        return x, reconst_x

class DCGAN(Model):
    def __init__(self):
        super(DCGAN, self).__init__()
        self.gen    = Generator()
        self.disc   = Discriminator()

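# For reference, the hinge objectives imported from losses.py commonly take the
# following form (an assumption for readability -- the actual definitions live
# in the local losses.py, which is not part of this diff):
#
#     def disc_hinge(d_real, d_fake):
#         return tf.reduce_mean(tf.nn.relu(1.0 - d_real)) + \
#                tf.reduce_mean(tf.nn.relu(1.0 + d_fake))
#
#     def gen_hinge(d_fake):
#         return -tf.reduce_mean(d_fake)
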
@tf.function
def dist_train_step(mirrored_strategy, model, model_gopt, model_copt, X, C, latent_dim=96, batch_size=64):

    diff_augment_policies = "color,translation"
    noise_vector          = tf.random.uniform(shape=(batch_size, latent_dim), minval=-1, maxval=1)
    noise_vector_2        = tf.random.uniform(shape=(batch_size, latent_dim), minval=-1, maxval=1)
    noise_vector          = tf.concat([noise_vector, C], axis=-1)
    noise_vector_2        = tf.concat([noise_vector_2, C], axis=-1)
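    # The class condition C is appended to both noise vectors, so the generator
    # sees (batch, latent_dim + n_class) inputs. "color,translation" selects the
    # differentiable-augmentation operations applied to images before they reach
    # the discriminator in the two inner steps below.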
    # @tf.function
    def train_step_disc(model, model_gopt, model_copt, X, C, latent_dim=96, batch_size=64):
        with tf.GradientTape() as ctape:
            # noise_vector = tf.random.uniform(shape=(batch_size, latent_dim), minval=-1, maxval=1)
            # noise_vector = tf.random.normal(shape=(batch_size, latent_dim))

            # The generator is frozen here; only the discriminator is updated.
            fake_img     = model.gen(noise_vector, training=False)

            # The same augmentation policy is applied to real and fake images.
            X_aug        = diff_augment(X, policy=diff_augment_policies)
            fake_img     = diff_augment(fake_img, policy=diff_augment_policies)

            D_real, X_recon = model.disc(X_aug, C, training=True)
            D_fake, _       = model.disc(fake_img, C, training=True)

            # c_loss     = disc_loss(D_real, D_fake) +\
            #              tf.reduce_mean(tf.keras.losses.MeanAbsoluteError(reduction=tf.keras.losses.Reduction.NONE)(X_aug, X_recon))
            # c_loss     = disc_hinge(D_real, D_fake) +\
            #              tf.reduce_mean(tf.keras.losses.MeanAbsoluteError(reduction=tf.keras.losses.Reduction.NONE)(X_aug, X_recon))
            c_loss       = disc_hinge(D_real, D_fake)

        variables = model.disc.trainable_variables
        gradients = ctape.gradient(c_loss, variables)
        model_copt.apply_gradients(zip(gradients, variables))
        return c_loss

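    # The commented-out MeanAbsoluteError terms above paired the hinge loss with
    # a reconstruction penalty on X_recon from the now-disabled autoencoder head.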
    # @tf.function
    def train_step_gen(model, model_gopt, model_copt, X, C, latent_dim=96, batch_size=64):
        with tf.GradientTape() as gtape:
            # noise_vector = tf.random.uniform(shape=(batch_size, latent_dim), minval=-1, maxval=1)
            # noise_vector = tf.random.normal(shape=(batch_size, latent_dim))

            # Two independent noise vectors are pushed through the generator.
            fake_img_o   = model.gen(noise_vector, training=True)
            fake_img_2_o = model.gen(noise_vector_2, training=True)
            # D_fake     = model.disc(fake_img, H_hat, training=False)

            fake_img     = diff_augment(fake_img_o, policy=diff_augment_policies)
            fake_img_2   = diff_augment(fake_img_2_o, policy=diff_augment_policies)

            D_fake, _    = model.disc(fake_img, C, training=False)
            D_fake_2, _  = model.disc(fake_img_2, C, training=False)
            # g_loss     = gen_loss(D_fake)
            g_loss       = gen_hinge(D_fake) + gen_hinge(D_fake_2)
            mode_loss    = tf.divide(tf.reduce_mean(tf.abs(tf.subtract(fake_img_2_o, fake_img_o))),
                                     tf.reduce_mean(tf.abs(tf.subtract(noise_vector_2, noise_vector))))
            mode_loss    = tf.divide(1.0, mode_loss + 1e-5)
            g_loss       = g_loss + 1.0 * mode_loss
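            # The two lines above penalize a small image-distance / latent-distance
            # ratio: maximizing |G(z1) - G(z2)| / |z1 - z2| pushes distinct latents
            # toward distinct images, discouraging mode collapse. This mirrors the
            # mode-seeking regularizer of MSGAN (Mao et al., 2019) -- an observation,
            # not a reference made by this repository.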

        variables = model.gen.trainable_variables  # + model.gcn.trainable_variables
        gradients = gtape.gradient(g_loss, variables)
        model_gopt.apply_gradients(zip(gradients, variables))
        return g_loss

    per_replica_loss_disc = mirrored_strategy.run(train_step_disc, args=(model, model_gopt, model_copt, X, C, latent_dim, batch_size,))
    per_replica_loss_gen  = mirrored_strategy.run(train_step_gen, args=(model, model_gopt, model_copt, X, C, latent_dim, batch_size,))

    # Sum the per-replica losses across devices; axis=None reduces over replicas
    # only, leaving the scalar loss values intact.
    discriminator_loss = mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss_disc, axis=None)
    generator_loss     = mirrored_strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_loss_gen, axis=None)
    return generator_loss, discriminator_loss
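
# Minimal driver sketch (not part of the original file): shows how dist_train_step
# is expected to be called under a MirroredStrategy. The batch size, Adam
# learning rates, and random stand-in data are placeholder assumptions.
if __name__ == "__main__":
    strategy = tf.distribute.MirroredStrategy()
    with strategy.scope():
        model    = DCGAN()
        gen_opt  = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.5)
        disc_opt = tf.keras.optimizers.Adam(learning_rate=4e-4, beta_1=0.5)

    # Dummy images in [-1, 1] and one-hot conditions stand in for a real pipeline.
    images     = tf.random.uniform((64, 128, 128, 3), minval=-1, maxval=1)
    conditions = tf.one_hot(tf.random.uniform((64,), maxval=10, dtype=tf.int32), depth=10)

    g_loss, d_loss = dist_train_step(strategy, model, gen_opt, disc_opt,
                                     images, conditions, latent_dim=96, batch_size=64)
    print("gen loss:", float(g_loss), "disc loss:", float(d_loss))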