Diff of /tf_models.py [000000] .. [e44b03]

Switch to unified view

a b/tf_models.py
1
import tensorflow as tf
2
from tensorflow.contrib.keras.python.keras.layers import Conv3D, MaxPooling3D, UpSampling3D, Activation, Conv3DTranspose
3
from tf_layers import *
4
5
6
def PlainCounterpart(input, name):
7
8
    x = Conv3DWithBN(input, filters=24, ksize=3, strides=1, padding='same', name=name + '_conv_15rf_1x')
9
    x = Conv3DWithBN(x, filters=36, ksize=3, strides=1, padding='same', name=name + '_conv_15rf_2x')
10
    x = Conv3DWithBN(x, filters=48, ksize=3, strides=1, padding='same', name=name + '_conv_15rf_3x')
11
    x = Conv3DWithBN(x, filters=60, ksize=3, strides=1, padding='same', name=name + '_conv_15rf_4x')
12
    x = Conv3DWithBN(x, filters=72, ksize=3, strides=1, padding='same', name=name + '_conv_15rf_5x')
13
    x = Conv3DWithBN(x, filters=84, ksize=3, strides=1, padding='same', name=name + '_conv_15rf_6x')
14
    x = Conv3DWithBN(x, filters=96, ksize=3, strides=1, padding='same', name=name + '_conv_15rf_7x')
15
16
    out_15rf = x
17
18
    x = Conv3DWithBN(x, filters=108, ksize=3, strides=1, padding='same', name=name + '_conv_27rf_1x')
19
    x = Conv3DWithBN(x, filters=120, ksize=3, strides=1, padding='same', name=name + '_conv_27rf_2x')
20
    x = Conv3DWithBN(x, filters=132, ksize=3, strides=1, padding='same', name=name + '_conv_27rf_3x')
21
    x = Conv3DWithBN(x, filters=144, ksize=3, strides=1, padding='same', name=name + '_conv_27rf_4x')
22
    x = Conv3DWithBN(x, filters=156, ksize=3, strides=1, padding='same', name=name + '_conv_27rf_5x')
23
    x = Conv3DWithBN(x, filters=168, ksize=3, strides=1, padding='same', name=name + '_conv_27rf_6x')
24
25
    out_27rf = x
26
27
    return out_15rf, out_27rf
28
29
30
def BraTS2ScaleDenseNetConcat(input, name):
31
32
    x = Conv3D(filters=24, kernel_size=3, strides=1, padding='same', name=name+'_conv_init')(input)
33
    x = DenseNetUnit3D(x, growth_rate=12, ksize=3, rep=6, name=name+'_denseblock1')
34
35
    out_15rf = BatchNormalization(center=True, scale=True)(x)
36
    out_15rf = Activation('relu')(out_15rf)
37
    out_15rf = Conv3DWithBN(out_15rf, filters=96, ksize=1, strides=1, name=name + '_out_15_postconv')
38
39
    x = DenseNetUnit3D(x, growth_rate=12, ksize=3, rep=6, name=name+'_denseblock2')
40
41
    out_27rf = BatchNormalization(center=True, scale=True)(x)
42
    out_27rf = Activation('relu')(out_27rf)
43
    out_27rf = Conv3DWithBN(out_27rf, filters=168, ksize=1, strides=1, name=name + '_out_27_postconv')
44
45
    return out_15rf, out_27rf
46
47
def BraTS2ScaleDenseNetConcat_large(input, name):
48
49
    x = Conv3D(filters=48, kernel_size=3, strides=1, padding='same', name=name+'_conv_init')(input)
50
    x = DenseNetUnit3D(x, growth_rate=12, ksize=3, rep=6, name=name+'_denseblock1')
51
52
    out_15rf = BatchNormalization(center=True, scale=True)(x)
53
    out_15rf = Activation('relu')(out_15rf)
54
    out_15rf = Conv3DWithBN(out_15rf, filters=192, ksize=1, strides=1, name=name + '_out_15_postconv')
55
56
    x = DenseNetUnit3D(x, growth_rate=24, ksize=3, rep=6, name=name+'_denseblock2')
57
58
    out_27rf = BatchNormalization(center=True, scale=True)(x)
59
    out_27rf = Activation('relu')(out_27rf)
