a b/tool/Code/adipose_pipeline.py
1
# Copyright 2019 Population Health Sciences and Image Analysis, German Center for Neurodegenerative Diseases(DZNE)
2
#
3
#    Licensed under the Apache License, Version 2.0 (the "License");
4
#    you may not use this file except in compliance with the License.
5
#    You may obtain a copy of the License at
6
#
7
#        http://www.apache.org/licenses/LICENSE-2.0
8
#
9
#    Unless required by applicable law or agreed to in writing, software
10
#    distributed under the License is distributed on an "AS IS" BASIS,
11
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
#    See the License for the specific language governing permissions and
13
#    limitations under the License.
14
15
16
from __future__ import division
17
import sys
18
sys.path.append('./')
19
sys.path.append('../')
20
21
import os
22
23
import nibabel as nib
24
import pandas as pd
25
from Code.utilities.misc import locate_file
26
from Code.utilities.visualization_misc import multiview_plotting
27
from Code.utilities.metrics import calculate_statistics_v2
28
from Code.utilities.models import run_adipose_localization,run_adipose_segmentation
29
import numpy as np
30
from keras import backend as K
31
from Code.utilities.image_processing import largets_connected_componets,find_labels
32
from Code.utilities.conform import conform
33
34
35
36
def clean_segmentations(label_map):
    """Return a post-processed copy of ``label_map``; the input is left untouched.

    The array is copied with ``np.copy`` (so a plain ndarray is produced even
    for ndarray subclasses). The copy is then handed to
    ``largets_connected_componets`` from ``Code.utilities.image_processing``,
    whose result is returned.
    NOTE(review): the name suggests that helper keeps only the largest
    connected component(s) per label; confirm in image_processing.
    """
    return largets_connected_componets(np.copy(label_map))
43
44
45
def extreme_AAT_increase_flag(predict_array, threshold=0.3):
    """Flag abrupt slice-to-slice changes of SAT (label 1) or VAT (label 2) area.

    Every interior slice along axis 0 is compared with both neighbouring
    slices. A change is "extreme" when its absolute size exceeds
    ``threshold`` times the current slice's voxel count, so decreases are
    flagged as well as increases. SAT is tested before VAT for each slice,
    and the verdict of the last (highest-index) triggering slice is reported.

    Args:
        predict_array: 3-D label volume indexed ``[slice, row, col]``.
        threshold: tolerated relative change between neighbouring slices.

    Returns:
        'SAT increase over the threshold', 'VAT increase over the threshold',
        or False when no slice triggers. This includes volumes with fewer than
        three slices.
    """
    nb_slices = predict_array.shape[0]
    if nb_slices < 3:
        return False

    # Voxel counts per slice, computed once instead of three times per slice.
    sat_counts = np.sum(predict_array == 1, axis=(1, 2))
    vat_counts = np.sum(predict_array == 2, axis=(1, 2))

    def _jumps(counts, idx):
        # True if either neighbour differs by more than threshold * counts[idx].
        limit = counts[idx] * threshold
        return (np.abs(counts[idx] - counts[idx - 1]) > limit
                or np.abs(counts[idx] - counts[idx + 1]) > limit)

    # Scanning backwards and stopping at the first hit yields the same verdict
    # as scanning forwards and keeping the last one.
    for idx in reversed(range(1, nb_slices - 1)):
        if _jumps(sat_counts, idx):
            return 'SAT increase over the threshold'
        if _jumps(vat_counts, idx):
            return 'VAT increase over the threshold'
    return False
65
66
def stats_variable_initialization(nb_comparments, weighted=True):
    """Build the column names of the per-subject statistics table.

    Layout: 'imageid', '#_Slices', 'FLAGS', then, for each ROI, the area
    columns, the volume columns and (when ``weighted``) the weighted-volume
    columns. Each of these is prefixed with the ROI name and '_'. The ROIs are
    'wb' followed by 'Q<nb_comparments>' down to 'Q1' ("From Feet to Head").

    Args:
        nb_comparments: number of compartments; 0 yields only the 'wb' ROI.
        weighted: include the weighted-volume ('W...') columns.

    Returns:
        (variable_columns, base_variable_len), where base_variable_len maps
        'Area', 'Volume' and 'W_Volume' to the size of each per-ROI group.
    """
    area_columns = ['HEIGHT_cm', 'AVG_AREA_cm2', 'AVG_PERIMETER_cm']
    volume_columns = ['VOL_cm3', 'SAT_VOL_cm3', 'VAT_VOL_cm3', 'AAT_VOL_cm3',
                      'VAT_VOL_TO_SAT_VOL', 'VAT_VOL_TO_AAT_VOL', 'SAT_VOL_TO_AAT_VOL']
    weighted_volume_columns = ['W_VOL_cm3', 'WSAT_VOL_cm3', 'WVAT_VOL_cm3',
                               'WAAT_VOL_cm3', 'WVAT_VOL_TO_WSAT_VOL',
                               'WVAT_VOL_TO_WAAT_VOL', 'WSAT_VOL_TO_WAAT_VOL']

    base_variable_len = {
        'Area': len(area_columns),
        'Volume': len(volume_columns),
        'W_Volume': len(weighted_volume_columns),
    }

    # Whole body first, then the compartments from feet to head.
    roi_areas = ['wb']
    if nb_comparments != 0:
        roi_areas += ['Q' + str(number) for number in range(int(nb_comparments), 0, -1)]

    suffixes = area_columns + volume_columns
    if weighted:
        suffixes = suffixes + weighted_volume_columns

    variable_columns = ['imageid', '#_Slices', 'FLAGS']
    variable_columns.extend(roi + '_' + suffix for roi in roi_areas for suffix in suffixes)

    return variable_columns, base_variable_len
