|
a |
|
b/3D/model.py |
|
|
1 |
|
|
|
2 |
import numpy as np |
|
|
3 |
from keras import backend as K |
|
|
4 |
from keras.engine import Input, Model |
|
|
5 |
from keras.layers import Conv3D, MaxPooling3D, UpSampling3D, Activation, BatchNormalization, PReLU |
|
|
6 |
from keras.optimizers import Adam |
|
|
7 |
from functools import partial |
|
|
8 |
|
|
|
9 |
#from metrics import dice_coef_loss, get_label_dice_coefficient_function, dice_coef |
|
|
10 |
|
|
|
11 |
K.set_image_data_format("channels_last") |
|
|
12 |
|
|
|
13 |
try: |
|
|
14 |
from keras.engine import merge |
|
|
15 |
except ImportError: |
|
|
16 |
from keras.layers.merge import concatenate |
|
|
17 |
|
|
|
18 |
|
|
|
19 |
|
|
|
20 |
def dice_coef(y_true, y_pred, smooth=1.): |
|
|
21 |
y_true_f = K.flatten(y_true) |
|
|
22 |
y_pred_f = K.flatten(y_pred) |
|
|
23 |
intersection = K.sum(y_true_f * y_pred_f) |
|
|
24 |
return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth) |
|
|
25 |
|
|
|
26 |
|
|
|
27 |
#def dice_coef_loss(y_true, y_pred): |
|
|
28 |
# return -dice_coef(y_true, y_pred) |
|
|
29 |
|
|
|
30 |
|
|
|
31 |
def dice_coef_loss(y_true, y_pred): |
|
|
32 |
distance = 0 |
|
|
33 |
for label_index in range(4): |
|
|
34 |
dice_coef_class = dice_coef(y_true[:,:,:,:,label_index], y_pred[:,:,:,:,label_index]) |
|
|
35 |
distance = 1 - dice_coef_class + distance |
|
|
36 |
return distance |
|
|
37 |
|
|
|
38 |
|
|
|
39 |
def label_wise_dice_coefficient(y_true, y_pred, label_index): |
|
|
40 |
return dice_coef(y_true[:,:,:,:,label_index], y_pred[:,:, :,:,label_index]) |
|
|
41 |
|
|
|
42 |
|
|
|
43 |
def get_label_dice_coefficient_function(label_index): |
|
|
44 |
f = partial(label_wise_dice_coefficient, label_index=label_index) |
|
|
45 |
f.__setattr__('__name__', 'label_{0}_dice_coef'.format(label_index)) |
|
|
46 |
return f |
|
|
47 |
|
|
|
48 |
def unet_model_3d(input_shape, pool_size=(2, 2, 2), n_labels=4, initial_learning_rate=0.00001, deconvolution=False, |
|
|
49 |
depth=3, n_base_filters=16, include_label_wise_dice_coefficients=True, metrics=dice_coef, |
|
|
50 |
batch_normalization=False): |
|
|
51 |
""" |
|
|
52 |
Builds the 3D UNet Keras model.f |
|
|
53 |
:param metrics: List metrics to be calculated during model training (default is dice coefficient). |
|
|
54 |
:param include_label_wise_dice_coefficients: If True and n_labels is greater than 1, model will report the dice |
|
|
55 |
coefficient for each label as metric. |
|
|
56 |
:param n_base_filters: The number of filters that the first layer in the convolution network will have. Following |
|
|
57 |
layers will contain a multiple of this number. Lowering this number will likely reduce the amount of memory required |
|
|
58 |
to train the model. |
|
|
59 |
:param depth: indicates the depth of the U-shape for the model. The greater the depth, the more max pooling |
|
|
60 |
layers will be added to the model. Lowering the depth may reduce the amount of memory required for training. |
|
|
61 |
:param input_shape: Shape of the input data (n_chanels, x_size, y_size, z_size). The x, y, and z sizes must be |
|
|
62 |
divisible by the pool size to the power of the depth of the UNet, that is pool_size^depth. |
|
|
63 |
:param pool_size: Pool size for the max pooling operations. |
|
|
64 |
:param n_labels: Number of binary labels that the model is learning. |
|
|
65 |
:param initial_learning_rate: Initial learning rate for the model. This will be decayed during training. |
|
|
66 |
:param deconvolution: If set to True, will use transpose convolution(deconvolution) instead of up-sampling. This |
|
|
67 |
increases the amount memory required during training. |
|
|
68 |
:return: Untrained 3D UNet Model |
|
|
69 |
""" |
|
|
70 |
inputs = Input(input_shape) |
|
|
71 |
current_layer = inputs |
|
|
72 |
levels = list() |
|
|
73 |
|
|
|
74 |
# add levels with max pooling |
|
|
75 |
for layer_depth in range(depth): |
|
|
76 |
layer1 = create_convolution_block(input_layer=current_layer, n_filters=n_base_filters*(2**layer_depth), |
|
|
77 |
batch_normalization=batch_normalization) |
|
|
78 |
layer2 = create_convolution_block(input_layer=layer1, n_filters=n_base_filters*(2**layer_depth)*2, |
|
|
79 |
batch_normalization=batch_normalization) |
|
|
80 |
if layer_depth < depth - 1: |
|
|
81 |
current_layer = MaxPooling3D(pool_size=pool_size)(layer2) |
|
|
82 |
levels.append([layer1, layer2, current_layer]) |
|
|
83 |
else: |
|
|
84 |
current_layer = layer2 |
|
|
85 |
levels.append([layer1, layer2]) |
|
|
86 |
|
|
|
87 |
# add levels with up-convolution or up-sampling |
|
|
88 |
for layer_depth in range(depth-2, -1, -1): |
|
|
89 |
up_convolution = get_up_convolution(pool_size=pool_size, deconvolution=deconvolution, depth=layer_depth, |
|
|
90 |
n_filters=current_layer._keras_shape[1], |
|
|
91 |
image_shape=input_shape[-3:])(current_layer) |
|
|
92 |
concat = concatenate([up_convolution, levels[layer_depth][1]], axis=-1) |
|
|
93 |
current_layer = create_convolution_block(n_filters=levels[layer_depth][1]._keras_shape[1], |
|
|
94 |
input_layer=concat, batch_normalization=batch_normalization) |
|
|
95 |
current_layer = create_convolution_block(n_filters=levels[layer_depth][1]._keras_shape[1], |
|
|
96 |
input_layer=current_layer, |
|
|
97 |
batch_normalization=batch_normalization) |
|
|
98 |
|
|
|
99 |
final_convolution = Conv3D(n_labels, (1, 1, 1))(current_layer) |
|
|
100 |
act = Activation('sigmoid')(final_convolution) |
|
|
101 |
model = Model(inputs=inputs, outputs=act) |
|
|
102 |
|
|
|
103 |
if not isinstance(metrics, list): |
|
|
104 |
metrics = [metrics] |
|
|
105 |
|
|
|
106 |
if include_label_wise_dice_coefficients and n_labels > 1: |
|
|
107 |
label_wise_dice_metrics = [get_label_dice_coefficient_function(index) for index in range(n_labels)] |
|
|
108 |
if metrics: |
|
|
109 |
metrics = metrics + label_wise_dice_metrics |
|
|
110 |
else: |
|
|
111 |
metrics = label_wise_dice_metrics |
|
|
112 |
|
|
|
113 |
model.compile(optimizer=Adam(lr=initial_learning_rate), loss=dice_coef_loss, metrics=metrics) |
|
|
114 |
return model |
|
|
115 |
|
|
|
116 |
|
|
|
117 |
def create_convolution_block(input_layer, n_filters, batch_normalization=False, kernel=(3, 3, 3), activation=None, |
|
|
118 |
padding='same'): |
|
|
119 |
""" |
|
|
120 |
|
|
|
121 |
:param input_layer: |
|
|
122 |
:param n_filters: |
|
|
123 |
:param batch_normalization: |
|
|
124 |
:param kernel: |
|
|
125 |
:param activation: Keras activation layer to use. (default is 'relu') |
|
|
126 |
:param padding: |
|
|
127 |
:return: |
|
|
128 |
""" |
|
|
129 |
layer = Conv3D(n_filters, kernel, padding=padding)(input_layer) |
|
|
130 |
if batch_normalization: |
|
|
131 |
layer = BatchNormalization(axis=1)(layer) |
|
|
132 |
if activation is None: |
|
|
133 |
return Activation('relu')(layer) |
|
|
134 |
else: |
|
|
135 |
return activation()(layer) |
|
|
136 |
|
|
|
137 |
|
|
|
138 |
def compute_level_output_shape(n_filters, depth, pool_size, image_shape): |
|
|
139 |
""" |
|
|
140 |
Each level has a particular output shape based on the number of filters used in that level and the depth or number |
|
|
141 |
of max pooling operations that have been done on the data at that point. |
|
|
142 |
:param image_shape: shape of the 3d image. |
|
|
143 |
:param pool_size: the pool_size parameter used in the max pooling operation. |
|
|
144 |
:param n_filters: Number of filters used by the last node in a given level. |
|
|
145 |
:param depth: The number of levels down in the U-shaped model a given node is. |
|
|
146 |
:return: 5D vector of the shape of the output node |
|
|
147 |
""" |
|
|
148 |
output_image_shape = np.asarray(np.divide(image_shape, np.power(pool_size, depth)), dtype=np.int32).tolist() |
|
|
149 |
return tuple([None, n_filters] + output_image_shape) |
|
|
150 |
|
|
|
151 |
|
|
|
152 |
def get_up_convolution(depth, n_filters, pool_size, image_shape, kernel_size=(2, 2, 2), strides=(2, 2, 2), |
|
|
153 |
deconvolution=False): |
|
|
154 |
if deconvolution: |
|
|
155 |
try: |
|
|
156 |
from keras_contrib.layers import Deconvolution3D |
|
|
157 |
except ImportError: |
|
|
158 |
raise ImportError("Install keras_contrib in order to use deconvolution. Otherwise set deconvolution=False." |
|
|
159 |
"\nTry: pip install git+https://www.github.com/farizrahman4u/keras-contrib.git") |
|
|
160 |
|
|
|
161 |
return Deconvolution3D(filters=n_filters, kernel_size=kernel_size, |
|
|
162 |
output_shape=compute_level_output_shape(n_filters=n_filters, depth=depth, |
|
|
163 |
pool_size=pool_size, image_shape=image_shape), |
|
|
164 |
strides=strides, input_shape=compute_level_output_shape(n_filters=n_filters, |
|
|
165 |
depth=depth, |
|
|
166 |
pool_size=pool_size, |
|
|
167 |
image_shape=image_shape)) |
|
|
168 |
else: |
|
|
169 |
return UpSampling3D(size=pool_size) |