60
    out_27rf = Conv3DWithBN(out_27rf, filters=336, ksize=1, strides=1, name=name + '_out_27_postconv')
61
62
    return out_15rf, out_27rf
63
64
65
def BraTS2ScaleDenseNet(input, num_labels):
66
67
    x = Conv3D(filters=24, kernel_size=3, strides=1, padding='same')(input)
68
    x = DenseNetUnit3D(x, growth_rate=12, ksize=3, rep=6)
69
70
    out_15rf = BatchNormalization(center=True, scale=True)(x)
71
    out_15rf = Activation('relu')(out_15rf)
72
    out_15rf = Conv3DWithBN(out_15rf, filters=96, ksize=1, strides=1, name='out_15_postconv')
73
74
    x = DenseNetUnit3D(x, growth_rate=12, ksize=3, rep=6)
75
76
    out_27rf = BatchNormalization(center=True, scale=True)(x)
77
    out_27rf = Activation('relu')(out_27rf)
78
    out_27rf = Conv3DWithBN(out_27rf, filters=168, ksize=1, strides=1, name='out_27_postconv')
79
80
    score_15rf = Conv3D(num_labels, kernel_size=1, strides=1, padding='same')(out_15rf)
81
    score_27rf = Conv3D(num_labels, kernel_size=1, strides=1, padding='same')(out_27rf)
82
83
    score = score_15rf[:, 13:25, 13:25, 13:25, :] + \
84
            score_27rf[:, 13:25, 13:25, 13:25, :]
85
86
    return score
87
88
89
def BraTS3ScaleDenseNet(input, num_labels):
90
91
    x = Conv3D(filters=24, kernel_size=3, strides=1, padding='same')(input)
92
    x = DenseNetUnit3D(x, growth_rate=12, ksize=3, rep=5)
93
94
    out_13rf = BatchNormalization(center=True, scale=True)(x)
95
    out_13rf = Activation('relu')(out_13rf)
96
    out_13rf = Conv3DWithBN(out_13rf, filters=84, ksize=1, strides=1, name='out_13_postconv')
97
98
    x = DenseNetUnit3D(x, growth_rate=12, ksize=3, rep=5)
99
100
    out_23rf = BatchNormalization(center=True, scale=True)(x)
101
    out_23rf = Activation('relu')(out_23rf)
102
    out_23rf = Conv3DWithBN(out_23rf, filters=144, ksize=1, strides=1, name='out_23_postconv')
103
104
    x = DenseNetUnit3D(x, growth_rate=12, ksize=3, rep=5)
105
106
    out_33rf = BatchNormalization(center=True, scale=True)(x)
107
    out_33rf = Activation('relu')(out_33rf)
108
    out_33rf = Conv3DWithBN(out_33rf, filters=204, ksize=1, strides=1, name='out_33_postconv')
109
110
    score_13rf = Conv3D(num_labels, kernel_size=1, strides=1, padding='same')(out_13rf)
111
    score_23rf = Conv3D(num_labels, kernel_size=1, strides=1, padding='same')(out_23rf)
112
    score_33rf = Conv3D(num_labels, kernel_size=1, strides=1, padding='same')(out_33rf)
113
114
    score = score_13rf[:, 16:28, 16:28, 16:28, :] + \
115
            score_23rf[:, 16:28, 16:28, 16:28, :] + \
116
            score_33rf[:, 16:28, 16:28, 16:28, :]
117
118
    return score
119
120
121
122
def BraTS1ScaleDenseNet(input, num_labels):
123
124
    x = Conv3D(filters=36, kernel_size=5, strides=1, padding='same')(input)
125
    x = DenseNetUnit3D(x, growth_rate=18, ksize=3, rep=6)
126
127
    out_15rf = BatchNormalization(center=True, scale=True)(x)
128
    out_15rf = Activation('relu')(out_15rf)
129
    out_15rf = Conv3DWithBN(out_15rf, filters=144, ksize=1, strides=1, name='out_17_postconv1')
130
    out_15rf = Conv3DWithBN(out_15rf, filters=144, ksize=1, strides=1, name='out_17_postconv2')
131
132
    score_15rf = Conv3D(num_labels, kernel_size=1, strides=1, padding='same')(out_15rf)
133
134
    score = score_15rf[:, 8:20, 8:20, 8:20, :]
135
    return score