|
b/inst/deepbleed/blocks/vnet.py
|
|
# @author: msharrock
# version: 0.0.1

"""
VNet Blocks for DeepBleed

tensorflow version 2.0
"""

import tensorflow as tf
from tensorflow.keras import layers


class VNetInBlock(layers.Layer):
    """Input block: lifts the single-channel volume to 16 feature channels."""

    def __init__(self):
        super(VNetInBlock, self).__init__()
        self.add = layers.Add()
        self.concatenate = layers.Concatenate()
        self.convolution = layers.Conv3D(filters=16, kernel_size=(5, 5, 5), strides=1,
                                         padding='same', kernel_initializer='he_normal',
                                         activation='relu')

    def call(self, inputs):
        x = self.convolution(inputs)
        # Tile the single-channel input to 16 channels so it can serve as the
        # residual branch of the V-Net addition.
        d = self.concatenate(16 * [inputs])
        return self.add([x, d])
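
# A minimal shape check for VNetInBlock (illustrative only; the 32^3
# single-channel input shape below is an assumption, not fixed by this module):
#
#   >>> block = VNetInBlock()
#   >>> block(tf.zeros((1, 32, 32, 32, 1))).shape
#   TensorShape([1, 32, 32, 32, 16])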
|
|
class VNetDownBlock(layers.Layer):
    """Encoder block: 2x2x2 strided downsampling, then n_convs 5x5x5
    convolutions with a residual addition back to the downsampled tensor."""

    def __init__(self, channels, n_convs, norm=False, drop=False, training=False):
        super(VNetDownBlock, self).__init__()
        self.channels = channels
        self.n_convs = n_convs
        self.training = training
        self.norm = norm
        self.drop = drop
        self.add = layers.Add()
        self.downsample = layers.Conv3D(filters=self.channels, kernel_size=(2, 2, 2), strides=2,
                                        padding='valid', kernel_initializer='he_normal',
                                        activation=None)
        # Note: this single Conv3D instance is applied n_convs times in call(),
        # so the repeated convolutions share one set of weights.
        self.convolution = layers.Conv3D(filters=self.channels, kernel_size=(5, 5, 5), strides=1,
                                         padding='same', kernel_initializer='he_normal',
                                         activation=None)
        # trainable follows the block's training flag, so the normalization
        # variables are frozen when the block is built for inference.
        self.batch_norm = layers.BatchNormalization(scale=False, renorm=True, trainable=self.training)
        self.activation = layers.Activation('relu')
        self.dropout = layers.Dropout(0.1)

    def call(self, inputs):
        d = self.downsample(inputs)
        if self.norm:
            d = self.batch_norm(d, training=self.training)
        d = self.activation(d)

        x = d
        for _ in range(self.n_convs):
            x = self.convolution(x)
            x = self.activation(x)
            if self.drop:
                x = self.dropout(x, training=self.training)

        return self.add([x, d])
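
# Illustrative shapes (the input shape is an assumption): with channels=32,
# a (1, 32, 32, 32, 16) tensor is downsampled to half resolution:
#
#   >>> VNetDownBlock(32, 2)(tf.zeros((1, 32, 32, 32, 16))).shape
#   TensorShape([1, 16, 16, 16, 32])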
|
|
class VNetUpBlock(layers.Layer):
    """Decoder block: 2x2x2 transposed-convolution upsampling, concatenation
    with the matching encoder skip connection, then n_convs 5x5x5 convolutions
    with a residual addition back to the concatenated tensor."""

    def __init__(self, channels, n_convs, norm=False, drop=False, training=False):
        super(VNetUpBlock, self).__init__()
        self.channels = channels
        self.n_convs = n_convs
        self.training = training
        self.norm = norm
        self.drop = drop
        self.add = layers.Add()
        self.concatenate = layers.Concatenate()
        # Upsample to half the block's channels; the skip connection supplies
        # the other half after concatenation.
        self.upsample = layers.Conv3DTranspose(filters=self.channels // 2, kernel_size=(2, 2, 2),
                                               strides=2, padding='valid',
                                               kernel_initializer='he_normal', activation=None)
        # As in VNetDownBlock, one Conv3D instance is reused n_convs times,
        # so the repeated convolutions share weights.
        self.convolution = layers.Conv3D(filters=self.channels, kernel_size=(5, 5, 5), strides=1,
                                         padding='same', kernel_initializer='he_normal',
                                         activation=None)
        self.batch_norm = layers.BatchNormalization(scale=False, renorm=True, trainable=self.training)
        self.activation = layers.Activation('relu')
        self.dropout = layers.Dropout(0.1)

    def call(self, inputs, skip):
        x = self.upsample(inputs)
        if self.norm:
            x = self.batch_norm(x, training=self.training)
        x = self.activation(x)
        cat = self.concatenate([x, skip])

        x = cat
        for _ in range(self.n_convs):
            x = self.convolution(x)
            x = self.activation(x)
            if self.drop:
                x = self.dropout(x, training=self.training)

        return self.add([x, cat])
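
# Illustrative shapes (assumed, for a skip tensor carrying channels//2
# feature maps at twice the spatial resolution): with channels=64,
#
#   >>> VNetUpBlock(64, 3)(tf.zeros((1, 8, 8, 8, 64)),
#   ...                    tf.zeros((1, 16, 16, 16, 32))).shape
#   TensorShape([1, 16, 16, 16, 64])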
|
|
class VNetOutBlock(layers.Layer):
    """Output block: 1x1x1 convolutions reduce the feature maps to two
    channels, then to a single sigmoid probability map."""

    def __init__(self):
        super(VNetOutBlock, self).__init__()
        self.final = layers.Conv3D(filters=2, kernel_size=(1, 1, 1), strides=1,
                                   padding='valid', kernel_initializer='he_normal',
                                   activation='relu')
        self.binary = layers.Conv3D(filters=1, kernel_size=(1, 1, 1), strides=1,
                                    padding='valid', kernel_initializer='he_normal',
                                    activation='sigmoid')

    def call(self, inputs):
        x = self.final(inputs)
        return self.binary(x)
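

if __name__ == "__main__":
    # Minimal end-to-end smoke test. This is an illustrative sketch only: the
    # actual DeepBleed model wires these blocks together elsewhere in the
    # package, and the 32^3 single-channel input shape is an assumption.
    x = tf.zeros((1, 32, 32, 32, 1))
    x16 = VNetInBlock()(x)                # (1, 32, 32, 32, 16)
    x32 = VNetDownBlock(32, 2)(x16)       # (1, 16, 16, 16, 32)
    x64 = VNetDownBlock(64, 3)(x32)       # (1, 8, 8, 8, 64)
    up32 = VNetUpBlock(64, 3)(x64, x32)   # (1, 16, 16, 16, 64)
    up16 = VNetUpBlock(32, 2)(up32, x16)  # (1, 32, 32, 32, 32)
    out = VNetOutBlock()(up16)            # (1, 32, 32, 32, 1)
    print(out.shape)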