|
a |
|
b/tool/Code/utilities/models.py |
|
|
1 |
# Copyright 2019 Population Health Sciences and Image Analysis, German Center for Neurodegenerative Diseases(DZNE) |
|
|
2 |
# |
|
|
3 |
# Licensed under the Apache License, Version 2.0 (the "License"); |
|
|
4 |
# you may not use this file except in compliance with the License. |
|
|
5 |
# You may obtain a copy of the License at |
|
|
6 |
# |
|
|
7 |
# http://www.apache.org/licenses/LICENSE-2.0 |
|
|
8 |
# |
|
|
9 |
# Unless required by applicable law or agreed to in writing, software |
|
|
10 |
# distributed under the License is distributed on an "AS IS" BASIS, |
|
|
11 |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
12 |
# See the License for the specific language governing permissions and |
|
|
13 |
# limitations under the License. |
|
|
14 |
|
|
|
15 |
|
|
|
16 |
|
|
|
17 |
import sys |
|
|
18 |
sys.path.append('../') |
|
|
19 |
sys.path.append('./') |
|
|
20 |
import os |
|
|
21 |
|
|
|
22 |
import numpy as np |
|
|
23 |
from keras.models import load_model |
|
|
24 |
from keras import backend as K |
|
|
25 |
import Code.utilities.loss as loss |
|
|
26 |
from Code.utilities.image_processing import change_data_plane,swap_axes,largets_connected_componets,find_labels |
|
|
27 |
|
|
|
28 |
|
|
|
29 |
|
|
|
30 |
|
|
|
31 |
def find_unique_index_slice(data, label=2, min_fraction=0.8):
    """Return the highest and lowest slice index dominated by ``label``.

    A slice ``data[z]`` qualifies when ``label`` occurs in it and its pixel
    count is at least ``min_fraction`` of all non-zero pixels in that slice.

    Args:
        data: ndarray whose first axis indexes slices (at least 3-D).
        label: label value that must dominate a slice (default 2).
        min_fraction: minimum share of the slice's non-zero pixels that must
            carry ``label`` (default 0.8).

    Returns:
        (higher_index, lower_index): the largest and smallest qualifying
        slice index along axis 0.

    Raises:
        ValueError: if no slice qualifies.
    """
    data = np.asarray(data)
    # Reduce over every axis except the slice axis.
    slice_axes = tuple(range(1, data.ndim))

    # Foreground = every non-zero pixel. Dropping the first entry of
    # np.unique's counts (previous approach) excluded a real label whenever a
    # slice contained no background (0) pixels at all.
    foreground = np.count_nonzero(data, axis=slice_axes)
    target = np.count_nonzero(data == label, axis=slice_axes)

    qualifying = np.flatnonzero((target > 0) & (target >= foreground * min_fraction))
    if qualifying.size == 0:
        raise ValueError(
            'no slice where label %r covers at least %.0f%% of the foreground'
            % (label, 100 * min_fraction))

    return np.max(qualifying), np.min(qualifying)