106
107
def check_image_contrast(water_array, fat_array):
    """Heuristic QC of the fat/water contrast on the central slice.

    Only the middle slice along axis 0 is inspected, with a 20-voxel margin
    cropped from both in-plane axes. The two crops are scaled by their joint
    maximum. Voxels whose fat intensity lies in [10 %, 30 %) of the slice's
    fat maximum form a "border" band, and every 10th band voxel (row-major
    order) is sampled. A sample where fat < water increments ``fat_count``;
    any other sample increments ``no_fat_count``.

    Args:
        water_array: water image indexed ``[slice, row, col]``.
        fat_array: fat image with the same layout.

    Returns:
        'Check image contrast' when no sample has fat < water (this includes a
        blank slice or an empty band), or when no_fat_count / fat_count > 0.75.
        Otherwise False.
    """
    mid_slice = fat_array.shape[0] // 2
    water_slice = water_array[mid_slice, 20:-20, 20:-20]
    fat_slice = fat_array[mid_slice, 20:-20, 20:-20]

    intensity_max = np.max([np.max(water_slice), np.max(fat_slice)])
    if intensity_max == 0:
        # Blank slice: nothing can be sampled (and dividing would produce NaNs).
        return 'Check image contrast'

    water_slice = water_slice / intensity_max
    fat_slice = fat_slice / intensity_max

    # Border band: fat intensity in [10 %, 30 %) of the fat maximum.
    fat_max = np.max(fat_slice)
    band = (fat_slice >= 0.10 * fat_max) & (fat_slice < 0.30 * fat_max)

    rows, cols = np.nonzero(band)
    rows, cols = rows[::10], cols[::10]
    differences = fat_slice[rows, cols] - water_slice[rows, cols]

    # NOTE(review): samples with fat < water are counted as "fat"; the naming
    # looks inverted but is kept from the original implementation.
    fat_count = int(np.count_nonzero(differences < 0))
    no_fat_count = differences.size - fat_count

    # fat_count == 0 previously raised ZeroDivisionError when no sample was
    # taken; when fat_count > 0, "no_fat > fat" already implies ratio > 0.75.
    if fat_count == 0 or no_fat_count / fat_count > 0.75:
        return 'Check image contrast'
    return False
147
148
def check_flags(predicted_array, water_array, fat_array, ratio_vat_sat, threshold=0.30, sat_to_vat_threshold=2.0):
    """Return the QC flag reported for a subject.

    Precedence:
      1. the message from ``check_image_contrast`` if it flags the images;
      2. 'High VAT to SAT ratio' when ``ratio_vat_sat > sat_to_vat_threshold``;
      3. otherwise the result of ``extreme_AAT_increase_flag`` (a message or False).

    ``threshold`` is forwarded to ``extreme_AAT_increase_flag``.
    """
    contrast_flag = check_image_contrast(water_array, fat_array)
    if contrast_flag:
        return contrast_flag

    # Evaluated before the ratio test, exactly as in the original ordering.
    increase_flag = extreme_AAT_increase_flag(predicted_array, threshold=threshold)
    if ratio_vat_sat > sat_to_vat_threshold:
        return 'High VAT to SAT ratio'
    return increase_flag
159
160
161
162
163
164
def _conform_single_frame(img, mod, args, flags, save_path):
    """Conform ``img``; return (conformed_img, array) or (None, None) for multi-frame input.

    The returned array has axes 0 and 2 swapped relative to the conformed
    image data, because the rest of this module indexes volumes along axis 0.
    """
    ishape = img.shape
    if len(ishape) > 3 and ishape[3] != 1:
        print('ERROR: Multiple input frames (' + format(ishape[3]) + ') not supported!')
        return None, None
    img = conform(img, flags=flags, order=args.order, save_path=save_path, mod=mod,
                  axial=args.axial)
    # np.asanyarray(img.dataobj) replaces the removed nibabel get_data().
    array = np.swapaxes(np.asanyarray(img.dataobj), 0, 2)
    return img, array


