Switch to unified view

a b/inst/deepbleed/blocks/vnet.py
1
# @author: msharrock
2
# version: 0.0.1
3
4
"""
5
VNet Blocks for DeepBleed  
6
7
tensorflow version 2.0
8
9
"""
10
11
import tensorflow as tf
12
from tensorflow.keras import layers
13
14
15
class VNetInBlock(layers.Layer):
16
    def __init__(self):
17
        super(VNetInBlock, self).__init__()
18
        self.add = layers.Add()
19
        self.concatenate = layers.Concatenate() 
20
        self.convolution = layers.Conv3D(filters=16, kernel_size=(5,5,5), strides=1, 
21
                                         padding='same', kernel_initializer='he_normal', activation='relu') 
22
23
    def call(self, inputs): 
24
        x = self.convolution(inputs)
25
        d = self.concatenate(16 * [inputs])
26
27
        return self.add([x, d])
28
29
30
class VNetDownBlock(layers.Layer):
31
    def __init__(self, channels, n_convs, norm=False, drop=False, training=False):
32
        super(VNetDownBlock, self).__init__()
33
        self.channels = channels
34
        self.n_convs = n_convs
35
        self.training = training
36
        self.norm = norm
37
        self.drop = drop
38
        self.add = layers.Add()
39
        self.downsample = layers.Conv3D(filters=self.channels, kernel_size=(2,2,2), strides=2,
40
                                         padding='valid', kernel_initializer='he_normal', activation=None)
41
        self.convolution = layers.Conv3D(filters=self.channels, kernel_size=(5,5,5), strides=1, 
42
                                         padding='same', kernel_initializer='he_normal', activation=None) 
43
        self.batch_norm = layers.BatchNormalization(scale=False, renorm=True, trainable=self.training)
44
        self.activation = layers.Activation('relu')
45
        self.dropout = layers.Dropout(0.1)
46
        
47
    def call(self, inputs):  
48
        d = self.downsample(inputs) 
49
        if self.norm:
50
            d = self.batch_norm(d, training=self.training)
51
        d = self.activation(d)
52
        x = d
53
        
54
        for _ in range(self.n_convs):
55
            x = self.convolution(x)
56
            x = self.activation(x)
57
            if self.drop:
58
                x = self.dropout(x, training=self.training)
59
            
60
        return self.add([x, d])  
61
62
class VNetUpBlock(layers.Layer):
63
    def __init__(self, channels, n_convs, norm=False, drop=False, training=False):
64
        super(VNetUpBlock, self).__init__()
65
        self.channels = channels
66
        self.n_convs = n_convs
67
        self.training = training
68
        self.norm = norm
69
        self.drop = drop
70
        self.add = layers.Add() 
71
        self.concatenate = layers.Concatenate() 
72
        self.upsample = layers.Conv3DTranspose(filters=self.channels//2, kernel_size=(2,2,2), strides=2,
73
                                               padding='valid', kernel_initializer='he_normal', activation=None)
74
        self.convolution = layers.Conv3D(filters=self.channels, kernel_size=(5,5,5), strides=1, 
75
                                         padding='same', kernel_initializer='he_normal', activation=None) 
76
        self.batch_norm = layers.BatchNormalization(scale=False, renorm=True, trainable=self.training)
77
        self.activation = layers.Activation('relu')
78
        self.dropout = layers.Dropout(0.1)
79
        
80
        
81
    def call(self, inputs, skip):  
82
        x = self.upsample(inputs)
83
        if self.norm:
84
            x = self.batch_norm(x, training=self.training)
85
        x = self.activation(x)
86
        cat = self.concatenate([x, skip])
87
        x = cat
88
        
89
        for _ in range(self.n_convs):
90
            x = self.convolution(x)
91
            x = self.activation(x)
92
            if self.drop:
93
                x = self.dropout(x, training=self.training)
94
            
95
        return self.add([x, cat])  
96
97
98
class VNetOutBlock(layers.Layer):
99
100
    def __init__(self):
101
        super(VNetOutBlock, self).__init__()             
102
        self.final = layers.Conv3D(filters=2, kernel_size=(1,1,1), strides=1, 
103
                                         padding='valid', kernel_initializer='he_normal', activation='relu')
104
        
105
        self.binary = layers.Conv3D(filters=1, kernel_size=(1,1,1), strides=1, 
106
                                         padding='valid', kernel_initializer='he_normal', activation='sigmoid')
107
               
108
    def call(self, inputs):     
109
        x = self.final(inputs)
110
111
        return self.binary(x)