[d6d24a]: / Segmentation / model / vnet.py

Download this file

110 lines (93 with data), 4.1 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
import tensorflow as tf
import tensorflow.keras.layers as tfkl
import inspect
from Segmentation.model.vnet_build_blocks import Conv_ResBlock, Up_ResBlock
class VNet(tf.keras.Model):
    """V-Net style encoder/decoder segmentation model.

    The encoder is one ``Conv_ResBlock`` per entry of ``num_channels``. The
    decoder is one ``Up_ResBlock`` per entry, built in reverse channel order.
    Each decoder block receives the running features together with the skip
    tensor returned by the corresponding encoder block (deepest first).

    The head is:
    - a ``kernel_size`` convolution;
    - an activation;
    - a 1x1 convolution to ``num_classes`` channels;
    - an optional slice reduction;
    - a sigmoid (``num_classes == 1``) or softmax output activation.

    Args:
        num_channels: Sequence of filter counts, one per resolution level.
        num_classes: Number of output channels / classes.
        use_2d: Build ``Conv2D`` head layers instead of ``Conv3D``. Also
            forwarded to every block.
        num_conv_layers, kernel_size, use_batchnorm, dropout_rate,
        use_spatial_dropout: Forwarded unchanged to every block.
            ``kernel_size`` is also used by the head convolution, so with
            ``use_2d=True`` a 2-D kernel size must be supplied. The default
            ``(3, 3, 3)`` is 3-D.
        activation: ``'prelu'`` gives the head a trainable ``PReLU``. Any
            other value is passed to ``tf.keras.layers.Activation``. Also
            forwarded to every block.
        noise: Stddev of Gaussian noise added to the input when training.
            A falsy value (e.g. 0.0) disables it.
        predict_slice: If True, collapse axis -4 of the pre-activation
            output before the output activation.
        slice_format: ``"mean"`` or ``"sum"`` reduction for
            ``predict_slice``. Any other value skips the reduction.
        **kwargs: Passed both to ``tf.keras.Model.__init__`` and to every
            block.
    """

    def __init__(self,
                 num_channels,
                 num_classes,
                 use_2d=False,
                 num_conv_layers=2,
                 kernel_size=(3, 3, 3),
                 activation='prelu',
                 use_batchnorm=True,
                 noise=0.0,
                 dropout_rate=0.25,
                 use_spatial_dropout=True,
                 predict_slice=False,
                 slice_format="mean",
                 **kwargs):
        # String snapshot of the constructor's locals. This must stay the
        # first statement, so that only the parameters (plus ``self`` and
        # ``kwargs``) are captured.
        self.params = str(inspect.currentframe().f_locals)
        super(VNet, self).__init__(**kwargs)
        self.noise = noise
        self.predict_slice = predict_slice
        self.slice_format = slice_format

        block_args = {
            'use_2d': use_2d,
            'num_conv_layers': num_conv_layers,
            'kernel_size': kernel_size,
            'activation': activation,
            'use_batchnorm': use_batchnorm,
            'dropout_rate': dropout_rate,
            'use_spatial_dropout': use_spatial_dropout,
        }

        # Encoder: one block per level, shallow to deep.
        self.contracting_path = [
            Conv_ResBlock(output_ch, **block_args, **kwargs)
            for output_ch in num_channels
        ]
        # Decoder: one block per level, deep to shallow.
        self.upsampling_path = [
            Up_ResBlock(output_ch, **block_args, **kwargs)
            for output_ch in reversed(num_channels)
        ]

        # Output head. Both the 2-D and 3-D variants project to num_classes
        # filters. Passing the num_channels list as ``filters`` would fail.
        conv = tfkl.Conv2D if use_2d else tfkl.Conv3D
        unit_kernel = (1, 1) if use_2d else (1, 1, 1)
        self.conv_output = conv(filters=num_classes,
                                kernel_size=kernel_size,
                                activation=None,
                                padding='same')
        if activation == 'prelu':
            self.activation = tfkl.PReLU()
        else:
            self.activation = tfkl.Activation(activation)
        self.conv_1x1 = conv(filters=num_classes,
                             kernel_size=unit_kernel,
                             padding='same')
        self.output_act = tfkl.Activation('sigmoid' if num_classes == 1 else 'softmax')

        # Build the (weightless) noise layer once, instead of constructing a
        # new layer on every forward pass.
        self.gaussian_noise = tfkl.GaussianNoise(noise) if noise else None

    def call(self, x, training=None):
        """Run the network on a batch.

        Args:
            x: Input batch. Presumably channels-last images/volumes matching
                ``use_2d``; confirm against the data pipeline.
            training: When truthy, enables input noise (if configured). It is
                forwarded to every encoder/decoder block.

        Returns:
            Output of the final sigmoid/softmax activation, with
            ``num_classes`` channels.
        """
        if self.noise and training:
            x = self.gaussian_noise(x, training=training)

        skips = []
        for down in self.contracting_path:
            # Each encoder block returns (features for the next level,
            # tensor kept for the matching decoder block).
            x, skip = down(x, training=training)
            skips.append(skip)

        # Decoder block j pairs with skip -j-1, so the deepest skip is used
        # first.
        for up, skip in zip(self.upsampling_path, reversed(skips)):
            x = up([x, skip], training=training)

        output = self.conv_output(x)
        output = self.activation(output)
        output = self.conv_1x1(output)

        if self.predict_slice:
            # Reduce axis -4 and re-insert a size-1 axis at position 1. For a
            # 5-D (batch, depth, height, width, channels) tensor this
            # collapses depth. That assumes 3-D input; TODO confirm the
            # intended behaviour when use_2d is set. Unrecognized
            # slice_format values leave the output unreduced.
            if self.slice_format == "mean":
                output = tf.expand_dims(tf.reduce_mean(output, -4), 1)
            elif self.slice_format == "sum":
                output = tf.expand_dims(tf.reduce_sum(output, -4), 1)

        return self.output_act(output)