Diff of /model.py [000000] .. [9e0229]

import tensorflow as tf
import tensorflow.contrib as tc
import tensorflow.contrib.layers as tcl

# Default activation: leaky ReLU (the commented-out returns are alternatives that were tried).
def leaky_relu(x, alpha=0.2):
    # Equivalent to max(x, alpha * x) for 0 < alpha < 1: identity for x > 0, alpha * x for x < 0.
    return tf.maximum(tf.minimum(0.0, alpha * x), x)
    #return tf.maximum(0.0, x)
    #return tf.nn.tanh(x)
    #return tf.nn.elu(x)

def conv_cond_concat(x, y):
    """Concatenate conditioning vector on feature map axis."""
    # Tile y across the spatial dimensions of x, then concatenate on the channel axis.
    return tf.concat([x, y * tf.ones([tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2], tf.shape(y)[3]])], 3)
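
# A shape sketch for conv_cond_concat (illustrative; the shapes are assumptions):
# for feature maps x of shape (bs, H, W, C) and a condition y of shape (bs, 1, 1, K),
# the result has shape (bs, H, W, C + K). It would be called like the commented-out
# `fc = conv_cond_concat(fc, yb)` in Generator_img below.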

class Discriminator(object):
    def __init__(self, input_dim, name, nb_layers=2, nb_units=256):
        self.input_dim = input_dim
        self.name = name
        self.nb_layers = nb_layers
        self.nb_units = nb_units

    def __call__(self, x, reuse=True):
        with tf.variable_scope(self.name) as vs:
            if reuse:
                vs.reuse_variables()

            fc = tcl.fully_connected(
                x, self.nb_units,
                #weights_initializer=tf.random_normal_initializer(stddev=0.02),
                activation_fn=tf.identity
            )
            #fc = tcl.batch_norm(fc)
            fc = leaky_relu(fc)
            for _ in range(self.nb_layers - 1):
                fc = tcl.fully_connected(
                    fc, self.nb_units,
                    #weights_initializer=tf.random_normal_initializer(stddev=0.02),
                    activation_fn=tf.identity
                )
                # The remaining hidden layers use batch norm followed by tanh
                # (the leaky_relu variant is commented out).
                fc = tcl.batch_norm(fc)
                #fc = leaky_relu(fc)
                fc = tf.nn.tanh(fc)

            # Linear (unbounded) scalar output.
            output = tcl.fully_connected(
                fc, 1,
                #weights_initializer=tf.random_normal_initializer(stddev=0.02),
                activation_fn=tf.identity
            )
            return output

    @property
    def vars(self):
        # All global variables whose name contains this network's scope name.
        return [var for var in tf.global_variables() if self.name in var.name]
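
# Usage sketch (not in the original file; names are illustrative): reuse defaults
# to True, so the first call must pass reuse=False to create the variables, and
# later calls share them.
#   d = Discriminator(input_dim=10, name='d_net')
#   d_real = d(x_real, reuse=False)   # builds the variables
#   d_fake = d(x_fake)                # reuses them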

class Generator(object):
    def __init__(self, input_dim, output_dim, name, nb_layers=2, nb_units=256, concat_every_fcl=True):
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.name = name
        self.nb_layers = nb_layers
        self.nb_units = nb_units
        self.concat_every_fcl = concat_every_fcl

    def __call__(self, z, reuse=True):
        #with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE) as vs:
        with tf.variable_scope(self.name) as vs:
            if reuse:
                vs.reuse_variables()
            # z is [noise, condition]; the condition occupies the columns from input_dim on.
            y = z[:, self.input_dim:]
            fc = tcl.fully_connected(
                z, self.nb_units,
                weights_initializer=tf.random_normal_initializer(stddev=0.02),
                weights_regularizer=tcl.l2_regularizer(2.5e-5),
                activation_fn=tf.identity
            )
            #fc = tc.layers.batch_norm(fc, decay=0.9, scale=True, updates_collections=None, is_training=True)
            fc = leaky_relu(fc)
            #fc = tf.nn.dropout(fc, 0.1)
            if self.concat_every_fcl:
                # Optionally re-append the condition after every hidden layer.
                fc = tf.concat([fc, y], 1)
            for _ in range(self.nb_layers - 1):
                fc = tcl.fully_connected(
                    fc, self.nb_units,
                    weights_initializer=tf.random_normal_initializer(stddev=0.02),
                    weights_regularizer=tcl.l2_regularizer(2.5e-5),
                    activation_fn=tf.identity
                )
                #fc = tc.layers.batch_norm(fc, decay=0.9, scale=True, updates_collections=None, is_training=True)
                fc = leaky_relu(fc)
                if self.concat_every_fcl:
                    fc = tf.concat([fc, y], 1)

            output = tcl.fully_connected(
                fc, self.output_dim,
                weights_initializer=tf.random_normal_initializer(stddev=0.02),
                weights_regularizer=tcl.l2_regularizer(2.5e-5),
                #activation_fn=tf.sigmoid
                activation_fn=tf.identity
            )
            #output = tc.layers.batch_norm(output, decay=0.9, scale=True, updates_collections=None, is_training=True)
            #output = tf.nn.relu(output)
            return output

    @property
    def vars(self):
        return [var for var in tf.global_variables() if self.name in var.name]
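
# Input-layout sketch (illustrative; the sizes are assumptions): since __call__
# slices y = z[:, input_dim:], z must be the noise followed by the condition.
#   g = Generator(input_dim=100, output_dim=10, name='g_net')
#   z = tf.concat([noise, onehot_labels], axis=1)   # (bs, 100 + nb_classes)
#   x_fake = g(z, reuse=False)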

class Encoder(object):
    def __init__(self, input_dim, output_dim, feat_dim, name, nb_layers=2, nb_units=256):
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.feat_dim = feat_dim
        self.name = name
        self.nb_layers = nb_layers
        self.nb_units = nb_units

    def __call__(self, x, reuse=True):
        #with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE) as vs:
        with tf.variable_scope(self.name) as vs:
            if reuse:
                vs.reuse_variables()
            fc = tcl.fully_connected(
                x, self.nb_units,
                #weights_initializer=tf.random_normal_initializer(stddev=0.02),
                activation_fn=tf.identity
            )
            fc = leaky_relu(fc)
            for _ in range(self.nb_layers - 1):
                fc = tcl.fully_connected(
                    fc, self.nb_units,
                    #weights_initializer=tf.random_normal_initializer(stddev=0.02),
                    activation_fn=tf.identity
                )
                fc = leaky_relu(fc)

            output = tcl.fully_connected(
                fc, self.output_dim,
                #weights_initializer=tf.random_normal_initializer(stddev=0.02),
                activation_fn=tf.identity
            )
            # Columns [feat_dim:] are class logits; the first feat_dim columns are
            # continuous features.
            logits = output[:, self.feat_dim:]
            y = tf.nn.softmax(logits)
            #return output[:, 0:self.feat_dim], y, logits
            return output, y

    @property
    def vars(self):
        return [var for var in tf.global_variables() if self.name in var.name]
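
# Output-layout sketch (illustrative; output_dim is assumed to be feat_dim + nb_classes,
# as the commented-out return above suggests):
#   e = Encoder(input_dim=10, output_dim=feat_dim + nb_classes, feat_dim=feat_dim, name='e_net')
#   out, y = e(x, reuse=False)   # y = softmax(out[:, feat_dim:])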

class Discriminator_img(object):
    def __init__(self, input_dim, name, nb_layers=2, nb_units=256, dataset='mnist'):
        self.input_dim = input_dim
        self.name = name
        self.nb_layers = nb_layers
        self.nb_units = nb_units
        self.dataset = dataset

    def __call__(self, z, reuse=True):
        with tf.variable_scope(self.name) as vs:
            if reuse:
                vs.reuse_variables()
            bs = tf.shape(z)[0]

            # Reshape the flat input back into an image batch.
            if self.dataset == "mnist":
                z = tf.reshape(z, [bs, 28, 28, 1])
            elif self.dataset == "cifar10":
                z = tf.reshape(z, [bs, 32, 32, 3])
            conv = tcl.convolution2d(z, 64, [4, 4], [2, 2],
                weights_initializer=tf.random_normal_initializer(stddev=0.02),
                activation_fn=tf.identity
            )
            # (bs, 14, 14, 64) for mnist
            conv = leaky_relu(conv)
            for _ in range(self.nb_layers - 1):
                conv = tcl.convolution2d(conv, 128, [4, 4], [2, 2],
                    weights_initializer=tf.random_normal_initializer(stddev=0.02),
                    activation_fn=tf.identity
                )
                #conv = tc.layers.batch_norm(conv, decay=0.9, scale=True, updates_collections=None)
                conv = leaky_relu(conv)
            # (bs, 7, 7, 128) for mnist with the default nb_layers=2
            #fc = tf.reshape(conv, [bs, -1])
            fc = tcl.flatten(conv)
            # (bs, 6272) for mnist
            fc = tcl.fully_connected(
                fc, 1024,
                weights_initializer=tf.random_normal_initializer(stddev=0.02),
                activation_fn=tf.identity
            )
            #fc = tc.layers.batch_norm(fc, decay=0.9, scale=True, updates_collections=None)
            fc = leaky_relu(fc)
            output = tcl.fully_connected(
                fc, 1,
                activation_fn=tf.identity
            )
            return output

    @property
    def vars(self):
        return [var for var in tf.global_variables() if self.name in var.name]
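
# Note (derived from the reshapes above): Discriminator_img expects flattened
# images, e.g. (bs, 784) for mnist or (bs, 3072) for cifar10, which matches the
# flat (bs, -1) output of Generator_img below.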

#generator for images, G()
class Generator_img(object):
    def __init__(self, nb_classes, output_dim, name, nb_layers=2, nb_units=256, dataset='mnist', is_training=True):
        self.nb_classes = nb_classes
        self.output_dim = output_dim
        self.name = name
        self.nb_layers = nb_layers
        self.nb_units = nb_units
        self.dataset = dataset
        self.is_training = is_training

    def __call__(self, z, reuse=True):
        #with tf.variable_scope(self.name, reuse=tf.AUTO_REUSE) as vs:
        with tf.variable_scope(self.name) as vs:
            if reuse:
                vs.reuse_variables()
            bs = tf.shape(z)[0]
            # The last nb_classes columns of z hold the one-hot label (previously a
            # hard-coded -10); y is only consumed by the commented-out conditional
            # concatenations below.
            y = z[:, -self.nb_classes:]
            #yb = tf.reshape(y, shape=[bs, 1, 1, self.nb_classes])
            fc = tcl.fully_connected(
                z, 1024,
                weights_initializer=tf.random_normal_initializer(stddev=0.02),
                weights_regularizer=tc.layers.l2_regularizer(2.5e-5),
                activation_fn=tf.identity
            )
            fc = tc.layers.batch_norm(fc, decay=0.9, scale=True, updates_collections=None, is_training=self.is_training)
            fc = tf.nn.relu(fc)
            #fc = tf.concat([fc, y], 1)

            # Project up to a small spatial grid, then upsample with transposed convolutions.
            if self.dataset == 'mnist':
                fc = tcl.fully_connected(
                    fc, 7 * 7 * 128,
                    weights_initializer=tf.random_normal_initializer(stddev=0.02),
                    weights_regularizer=tc.layers.l2_regularizer(2.5e-5),
                    activation_fn=tf.identity
                )
                fc = tf.reshape(fc, tf.stack([bs, 7, 7, 128]))
            elif self.dataset == 'cifar10':
                fc = tcl.fully_connected(
                    fc, 8 * 8 * 128,
                    weights_initializer=tf.random_normal_initializer(stddev=0.02),
                    weights_regularizer=tc.layers.l2_regularizer(2.5e-5),
                    activation_fn=tf.identity
                )
                fc = tf.reshape(fc, tf.stack([bs, 8, 8, 128]))
            fc = tc.layers.batch_norm(fc, decay=0.9, scale=True, updates_collections=None, is_training=self.is_training)
            fc = tf.nn.relu(fc)
            #fc = conv_cond_concat(fc, yb)
            conv = tcl.convolution2d_transpose(
                fc, 64, [4, 4], [2, 2],
                weights_initializer=tf.random_normal_initializer(stddev=0.02),
                weights_regularizer=tc.layers.l2_regularizer(2.5e-5),
                activation_fn=tf.identity
            )
            # (bs, 14, 14, 64) for mnist
            conv = tc.layers.batch_norm(conv, decay=0.9, scale=True, updates_collections=None, is_training=self.is_training)
            conv = tf.nn.relu(conv)
            if self.dataset == 'mnist':
                output = tcl.convolution2d_transpose(
                    conv, 1, [4, 4], [2, 2],
                    weights_initializer=tf.random_normal_initializer(stddev=0.02),
                    weights_regularizer=tc.layers.l2_regularizer(2.5e-5),
                    activation_fn=tf.nn.sigmoid
                )
                output = tf.reshape(output, [bs, -1])
            elif self.dataset == 'cifar10':
                output = tcl.convolution2d_transpose(
                    conv, 3, [4, 4], [2, 2],
                    weights_initializer=tf.random_normal_initializer(stddev=0.02),
                    weights_regularizer=tc.layers.l2_regularizer(2.5e-5),
                    activation_fn=tf.nn.sigmoid
                )
                output = tf.reshape(output, [bs, -1])
            # pixel values lie in (0, 1) via the sigmoid activations above
            return output

    @property
    def vars(self):
        return [var for var in tf.global_variables() if self.name in var.name]
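
# Usage sketch (illustrative; names and sizes are assumptions): the label one-hot
# occupies the last nb_classes columns of z, and the output is a flat image in (0, 1),
# with output_dim = 784 for mnist.
#   g = Generator_img(nb_classes=10, output_dim=784, name='g_net', dataset='mnist')
#   z = tf.concat([noise, onehot_labels], axis=1)
#   x_fake = g(z, reuse=False)   # shape (bs, 784)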

#encoder for images, H()
class Encoder_img(object):
    def __init__(self, nb_classes, output_dim, name, nb_layers=2, nb_units=256, dataset='mnist', cond=True):
        self.nb_classes = nb_classes
        self.output_dim = output_dim
        self.name = name
        self.nb_layers = nb_layers
        self.nb_units = nb_units
        self.dataset = dataset
        self.cond = cond

    def __call__(self, x, reuse=True):
        with tf.variable_scope(self.name) as vs:
            if reuse:
                vs.reuse_variables()
            bs = tf.shape(x)[0]
            # Reshape the flat input back into an image batch.
            if self.dataset == "mnist":
                x = tf.reshape(x, [bs, 28, 28, 1])
            elif self.dataset == "cifar10":
                x = tf.reshape(x, [bs, 32, 32, 3])
            conv = tcl.convolution2d(x, 64, [4, 4], [2, 2],
                weights_initializer=tf.random_normal_initializer(stddev=0.02),
                weights_regularizer=tc.layers.l2_regularizer(2.5e-5),
                activation_fn=tf.identity
            )
            conv = leaky_relu(conv)
            for _ in range(self.nb_layers - 1):
                conv = tcl.convolution2d(conv, self.nb_units, [4, 4], [2, 2],
                    weights_initializer=tf.random_normal_initializer(stddev=0.02),
                    weights_regularizer=tc.layers.l2_regularizer(2.5e-5),
                    activation_fn=tf.identity
                )
                conv = leaky_relu(conv)
            conv = tcl.flatten(conv)
            fc = tcl.fully_connected(conv, 1024,
                weights_initializer=tf.random_normal_initializer(stddev=0.02),
                weights_regularizer=tc.layers.l2_regularizer(2.5e-5),
                activation_fn=tf.identity)
            fc = leaky_relu(fc)
            output = tcl.fully_connected(
                fc, self.output_dim,
                activation_fn=tf.identity
            )
            # The last nb_classes columns are class logits; the rest are continuous features.
            logits = output[:, -self.nb_classes:]
            y = tf.nn.softmax(logits)
            return output[:, :-self.nb_classes], y, logits

    @property
    def vars(self):
        return [var for var in tf.global_variables() if self.name in var.name]
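
# End-to-end wiring sketch (a hedged example, not part of the original file; all
# names below are illustrative): one way the image networks could be combined
# into a conditional adversarial pair on mnist.
#   g = Generator_img(nb_classes=10, output_dim=784, name='g_net')
#   d = Discriminator_img(input_dim=784, name='d_net')
#   h = Encoder_img(nb_classes=10, output_dim=64 + 10, name='h_net')
#   x_fake = g(z, reuse=False)
#   d_real, d_fake = d(x_real, reuse=False), d(x_fake)
#   feat, y, logits = h(x_fake, reuse=False)
#   # train d.vars against g.vars (and h.vars) with adversarial losses of choice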