a b/tool/Code/utilities/models.py
1
# Copyright 2019 Population Health Sciences and Image Analysis, German Center for Neurodegenerative Diseases(DZNE)
2
#
3
#    Licensed under the Apache License, Version 2.0 (the "License");
4
#    you may not use this file except in compliance with the License.
5
#    You may obtain a copy of the License at
6
#
7
#        http://www.apache.org/licenses/LICENSE-2.0
8
#
9
#    Unless required by applicable law or agreed to in writing, software
10
#    distributed under the License is distributed on an "AS IS" BASIS,
11
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
#    See the License for the specific language governing permissions and
13
#    limitations under the License.
14
15
16
17
import sys
18
sys.path.append('../')
19
sys.path.append('./')
20
import os
21
22
import numpy as np
23
from keras.models import load_model
24
from keras import backend as K
25
import Code.utilities.loss as loss
26
from Code.utilities.image_processing import change_data_plane,swap_axes,largets_connected_componets,find_labels
27
28
29
30
31
def find_unique_index_slice(data, label=2, min_fraction=0.8):
    """Return the highest and lowest slice indices dominated by ``label``.

    A slice ``data[z, :, :]`` qualifies when ``label`` occurs in it and its voxel
    count is at least ``min_fraction`` of the slice's non-zero voxels (label 0 is
    treated as background and excluded from the total).

    Args:
        data: label array indexed as ``data[z, :, :]``; the first axis is the slice axis.
        label: label value that must dominate a slice (default 2, the original behavior).
        min_fraction: required share of the slice's non-zero voxels (default 0.8).

    Returns:
        tuple ``(higher_index, lower_index)`` of ints: the largest and the smallest
        qualifying slice index.

    Raises:
        ValueError: if no slice qualifies.
    """
    aux_index = []

    for z in range(data.shape[0]):
        labels, counts = np.unique(data[z, :, :], return_counts=True)
        hit = labels == label
        if not hit.any():
            continue
        # Sum the counts of every non-zero label explicitly. The previous
        # ``counts[1:]`` assumed label 0 was present in the slice and silently
        # dropped the first foreground label's count when it was not.
        num_pixels = counts[labels != 0].sum()
        if counts[hit][0] >= num_pixels * min_fraction:
            aux_index.append(z)

    if not aux_index:
        # np.max on an empty list used to raise an opaque "zero-size array" error.
        raise ValueError('find_unique_index_slice: no slice is dominated by label %s' % label)

    return max(aux_index), min(aux_index)
