a b/models/autoencoder.py
1
import numpy as np
2
import tensorflow as tf
3
from tensorflow.compat.v1.layers import Dense
4
from tensorflow.python.keras.layers import Conv2D, Flatten, Dropout
5
6
from models.customlayers import build_unified_encoder, build_unified_decoder
7
8
9
def autoencoder(x, dropout_rate, dropout, config):
    """Build a convolutional autoencoder graph with a dense bottleneck.

    The input runs through the layers returned by ``build_unified_encoder``.
    A 1x1 convolution then reduces the channel count to 1/8. The result is
    flattened and projected to a ``config.zDim``-dimensional code ``z``.
    On the way back, ``z`` is projected to the flattened feature-map size and
    reshaped. A 1x1 convolution restores the original channel count, and the
    result runs through the layers returned by ``build_unified_decoder``.

    Args:
        x: Input tensor whose static shape is known. The encoder output is
            indexed at ``[3]`` for its channel count, so a rank-4 tensor is
            expected (presumably NHWC -- TODO confirm against callers).
        dropout_rate: Rate given to the bottleneck ``Dropout`` layer.
        dropout: Passed as the ``training`` argument of the bottleneck dropout
            layer (both applications). Presumably a boolean tensor/flag that
            is true while training.
        config: Object providing ``intermediateResolutions``, ``zDim``,
            ``outputWidth`` and ``numChannels``.

    Returns:
        dict with ``'z'`` (the latent code after dropout) and ``'x_hat'``
        (the decoder output).
    """
    outputs = {}

    with tf.variable_scope('Encoder'):
        encoder = build_unified_encoder(x.get_shape().as_list(), config.intermediateResolutions)

        temp_out = x
        for layer in encoder:
            temp_out = layer(temp_out)

    with tf.variable_scope("Bottleneck"):
        num_channels = temp_out.get_shape().as_list()[3]
        # 1x1 convolutions: shrink the channel dimension 8x before the dense
        # projection, and restore it after the dense reconstruction.
        intermediate_conv = Conv2D(num_channels // 8, 1, padding='same')
        intermediate_conv_reverse = Conv2D(num_channels, 1, padding='same')
        dropout_layer = Dropout(dropout_rate)
        temp_out = intermediate_conv(temp_out)

        # Per-sample feature-map shape (batch axis dropped). It is used to
        # size the decoder-side dense layer and to un-flatten its output.
        reshape = temp_out.get_shape().as_list()[1:]
        z_layer = Dense(config.zDim)
        dec_dense = Dense(int(np.prod(reshape)))

        outputs['z'] = z = dropout_layer(z_layer(Flatten()(temp_out)), training=dropout)
        # Use the same training flag as above. If it is omitted, Keras falls
        # back to its learning phase and this dropout ignores ``dropout``.
        decoded = dropout_layer(dec_dense(z), training=dropout)
        temp_out = intermediate_conv_reverse(tf.reshape(decoded, [-1, *reshape]))

    with tf.variable_scope('Decoder'):
        decoder = build_unified_decoder(config.outputWidth, config.intermediateResolutions, config.numChannels)
        # Decode: z -> x_hat
        for layer in decoder:
            temp_out = layer(temp_out)

        outputs['x_hat'] = temp_out

    return outputs