|
a |
|
b/model.py |
|
|
1 |
import tensorflow as tf |
|
|
2 |
|
|
|
3 |
from layers.encoder import Encoder |
|
|
4 |
from layers.decoder import Decoder |
|
|
5 |
from layers.vae import VariationalAutoencoder |
|
|
6 |
|
|
|
7 |
|
|
|
8 |
class Model(tf.keras.models.Model): |
|
|
9 |
def __init__(self, |
|
|
10 |
data_format='channels_last', |
|
|
11 |
groups=8, |
|
|
12 |
reduction=2, |
|
|
13 |
l2_scale=1e-5, |
|
|
14 |
dropout=0.2, |
|
|
15 |
downsampling='conv', |
|
|
16 |
upsampling='conv', |
|
|
17 |
base_filters=16, |
|
|
18 |
depth=4, |
|
|
19 |
in_ch=2, |
|
|
20 |
out_ch=3): |
|
|
21 |
""" Initializes the model, a cross between the 3D U-net |
|
|
22 |
and 2018 BraTS Challenge top model with VAE regularization. |
|
|
23 |
|
|
|
24 |
References: |
|
|
25 |
- [3D U-Net: Learning Dense Volumetric Segmentation from Sparse Annotation](https://arxiv.org/pdf/1606.06650.pdf) |
|
|
26 |
- [3D MRI brain tumor segmentation using autoencoder regularization](https://arxiv.org/pdf/1810.11654.pdf) |
|
|
27 |
""" |
|
|
28 |
super(Model, self).__init__() |
|
|
29 |
self.epoch = tf.Variable(0, name='epoch', trainable=False) |
|
|
30 |
self.encoder = Encoder( |
|
|
31 |
data_format=data_format, |
|
|
32 |
groups=groups, |
|
|
33 |
reduction=reduction, |
|
|
34 |
l2_scale=l2_scale, |
|
|
35 |
dropout=dropout, |
|
|
36 |
downsampling=downsampling, |
|
|
37 |
base_filters=base_filters, |
|
|
38 |
depth=depth) |
|
|
39 |
self.decoder = Decoder( |
|
|
40 |
data_format=data_format, |
|
|
41 |
groups=groups, |
|
|
42 |
reduction=reduction, |
|
|
43 |
l2_scale=l2_scale, |
|
|
44 |
upsampling=upsampling, |
|
|
45 |
base_filters=base_filters, |
|
|
46 |
depth=depth, |
|
|
47 |
out_ch=out_ch) |
|
|
48 |
self.vae = VariationalAutoencoder( |
|
|
49 |
data_format=data_format, |
|
|
50 |
groups=groups, |
|
|
51 |
reduction=reduction, |
|
|
52 |
l2_scale=l2_scale, |
|
|
53 |
upsampling=upsampling, |
|
|
54 |
base_filters=base_filters, |
|
|
55 |
depth=depth, |
|
|
56 |
out_ch=in_ch) |
|
|
57 |
|
|
|
58 |
def call(self, inputs, training=None, inference=None): |
|
|
59 |
# Inference mode does not evaluate VAE branch. |
|
|
60 |
assert (not inference or not training), \ |
|
|
61 |
'Cannot run training and inference modes simultaneously.' |
|
|
62 |
|
|
|
63 |
inputs = self.encoder(inputs, training=training) |
|
|
64 |
|
|
|
65 |
y_pred = self.decoder((inputs[-1], inputs[:-1]), training=training) |
|
|
66 |
|
|
|
67 |
if inference: |
|
|
68 |
return (y_pred, None, None, None) |
|
|
69 |
y_vae, z_mean, z_logvar = self.vae(inputs[-1], training=training) |
|
|
70 |
|
|
|
71 |
return (y_pred, y_vae, z_mean, z_logvar) |