|
a |
|
b/tf_models.py |
|
|
1 |
import tensorflow as tf |
|
|
2 |
from tensorflow.contrib.keras.python.keras.layers import Conv3D, MaxPooling3D, UpSampling3D, Activation, Conv3DTranspose |
|
|
3 |
from tf_layers import * |
|
|
4 |
|
|
|
5 |
|
|
|
6 |
def PlainCounterpart(input, name): |
|
|
7 |
|
|
|
8 |
x = Conv3DWithBN(input, filters=24, ksize=3, strides=1, padding='same', name=name + '_conv_15rf_1x') |
|
|
9 |
x = Conv3DWithBN(x, filters=36, ksize=3, strides=1, padding='same', name=name + '_conv_15rf_2x') |
|
|
10 |
x = Conv3DWithBN(x, filters=48, ksize=3, strides=1, padding='same', name=name + '_conv_15rf_3x') |
|
|
11 |
x = Conv3DWithBN(x, filters=60, ksize=3, strides=1, padding='same', name=name + '_conv_15rf_4x') |
|
|
12 |
x = Conv3DWithBN(x, filters=72, ksize=3, strides=1, padding='same', name=name + '_conv_15rf_5x') |
|
|
13 |
x = Conv3DWithBN(x, filters=84, ksize=3, strides=1, padding='same', name=name + '_conv_15rf_6x') |
|
|
14 |
x = Conv3DWithBN(x, filters=96, ksize=3, strides=1, padding='same', name=name + '_conv_15rf_7x') |
|
|
15 |
|
|
|
16 |
out_15rf = x |
|
|
17 |
|
|
|
18 |
x = Conv3DWithBN(x, filters=108, ksize=3, strides=1, padding='same', name=name + '_conv_27rf_1x') |
|
|
19 |
x = Conv3DWithBN(x, filters=120, ksize=3, strides=1, padding='same', name=name + '_conv_27rf_2x') |
|
|
20 |
x = Conv3DWithBN(x, filters=132, ksize=3, strides=1, padding='same', name=name + '_conv_27rf_3x') |
|
|
21 |
x = Conv3DWithBN(x, filters=144, ksize=3, strides=1, padding='same', name=name + '_conv_27rf_4x') |
|
|
22 |
x = Conv3DWithBN(x, filters=156, ksize=3, strides=1, padding='same', name=name + '_conv_27rf_5x') |
|
|
23 |
x = Conv3DWithBN(x, filters=168, ksize=3, strides=1, padding='same', name=name + '_conv_27rf_6x') |
|
|
24 |
|
|
|
25 |
out_27rf = x |
|
|
26 |
|
|
|
27 |
return out_15rf, out_27rf |
|
|
28 |
|
|
|
29 |
|
|
|
30 |
def BraTS2ScaleDenseNetConcat(input, name): |
|
|
31 |
|
|
|
32 |
x = Conv3D(filters=24, kernel_size=3, strides=1, padding='same', name=name+'_conv_init')(input) |
|
|
33 |
x = DenseNetUnit3D(x, growth_rate=12, ksize=3, rep=6, name=name+'_denseblock1') |
|
|
34 |
|
|
|
35 |
out_15rf = BatchNormalization(center=True, scale=True)(x) |
|
|
36 |
out_15rf = Activation('relu')(out_15rf) |
|
|
37 |
out_15rf = Conv3DWithBN(out_15rf, filters=96, ksize=1, strides=1, name=name + '_out_15_postconv') |
|
|
38 |
|
|
|
39 |
x = DenseNetUnit3D(x, growth_rate=12, ksize=3, rep=6, name=name+'_denseblock2') |
|
|
40 |
|
|
|
41 |
out_27rf = BatchNormalization(center=True, scale=True)(x) |
|
|
42 |
out_27rf = Activation('relu')(out_27rf) |
|
|
43 |
out_27rf = Conv3DWithBN(out_27rf, filters=168, ksize=1, strides=1, name=name + '_out_27_postconv') |
|
|
44 |
|
|
|
45 |
return out_15rf, out_27rf |
|
|
46 |
|
|
|
47 |
def BraTS2ScaleDenseNetConcat_large(input, name): |
|
|
48 |
|
|
|
49 |
x = Conv3D(filters=48, kernel_size=3, strides=1, padding='same', name=name+'_conv_init')(input) |
|
|
50 |
x = DenseNetUnit3D(x, growth_rate=12, ksize=3, rep=6, name=name+'_denseblock1') |
|
|
51 |
|
|
|
52 |
out_15rf = BatchNormalization(center=True, scale=True)(x) |
|
|
53 |
out_15rf = Activation('relu')(out_15rf) |
|
|
54 |
out_15rf = Conv3DWithBN(out_15rf, filters=192, ksize=1, strides=1, name=name + '_out_15_postconv') |
|
|
55 |
|
|
|
56 |
x = DenseNetUnit3D(x, growth_rate=24, ksize=3, rep=6, name=name+'_denseblock2') |
|
|
57 |
|
|
|
58 |
out_27rf = BatchNormalization(center=True, scale=True)(x) |
|
|
59 |
out_27rf = Activation('relu')(out_27rf) |
|
|
60 |
out_27rf = Conv3DWithBN(out_27rf, filters=336, ksize=1, strides=1, name=name + '_out_27_postconv') |
|
|
61 |
|
|
|
62 |
return out_15rf, out_27rf |
|
|
63 |
|
|
|
64 |
|
|
|
65 |
def BraTS2ScaleDenseNet(input, num_labels): |
|
|
66 |
|
|
|
67 |
x = Conv3D(filters=24, kernel_size=3, strides=1, padding='same')(input) |
|
|
68 |
x = DenseNetUnit3D(x, growth_rate=12, ksize=3, rep=6) |
|
|
69 |
|
|
|
70 |
out_15rf = BatchNormalization(center=True, scale=True)(x) |
|
|
71 |
out_15rf = Activation('relu')(out_15rf) |
|
|
72 |
out_15rf = Conv3DWithBN(out_15rf, filters=96, ksize=1, strides=1, name='out_15_postconv') |
|
|
73 |
|
|
|
74 |
x = DenseNetUnit3D(x, growth_rate=12, ksize=3, rep=6) |
|
|
75 |
|
|
|
76 |
out_27rf = BatchNormalization(center=True, scale=True)(x) |
|
|
77 |
out_27rf = Activation('relu')(out_27rf) |
|
|
78 |
out_27rf = Conv3DWithBN(out_27rf, filters=168, ksize=1, strides=1, name='out_27_postconv') |
|
|
79 |
|
|
|
80 |
score_15rf = Conv3D(num_labels, kernel_size=1, strides=1, padding='same')(out_15rf) |
|
|
81 |
score_27rf = Conv3D(num_labels, kernel_size=1, strides=1, padding='same')(out_27rf) |
|
|
82 |
|
|
|
83 |
score = score_15rf[:, 13:25, 13:25, 13:25, :] + \ |
|
|
84 |
score_27rf[:, 13:25, 13:25, 13:25, :] |
|
|
85 |
|
|
|
86 |
return score |
|
|
87 |
|
|
|
88 |
|
|
|
89 |
def BraTS3ScaleDenseNet(input, num_labels): |
|
|
90 |
|
|
|
91 |
x = Conv3D(filters=24, kernel_size=3, strides=1, padding='same')(input) |
|
|
92 |
x = DenseNetUnit3D(x, growth_rate=12, ksize=3, rep=5) |
|
|
93 |
|
|
|
94 |
out_13rf = BatchNormalization(center=True, scale=True)(x) |
|
|
95 |
out_13rf = Activation('relu')(out_13rf) |
|
|
96 |
out_13rf = Conv3DWithBN(out_13rf, filters=84, ksize=1, strides=1, name='out_13_postconv') |
|
|
97 |
|
|
|
98 |
x = DenseNetUnit3D(x, growth_rate=12, ksize=3, rep=5) |
|
|
99 |
|
|
|
100 |
out_23rf = BatchNormalization(center=True, scale=True)(x) |
|
|
101 |
out_23rf = Activation('relu')(out_23rf) |
|
|
102 |
out_23rf = Conv3DWithBN(out_23rf, filters=144, ksize=1, strides=1, name='out_23_postconv') |
|
|
103 |
|
|
|
104 |
x = DenseNetUnit3D(x, growth_rate=12, ksize=3, rep=5) |
|
|
105 |
|
|
|
106 |
out_33rf = BatchNormalization(center=True, scale=True)(x) |
|
|
107 |
out_33rf = Activation('relu')(out_33rf) |
|
|
108 |
out_33rf = Conv3DWithBN(out_33rf, filters=204, ksize=1, strides=1, name='out_33_postconv') |
|
|
109 |
|
|
|
110 |
score_13rf = Conv3D(num_labels, kernel_size=1, strides=1, padding='same')(out_13rf) |
|
|
111 |
score_23rf = Conv3D(num_labels, kernel_size=1, strides=1, padding='same')(out_23rf) |
|
|
112 |
score_33rf = Conv3D(num_labels, kernel_size=1, strides=1, padding='same')(out_33rf) |
|
|
113 |
|
|
|
114 |
score = score_13rf[:, 16:28, 16:28, 16:28, :] + \ |
|
|
115 |
score_23rf[:, 16:28, 16:28, 16:28, :] + \ |
|
|
116 |
score_33rf[:, 16:28, 16:28, 16:28, :] |
|
|
117 |
|
|
|
118 |
return score |
|
|
119 |
|
|
|
120 |
|
|
|
121 |
|
|
|
122 |
def BraTS1ScaleDenseNet(input, num_labels): |
|
|
123 |
|
|
|
124 |
x = Conv3D(filters=36, kernel_size=5, strides=1, padding='same')(input) |
|
|
125 |
x = DenseNetUnit3D(x, growth_rate=18, ksize=3, rep=6) |
|
|
126 |
|
|
|
127 |
out_15rf = BatchNormalization(center=True, scale=True)(x) |
|
|
128 |
out_15rf = Activation('relu')(out_15rf) |
|
|
129 |
out_15rf = Conv3DWithBN(out_15rf, filters=144, ksize=1, strides=1, name='out_17_postconv1') |
|
|
130 |
out_15rf = Conv3DWithBN(out_15rf, filters=144, ksize=1, strides=1, name='out_17_postconv2') |
|
|
131 |
|
|
|
132 |
score_15rf = Conv3D(num_labels, kernel_size=1, strides=1, padding='same')(out_15rf) |
|
|
133 |
|
|
|
134 |
score = score_15rf[:, 8:20, 8:20, 8:20, :] |
|
|
135 |
return score |