|
|
48 |
|
|
|
49 |
|
|
|
50 |
def run_adipose_localization(data, flags):
    """Estimate the abdominal ROI slice boundaries from localization networks.

    Runs one localization network per plane (coronal and sagittal) through
    ``test_localization_model`` and averages the returned boundaries.

    Args:
        data: image volume passed unchanged to ``test_localization_model``.
        flags: dict; ``flags['localizationModels']`` is the directory holding
            the ``Loc_CDFNet_Baseline_<plane>`` model folders.

    Returns:
        (high_idx, low_idx): the per-plane boundaries averaged with floor
        division, as Python ints.
    """
    print('-' * 30)
    print ('Run Abdominal Localization Block')
    planes = ['coronal', 'sagittal']
    high_idx = 0
    low_idx = 0
    for plane in planes:
        plane_model = os.path.join(flags['localizationModels'], 'Loc_CDFNet_Baseline_' + str(plane))
        params_path = os.path.join(plane_model, 'train_parameters.npy')
        # The parameters were stored as a pickled dict; numpy >= 1.16.3 refuses
        # to unpickle unless allow_pickle=True is given. Only load trusted files.
        params = np.load(params_path, allow_pickle=True).item()
        params['modelParams']['SavePath'] = plane_model
        tmp_high_idx, tmp_low_idx = test_localization_model(params, data)
        high_idx += tmp_high_idx
        low_idx += tmp_low_idx

    # Average over however many planes were evaluated (not a hard-coded 2).
    high_idx = int(high_idx // len(planes))
    low_idx = int(low_idx // len(planes))
    return high_idx, low_idx
|
|
69 |
|
|
|
70 |
|
|
|
71 |
|
|
|
72 |
def run_adipose_segmentation(data, flags, args):
    """Segment adipose tissue with single-view and multi-view networks.

    With ``args.axial`` set, only the axial network is run and its argmax is
    returned. Otherwise the axial, coronal and sagittal networks are run, their
    outputs are stacked along the last axis, and ``test_multiplane`` fuses them.

    Args:
        data: image volume passed unchanged to ``test_model`` for each view.
        flags: dict with ``'multiviewModel'`` and ``'singleViewModels'`` model
            root directories.
        args: namespace-like object; ``args.axial`` selects axial-only mode.

    Returns:
        int16 label volume (argmax over classes).
    """
    print('-' * 30)
    print ('Run AAT Segmentation Block')

    # ============ Load Params ==================================
    # Multiview model (parameters are loaded in both modes, as before).
    # allow_pickle=True is required on numpy >= 1.16.3 for the pickled dict;
    # only load trusted parameter files.
    multiview_path = os.path.join(flags['multiviewModel'], 'Baseline_Mixed_Multi_Plane')
    multiview_params = np.load(os.path.join(multiview_path, 'train_parameters.npy'),
                               allow_pickle=True).item()
    multiview_params['modelParams']['SavePath'] = multiview_path
    nbclasses = multiview_params['modelParams']['nClasses']

    # Single-view model directories
    if args.axial:
        print('-' * 30)
        print('Segmentation done only on the axial plane')
        print('-' * 30)
        planes = ['axial']
    else:
        planes = ['axial', 'coronal', 'sagittal']
    base_line_dirs = [os.path.join(flags['singleViewModels'], 'CDFNet_Baseline_' + plane)
                      for plane in planes]

    # Buffer for the stacked per-view outputs; only needed for multi-view fusion.
    test_data = None
    if not args.axial:
        patch_size = multiview_params['modelParams']['PatchSize']
        test_data = np.zeros((1, patch_size[0], patch_size[1], patch_size[2],
                              len(base_line_dirs) * nbclasses))

    y_predict = None
    for i, plane_model in enumerate(base_line_dirs):
        print(plane_model)
        params_path = os.path.join(plane_model, 'train_parameters.npy')
        params = np.load(params_path, allow_pickle=True).item()
        params['modelParams']['SavePath'] = plane_model
        if args.axial:
            y_predict = test_model(params, data)
        else:
            # View i occupies channels [i*nbclasses, (i+1)*nbclasses).
            test_data[0, 0:data.shape[0], :, :, i * nbclasses:(i + 1) * nbclasses] = test_model(params, data)

    if args.axial:
        final_img = np.asarray(np.argmax(y_predict, axis=-1), dtype=np.int16)
    else:
        final_img = test_multiplane(multiview_params, test_data)

    return final_img
|
|
123 |
|
|
|
124 |
|
|
|
125 |
def test_multiplane(params, data):
    """Fuse stacked per-view probability maps with the multi-view network.

    Args:
        params: dict of training parameters; reads ``params['modelParams']``
            keys 'ModelName', 'SavePath', 'BatchSize', 'MedFrequency',
            'Loss_Function' and 'GradientSigma'.
        data: ndarray of stacked per-view probability maps with a leading
            batch axis. The final reshape assumes a batch size of 1.

    Returns:
        int16 ndarray of shape ``data.shape[1:4]`` holding the argmax class
        label per voxel.
    """
    model_params = params['modelParams']

    # ============ Path Configuration ==================================
    model_name = model_params['ModelName']
    model_path = model_params['SavePath']
    weights_path = os.path.join(model_path, 'logs', model_name + '_best_weights.h5')

    # ============ Model Configuration ==================================
    batch_size = model_params['BatchSize']
    med_bal_factor = model_params['MedFrequency']
    loss_type = model_params['Loss_Function']
    sigma = model_params['GradientSigma']
    print('-' * 30)
    print('model path')
    print(weights_path)

    # Every loss name the model may have been compiled with resolves to the
    # same configured loss, so build it once.
    loss_fn = loss.custom_loss(med_bal_factor, sigma, loss_type)
    custom_objects = dict.fromkeys(
        ('logistic_loss', 'weighted_logistic_loss', 'weighted_gradient_loss',
         'mixed_loss', 'dice_loss'),
        loss_fn)
    custom_objects.update({
        'dice_coef': loss.dice_coef,
        'dice_coef_0': loss.dice_coef_0,
        'dice_coef_1': loss.dice_coef_1,
        'dice_coef_2': loss.dice_coef_2,
        'dice_coef_3': loss.dice_coef_3,
        'dice_coef_4': loss.dice_coef_4,
        'average_dice_coef': loss.average_dice_coef,
    })
    model = load_model(weights_path, custom_objects=custom_objects)

    print('-' * 30)
    print('Evaluating Multiview model ...')
    print('-' * 30)

    y_predict = model.predict(data, batch_size=batch_size, verbose=0)

    # Reorganize prediction data: class probabilities -> label volume.
    y_predict = np.argmax(y_predict, axis=-1)
    y_predict = y_predict.reshape(data.shape[1], data.shape[2], data.shape[3])
    y_predict = np.asarray(y_predict, dtype=np.int16)
    print(y_predict.shape)

    return y_predict
|
|
181 |
|
|
|
182 |
def test_model(params, data):
    """Run one single-view segmentation network over a volume.

    The volume is re-oriented into the network's plane with
    ``change_data_plane``, predicted, and then re-oriented back.

    Args:
        params: dict of training parameters; reads ``params['modelParams']``
            keys 'ModelName', 'SavePath', 'nChannels', 'BatchSize',
            'MedFrequency', 'Loss_Function', 'GradientSigma' and 'Plane'.
        data: ndarray (int or float) containing the fat image.

    Returns:
        4-D ndarray of network outputs, restricted to the slice range
        ``idx_low:idx_high`` reported by ``change_data_plane``. The last axis
        presumably holds the per-class scores (callers argmax over it).
    """
    model_params = params['modelParams']

    # ============ Path Configuration ==================================
    model_name = model_params['ModelName']
    model_path = model_params['SavePath']
    weights_path = os.path.join(model_path, 'logs', model_name + '_best_weights.h5')

    # ============ Model Configuration ==================================
    n_ch = model_params['nChannels']
    batch_size = model_params['BatchSize']
    med_bal_factor = model_params['MedFrequency']
    loss_type = model_params['Loss_Function']
    sigma = model_params['GradientSigma']
    model_plane = model_params['Plane']
    # Map legacy plane spellings to the names used in log output;
    # change_data_plane still receives the raw value stored in params.
    plane = {'frontal': 'coronal', 'sagital': 'sagittal'}.get(model_plane, model_plane)

    print('-' * 30)
    print('Evaluating %s...' % plane)
    print('-' * 30)
    print('Testing %s' % model_name)
    print('model path')
    print(weights_path)

    # Every loss name the model may have been compiled with resolves to the
    # same configured loss, so build it once.
    loss_fn = loss.custom_loss(med_bal_factor, sigma, loss_type)
    custom_objects = dict.fromkeys(
        ('logistic_loss', 'weighted_logistic_loss', 'weighted_gradient_loss',
         'mixed_loss', 'dice_loss'),
        loss_fn)
    custom_objects.update({
        'dice_coef': loss.dice_coef,
        'dice_coef_0': loss.dice_coef_0,
        'dice_coef_1': loss.dice_coef_1,
        'dice_coef_2': loss.dice_coef_2,
        'dice_coef_3': loss.dice_coef_3,
        'dice_coef_4': loss.dice_coef_4,
        'average_dice_coef': loss.average_dice_coef,
    })
    model = load_model(weights_path, custom_objects=custom_objects)

    # Work on a copy so the caller's array is untouched regardless of what
    # change_data_plane does internally.
    x_test = np.copy(data)
    print('input size')
    print(x_test.shape)
    x_test, idx_low, idx_high = change_data_plane(x_test, plane=model_plane, return_index=True)

    # Add the trailing channel axis expected by the network.
    x_test = x_test.reshape((x_test.shape[0], x_test.shape[1], x_test.shape[2], n_ch))

    # ============ Evaluating ==================================
    print('-' * 30)
    print('Evaluating %s...' % plane)
    print('-' * 30)
    y_predict = model.predict(x_test, batch_size=batch_size, verbose=0)
    print('Change Plane to %s' % plane)
    y_predict = change_data_plane(y_predict, plane=model_plane)
    y_predict = y_predict[idx_low:idx_high, :, :, :]
    print(y_predict.shape)

    return y_predict
|
|
253 |
|
|
|
254 |
|
|
|
255 |
def test_localization_model(params, data):
    """Run a localization network and return the ROI slice boundaries.

    The volume is re-oriented into the network's plane, predicted, converted
    to argmax labels, re-oriented back, and passed to
    ``find_unique_index_slice``.

    Args:
        params: dict of training parameters; reads ``params['modelParams']``
            keys 'ModelName', 'SavePath', 'nChannels', 'BatchSize',
            'MedFrequency', 'Loss_Function', 'GradientSigma' and 'Plane'.
        data: ndarray (int or float) containing the fat image.

    Returns:
        (high_idx, low_idx): slice boundaries of the ROI, as returned by
        ``find_unique_index_slice``.
    """
    model_params = params['modelParams']

    # ============ Path Configuration ==================================
    model_name = model_params['ModelName']
    model_path = model_params['SavePath']
    weights_path = os.path.join(model_path, 'logs', model_name + '_best_weights.h5')

    # ============ Model Configuration ==================================
    n_ch = model_params['nChannels']
    batch_size = model_params['BatchSize']
    med_bal_factor = model_params['MedFrequency']
    loss_type = model_params['Loss_Function']
    sigma = model_params['GradientSigma']
    model_plane = model_params['Plane']
    # Map legacy plane spellings to the names used in log output;
    # change_data_plane still receives the raw value stored in params.
    plane = {'frontal': 'coronal', 'sagital': 'sagittal'}.get(model_plane, model_plane)

    print('-' * 30)
    print('Evaluating %s...' % plane)
    print('-' * 30)
    print('Testing %s' % model_name)
    print('model path')
    print(weights_path)

    # Every loss name the model may have been compiled with resolves to the
    # same configured loss, so build it once.
    loss_fn = loss.custom_loss(med_bal_factor, sigma, loss_type)
    custom_objects = dict.fromkeys(
        ('logistic_loss', 'weighted_logistic_loss', 'weighted_gradient_loss',
         'mixed_loss', 'dice_loss'),
        loss_fn)
    custom_objects.update({
        'dice_coef': loss.dice_coef,
        'dice_coef_0': loss.dice_coef_0,
        'dice_coef_1': loss.dice_coef_1,
        'dice_coef_2': loss.dice_coef_2,
        'dice_coef_3': loss.dice_coef_3,
        'dice_coef_4': loss.dice_coef_4,
        'average_dice_coef': loss.average_dice_coef,
    })
    model = load_model(weights_path, custom_objects=custom_objects)

    # Work on a copy so the caller's array is untouched regardless of what
    # change_data_plane does internally.
    x_test = np.copy(data)
    print('input size')
    print(x_test.shape)
    x_test, idx_low, idx_high = change_data_plane(x_test, plane=model_plane, return_index=True)

    # Add the trailing channel axis expected by the network.
    x_test = x_test.reshape((x_test.shape[0], x_test.shape[1], x_test.shape[2], n_ch))

    # ============ Evaluating ==================================
    print('-' * 30)
    y_predict = model.predict(x_test, batch_size=batch_size, verbose=0)
    # Collapse class scores to a label map before re-orienting.
    y_predict = np.argmax(y_predict, axis=-1)
    print('Change Plane to %s' % plane)
    print(y_predict.shape)
    y_predict = change_data_plane(y_predict, plane=model_plane)
    y_predict = y_predict[idx_low:idx_high, :, :]

    print(y_predict.shape)
    high_idx, low_idx = find_unique_index_slice(y_predict)

    return high_idx, low_idx