b/fetal_net/model/unet3d/unet.py
import numpy as np
from keras import backend as K
from keras.engine import Input, Model
from keras.layers import Conv3D, MaxPooling3D, UpSampling3D, Activation, BatchNormalization, PReLU, Deconvolution3D
from keras.optimizers import Adam

from ...metrics import dice_coefficient_loss, get_label_dice_coefficient_function, dice_coefficient, vod_coefficient

K.set_image_data_format("channels_first")

# the decoder below uses the ``concatenate`` helper; make sure it is imported
# regardless of where this Keras version exposes it
try:
    from keras.layers.merge import concatenate
except ImportError:
    from keras.layers import concatenate


def unet_model_3d(input_shape, pool_size=(2, 2, 2), n_labels=1, initial_learning_rate=0.00001, deconvolution=False,
                  depth=4, n_base_filters=32, include_label_wise_dice_coefficients=False,
                  batch_normalization=False, activation_name="sigmoid", loss_function=dice_coefficient_loss,
                  **kwargs):
    """
    Builds the 3D UNet Keras model.
    :param include_label_wise_dice_coefficients: If True and n_labels is greater than 1, the model will report the
    dice coefficient for each label as a metric.
    :param n_base_filters: The number of filters that the first layer in the convolution network will have. Following
    layers will contain a multiple of this number. Lowering this number will likely reduce the amount of memory
    required to train the model.
    :param depth: Indicates the depth of the U-shape for the model. The greater the depth, the more max pooling
    layers will be added to the model. Lowering the depth may reduce the amount of memory required for training.
    :param input_shape: Shape of the input data (n_channels, x_size, y_size, z_size). The x, y, and z sizes must be
    divisible by the pool size to the power of the depth of the UNet, that is pool_size^depth (see the usage sketch
    after this function).
    :param pool_size: Pool size for the max pooling operations.
    :param n_labels: Number of binary labels that the model is learning.
    :param initial_learning_rate: Initial learning rate for the model. This will be decayed during training.
    :param deconvolution: If set to True, will use transpose convolution (deconvolution) instead of up-sampling. This
    increases the amount of memory required during training.
    :param batch_normalization: If True, adds batch normalization after each convolution.
    :param activation_name: Activation applied to the final 1x1x1 convolution (default is 'sigmoid').
    :param loss_function: Loss used when compiling the model (default is dice_coefficient_loss).
    :return: Untrained 3D UNet Model
    """
|
|
40 |
inputs = Input(input_shape) |
|
|
41 |
current_layer = inputs |
|
|
42 |
levels = list() |
|
|
43 |
|
|
|
44 |
# add levels with max pooling |
|
|
45 |
for layer_depth in range(depth): |
|
|
46 |
layer1 = create_convolution_block(input_layer=current_layer, n_filters=n_base_filters*(2**layer_depth), |
|
|
47 |
batch_normalization=batch_normalization) |
|
|
48 |
layer2 = create_convolution_block(input_layer=layer1, n_filters=n_base_filters*(2**layer_depth)*2, |
|
|
49 |
batch_normalization=batch_normalization) |
|
|
50 |
if layer_depth < depth - 1: |
|
|
51 |
current_layer = MaxPooling3D(pool_size=pool_size)(layer2) |
|
|
52 |
levels.append([layer1, layer2, current_layer]) |
|
|
53 |
else: |
|
|
54 |
current_layer = layer2 |
|
|
55 |
levels.append([layer1, layer2]) |
|
|
56 |
|
|
|
57 |
# add levels with up-convolution or up-sampling |
|
|
58 |
for layer_depth in range(depth-2, -1, -1): |
|
|
59 |
up_convolution = get_up_convolution(pool_size=pool_size, deconvolution=deconvolution, |
|
|
60 |
n_filters=current_layer._keras_shape[1])(current_layer) |
|
|
61 |
concat = concatenate([up_convolution, levels[layer_depth][1]], axis=1) |
|
|
62 |
current_layer = create_convolution_block(n_filters=levels[layer_depth][1]._keras_shape[1], |
|
|
63 |
input_layer=concat, batch_normalization=batch_normalization) |
|
|
64 |
current_layer = create_convolution_block(n_filters=levels[layer_depth][1]._keras_shape[1], |
|
|
65 |
input_layer=current_layer, |
|
|
66 |
batch_normalization=batch_normalization) |
|
|
67 |
|
|
|
68 |
final_convolution = Conv3D(n_labels, (1, 1, 1))(current_layer) |
|
|
69 |
act = Activation(activation_name)(final_convolution) |
|
|
70 |
model = Model(inputs=inputs, outputs=act) |
|
|
71 |
|
|
|
72 |
# if not isinstance(metrics, list): |
|
|
73 |
# metrics = [metrics] |
|
|
74 |
# |
|
|
75 |
# if include_label_wise_dice_coefficients and n_labels > 1: |
|
|
76 |
# label_wise_dice_metrics = [get_label_dice_coefficient_function(index) for index in range(n_labels)] |
|
|
77 |
# if metrics: |
|
|
78 |
# metrics = metrics + label_wise_dice_metrics |
|
|
79 |
# else: |
|
|
80 |
# metrics = label_wise_dice_metrics |
|
|
81 |
metrics = ['binary_accuracy', vod_coefficient] |
|
|
82 |
if loss_function != dice_coefficient_loss: |
|
|
83 |
metrics += [dice_coefficient] |
|
|
84 |
|
|
|
85 |
model.compile(optimizer=Adam(lr=initial_learning_rate), loss=loss_function, metrics=metrics) |
|
|
86 |
return model |
|
|
87 |
|
|
|
88 |
|
|
|
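
# Usage sketch (values here are illustrative only): with the default depth=4
# and pool_size=(2, 2, 2), every spatial dimension of the input must be
# divisible by 2**4 = 16, so a channels-first patch shape such as
# (1, 64, 64, 64) works:
#
#     model = unet_model_3d(input_shape=(1, 64, 64, 64), n_labels=1,
#                           batch_normalization=True)
#     model.summary()
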

def create_convolution_block(input_layer, n_filters, batch_normalization=False, kernel=(3, 3, 3), activation=None,
                             padding='same', strides=(1, 1, 1), instance_normalization=False):
    """
    Creates one convolution block: Conv3D, optional normalization, and a non-linearity.
    :param strides: Strides of the convolution (default is (1, 1, 1)).
    :param input_layer: Layer that the block is applied to.
    :param n_filters: Number of filters in the convolution.
    :param batch_normalization: If True, apply channels-first batch normalization after the convolution.
    :param kernel: Size of the convolution kernel (default is (3, 3, 3)).
    :param activation: Keras activation layer to use. (default is 'relu')
    :param padding: Padding mode passed to Conv3D (default is 'same').
    :param instance_normalization: If True (and batch_normalization is False), apply instance normalization instead;
    requires keras_contrib.
    :return: Output tensor of the block.
    """
    layer = Conv3D(n_filters, kernel, padding=padding, strides=strides)(input_layer)
    if batch_normalization:
        layer = BatchNormalization(axis=1)(layer)
    elif instance_normalization:
        try:
            from keras_contrib.layers.normalization import InstanceNormalization
        except ImportError:
            raise ImportError("Install keras_contrib in order to use instance normalization."
                              "\nTry: pip install git+https://www.github.com/farizrahman4u/keras-contrib.git")
        layer = InstanceNormalization(axis=1)(layer)
    if activation is None:
        return Activation('relu')(layer)
    else:
        return activation()(layer)

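
# For example (illustrative): create_convolution_block(input_layer=x, n_filters=32,
# batch_normalization=True) applies a 3x3x3 Conv3D with 'same' padding, a
# channels-first BatchNormalization, and a ReLU. Passing activation=PReLU ends
# the block with a PReLU layer instead (the activation is instantiated with no
# arguments, so pass the layer class itself):
#
#     block = create_convolution_block(input_layer=x, n_filters=32, activation=PReLU)
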

def compute_level_output_shape(n_filters, depth, pool_size, image_shape):
    """
    Each level has a particular output shape based on the number of filters used in that level and the depth or number
    of max pooling operations that have been done on the data at that point.
    :param image_shape: shape of the 3d image.
    :param pool_size: the pool_size parameter used in the max pooling operation.
    :param n_filters: Number of filters used by the last node in a given level.
    :param depth: The number of levels down in the U-shaped model a given node is.
    :return: 5D vector of the shape of the output node
    """
    output_image_shape = np.asarray(np.divide(image_shape, np.power(pool_size, depth)), dtype=np.int32).tolist()
    return tuple([None, n_filters] + output_image_shape)

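
# Worked example: compute_level_output_shape(n_filters=128, depth=2,
# pool_size=(2, 2, 2), image_shape=(64, 64, 64)) returns
# (None, 128, 16, 16, 16), since two rounds of (2, 2, 2) pooling divide each
# spatial dimension by 4.
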

def get_up_convolution(n_filters, pool_size, kernel_size=(2, 2, 2), strides=(2, 2, 2),
                       deconvolution=False):
    if deconvolution:
        return Deconvolution3D(filters=n_filters, kernel_size=kernel_size,
                               strides=strides)
    else:
        return UpSampling3D(size=pool_size)
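
# For example, get_up_convolution(n_filters=256, pool_size=(2, 2, 2)) returns an
# UpSampling3D layer that repeats voxels to double each spatial dimension, while
# deconvolution=True returns a trainable Deconvolution3D (transpose convolution)
# with stride (2, 2, 2), which costs more memory during training (as noted in the
# unet_model_3d docstring).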