[e44b03]: / tf_models.py

Download this file

135 lines (88 with data), 6.1 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import tensorflow as tf
from tensorflow.contrib.keras.python.keras.layers import Conv3D, MaxPooling3D, UpSampling3D, Activation, Conv3DTranspose
from tf_layers import *
def PlainCounterpart(input, name):
    """Chain thirteen 3x3x3 ``Conv3DWithBN`` layers and return two taps.

    Every layer uses ``ksize=3``, ``strides=1`` and ``padding='same'``.  The
    requested filter count starts at 24 and increases by 12 with each layer
    (24, 36, ..., 96 for the first seven layers; 108, ..., 168 for the
    remaining six).

    NOTE(review): the ``15rf`` / ``27rf`` labels presumably refer to the
    receptive-field width after 7 and 13 stacked 3-wide kernels -- this
    assumes ``Conv3DWithBN`` (defined outside this block) is an ordinary,
    non-dilated convolution; confirm against ``tf_layers``.

    Args:
        input: tensor handed to the first ``Conv3DWithBN`` call.
        name: string prefix.  Layers are named ``name + '_conv_15rf_<k>x'``
            for k = 1..7 and ``name + '_conv_27rf_<k>x'`` for k = 1..6.

    Returns:
        Tuple ``(out_15rf, out_27rf)``: the output of the seventh layer and
        the output of the thirteenth (last) layer.
    """
    tensor = input

    # First stage: seven layers, filters 24 -> 96 in steps of 12.
    for idx in range(1, 8):
        tensor = Conv3DWithBN(
            tensor,
            filters=12 * (idx + 1),
            ksize=3,
            strides=1,
            padding='same',
            name=name + '_conv_15rf_%dx' % idx,
        )
    out_15rf = tensor

    # Second stage continues from the first: six layers, filters 108 -> 168.
    for idx in range(1, 7):
        tensor = Conv3DWithBN(
            tensor,
            filters=12 * (idx + 8),
            ksize=3,
            strides=1,
            padding='same',
            name=name + '_conv_27rf_%dx' % idx,
        )
    out_27rf = tensor

    return out_15rf, out_27rf
def BraTS2ScaleDenseNetConcat(input, name):
    """Build a two-stage dense feature extractor and return one tap per stage.

    Pipeline: an initial 3x3x3 ``Conv3D`` (24 filters, 'same' padding), then
    two ``DenseNetUnit3D`` stages (``growth_rate=12``, ``ksize=3``, ``rep=6``).
    After each stage the running features are passed through
    BatchNormalization -> ReLU -> a 1x1x1 ``Conv3DWithBN`` (96 filters after
    stage one, 168 after stage two) to form that stage's tap.  The second
    stage consumes the raw output of the first stage, not the tap.

    Args:
        input: tensor handed to the initial convolution.
        name: string prefix for the named layers (``_conv_init``,
            ``_denseblock1``, ``_denseblock2``, ``_out_15_postconv``,
            ``_out_27_postconv``).

    Returns:
        Tuple ``(out_15rf, out_27rf)`` with the tap after the first and after
        the second dense stage.
    """
    def _tap(features, filters, suffix):
        # BN -> ReLU -> 1x1x1 post-convolution on a dense stage's output.
        normed = BatchNormalization(center=True, scale=True)(features)
        activated = Activation('relu')(normed)
        return Conv3DWithBN(activated, filters=filters, ksize=1, strides=1,
                            name=name + suffix)

    # Layer creation order below mirrors the original statement order.
    stem = Conv3D(filters=24, kernel_size=3, strides=1, padding='same',
                  name=name + '_conv_init')(input)

    stage1 = DenseNetUnit3D(stem, growth_rate=12, ksize=3, rep=6,
                            name=name + '_denseblock1')
    out_15rf = _tap(stage1, 96, '_out_15_postconv')

    stage2 = DenseNetUnit3D(stage1, growth_rate=12, ksize=3, rep=6,
                            name=name + '_denseblock2')
    out_27rf = _tap(stage2, 168, '_out_27_postconv')

    return out_15rf, out_27rf
