# inst/deepbleed/blocks/vnet.py
# @author: msharrock
# version: 0.0.1
"""
VNet Blocks for DeepBleed
tensorflow version 2.0
"""
import tensorflow as tf
from tensorflow.keras import layers
class VNetInBlock(layers.Layer):
    """VNet entry block.

    Applies a 5x5x5 convolution (16 filters, ReLU) and sums the result with
    the input replicated to 16 channels, forming the residual entry stage of
    a VNet.
    """

    def __init__(self):
        super(VNetInBlock, self).__init__()
        self.add = layers.Add()
        self.concatenate = layers.Concatenate()
        self.convolution = layers.Conv3D(
            filters=16, kernel_size=(5, 5, 5), strides=1, padding='same',
            kernel_initializer='he_normal', activation='relu')

    def call(self, inputs):
        convolved = self.convolution(inputs)
        # Replicate the input along the channel axis so it matches the
        # 16-filter conv output for the residual Add.
        # NOTE(review): assumes inputs is single-channel -- confirm callers.
        replicated = self.concatenate([inputs] * 16)
        return self.add([convolved, replicated])
class VNetDownBlock(layers.Layer):
    """VNet encoder stage.

    Downsamples with a strided 2x2x2 convolution, then runs a stack of
    5x5x5 convolutions and joins the stack output with the downsampled
    activation through a residual Add.

    Args:
        channels: number of filters for every convolution in this stage.
        n_convs: how many 5x5x5 convolution passes to apply after downsampling.
        norm: if True, apply batch renormalization after the downsample conv.
        drop: if True, apply Dropout(0.1) after each convolution pass.
        training: flag forwarded to batch norm / dropout at call time.
    """

    def __init__(self, channels, n_convs, norm=False, drop=False, training=False):
        super(VNetDownBlock, self).__init__()
        self.channels = channels
        self.n_convs = n_convs
        self.training = training
        self.norm = norm
        self.drop = drop
        self.add = layers.Add()
        # Stride-2 2x2x2 conv halves each spatial dimension (no activation here;
        # ReLU is applied in call(), optionally after batch norm).
        self.downsample = layers.Conv3D(filters=self.channels, kernel_size=(2, 2, 2), strides=2,
                                        padding='valid', kernel_initializer='he_normal', activation=None)
        # NOTE(review): this single Conv3D is reused n_convs times in call(),
        # so every pass shares one set of weights -- confirm this weight
        # sharing is intentional (canonical VNet uses distinct convs per pass).
        self.convolution = layers.Conv3D(filters=self.channels, kernel_size=(5, 5, 5), strides=1,
                                         padding='same', kernel_initializer='he_normal', activation=None)
        # NOTE(review): trainable is tied to the constructor's training flag,
        # so with the default training=False the renorm parameters are frozen
        # -- verify this inference-only behavior is intended.
        self.batch_norm = layers.BatchNormalization(scale=False, renorm=True, trainable=self.training)
        self.activation = layers.Activation('relu')
        self.dropout = layers.Dropout(0.1)

    def call(self, inputs):
        d = self.downsample(inputs)
        if self.norm:
            d = self.batch_norm(d, training=self.training)
        d = self.activation(d)
        x = d
        # Shared-weight convolution stack; the residual Add below joins its
        # output with the downsampled activation d.
        for _ in range(self.n_convs):
            x = self.convolution(x)
            x = self.activation(x)
            if self.drop:
                x = self.dropout(x, training=self.training)
        return self.add([x, d])
class VNetUpBlock(layers.Layer):
    """VNet decoder stage.

    Upsamples with a transposed 2x2x2 convolution, concatenates the result
    with the matching encoder skip connection, then runs a stack of 5x5x5
    convolutions and joins the stack output with the concatenation through a
    residual Add.

    Args:
        channels: filter count for the 5x5x5 convolutions; the transposed
            conv emits channels // 2 filters (the skip concat restores depth).
        n_convs: how many 5x5x5 convolution passes to apply after the concat.
        norm: if True, apply batch renormalization after the upsample conv.
        drop: if True, apply Dropout(0.1) after each convolution pass.
        training: flag forwarded to batch norm / dropout at call time.
    """

    def __init__(self, channels, n_convs, norm=False, drop=False, training=False):
        super(VNetUpBlock, self).__init__()
        self.channels = channels
        self.n_convs = n_convs
        self.training = training
        self.norm = norm
        self.drop = drop
        self.add = layers.Add()
        self.concatenate = layers.Concatenate()
        # Stride-2 transposed conv doubles each spatial dimension; emits half
        # the stage's channels (the skip concat brings depth back up).
        self.upsample = layers.Conv3DTranspose(filters=self.channels//2, kernel_size=(2, 2, 2), strides=2,
                                               padding='valid', kernel_initializer='he_normal', activation=None)
        # NOTE(review): this single Conv3D is reused n_convs times in call(),
        # so every pass shares one set of weights -- confirm this weight
        # sharing is intentional (canonical VNet uses distinct convs per pass).
        self.convolution = layers.Conv3D(filters=self.channels, kernel_size=(5, 5, 5), strides=1,
                                         padding='same', kernel_initializer='he_normal', activation=None)
        # NOTE(review): trainable is tied to the constructor's training flag,
        # so with the default training=False the renorm parameters are frozen
        # -- verify this inference-only behavior is intended.
        self.batch_norm = layers.BatchNormalization(scale=False, renorm=True, trainable=self.training)
        self.activation = layers.Activation('relu')
        self.dropout = layers.Dropout(0.1)

    def call(self, inputs, skip):
        x = self.upsample(inputs)
        if self.norm:
            x = self.batch_norm(x, training=self.training)
        x = self.activation(x)
        # Fuse upsampled features with the encoder skip connection.
        cat = self.concatenate([x, skip])
        x = cat
        # Shared-weight convolution stack; the residual Add below joins its
        # output with the concatenated features.
        for _ in range(self.n_convs):
            x = self.convolution(x)
            x = self.activation(x)
            if self.drop:
                x = self.dropout(x, training=self.training)
        return self.add([x, cat])
class VNetOutBlock(layers.Layer):
    """VNet output head.

    Collapses the feature maps to two channels with a 1x1x1 ReLU convolution,
    then to a single-channel sigmoid probability map.
    """

    def __init__(self):
        super(VNetOutBlock, self).__init__()
        self.final = layers.Conv3D(
            filters=2, kernel_size=(1, 1, 1), strides=1, padding='valid',
            kernel_initializer='he_normal', activation='relu')
        self.binary = layers.Conv3D(
            filters=1, kernel_size=(1, 1, 1), strides=1, padding='valid',
            kernel_initializer='he_normal', activation='sigmoid')

    def call(self, inputs):
        # Two-class squeeze followed by the sigmoid probability projection.
        return self.binary(self.final(inputs))