48
49
50
def run_adipose_localization(data, flags):
    """Estimate the abdominal slice range by averaging per-plane localization models.

    For each plane in ('coronal', 'sagittal') the model folder
    ``<flags['localizationModels']>/Loc_CDFNet_Baseline_<plane>`` is used: its
    ``train_parameters.npy`` is loaded and ``test_localization_model`` is run on
    ``data``. The returned upper/lower slice indices are summed and divided
    (floor division) by the number of planes.

    Args:
        data: ndarray with the input image, forwarded unchanged to each model.
        flags: dict; ``flags['localizationModels']`` is the directory containing
            the per-plane localization model folders.

    Returns:
        tuple ``(high_idx, low_idx)`` of ints: the averaged upper and lower slice indices.
    """
    print('-' * 30)
    print('Run Abdominal Localization Block')
    planes = ['coronal', 'sagittal']
    high_idx = 0
    low_idx = 0
    for plane in planes:
        plane_model = os.path.join(flags['localizationModels'], 'Loc_CDFNet_Baseline_' + str(plane))
        params_path = os.path.join(plane_model, 'train_parameters.npy')
        # The parameter file holds a pickled dict (hence ``.item()``); NumPy >= 1.16.3
        # refuses to unpickle unless allow_pickle=True is given. Only load parameter
        # files that ship with the trusted model directory.
        params = np.load(params_path, allow_pickle=True).item()
        params['modelParams']['SavePath'] = plane_model
        tmp_high_idx, tmp_low_idx = test_localization_model(params, data)
        high_idx += tmp_high_idx
        low_idx += tmp_low_idx

    # Average over the planes actually evaluated (previously a hard-coded 2).
    high_idx = int(high_idx // len(planes))
    low_idx = int(low_idx // len(planes))
    return high_idx, low_idx
69
70
71
72
def run_adipose_segmentation(data, flags, args):
    """Segment ``data`` with the single-view networks, optionally fused by the multiview network.

    If ``args.axial`` is true, only the axial single-view network is run and the
    arg-max over its last axis is returned. Otherwise the axial, coronal and
    sagittal networks are run. Their outputs are written side by side along the
    last (channel) axis of a zero-initialized buffer sized from the multiview
    model's 'PatchSize', and ``test_multiplane`` produces the final labels.

    Args:
        data: ndarray with the input image; forwarded to each single-view model.
        flags: dict with the model directories 'multiviewModel' and 'singleViewModels'.
        args: object with a boolean ``axial`` attribute selecting axial-only mode.

    Returns:
        ndarray of dtype int16 with the per-voxel class labels.
    """
    print('-' * 30)
    print('Run AAT Segmentation Block')

    # ============  Load Params ==================================
    # Multiview model. The parameter file holds a pickled dict, which NumPy >= 1.16.3
    # only loads with allow_pickle=True (trusted, shipped model files only).
    multiview_path = os.path.join(flags['multiviewModel'], 'Baseline_Mixed_Multi_Plane')
    multiview_params = np.load(os.path.join(multiview_path, 'train_parameters.npy'),
                               allow_pickle=True).item()
    multiview_params['modelParams']['SavePath'] = multiview_path
    nbclasses = multiview_params['modelParams']['nClasses']

    # Single-view model directories.
    if args.axial:
        print('-' * 30)
        print('Segmentation done only on the axial plane')
        print('-' * 30)
        base_line_dirs = [os.path.join(flags['singleViewModels'], 'CDFNet_Baseline_axial')]
    else:
        base_line_dirs = [
            os.path.join(flags['singleViewModels'], 'CDFNet_Baseline_axial'),
            os.path.join(flags['singleViewModels'], 'CDFNet_Baseline_coronal'),
            os.path.join(flags['singleViewModels'], 'CDFNet_Baseline_sagittal'),
        ]
        patch_size = multiview_params['modelParams']['PatchSize']
        # One block of ``nbclasses`` channels per single-view model.
        test_data = np.zeros((1, patch_size[0], patch_size[1], patch_size[2],
                              len(base_line_dirs) * nbclasses))

    for i, plane_model in enumerate(base_line_dirs):
        print(plane_model)
        params_path = os.path.join(plane_model, 'train_parameters.npy')
        params = np.load(params_path, allow_pickle=True).item()
        params['modelParams']['SavePath'] = plane_model
        if args.axial:
            y_predict = test_model(params, data)
        else:
            test_data[0, 0:data.shape[0], :, :, i * nbclasses:(i + 1) * nbclasses] = test_model(params, data)

    if args.axial:
        final_img = np.asarray(np.argmax(y_predict, axis=-1), dtype=np.int16)
    else:
        final_img = test_multiplane(multiview_params, test_data)

    return final_img
123
124
125
def test_multiplane(params,data):
    """Run the multiview (fusion) segmentation network on stacked per-plane outputs.

    The network weights are loaded from
    ``<SavePath>/logs/<ModelName>_best_weights.h5``. Every custom loss name maps
    to ``loss.custom_loss`` built from the configured MedFrequency / GradientSigma /
    Loss_Function, and the dice metrics map to their ``loss`` counterparts.

    Args:
        params: dict of training parameters; ``params['modelParams']`` must provide
            'ModelName', 'SavePath', 'nChannels', 'nClasses', 'BatchSize',
            'MedFrequency', 'Loss_Function' and 'GradientSigma'.
        data: ndarray passed directly to ``model.predict``; at least 4-D, with
            ``data.shape[1:4]`` giving the output volume shape (presumably a batch
            of one volume with per-plane probability maps as channels — confirm).

    Returns:
        ndarray of dtype int16 with shape ``(data.shape[1], data.shape[2], data.shape[3])``
        holding the arg-max label of the network output along its last axis.
    """
    cfg = params['modelParams']

    # ============  Path Configuration ==================================
    model_name = cfg['ModelName']
    model_path = cfg['SavePath']

    # ============  Model Configuration ==================================
    n_ch = cfg['nChannels']          # read but not used below
    nb_classes = cfg['nClasses']     # read but not used below
    batch_size = cfg['BatchSize']
    med_bal_factor = cfg['MedFrequency']
    loss_type = cfg['Loss_Function']
    sigma = cfg['GradientSigma']

    print('-' * 30)
    print('model path')
    weights_file = model_path + '/logs/' + model_name + '_best_weights.h5'
    print(weights_file)

    loss_names = ('logistic_loss', 'weighted_logistic_loss', 'weighted_gradient_loss',
                  'mixed_loss', 'dice_loss')
    custom_objects = {name: loss.custom_loss(med_bal_factor, sigma, loss_type) for name in loss_names}
    custom_objects.update([
        ('dice_coef', loss.dice_coef),
        ('dice_coef_0', loss.dice_coef_0),
        ('dice_coef_1', loss.dice_coef_1),
        ('dice_coef_2', loss.dice_coef_2),
        ('dice_coef_3', loss.dice_coef_3),
        ('dice_coef_4', loss.dice_coef_4),
        ('average_dice_coef', loss.average_dice_coef),
    ])
    model = load_model(weights_file, custom_objects=custom_objects)

    print('-' * 30)
    print('Evaluating Multiview model  ...')
    print('-' * 30)

    probabilities = model.predict(data, batch_size=batch_size, verbose=0)

    # Collapse the last axis to labels and reshape to the spatial size of ``data``.
    labels = np.argmax(probabilities, axis=-1).reshape(data.shape[1], data.shape[2], data.shape[3])
    labels = np.asarray(labels, dtype=np.int16)
    print(labels.shape)

    return labels
181
182
def test_model(params,data):
    """Run one single-view segmentation network and map its output back to the input orientation.

    The input is copied and re-oriented with ``change_data_plane`` using the raw
    'Plane' value from ``params``, then reshaped to 4-D with 'nChannels' channels and
    predicted. The raw network output is re-oriented back and cropped along the
    first axis to the ``[idx_low:idx_high]`` range returned by the first
    ``change_data_plane`` call.

    Args:
        params: dict of training parameters; ``params['modelParams']`` must provide
            'ModelName', 'SavePath', 'nChannels', 'nClasses', 'BatchSize',
            'MedFrequency', 'Loss_Function', 'GradientSigma' and 'Plane'.
        data: ndarray with the input image (not modified; a copy is processed).

    Returns:
        ndarray: the re-oriented, cropped ``model.predict`` output, indexed with
        four axes (presumably per-class probabilities on the last axis — confirm).
    """
    cfg = params['modelParams']

    # ============  Path Configuration ==================================
    model_name = cfg['ModelName']
    model_path = cfg['SavePath']

    # ============  Model Configuration ==================================
    n_ch = cfg['nChannels']
    nb_classes = cfg['nClasses']     # read but not used below
    batch_size = cfg['BatchSize']
    med_bal_factor = cfg['MedFrequency']
    loss_type = cfg['Loss_Function']
    sigma = cfg['GradientSigma']
    raw_plane = cfg['Plane']

    # Readable plane name used only in log messages; change_data_plane still
    # receives the raw value stored in params.
    plane = {'frontal': 'coronal', 'sagital': 'sagittal'}.get(raw_plane, raw_plane)

    print('-' * 30)
    print('Evaluating %s...' % plane)
    print('-' * 30)
    print('Testing %s' % model_name)
    print('model path')
    weights_file = model_path + '/logs/' + model_name + '_best_weights.h5'
    print(weights_file)

    loss_names = ('logistic_loss', 'weighted_logistic_loss', 'weighted_gradient_loss',
                  'mixed_loss', 'dice_loss')
    custom_objects = {name: loss.custom_loss(med_bal_factor, sigma, loss_type) for name in loss_names}
    custom_objects.update([
        ('dice_coef', loss.dice_coef),
        ('dice_coef_0', loss.dice_coef_0),
        ('dice_coef_1', loss.dice_coef_1),
        ('dice_coef_2', loss.dice_coef_2),
        ('dice_coef_3', loss.dice_coef_3),
        ('dice_coef_4', loss.dice_coef_4),
        ('average_dice_coef', loss.average_dice_coef),
    ])
    model = load_model(weights_file, custom_objects=custom_objects)

    volume = np.copy(data)
    print('input size')
    print(volume.shape)
    volume, idx_low, idx_high = change_data_plane(volume, plane=raw_plane, return_index=True)
    volume = volume.reshape((volume.shape[0], volume.shape[1], volume.shape[2], n_ch))

    # ============  Evaluating ==================================
    print('-' * 30)
    print('Evaluating %s...' % plane)
    print('-' * 30)
    prediction = model.predict(volume, batch_size=batch_size, verbose=0)
    print('Change Plane to %s' % plane)
    prediction = change_data_plane(prediction, plane=raw_plane)
    prediction = prediction[idx_low:idx_high, :, :, :]
    print(prediction.shape)

    return prediction
253
254
255
def test_localization_model(params,data):
    """Run a localization network and return the slice bounds of the region it marks.

    The input is copied, re-oriented with ``change_data_plane`` (raw 'Plane' value
    from ``params``), reshaped to 4-D with 'nChannels' channels and predicted. The
    arg-max label map is re-oriented back, cropped along the first axis to
    ``[idx_low:idx_high]``, and passed to ``find_unique_index_slice``.

    Args:
        params: dict of training parameters; ``params['modelParams']`` must provide
            'ModelName', 'SavePath', 'nChannels', 'nClasses', 'BatchSize',
            'MedFrequency', 'Loss_Function', 'GradientSigma' and 'Plane'.
        data: ndarray with the input image (not modified; a copy is processed).

    Returns:
        tuple ``(high_idx, low_idx)`` as returned by ``find_unique_index_slice``.
    """
    cfg = params['modelParams']

    # ============  Path Configuration ==================================
    model_name = cfg['ModelName']
    model_path = cfg['SavePath']

    # ============  Model Configuration ==================================
    n_ch = cfg['nChannels']
    nb_classes = cfg['nClasses']     # read but not used below
    batch_size = cfg['BatchSize']
    med_bal_factor = cfg['MedFrequency']
    loss_type = cfg['Loss_Function']
    sigma = cfg['GradientSigma']
    raw_plane = cfg['Plane']

    # Readable plane name used only in log messages; change_data_plane still
    # receives the raw value stored in params.
    plane = {'frontal': 'coronal', 'sagital': 'sagittal'}.get(raw_plane, raw_plane)

    print('-' * 30)
    print('Evaluating %s...' % plane)
    print('-' * 30)
    print('Testing %s' % model_name)
    print('model path')
    weights_file = model_path + '/logs/' + model_name + '_best_weights.h5'
    print(weights_file)

    loss_names = ('logistic_loss', 'weighted_logistic_loss', 'weighted_gradient_loss',
                  'mixed_loss', 'dice_loss')
    custom_objects = {name: loss.custom_loss(med_bal_factor, sigma, loss_type) for name in loss_names}
    custom_objects.update([
        ('dice_coef', loss.dice_coef),
        ('dice_coef_0', loss.dice_coef_0),
        ('dice_coef_1', loss.dice_coef_1),
        ('dice_coef_2', loss.dice_coef_2),
        ('dice_coef_3', loss.dice_coef_3),
        ('dice_coef_4', loss.dice_coef_4),
        ('average_dice_coef', loss.average_dice_coef),
    ])
    model = load_model(weights_file, custom_objects=custom_objects)

    volume = np.copy(data)
    print('input size')
    print(volume.shape)
    volume, idx_low, idx_high = change_data_plane(volume, plane=raw_plane, return_index=True)
    volume = volume.reshape((volume.shape[0], volume.shape[1], volume.shape[2], n_ch))

    # ============  Evaluating ==================================
    print('-' * 30)
    label_map = np.argmax(model.predict(volume, batch_size=batch_size, verbose=0), axis=-1)
    print('Change Plane to %s' % plane)
    print(label_map.shape)
    label_map = change_data_plane(label_map, plane=raw_plane)
    label_map = label_map[idx_low:idx_high, :, :]
    print(label_map.shape)

    upper, lower = find_unique_index_slice(label_map)
    return upper, lower