def BraTS2ScaleDenseNetConcat_large(input, name):
    """Wider variant of the two-stage dense feature extractor.

    Starts with a 3x3x3 ``Conv3D`` of 48 filters, then runs two
    ``DenseNetUnit3D`` stages (``ksize=3``, ``rep=6``) with growth rates 12
    and 24 respectively.  Each stage output is tapped through
    BatchNormalization -> ReLU -> 1x1x1 ``Conv3DWithBN`` with 192 filters
    (first tap) and 336 filters (second tap).

    Args:
        input: tensor handed to the initial convolution.
        name: string prefix for the named layers (``_conv_init``,
            ``_denseblock1``, ``_denseblock2``, ``_out_15_postconv``,
            ``_out_27_postconv``).

    Returns:
        Tuple ``(out_15rf, out_27rf)`` with the tap after the first and after
        the second dense stage.
    """
    # (growth_rate, dense block suffix, post-conv filters, post-conv suffix)
    stage_specs = (
        (12, '_denseblock1', 192, '_out_15_postconv'),
        (24, '_denseblock2', 336, '_out_27_postconv'),
    )

    features = Conv3D(filters=48, kernel_size=3, strides=1, padding='same',
                      name=name + '_conv_init')(input)

    taps = []
    for growth, block_suffix, post_filters, post_suffix in stage_specs:
        # The dense stages chain on the untapped features.
        features = DenseNetUnit3D(features, growth_rate=growth, ksize=3, rep=6,
                                  name=name + block_suffix)
        tap = BatchNormalization(center=True, scale=True)(features)
        tap = Activation('relu')(tap)
        tap = Conv3DWithBN(tap, filters=post_filters, ksize=1, strides=1,
                           name=name + post_suffix)
        taps.append(tap)

    out_15rf, out_27rf = taps
    return out_15rf, out_27rf
def BraTS2ScaleDenseNet(input, num_labels):
    """Build a two-scale dense network and return the summed, cropped scores.

    An initial 3x3x3 ``Conv3D`` (24 filters) feeds two chained
    ``DenseNetUnit3D`` stages (``growth_rate=12``, ``ksize=3``, ``rep=6``).
    Each stage output goes through BatchNormalization -> ReLU -> 1x1x1
    ``Conv3DWithBN`` (96 and 168 filters) and then a 1x1x1 ``Conv3D`` with
    ``num_labels`` filters.  Both score tensors are cropped to indices
    13..24 on axes 1-3 and added together.

    NOTE(review): the fixed 13:25 crop presumably assumes a particular input
    patch size with axes 1-3 being spatial -- confirm against the caller.

    Args:
        input: tensor handed to the initial convolution; the final indexing
            requires the score tensors to be 5-D.
        num_labels: number of filters of the two scoring convolutions.

    Returns:
        The element-wise sum of the two cropped score tensors.
    """
    stem = Conv3D(filters=24, kernel_size=3, strides=1, padding='same')(input)

    # Scale 1: first dense stage and its tap.
    dense1 = DenseNetUnit3D(stem, growth_rate=12, ksize=3, rep=6)
    feat_15 = BatchNormalization(center=True, scale=True)(dense1)
    feat_15 = Activation('relu')(feat_15)
    feat_15 = Conv3DWithBN(feat_15, filters=96, ksize=1, strides=1,
                           name='out_15_postconv')

    # Scale 2: second dense stage stacked on the first, and its tap.
    dense2 = DenseNetUnit3D(dense1, growth_rate=12, ksize=3, rep=6)
    feat_27 = BatchNormalization(center=True, scale=True)(dense2)
    feat_27 = Activation('relu')(feat_27)
    feat_27 = Conv3DWithBN(feat_27, filters=168, ksize=1, strides=1,
                           name='out_27_postconv')

    # Per-scale scoring heads (created after both taps, as in the original).
    logits_15 = Conv3D(num_labels, kernel_size=1, strides=1, padding='same')(feat_15)
    logits_27 = Conv3D(num_labels, kernel_size=1, strides=1, padding='same')(feat_27)

    # Same window on axes 1-3; axes 0 and 4 are kept whole.
    window = slice(13, 25)
    cropped_15 = logits_15[:, window, window, window, :]
    cropped_27 = logits_27[:, window, window, window, :]
    return cropped_15 + cropped_27
