BraTs18Challege/Vnet/layer.py
|
'''
Convolution layers, pooling layers, weight initialization, etc.
'''
from __future__ import division
import tensorflow as tf
import numpy as np
import cv2


# Weight initialization (Xavier's init)
def weight_xavier_init(shape, n_inputs, n_outputs, activefunction='sigomd', uniform=True, variable_name=None):
    # 'sigomd' (sic) selects the sigmoid/Xavier branch; callers such as conv_sigmod() pass this exact string
    if activefunction == 'sigomd':
        if uniform:
            init_range = tf.sqrt(6.0 / (n_inputs + n_outputs))
            initial = tf.random_uniform(shape, -init_range, init_range)
            return tf.get_variable(name=variable_name, initializer=initial, trainable=True)
        else:
            stddev = tf.sqrt(2.0 / (n_inputs + n_outputs))
            initial = tf.truncated_normal(shape, mean=0.0, stddev=stddev)
            return tf.get_variable(name=variable_name, initializer=initial, trainable=True)
    elif activefunction == 'relu':
        if uniform:
            init_range = tf.sqrt(6.0 / (n_inputs + n_outputs)) * np.sqrt(2)
            initial = tf.random_uniform(shape, -init_range, init_range)
            return tf.get_variable(name=variable_name, initializer=initial, trainable=True)
        else:
            stddev = tf.sqrt(2.0 / (n_inputs + n_outputs)) * np.sqrt(2)
            initial = tf.truncated_normal(shape, mean=0.0, stddev=stddev)
            return tf.get_variable(name=variable_name, initializer=initial, trainable=True)
    elif activefunction == 'tan':
        if uniform:
            init_range = tf.sqrt(6.0 / (n_inputs + n_outputs)) * 4
            initial = tf.random_uniform(shape, -init_range, init_range)
            return tf.get_variable(name=variable_name, initializer=initial, trainable=True)
        else:
            stddev = tf.sqrt(2.0 / (n_inputs + n_outputs)) * 4
            initial = tf.truncated_normal(shape, mean=0.0, stddev=stddev)
            return tf.get_variable(name=variable_name, initializer=initial, trainable=True)
|
|
# Bias initialization
def bias_variable(shape, variable_name=None):
    initial = tf.constant(0.1, shape=shape)
    return tf.get_variable(name=variable_name, initializer=initial, trainable=True)


# 3D convolution
def conv3d(x, W, stride=1):
    conv_3d = tf.nn.conv3d(x, W, strides=[1, stride, stride, stride, 1], padding='SAME')
    return conv_3d
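
# Example (illustrative sketch, not part of the original file): the three helpers above are
# typically composed into a single convolution layer, exactly as conv_bn_relu_drop() does
# further below. The kernel shape (3x3x3, 16 -> 32 channels) and the tensor `x` are assumed
# example values only.
#
#   kernal = (3, 3, 3, 16, 32)
#   W = weight_xavier_init(shape=kernal, n_inputs=3 * 3 * 3 * 16, n_outputs=32,
#                          activefunction='relu', variable_name='demo_W')
#   B = bias_variable([32], variable_name='demo_B')
#   out = conv3d(x, W) + B   # x: [batch, depth, height, width, 16]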
|
|
# 3D upsampling
def upsample3d(x, scale_factor, scope=None):
    '''
    X shape is [nsample, dim, rows, cols, channel]
    out shape is [nsample, dim*scale_factor, rows*scale_factor, cols*scale_factor, channel]
    '''
    x_shape = tf.shape(x)
    k = tf.ones([scale_factor, scale_factor, scale_factor, x_shape[-1], x_shape[-1]])
    # note k.shape = [dim, rows, cols, depth_in, depth_output]
    output_shape = tf.stack(
        [x_shape[0], x_shape[1] * scale_factor, x_shape[2] * scale_factor, x_shape[3] * scale_factor, x_shape[4]])
    upsample = tf.nn.conv3d_transpose(value=x, filter=k, output_shape=output_shape,
                                      strides=[1, scale_factor, scale_factor, scale_factor, 1],
                                      padding='SAME', name=scope)
    return upsample
|
|
# 3D deconvolution
def deconv3d(x, W, samefeature=False, depth=False):
    """
    depth flag: False keeps the z axis the same size between input and output;
    True makes the z axis of the output twice that of the input
    """
    x_shape = tf.shape(x)
    if depth:
        if samefeature:
            output_shape = tf.stack([x_shape[0], x_shape[1] * 2, x_shape[2] * 2, x_shape[3] * 2, x_shape[4]])
        else:
            output_shape = tf.stack([x_shape[0], x_shape[1] * 2, x_shape[2] * 2, x_shape[3] * 2, x_shape[4] // 2])
        deconv = tf.nn.conv3d_transpose(x, W, output_shape, strides=[1, 2, 2, 2, 1], padding='SAME')
    else:
        if samefeature:
            output_shape = tf.stack([x_shape[0], x_shape[1] * 2, x_shape[2] * 2, x_shape[3], x_shape[4]])
        else:
            output_shape = tf.stack([x_shape[0], x_shape[1] * 2, x_shape[2] * 2, x_shape[3], x_shape[4] // 2])
        deconv = tf.nn.conv3d_transpose(x, W, output_shape, strides=[1, 2, 2, 1, 1], padding='SAME')
    return deconv
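
# Shape sketch (illustrative, values assumed): with depth=True and samefeature=False,
# an input of [batch, 16, 32, 32, 64] is deconvolved to [batch, 32, 64, 64, 32]
# (spatial dims doubled by the stride-2 transposed conv, channels halved).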
|
|
# Max Pooling
def max_pool3d(x, depth=False):
    """
    depth flag: False keeps the z axis the same size between input and output;
    True halves the z axis of the output (the input z axis is twice the output)
    """
    if depth:
        pool3d = tf.nn.max_pool3d(x, ksize=[1, 2, 2, 2, 1], strides=[1, 2, 2, 2, 1], padding='SAME')
    else:
        pool3d = tf.nn.max_pool3d(x, ksize=[1, 2, 2, 1, 1], strides=[1, 2, 2, 1, 1], padding='SAME')
    return pool3d
|
|
# Unet crop and concat
def crop_and_concat(x1, x2):
    """
    crop x1 to the spatial size of x2 and concatenate them along the channel axis
    :param x1:
    :param x2:
    :return:
    """
    x1_shape = tf.shape(x1)
    x2_shape = tf.shape(x2)
    # offsets for the top left corner of the crop
    offsets = [0, (x1_shape[1] - x2_shape[1]) // 2,
               (x1_shape[2] - x2_shape[2]) // 2, (x1_shape[3] - x2_shape[3]) // 2, 0]
    size = [-1, x2_shape[1], x2_shape[2], x2_shape[3], -1]
    x1_crop = tf.slice(x1, offsets, size)
    return tf.concat([x1_crop, x2], 4)
|
|
# Normalization layer (batch / group)
def normalizationlayer(x, is_train, height=None, width=None, image_z=None, norm_type=None, G=16, esp=1e-5, scope=None):
    """
    normalizationlayer
    :param x: input data with shape [batch, depth, height, width, channel]
    :param is_train: flag for the normalization layer, True for training, False for testing
    :param height: in some conditions the data height is only known at runtime, e.g. after deconv and conv layers
    :param width: in some conditions the data width is only known at runtime
    :param image_z:
    :param norm_type: normalization type, one of "batch", "group", None
    :param G: in group normalization, the channels are separated into G groups
    :param esp: small constant that prevents division by zero
    :param scope: normalizationlayer scope
    :return:
    """
    with tf.name_scope(scope + (norm_type or 'none')):
        if norm_type is None:
            output = x
        elif norm_type == 'batch':
            output = tf.contrib.layers.batch_norm(x, center=True, scale=True, is_training=is_train)
        elif norm_type == "group":
            # transpose: [bs, z, h, w, c] to [bs, c, z, h, w] following the paper
            x = tf.transpose(x, [0, 4, 1, 2, 3])
            N, C, Z, H, W = x.get_shape().as_list()
            G = min(G, C)
            if H is None and W is None and Z is None:
                Z, H, W = image_z, height, width
            x = tf.reshape(x, [-1, G, C // G, Z, H, W])
            mean, var = tf.nn.moments(x, [2, 3, 4, 5], keep_dims=True)
            x = (x - mean) / tf.sqrt(var + esp)
            gama = tf.get_variable(scope + norm_type + 'group_gama', [C], initializer=tf.constant_initializer(1.0))
            beta = tf.get_variable(scope + norm_type + 'group_beta', [C], initializer=tf.constant_initializer(0.0))
            gama = tf.reshape(gama, [1, C, 1, 1, 1])
            beta = tf.reshape(beta, [1, C, 1, 1, 1])
            output = tf.reshape(x, [-1, C, Z, H, W]) * gama + beta
            # transpose: [bs, c, z, h, w] back to [bs, z, h, w, c] following the paper
            output = tf.transpose(output, [0, 2, 3, 4, 1])
        return output
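
# Group-norm sketch (illustrative, values assumed): with G=16 and C=64 channels, the
# [bs, 64, Z, H, W] tensor is reshaped to [bs, 16, 4, Z, H, W] and mean/variance are
# computed per group over the (4, Z, H, W) axes before scaling with gama and beta.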
|
|
# resnet add_connect
def resnet_Add(x1, x2):
    """
    add x1 and x2
    :param x1:
    :param x2:
    :return:
    """
    if x1.get_shape().as_list()[4] != x2.get_shape().as_list()[4]:
        # Option A: Zero-padding
        residual_connection = x2 + tf.pad(x1, [[0, 0], [0, 0], [0, 0], [0, 0],
                                               [0, x2.get_shape().as_list()[4] -
                                                x1.get_shape().as_list()[4]]])
    else:
        residual_connection = x2 + x1
    return residual_connection
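
# Channel-padding sketch (illustrative, shapes assumed): if x1 has 32 channels and x2 has 64,
# x1 is zero-padded with 32 extra channels so the element-wise addition is valid.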
|
|
def save_images(images, size, path):
    # map images from [-1, 1] to [0, 1], tile them into a size[0] x size[1] grid and write to disk
    img = (images + 1.0) / 2.0
    h, w = img.shape[1], img.shape[2]
    merge_img = np.zeros((h * size[0], w * size[1]))
    for idx, image in enumerate(img):
        i = idx % size[1]
        j = idx // size[1]
        merge_img[j * h:j * h + h, i * w:i * w + w] = image
    result = merge_img * 255.
    result = np.clip(result, 0, 255).astype('uint8')
    return cv2.imwrite(path, result)
|
|
def gatingsignal3d(x, kernal, phase, image_z=None, height=None, width=None, scope=None):
    """
    gating signal (query) for the attention gate: simply a 1x1x1 convolution, normalization and activation
    :param x:
    :param kernal: (1, 1, 1, inputfilters, outputfilters)
    :param phase:
    :param image_z:
    :param height:
    :param width:
    :param scope:
    :return:
    """
    with tf.name_scope(scope):
        W = weight_xavier_init(shape=kernal, n_inputs=kernal[0] * kernal[1] * kernal[2] * kernal[3],
                               n_outputs=kernal[-1], activefunction='relu', variable_name=scope + 'conv_W')
        B = bias_variable([kernal[-1]], variable_name=scope + 'conv_B')
        conv = conv3d(x, W) + B
        conv = normalizationlayer(conv, is_train=phase, height=height, width=width, image_z=image_z, norm_type='group',
                                  scope=scope)
        conv = tf.nn.relu(conv)
        return conv
|
|
def attngatingblock(x, g, inputfilters, outfilters, scale_factor, phase, image_z=None, height=None, width=None,
                    scope=None):
    """
    take g, the spatially smaller signal, and apply a conv to get the same number of feature channels as x
    (which is spatially bigger); apply a strided conv on x to also get the same feature channels (theta_x);
    then add x and g, apply relu, a 1x1x1 conv and a sigmoid, and finally upsample the result back to the
    size of x - this gives the attention coefficients
    :param x:
    :param g:
    :param inputfilters:
    :param outfilters:
    :param scale_factor: 2
    :param scope:
    :return:
    """
    with tf.name_scope(scope):
        kernalx = (1, 1, 1, inputfilters, outfilters)
        Wx = weight_xavier_init(shape=kernalx, n_inputs=kernalx[0] * kernalx[1] * kernalx[2] * kernalx[3],
                                n_outputs=kernalx[-1], activefunction='relu', variable_name=scope + 'conv_Wx')
        Bx = bias_variable([kernalx[-1]], variable_name=scope + 'conv_Bx')
        theta_x = conv3d(x, Wx, scale_factor) + Bx

        kernalg = (1, 1, 1, inputfilters, outfilters)
        Wg = weight_xavier_init(shape=kernalg, n_inputs=kernalg[0] * kernalg[1] * kernalg[2] * kernalg[3],
                                n_outputs=kernalg[-1], activefunction='relu', variable_name=scope + 'conv_Wg')
        Bg = bias_variable([kernalg[-1]], variable_name=scope + 'conv_Bg')
        phi_g = conv3d(g, Wg) + Bg

        add_xg = resnet_Add(theta_x, phi_g)
        act_xg = tf.nn.relu(add_xg)

        kernalpsi = (1, 1, 1, outfilters, 1)
        Wpsi = weight_xavier_init(shape=kernalpsi, n_inputs=kernalpsi[0] * kernalpsi[1] * kernalpsi[2] * kernalpsi[3],
                                  n_outputs=kernalpsi[-1], activefunction='relu', variable_name=scope + 'conv_Wpsi')
        Bpsi = bias_variable([kernalpsi[-1]], variable_name=scope + 'conv_Bpsi')
        psi = conv3d(act_xg, Wpsi) + Bpsi
        sigmoid_psi = tf.nn.sigmoid(psi)

        upsample_psi = upsample3d(sigmoid_psi, scale_factor=scale_factor, scope=scope + "resampler")

        # Attention: upsample_psi * x
        gat_x = tf.multiply(upsample_psi, x)
        kernal_gat_x = (1, 1, 1, outfilters, outfilters)
        Wgatx = weight_xavier_init(shape=kernal_gat_x,
                                   n_inputs=kernal_gat_x[0] * kernal_gat_x[1] * kernal_gat_x[2] * kernal_gat_x[3],
                                   n_outputs=kernal_gat_x[-1], activefunction='relu',
                                   variable_name=scope + 'conv_Wgatx')
        Bgatx = bias_variable([kernal_gat_x[-1]], variable_name=scope + 'conv_Bgatx')
        gat_x_out = conv3d(gat_x, Wgatx) + Bgatx
        gat_x_out = normalizationlayer(gat_x_out, is_train=phase, height=height, width=width, image_z=image_z,
                                       norm_type='group', scope=scope)
    return gat_x_out
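
# Usage sketch (illustrative, names and shapes assumed): inside a decoder stage one would
# typically derive the gating signal from the coarser feature map and gate the skip connection:
#
#   g = gatingsignal3d(deep_features, (1, 1, 1, 128, 64), phase, scope='gating1')
#   attn_skip = attngatingblock(skip_features, g, 64, 64, scale_factor=2, phase=phase, scope='attn1')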
|
|
def positionAttentionblock(x, inputfilters, outfilters, kernal_size=1, scope=None):
    """
    Position attention module
    :param x:
    :param inputfilters: input filter number
    :param outfilters: output filter number
    :param scope:
    :return:
    """
    with tf.name_scope(scope):
        m_batchsize, Z, H, W, C = x.get_shape().as_list()

        kernalquery = (kernal_size, kernal_size, kernal_size, inputfilters, outfilters)
        Wquery = weight_xavier_init(shape=kernalquery,
                                    n_inputs=kernalquery[0] * kernalquery[1] * kernalquery[2] * kernalquery[3],
                                    n_outputs=kernalquery[-1], activefunction='relu',
                                    variable_name=scope + 'conv_Wquery')
        Bquery = bias_variable([kernalquery[-1]], variable_name=scope + 'conv_Bquery')
        query_conv = conv3d(x, Wquery) + Bquery
        query_conv_new = tf.reshape(query_conv, [-1, Z * H * W, C])

        kernalkey = (kernal_size, kernal_size, kernal_size, inputfilters, outfilters)
        Wkey = weight_xavier_init(shape=kernalkey, n_inputs=kernalkey[0] * kernalkey[1] * kernalkey[2] * kernalkey[3],
                                  n_outputs=kernalkey[-1], activefunction='relu', variable_name=scope + 'conv_Wkey')
        Bkey = bias_variable([kernalkey[-1]], variable_name=scope + 'conv_Bkey')
        key_conv = conv3d(x, Wkey) + Bkey
        key_conv_new = tf.reshape(key_conv, [-1, Z * H * W, C])

        # beware of OOM: the attention matrix is (Z*H*W) x (Z*H*W), e.g. for a 512x512x32 volume it is 8388608 x 8388608
        key_conv_new = tf.transpose(key_conv_new, [0, 2, 1])
        # batched matmul shapes: (2,2,2,3)*(2,2,3,4)=(2,2,2,4), (2,2,3)*(2,3,4)=(2,2,4)
        energy = tf.matmul(query_conv_new, key_conv_new)  # (m_batchsize, Z*H*W, Z*H*W)
        attention = tf.nn.softmax(energy, -1)

        kernalproj = (kernal_size, kernal_size, kernal_size, inputfilters, outfilters)
        Wproj = weight_xavier_init(shape=kernalproj,
                                   n_inputs=kernalproj[0] * kernalproj[1] * kernalproj[2] * kernalproj[3],
                                   n_outputs=kernalproj[-1], activefunction='relu', variable_name=scope + 'conv_Wproj')
        Bproj = bias_variable([kernalproj[-1]], variable_name=scope + 'conv_Bproj')
        proj_value = conv3d(x, Wproj) + Bproj
        proj_value_new = tf.reshape(proj_value, [-1, Z * H * W, C])

        out = tf.matmul(attention, proj_value_new)  # (m_batchsize, Z*H*W, C)
        out_new = tf.reshape(out, [-1, Z, H, W, C])

        out_new = resnet_Add(out_new, x)
    return out_new
|
|
def channelAttentionblock(x, scope=None):
    """
    Channel attention module
    :param x: input
    :param scope: scope name
    :return: channel attention result
    """
    with tf.name_scope(scope):
        m_batchsize, Z, H, W, C = x.get_shape().as_list()

        proj_query = tf.reshape(x, [-1, Z * H * W, C])
        proj_key = tf.reshape(x, [-1, Z * H * W, C])
        proj_query = tf.transpose(proj_query, [0, 2, 1])

        energy = tf.matmul(proj_query, proj_key)  # (-1, C, C)
        attention = tf.nn.softmax(energy, -1)  # (-1, C, C)

        proj_value = tf.reshape(x, [-1, Z * H * W, C])
        proj_value = tf.transpose(proj_value, [0, 2, 1])
        out = tf.matmul(attention, proj_value)  # (-1, C, Z*H*W)

        out = tf.transpose(out, [0, 2, 1])  # back to (-1, Z*H*W, C) before restoring the spatial layout
        out = tf.reshape(out, [-1, Z, H, W, C])
        out = resnet_Add(out, x)
    return out
|
|
def NonLocalBlock(input_x, phase, image_z=None, image_height=None, image_width=None, scope=None):
    """
    Non-local neural network block
    :param input_x:
    :param phase:
    :param scope:
    :return:
    """
    batchsize, dimensizon, height, width, out_channels = input_x.get_shape().as_list()
    with tf.name_scope(scope):
        kernal_thela = (1, 1, 1, out_channels, out_channels // 2)
        W_thela = weight_xavier_init(shape=kernal_thela,
                                     n_inputs=kernal_thela[0] * kernal_thela[1] * kernal_thela[2] * kernal_thela[3],
                                     n_outputs=kernal_thela[-1], activefunction='relu',
                                     variable_name=scope + 'conv_W_thela')
        B_thela = bias_variable([kernal_thela[-1]], variable_name=scope + 'conv_B_thela')
        thela = conv3d(input_x, W_thela) + B_thela
        thela = normalizationlayer(thela, is_train=phase, height=image_height, width=image_width, image_z=image_z,
                                   norm_type='group', scope=scope + "NonLocalbn1")

        kernal_phi = (1, 1, 1, out_channels, out_channels // 2)
        W_phi = weight_xavier_init(shape=kernal_phi,
                                   n_inputs=kernal_phi[0] * kernal_phi[1] * kernal_phi[2] * kernal_phi[3],
                                   n_outputs=kernal_phi[-1], activefunction='relu',
                                   variable_name=scope + 'conv_W_phi')
        B_phi = bias_variable([kernal_phi[-1]], variable_name=scope + 'conv_B_phi')
        phi = conv3d(input_x, W_phi) + B_phi
        phi = normalizationlayer(phi, is_train=phase, height=image_height, width=image_width, image_z=image_z,
                                 norm_type='group', scope=scope + "NonLocalbn2")

        kernal_g = (1, 1, 1, out_channels, out_channels // 2)
        W_g = weight_xavier_init(shape=kernal_g,
                                 n_inputs=kernal_g[0] * kernal_g[1] * kernal_g[2] * kernal_g[3],
                                 n_outputs=kernal_g[-1], activefunction='relu',
                                 variable_name=scope + 'conv_W_g')
        B_g = bias_variable([kernal_g[-1]], variable_name=scope + 'conv_B_g')
        g = conv3d(input_x, W_g) + B_g
        g = normalizationlayer(g, is_train=phase, height=image_height, width=image_width, image_z=image_z,
                               norm_type='group', scope=scope + "NonLocalbn3")

        g_x = tf.reshape(g, [-1, dimensizon * height * width, out_channels // 2])
        theta_x = tf.reshape(thela, [-1, dimensizon * height * width, out_channels // 2])
        phi_x = tf.reshape(phi, [-1, dimensizon * height * width, out_channels // 2])
        phi_x = tf.transpose(phi_x, [0, 2, 1])

        f = tf.matmul(theta_x, phi_x)

        f_softmax = tf.nn.softmax(f, -1)
        y = tf.matmul(f_softmax, g_x)
        y = tf.reshape(y, [-1, dimensizon, height, width, out_channels // 2])

        kernal_y = (1, 1, 1, out_channels // 2, out_channels)
        W_y = weight_xavier_init(shape=kernal_y,
                                 n_inputs=kernal_y[0] * kernal_y[1] * kernal_y[2] * kernal_y[3],
                                 n_outputs=kernal_y[-1], activefunction='relu',
                                 variable_name=scope + 'conv_W_y')
        B_y = bias_variable([kernal_y[-1]], variable_name=scope + 'conv_B_y')
        w_y = conv3d(y, W_y) + B_y
        w_y = normalizationlayer(w_y, is_train=phase, height=image_height, width=image_width, image_z=image_z,
                                 norm_type='group', scope=scope + "NonLocalbn4")
        z = resnet_Add(input_x, w_y)
    return z
|
|
def conv_bn_relu_drop(x, kernal, phase, drop, image_z=None, height=None, width=None, scope=None):
    """
    conv + group norm + leaky relu + dropout
    :param x:
    :param kernal:
    :param phase:
    :param drop:
    :param image_z:
    :param height:
    :param width:
    :param scope:
    :return:
    """
    with tf.name_scope(scope):
        W = weight_xavier_init(shape=kernal, n_inputs=kernal[0] * kernal[1] * kernal[2] * kernal[3],
                               n_outputs=kernal[-1], activefunction='relu', variable_name=scope + 'conv_W')
        B = bias_variable([kernal[-1]], variable_name=scope + 'conv_B')
        conv = conv3d(x, W) + B
        conv = normalizationlayer(conv, is_train=phase, height=height, width=width, image_z=image_z, norm_type='group',
                                  scope=scope)
        conv = tf.nn.dropout(tf.nn.leaky_relu(conv), drop)
        return conv
|
|
def down_sampling(x, kernal, phase, drop, image_z=None, height=None, width=None, scope=None):
    """
    downsampling with conv stride=2
    :param x:
    :param kernal:
    :param phase:
    :param drop:
    :param image_z:
    :param height:
    :param width:
    :param scope:
    :return:
    """
    with tf.name_scope(scope):
        W = weight_xavier_init(shape=kernal, n_inputs=kernal[0] * kernal[1] * kernal[2] * kernal[3],
                               n_outputs=kernal[-1],
                               activefunction='relu', variable_name=scope + 'W')
        B = bias_variable([kernal[-1]], variable_name=scope + 'B')
        conv = conv3d(x, W, 2) + B
        conv = normalizationlayer(conv, is_train=phase, height=height, width=width, image_z=image_z, norm_type='group',
                                  scope=scope)
        conv = tf.nn.dropout(tf.nn.leaky_relu(conv), drop)
        return conv
|
|
def deconv_relu(x, kernal, samefeture=False, scope=None):
    """
    deconv + leaky relu
    :param x:
    :param kernal:
    :param samefeture:
    :param scope:
    :return:
    """
    with tf.name_scope(scope):
        W = weight_xavier_init(shape=kernal, n_inputs=kernal[0] * kernal[1] * kernal[2] * kernal[-1],
                               n_outputs=kernal[-2], activefunction='relu', variable_name=scope + 'W')
        B = bias_variable([kernal[-2]], variable_name=scope + 'B')
        conv = deconv3d(x, W, samefeture, True) + B
        conv = tf.nn.leaky_relu(conv)
        return conv
|
|
def conv_sigmod(x, kernal, scope=None):
    """
    conv + sigmoid output layer
    :param x:
    :param kernal:
    :param scope:
    :return:
    """
    with tf.name_scope(scope):
        W = weight_xavier_init(shape=kernal, n_inputs=kernal[0] * kernal[1] * kernal[2] * kernal[3],
                               n_outputs=kernal[-1], activefunction='sigomd', variable_name=scope + 'W')
        B = bias_variable([kernal[-1]], variable_name=scope + 'B')
        conv = conv3d(x, W) + B
        conv = tf.nn.sigmoid(conv)
        return conv
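

# Minimal smoke test (illustrative sketch, not part of the original file). It assumes a TF 1.x
# environment in graph mode and uses arbitrary example shapes; it only checks that the main
# building blocks wire together, without running a session.
if __name__ == '__main__':
    tf.reset_default_graph()
    # a tiny 3D volume: [batch, Z, H, W, C]
    x = tf.placeholder(tf.float32, shape=[1, 16, 32, 32, 1])
    phase = tf.placeholder(tf.bool)
    drop = tf.placeholder(tf.float32)

    conv1 = conv_bn_relu_drop(x, (3, 3, 3, 1, 8), phase, drop, image_z=16, height=32, width=32, scope='layer1')
    down1 = down_sampling(conv1, (3, 3, 3, 8, 16), phase, drop, image_z=8, height=16, width=16, scope='down1')
    up1 = deconv_relu(down1, (3, 3, 3, 8, 16), samefeture=False, scope='up1')
    out = conv_sigmod(up1, (1, 1, 1, 8, 1), scope='output')
    print(conv1.get_shape(), down1.get_shape(), up1.get_shape(), out.get_shape())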