|
a |
|
b/models/autoencoder.py |
|
|
1 |
import numpy as np |
|
|
2 |
import tensorflow as tf |
|
|
3 |
from tensorflow.compat.v1.layers import Dense |
|
|
4 |
from tensorflow.python.keras.layers import Conv2D, Flatten, Dropout |
|
|
5 |
|
|
|
6 |
from models.customlayers import build_unified_encoder, build_unified_decoder |
|
|
7 |
|
|
|
8 |
|
|
|
9 |
def autoencoder(x, dropout_rate, dropout, config):
    """Build a convolutional autoencoder graph with a dense bottleneck.

    The encoder and decoder layer stacks come from ``build_unified_encoder``
    and ``build_unified_decoder``. Between them, the bottleneck:

    1. applies a 1x1 convolution that reduces the channel count by a factor
       of 8 (integer division),
    2. flattens the result and projects it with a dense layer to a
       ``config.zDim``-dimensional code ``z`` (followed by dropout),
    3. projects ``z`` back to the flattened feature size (followed by
       dropout), reshapes it to the reduced feature-map shape, and
    4. restores the original channel count with a second 1x1 convolution.

    Args:
        x: Input tensor. Its shape is read via ``get_shape().as_list()`` and
            index 3 is used as the channel count after encoding, so it must
            be at least rank 4; presumably NHWC images -- confirm against the
            data pipeline. The post-encoder height/width/channels must be
            statically known, because the dense layer sizes are derived
            from the static shape.
        dropout_rate: Rate passed to the bottleneck ``Dropout`` layer
            (fraction of units dropped).
        dropout: Training flag forwarded to the ``Dropout`` layer. It
            controls dropout on both dense projections of the bottleneck.
        config: Object providing ``intermediateResolutions``, ``zDim``,
            ``outputWidth`` and ``numChannels``.

    Returns:
        dict with ``'z'`` (the latent code, after dropout) and ``'x_hat'``
        (the decoder output).
    """
    outputs = {}

    with tf.variable_scope('Encoder'):
        encoder = build_unified_encoder(x.get_shape().as_list(), config.intermediateResolutions)

        temp_out = x
        for layer in encoder:
            temp_out = layer(temp_out)

    # NOTE: layer construction/call order below is kept as-is so that
    # auto-generated variable names (and therefore existing checkpoints)
    # stay unchanged.
    with tf.variable_scope("Bottleneck"):
        channels = temp_out.get_shape().as_list()[3]
        intermediate_conv = Conv2D(channels // 8, 1, padding='same')
        intermediate_conv_reverse = Conv2D(channels, 1, padding='same')
        # One Dropout layer is shared by the encoder- and decoder-side
        # projections.
        dropout_layer = Dropout(dropout_rate)
        temp_out = intermediate_conv(temp_out)

        # Static shape of the reduced feature map (without batch dim); used to
        # size the decoder projection and to reshape back to a feature map.
        reshape = temp_out.get_shape().as_list()[1:]
        z_layer = Dense(config.zDim)
        # np.prod returns a NumPy scalar; pass a plain int as the unit count.
        dec_dense = Dense(int(np.prod(reshape)))

        outputs['z'] = z = dropout_layer(z_layer(Flatten()(temp_out)), dropout)
        # Pass the same training flag as for z. Previously it was omitted, so
        # this dropout fell back to the Keras learning phase and ignored the
        # caller's `dropout` switch.
        decoded = dropout_layer(dec_dense(z), dropout)
        temp_out = intermediate_conv_reverse(tf.reshape(decoded, [-1, *reshape]))

    with tf.variable_scope('Decoder'):
        decoder = build_unified_decoder(config.outputWidth, config.intermediateResolutions, config.numChannels)
        # Decode: z -> x_hat
        for layer in decoder:
            temp_out = layer(temp_out)

    outputs['x_hat'] = temp_out

    return outputs