def _save_qc_images(fat_array, pred_array, save_path):
    """Write four multiview QC PNGs (<save_path>/QC/QC_<i>.png) of the SAT/VAT labels."""
    # Display orientation: flip axes 0 and 1; only labels 1 and 2 (SAT, VAT) are shown.
    disp_fat = np.fliplr(np.flipud(fat_array))
    disp_pred = np.fliplr(np.flipud(np.copy(pred_array)))
    disp_pred[disp_pred >= 3] = 0

    labelled_slices = np.nonzero(disp_pred > 0)[0]
    if labelled_slices.size == 0:
        # Previously np.min() on an empty selection raised ValueError here.
        print('No SAT/VAT labels found, QC images were not generated')
        return

    qc_path = os.path.join(save_path, 'QC')
    if not os.path.isdir(qc_path):
        os.makedirs(qc_path)

    low_idx = np.min(labelled_slices)
    high_idx = np.max(labelled_slices)
    interval = (high_idx - low_idx) // 4

    for i in range(4):
        # One random slice from each quarter of the labelled range, centred in-plane.
        control_point = [0, int(np.ceil(disp_fat.shape[1] / 2)), int(np.ceil(disp_fat.shape[2] / 2))]
        control_point[0] = int(np.ceil(np.random.uniform(high_idx - interval * i,
                                                         high_idx - interval * (i + 1))))
        multiview_plotting(disp_fat, disp_pred, control_point,
                           os.path.join(qc_path, 'QC_%s.png' % i),
                           classes=5, alpha=0.5, nbviews=3)


def _save_predictions(pred_array, fat_img, seg_path, output_pred, output_pred_fat):
    """Save the full label map and a copy with labels >= 3 zeroed, in fat_img's geometry."""
    pred_array = np.swapaxes(pred_array, 2, 0)
    nib.save(nib.Nifti1Image(pred_array, fat_img.affine, fat_img.header),
             os.path.join(seg_path, output_pred))

    # Work on a copy so the caller's array is not modified through the view.
    adipose_only = pred_array.copy()
    adipose_only[adipose_only >= 3] = 0
    nib.save(nib.Nifti1Image(adipose_only, fat_img.affine, fat_img.header),
             os.path.join(seg_path, output_pred_fat))


