# Network.py
1
# digital networks for End-to-end optimization
# Author: Yicheng Wu @ Rice University
# 03/29/2019
6
import tensorflow as tf
7
8
9
def max_pool_2x2(x):
    """Downsample an NHWC tensor by a factor of two in both spatial dims."""
    pooled = tf.nn.max_pool(
        x,
        ksize=[1, 2, 2, 1],
        strides=[1, 2, 2, 1],
        padding='SAME',
    )
    return pooled
11
12
13
def conv2dPad(x, W):
    """2-D convolution with symmetric (mirror) padding.

    The input is mirror-padded by half the kernel width on each spatial
    border, so the following VALID convolution keeps the spatial size
    unchanged (equivalent extent to SAME, but without zero borders).
    """
    half = (W.shape[0].value - 1) // 2
    paddings = [[0, 0], [half, half], [half, half], [0, 0]]
    mirrored = tf.pad(x, paddings, "SYMMETRIC")
    return tf.nn.conv2d(mirrored, W, strides=[1, 1, 1, 1], padding='VALID')
17
18
19
def BN(x, phase_BN, scope):
    """Batch-normalize `x` (training mode controlled by `phase_BN`).

    NOTE(review): `scope` is accepted for signature compatibility but is
    unused — the layer is named by the enclosing variable scope.
    """
    normalized = tf.layers.batch_normalization(x, momentum=0.9, training=phase_BN)
    return normalized
21
22
23
def cnnLayerPad(scope_name, inputs, outChannels, kernel_size, is_training, relu=True, maxpool=True):
    """Conv (symmetric pad) -> batch norm -> optional ReLU -> optional 2x2 max pool.

    Variables are created as `<scope_name>/W_conv` and `<scope_name>/b_conv`.
    """
    with tf.variable_scope(scope_name) as scope:
        inChannels = inputs.shape[-1].value
        weights = tf.get_variable('W_conv', [kernel_size, kernel_size, inChannels, outChannels])
        bias = tf.get_variable('b_conv', [outChannels])
        features = BN(conv2dPad(inputs, weights) + bias, is_training, scope)
        if relu:
            features = tf.nn.relu(features)
        return max_pool_2x2(features) if maxpool else features
35
36
37
def cnn3x3Pad(scope_name, inputs, outChannels, is_training, relu=True):
    """3x3 conv (symmetric pad) -> batch norm -> optional ReLU, no pooling.

    Thin wrapper around cnnLayerPad with kernel_size=3 and maxpool=False;
    the original body duplicated that layer's code line for line. Delegating
    keeps one implementation while creating the exact same variables
    (`<scope_name>/W_conv`, `<scope_name>/b_conv`), so existing checkpoints
    still load.
    """
    return cnnLayerPad(scope_name, inputs, outChannels, 3, is_training,
                       relu=relu, maxpool=False)
48
49
50
def deconv(scope_name, inputs, outChannels):
    """Upsample 2x via nearest-neighbour resize followed by a 3x3 conv.

    Used in the decoder in place of a transposed convolution. No bias and
    no activation here — callers apply further conv layers afterwards.
    """
    with tf.variable_scope(scope_name):
        inChannels = inputs.shape[-1].value
        height = inputs.shape[1].value
        width = inputs.shape[2].value
        kernel = tf.get_variable('resize_conv', [3, 3, inChannels, outChannels])
        upsampled = tf.image.resize_nearest_neighbor(inputs, [2 * height, 2 * width])
        return conv2dPad(upsampled, kernel)
59
60
61
def UNet(inputs, phase_BN):
    """Residual U-Net: predicts a correction added to `inputs`, clamped to [0, 1].

    In the decoder, instead of transposed conv, upsampling is nearest-neighbour
    resize-by-2 + conv (see deconv); every conv uses symmetric pad + VALID
    instead of SAME. Scope names (downN_M / upN_M) match the original layout,
    so variable names are unchanged.
    """
    channels = [32, 64, 128, 256, 512]

    # Encoder: two 3x3 conv layers per level, 2x2 max pool between levels.
    # Keep each level's final activation for the decoder skip connections.
    skips = []
    x = inputs
    for level, ch in enumerate(channels, start=1):
        if level > 1:
            x = max_pool_2x2(x)
        x = cnn3x3Pad('down%d_1' % level, x, ch, phase_BN)
        x = cnn3x3Pad('down%d_2' % level, x, ch, phase_BN)
        skips.append(x)

    # Decoder: upsample, concatenate the matching skip, two 3x3 conv layers.
    for level in range(4, 0, -1):
        ch = channels[level - 1]
        x = tf.concat([deconv('up%d_0' % level, x, ch), skips[level - 1]], axis=-1)
        x = cnn3x3Pad('up%d_1' % level, x, ch, phase_BN)
        x = cnn3x3Pad('up%d_2' % level, x, ch, phase_BN)

    # 1x1 conv to one channel, bounded residual via tanh, then clamp to (0, 1).
    residual = cnnLayerPad('up1_3', x, 1, 1, phase_BN, relu=False, maxpool=False)
    out = tf.tanh(residual) + inputs
    out = tf.maximum(tf.minimum(out, 1.0), 0.0)
    return out
107