def BraTS3ScaleDenseNet(input, num_labels):
    """Build a three-scale dense network and return the summed, cropped scores.

    An initial 3x3x3 ``Conv3D`` (24 filters) feeds three chained
    ``DenseNetUnit3D`` stages (``growth_rate=12``, ``ksize=3``, ``rep=5``).
    After each stage the features are tapped through BatchNormalization ->
    ReLU -> 1x1x1 ``Conv3DWithBN`` (84, 144 and 204 filters).  Each tap is
    scored by a 1x1x1 ``Conv3D`` with ``num_labels`` filters; the three score
    tensors are cropped to indices 16..27 on axes 1-3 and summed.

    NOTE(review): the fixed 16:28 crop presumably assumes a particular input
    patch size with axes 1-3 being spatial -- confirm against the caller.

    Args:
        input: tensor handed to the initial convolution; the final indexing
            requires the score tensors to be 5-D.
        num_labels: number of filters of the three scoring convolutions.

    Returns:
        The element-wise sum of the three cropped score tensors.
    """
    # (post-conv filters, post-conv layer name) for each scale, in order.
    tap_specs = (
        (84, 'out_13_postconv'),
        (144, 'out_23_postconv'),
        (204, 'out_33_postconv'),
    )

    features = Conv3D(filters=24, kernel_size=3, strides=1, padding='same')(input)

    taps = []
    for post_filters, post_name in tap_specs:
        # Each dense stage continues from the previous stage's raw output.
        features = DenseNetUnit3D(features, growth_rate=12, ksize=3, rep=5)
        tap = BatchNormalization(center=True, scale=True)(features)
        tap = Activation('relu')(tap)
        tap = Conv3DWithBN(tap, filters=post_filters, ksize=1, strides=1,
                           name=post_name)
        taps.append(tap)

    # Scoring heads are created only after all taps exist, as in the original.
    logits = []
    for tap in taps:
        logits.append(
            Conv3D(num_labels, kernel_size=1, strides=1, padding='same')(tap))

    # Crop and accumulate left to right: (s13 + s23) + s33.
    window = slice(16, 28)
    score = None
    for scale_logits in logits:
        cropped = scale_logits[:, window, window, window, :]
        if score is None:
            score = cropped
        else:
            score = score + cropped
    return score
def BraTS1ScaleDenseNet(input, num_labels):
    """Build a single-scale dense network and return its cropped scores.

    A 5x5x5 ``Conv3D`` (36 filters, 'same' padding) feeds one
    ``DenseNetUnit3D`` stage (``growth_rate=18``, ``ksize=3``, ``rep=6``).
    Its output passes through BatchNormalization -> ReLU, then two successive
    1x1x1 ``Conv3DWithBN`` layers of 144 filters, and finally a 1x1x1
    ``Conv3D`` with ``num_labels`` filters.  The score tensor is cropped to
    indices 8..19 on axes 1-3.

    NOTE(review): the fixed 8:20 crop presumably assumes a particular input
    patch size with axes 1-3 being spatial -- confirm against the caller.

    Args:
        input: tensor handed to the initial convolution; the final indexing
            requires the score tensor to be 5-D.
        num_labels: number of filters of the scoring convolution.

    Returns:
        The cropped score tensor.
    """
    stem = Conv3D(filters=36, kernel_size=5, strides=1, padding='same')(input)
    dense = DenseNetUnit3D(stem, growth_rate=18, ksize=3, rep=6)

    head = BatchNormalization(center=True, scale=True)(dense)
    head = Activation('relu')(head)
    # Two stacked 1x1x1 post-convolutions with identical settings.
    for post_name in ('out_17_postconv1', 'out_17_postconv2'):
        head = Conv3DWithBN(head, filters=144, ksize=1, strides=1, name=post_name)

    logits = Conv3D(num_labels, kernel_size=1, strides=1, padding='same')(head)

    # Same window on axes 1-3; axes 0 and 4 are kept whole.
    window = slice(8, 20)
    return logits[:, window, window, window, :]