|
a |
|
b/Network.py |
|
|
1 |
# digital networks for End-to-end optimization |
|
|
2 |
# Author: Yicheng Wu @ Rice University |
|
|
3 |
# 03/29/2019 |
|
|
4 |
|
|
|
5 |
|
|
|
6 |
import tensorflow as tf |
|
|
7 |
|
|
|
8 |
|
|
|
9 |
def max_pool_2x2(x):
    """Downsample by 2 with a 2x2 max pool, stride 2 and 'SAME' padding.

    The window/stride spec [1, 2, 2, 1] pools axes 1 and 2, i.e. it assumes
    the TensorFlow default NHWC layout.
    """
    window = [1, 2, 2, 1]  # pool only over the two spatial axes
    return tf.nn.max_pool(x, ksize=window, strides=window, padding='SAME')
|
|
11 |
|
|
|
12 |
|
|
|
13 |
def conv2dPad(x, W):
    """Stride-1 2-D convolution with mirror padding instead of zero 'SAME' padding.

    The input is padded in "SYMMETRIC" mode by (kh - 1) // 2 rows on top and
    bottom and (kw - 1) // 2 columns on left and right. The height and width
    amounts are computed separately from the kernel's height and width. The
    padded input is then convolved with 'VALID' padding.

    Args:
        x: 4-D input tensor; assumed NHWC, since the padding spec pads axes 1 and 2.
        W: convolution kernel with a statically known shape
            [kh, kw, in_channels, out_channels] (tf.nn.conv2d filter layout).

    Returns:
        The convolution output. For odd kh and kw its height and width equal
        those of ``x``; for an even kernel size that axis shrinks by 1.

    NOTE(review): SYMMETRIC padding requires the pad amount not to exceed the
    padded dimension's size (see tf.pad docs), so very small inputs with large
    kernels will fail.
    """
    # Compute height and width padding separately so non-square kernels also
    # preserve the spatial size. The original used W.shape[0] for both axes.
    pad_h = (W.shape[0].value - 1) // 2
    pad_w = (W.shape[1].value - 1) // 2
    x_pad = tf.pad(x, [[0, 0], [pad_h, pad_h], [pad_w, pad_w], [0, 0]], "SYMMETRIC")
    return tf.nn.conv2d(x_pad, W, strides=[1, 1, 1, 1], padding='VALID')
|
|
17 |
|
|
|
18 |
|
|
|
19 |
def BN(x, phase_BN, scope):
    """Batch-normalize ``x`` with moving-average momentum 0.9.

    Args:
        x: tensor to normalize.
        phase_BN: passed as ``training`` to tf.layers.batch_normalization
            (a Python bool or boolean tensor selecting training vs. inference mode).
        scope: accepted for call-site compatibility; it is not used here.

    Returns:
        The normalized tensor.

    NOTE(review): per the tf.layers.batch_normalization documentation, the
    moving mean/variance update ops are collected in tf.GraphKeys.UPDATE_OPS
    and must be run alongside the train op. Confirm the training loop does so.
    """
    bn_options = dict(momentum=0.9, training=phase_BN)
    return tf.layers.batch_normalization(x, **bn_options)
|
|
21 |
|
|
|
22 |
|
|
|
23 |
def cnnLayerPad(scope_name, inputs, outChannels, kernel_size, is_training, relu=True, maxpool=True):
    """Square conv (mirror-padded) + bias + batch norm, with optional ReLU and 2x2 max pool.

    Variables 'W_conv' (shape [kernel_size, kernel_size, in_channels, outChannels])
    and 'b_conv' (shape [outChannels]) are created via tf.get_variable under the
    variable scope ``scope_name``. Batch-norm variables are created by BN inside
    the same scope.

    Args:
        scope_name: variable scope name for this layer.
        inputs: 4-D input tensor whose last axis (channels) is statically known.
        outChannels: number of output channels.
        kernel_size: side length of the square kernel.
        is_training: forwarded to BN as its training flag.
        relu: if True, apply ReLU after batch norm.
        maxpool: if True, finish with max_pool_2x2 (halves height and width).

    Returns:
        The layer output tensor.
    """
    with tf.variable_scope(scope_name) as layer_scope:
        kernel_shape = [kernel_size, kernel_size, inputs.shape[-1].value, outChannels]
        weights = tf.get_variable('W_conv', kernel_shape)
        bias = tf.get_variable('b_conv', [outChannels])
        y = BN(conv2dPad(inputs, weights) + bias, is_training, layer_scope)
        y = tf.nn.relu(y) if relu else y
        return max_pool_2x2(y) if maxpool else y
|
|
35 |
|
|
|
36 |
|
|
|
37 |
def cnn3x3Pad(scope_name, inputs, outChannels, is_training, relu=True):
    """3x3 conv (mirror-padded) + bias + batch norm, with optional ReLU and no pooling.

    This is cnnLayerPad specialised to kernel_size=3 and maxpool=False. It
    delegates instead of duplicating that code. The same variables ('W_conv',
    'b_conv' and the batch-norm variables) are created under ``scope_name`` in
    the same order, so variable names and checkpoints are unaffected.

    Args:
        scope_name: variable scope name for this layer.
        inputs: 4-D input tensor whose last axis (channels) is statically known.
        outChannels: number of output channels.
        is_training: forwarded to batch norm as its training flag.
        relu: if True, apply ReLU after batch norm.

    Returns:
        The layer output tensor (same spatial size as ``inputs``).
    """
    return cnnLayerPad(scope_name, inputs, outChannels, 3, is_training,
                       relu=relu, maxpool=False)