def run_adipose_pipeline(args, flags, save_path='/', data_path='/', id='Test'):
    """Run the adipose-tissue pipeline for one subject and write its outputs.

    The fat image (and the water image, if present) is located in
    ``data_path``, conformed and loaded. Unless ``args.run_stats`` is set,
    a label map is produced with ``run_adipose_segmentation``, optionally
    restricted to the slice range returned by ``run_adipose_localization``.
    Otherwise an existing ``*AAT_pred.nii.gz`` in ``data_path`` is reused.

    Outputs:
      * ``<save_path>/Segmentations/AAT_stats.tsv`` and
        ``AAT_variables_summary.json``: one row of statistics plus QC flags.
      * ``ALL_pred.nii.gz``: all labels.
      * ``AAT_pred.nii.gz``: the labels with values >= 3 zeroed.
      * ``<save_path>/QC/QC_<i>.png``: QC images, written when
        ``args.control_images`` is falsy.

    Args:
        args: parsed options. Read here: fat_image, water_image, order, axial,
            compartments, run_stats, run_localization, increase_threshold,
            sat_to_vat_threshold, control_images.
        flags: forwarded unchanged to conform and the model helpers.
        save_path: output directory; missing directories are created.
        data_path: directory searched for the subject's input images.
        id: subject identifier written into the 'imageid' column.

    Returns:
        None. Problems are reported with ``print`` and the subject is skipped.
    """
    output_stats = 'AAT_stats.tsv'
    output_pred_fat = 'AAT_pred.nii.gz'
    output_pred = 'ALL_pred.nii.gz'

    print('-' * 30)
    print('Loading Subject')
    print(id)
    sub = id

    fat_file = locate_file('*' + str(args.fat_image), data_path)
    water_file = locate_file('*' + str(args.water_image), data_path)

    if not fat_file:
        print('')
        print('-' * 30)
        print('ERROR: Subject doesnt have a Fat Image named %s,\n'
              'Please verified that the name provided to the -fat argument matches the one in the participants folder (default : FatImaging_F.nii.gz )' % str(args.fat_image))
        return

    # ---- Fat image ----------------------------------------------------------
    print('-' * 30)
    print('Loading Fat Image')
    print(fat_file[0])
    fat_img = nib.load(fat_file[0])

    # Volumes from example_data_folder only contain the value -9999.
    if len(np.unique(np.asanyarray(fat_img.dataobj))) <= 2:
        print('ERROR: Input image empty \n'
              'Note : Volumes from the example_data_folder are empty \n'
              'The example_data_folder is only a ilustrative example on how volumes have to be organized for FatSegNet to work.')
        print('Please provided your own dixon MR scans')
        return

    fat_img, fat_array = _conform_single_frame(fat_img, 'fat', args, flags, save_path)
    if fat_img is None:
        # Previously execution continued and crashed on the undefined fat_array.
        return
    fat_zooms = fat_img.header.get_zooms()

    # ---- Water image (optional) ---------------------------------------------
    print('-' * 30)
    print('Loading Water Image')
    water_array = None
    if water_file:
        print(water_file[0])
        # The frame check now inspects the water image itself (was fat_img).
        _, water_array = _conform_single_frame(nib.load(water_file[0]), 'water', args, flags,
                                               save_path)
    weighted = water_array is not None
    if not weighted:
        print('No water image found, weighted volumes would not be calculated')
        # A zero water image keeps the statistics/flag code path uniform.
        water_array = np.zeros(fat_array.shape)

    variable_columns, base_variable_len = stats_variable_initialization(args.compartments, weighted)
    ratio_position = variable_columns.index('wb_VAT_VOL_TO_SAT_VOL')
    pixel_matrix = np.zeros((1, len(variable_columns)), dtype=object)
    img_spacing = np.copy(fat_zooms)

    # ---- Segmentation or reuse of an existing prediction --------------------
    if not args.run_stats:
        if args.run_localization:
            high_idx, low_idx = run_adipose_localization(fat_array, flags)
            K.clear_session()
        else:
            high_idx = fat_array.shape[0]
            low_idx = 0

        print('the index values are %d, %d' % (low_idx, high_idx))

        pred_array = run_adipose_segmentation(fat_array, flags, args)
        K.clear_session()
    else:
        pred_file = locate_file('*AAT_pred.nii.gz', data_path)
        if not pred_file:
            print('Subject has no prediction map, a ATT_pred.nii.gz file is required to run the stats option')
            print('-' * 30)
            print('ERROR: Subject doesnt have a AAT_pred.nii.gz')
            # Previously execution continued and crashed on the undefined pred_array.
            return
        pred_img = nib.load(pred_file[0])
        pred_array = np.swapaxes(np.asanyarray(pred_img.dataobj), 0, 2)
        img_spacing = np.copy(pred_img.header.get_zooms())
        high_idx, low_idx = find_labels(pred_array)

    # ---- Statistics ---------------------------------------------------------
    print('-' * 30)
    print('Calculating Stats')

    # Keep labels only inside [low_idx, high_idx) and clean that range.
    pred_array[0:low_idx, :, :] = 0
    pred_array[high_idx:, :, :] = 0
    pred_array[low_idx:high_idx, :, :] = clean_segmentations(pred_array[low_idx:high_idx, :, :])

    pixel_matrix[0:1, 0] = sub
    pixel_matrix[0:1, 3:] = calculate_statistics_v2(pred_array[low_idx:high_idx, :, :],
                                                    water_array[low_idx:high_idx, :, :],
                                                    fat_array[low_idx:high_idx, :, :],
                                                    low_idx, high_idx, variable_columns[3:],
                                                    base_variable_len, img_spacing,
                                                    args.compartments, weighted=weighted)
    pixel_matrix[0:1, 1] = int(high_idx - low_idx)
    pixel_matrix[0:1, 2] = check_flags(pred_array[low_idx:high_idx, :, :],
                                       water_array=water_array, fat_array=fat_array,
                                       ratio_vat_sat=pixel_matrix[0, ratio_position],
                                       threshold=args.increase_threshold,
                                       sat_to_vat_threshold=args.sat_to_vat_threshold)

    df = pd.DataFrame(pixel_matrix, columns=variable_columns)

    seg_path = os.path.join(save_path, 'Segmentations')
    if not os.path.isdir(seg_path):
        os.makedirs(seg_path)

    df.to_csv(os.path.join(seg_path, output_stats), sep='\t', index=False)
    df.to_json(os.path.join(seg_path, 'AAT_variables_summary.json'), orient='records')

    # ---- Control images of the segmentation ---------------------------------
    # NOTE(review): QC images are produced when args.control_images is falsy;
    # confirm this polarity against the CLI definition.
    if not args.control_images:
        _save_qc_images(fat_array, pred_array, save_path)

    # ---- Save prediction ----------------------------------------------------
    print('-' * 30)
    print('Saving Segmentation')
    _save_predictions(pred_array, fat_img, seg_path, output_pred, output_pred_fat)

    print('-' * 30)
    print('Finish Subject %s' % sub)
    print('-' * 30)
362
363