|
|
48 |
|
|
|
49 |
|
|
|
50 |
def deconv(scope_name, inputs, outChannels):
    """Upsample by 2 with a nearest-neighbor resize followed by a 3x3 conv.

    Used in the decoder in place of a transposed convolution. The only
    variable is the 3x3 kernel 'resize_conv', created under ``scope_name``.
    No bias, batch norm or activation is applied here.

    Args:
        scope_name: variable scope name for the kernel.
        inputs: 4-D tensor, assumed NHWC (axes 1 and 2 are resized); the
            channel axis must be statically known.
        outChannels: number of output channels.

    Returns:
        Tensor with twice the height and width of ``inputs`` and
        ``outChannels`` channels.
    """
    with tf.variable_scope(scope_name):
        inChannels = inputs.shape[-1].value
        Nx = inputs.shape[1].value
        Ny = inputs.shape[2].value
        W = tf.get_variable('resize_conv', [3, 3, inChannels, outChannels])
        if Nx is not None and Ny is not None:
            # Statically known size: keeps the output's static shape fully defined.
            new_size = [2 * Nx, 2 * Ny]
        else:
            # Spatial dims unknown at graph-build time (e.g. a placeholder with
            # None height/width). Static arithmetic on None would raise, so
            # compute the target size from the runtime shape instead.
            new_size = 2 * tf.shape(inputs)[1:3]
        resize = tf.image.resize_nearest_neighbor(inputs, new_size)
        return conv2dPad(resize, W)
|
|
59 |
|
|
|
60 |
|
|
|
61 |
def UNet(inputs, phase_BN):
    """Residual U-Net: predicts a tanh-bounded correction that is added to ``inputs``.

    The encoder has 5 levels with 32, 64, 128, 256 and 512 channels, two
    3x3 conv layers per level and a 2x2 max pool between levels. The decoder
    mirrors it: at each level it upsamples with ``deconv``, concatenates the
    matching encoder output (skip connection) and applies two 3x3 conv layers.
    A final 1x1 conv (with batch norm, no ReLU) produces one channel. Its tanh
    is added to ``inputs`` and the sum is clipped to [0, 1].

    Args:
        inputs: 4-D image batch, assumed NHWC. Height and width must be
            multiples of 16, otherwise the skip-connection concatenations do not
            line up after four poolings and four 2x upsamplings. The
            1-channel correction is added to ``inputs`` directly; presumably
            ``inputs`` is single-channel (otherwise it broadcasts). TODO confirm.
        phase_BN: training flag forwarded to every batch-norm layer.

    Returns:
        Tensor ``clip(inputs + tanh(correction), 0, 1)``.

    Raises:
        ValueError: if a statically known height or width is not a multiple of 16.
    """
    # In the decoder, instead of a transposed conv, do resize-by-2 + conv (deconv);
    # every conv mirror-pads and uses 'VALID' instead of 'SAME' (conv2dPad).
    channels = [32, 64, 128, 256, 512]
    factor = 2 ** (len(channels) - 1)

    # Fail early with a clear message instead of an opaque concat shape error.
    if inputs.shape.ndims == 4:
        for axis in (1, 2):
            size = inputs.shape[axis].value
            if size is not None and size % factor != 0:
                raise ValueError(
                    'UNet input height/width must be divisible by {}; '
                    'axis {} has size {}'.format(factor, axis, size))

    # Encoder. Scope names 'down{level}_{1,2}' and the creation order match the
    # original unrolled code, so variable names and checkpoints are unchanged.
    x = inputs
    skips = []
    for level, ch in enumerate(channels, start=1):
        if level > 1:
            x = max_pool_2x2(x)
        x = cnn3x3Pad('down{}_1'.format(level), x, ch, phase_BN)
        x = cnn3x3Pad('down{}_2'.format(level), x, ch, phase_BN)
        skips.append(x)

    # Decoder: levels 4, 3, 2, 1. Each level upsamples and concatenates the
    # encoder output of the same level.
    for level in range(len(channels) - 1, 0, -1):
        ch = channels[level - 1]
        up = deconv('up{}_0'.format(level), x, ch)
        x = tf.concat([up, skips[level - 1]], axis=-1)
        x = cnn3x3Pad('up{}_1'.format(level), x, ch, phase_BN)
        x = cnn3x3Pad('up{}_2'.format(level), x, ch, phase_BN)

    correction = cnnLayerPad('up1_3', x, 1, 1, phase_BN, relu=False, maxpool=False)
    out = tf.tanh(correction) + inputs

    # Clip to the closed range [0, 1].
    out = tf.minimum(out, 1.0)
    out = tf.maximum(out, 0.0)

    return out
|
|
107 |
|