a b/examples/deepstrain_vs_cvi.py
1
import os
2
import glob
3
import time
4
import pydicom
5
import numpy as np
6
import pandas as pd
7
import nibabel as nib
8
9
# Stage toggles for this example script; each stage runs only when its flag is True.
# Stage 1: segment the raw cines with CarSON and write the DeepStrain inputs.
PREPARE_INPUT_DATA_WITH_CARSON: bool = False
# Stage 2: run CarMEN on the prepared inputs and tabulate strain per time frame.
PREDICT: bool = False
11
12
if PREPARE_INPUT_DATA_WITH_CARSON:
13
    
14
    from data import base_dataset
15
    from data.nifti_dataset import resample_nifti
16
    from tensorflow.keras.optimizers import Adam
17
    from options.test_options import TestOptions
18
    from models import deep_strain_model
19
20
    def normalize(x, axis=(0, 1, 2)):
        """Z-score ``x`` over ``axis``.

        One mean and one standard deviation are computed for every index of
        the axes NOT listed in ``axis`` (keepdims, so they broadcast back onto
        ``x``). The default pools axes 0-2, i.e. one statistic per entry of
        any trailing axis. A small epsilon keeps constant inputs finite.
        """
        eps = 1e-8
        centred = x - np.mean(x, axis=axis, keepdims=True)
        spread = np.std(x, axis=axis, keepdims=True) + eps
        return centred / spread
25
26
    def get_mask(V, netS, crop=128):
        """Segment every slice/frame of a cine volume with CarSON (``netS``).

        The network only sees a square ``crop`` x ``crop`` window taken from
        the in-plane centre of each image, so this is only appropriate when
        the heart lies roughly at the image centre. Images smaller than the
        window are padded with the volume minimum before cropping, and the
        padding is removed from the result. Previously the crop start went
        negative, the slice wrapped around, and the final reshape failed.

        Parameters
        ----------
        V : ndarray, shape (nx, ny, nz, nt)
            Cine volume.
        netS : callable
            Segmentation network, called once on a batch of shape
            ``(nz*nt, crop, crop, 1)``. The class axis of its output is taken
            to be the last one (``argmax`` over ``-1``).
        crop : int, optional
            In-plane window size. 128 reproduces the previously hard-coded crop.

        Returns
        -------
        M : ndarray of float, shape (nx, ny, nz, nt)
            Argmax labels inside the window, 0 elsewhere.
        """
        nx, ny, nz, nt = V.shape

        # Pad in-plane (split before/after) only when the image is smaller
        # than the window; this is a no-op for the usual case.
        pad_x = max(crop - nx, 0)
        pad_y = max(crop - ny, 0)
        off_x, off_y = pad_x // 2, pad_y // 2
        if pad_x or pad_y:
            V = np.pad(V,
                       ((off_x, pad_x - off_x), (off_y, pad_y - off_y), (0, 0), (0, 0)),
                       mode='constant', constant_values=V.min())
        px, py = V.shape[:2]

        # top-left corner of the centred window (nx//2-64 when crop == 128)
        x0 = px // 2 - crop // 2
        y0 = py // 2 - crop // 2

        v = V.transpose((2, 3, 0, 1)).reshape((-1, px, py))  # (nz*nt, px, py)
        v = normalize(v)
        m = netS(v[:, x0:x0 + crop, y0:y0 + crop, None])
        # (nz*nt, crop, crop) -> (crop, crop, nz, nt); z was the slower axis above
        labels = np.argmax(m, -1).transpose((1, 2, 0)).reshape((crop, crop, nz, nt))

        M = np.zeros((px, py, nz, nt))
        M[x0:x0 + crop, y0:y0 + crop] = labels
        # drop any padding so the mask lives on the input grid
        return M[off_x:off_x + nx, off_y:off_y + ny]
36
37
    # --- Load CarSON (segmentation network) ------------------------------
    opt = TestOptions().parse()
    model = deep_strain_model.DeepStrain(Adam, opt)
    netS = model.get_netS()
    netS.load_weights('/home/mmorales/main_python/DeepStrain/pretrained_models/carson_Jan2021.h5')

    DATA_ROOT = '/mnt/alp/Research Data Sets/DeepStrain_vs_CVI'
    NIFTI_SUFFIX = '.nii.gz'

    # Per-call durations in seconds (monotonic clock), saved as .npy at the end.
    time_resample = []  # resampling to 1.25 mm and back to native resolution
    time_carson = []    # CarSON inference (get_mask)

    # load subjects by batches
    batches = (['batch_%d' % j for j in range(1, 11)]
               + ['HFpEF_batch_%d' % j for j in range(1, 5)])

    for batch in batches:
        niftis_folder = '%s/%s/niftis/standard' % (DATA_ROOT, batch)
        niftis_folder_out_carson = '%s/%s/input_to_DeepStrain_with_CarSON' % (DATA_ROOT, batch)

        for SubjectID_folder in glob.glob(os.path.join(niftis_folder, '*')):
            output_folder = os.path.join(niftis_folder_out_carson,
                                         os.path.basename(SubjectID_folder))

            for nifti_path in glob.glob(os.path.join(SubjectID_folder, '*' + NIFTI_SUFFIX)):
                try:
                    V_nifti = nib.load(nifti_path)
                    start = time.perf_counter()
                    V_nifti_resampled = resample_nifti(V_nifti, order=1, in_plane_resolution_mm=1.25,
                                                       number_of_slices=None)
                    time_resample.append(time.perf_counter() - start)

                    # here we normalize per image, not volume
                    V = normalize(V_nifti_resampled.get_fdata(), axis=(0, 1))

                    # No segmentation exists yet to crop around, so get_mask
                    # centre-crops: ONLY valid if the heart is roughly centred.
                    start = time.perf_counter()
                    M = get_mask(V, netS)
                    time_carson.append(time.perf_counter() - start)

                    # resample the mask back to the original resolution
                    # (order=0 so label values are not interpolated)
                    M_nifti_resampled = nib.Nifti1Image(M, affine=V_nifti_resampled.affine)
                    start = time.perf_counter()
                    M_nifti = base_dataset.resample_nifti_inv(nifti_resampled=M_nifti_resampled,
                                                              zooms=V_nifti.header.get_zooms()[:3],
                                                              order=0, mode='nearest')
                    time_resample.append(time.perf_counter() - start)

                    # Drop the literal '.nii.gz' suffix (guaranteed by the glob).
                    # str.strip('.nii.gz') would strip any of '.', 'n', 'i', 'g',
                    # 'z' from BOTH ends, e.g. 'main.nii.gz' -> 'ma'.
                    fname = os.path.basename(nifti_path)[:-len(NIFTI_SUFFIX)]
                    fname = fname.replace('(', '').replace(')', '')

                    os.makedirs(output_folder, exist_ok=True)
                    V_nifti.to_filename(os.path.join(output_folder, fname + NIFTI_SUFFIX))
                    M_nifti.to_filename(os.path.join(output_folder, fname + '_segmentation' + NIFTI_SUFFIX))
                except Exception as err:
                    # best effort: report why this file failed and move on
                    print("Error here, check!", nifti_path, err)

    np.save(os.path.join(DATA_ROOT, 'time_resample'), time_resample)
    np.save(os.path.join(DATA_ROOT, 'time_carson'), time_carson)
97
98
99
100
101
102
if PREDICT:
103
104
    from data.nifti_dataset import resample_nifti
105
    from data.base_dataset import _roll2center_crop
106
    from scipy.ndimage.measurements import center_of_mass
107
108
   
109
    from aux import myocardial_strain
110
    from scipy.ndimage import gaussian_filter
111
112
    from tensorflow.keras.optimizers import Adam
113
    from options.test_options import TestOptions
114
    from models import deep_strain_model
115
116
    def normalize(x, axis=(0, 1, 2)):
        """Z-score ``x`` over ``axis`` (default: one statistic per trailing index).

        With the default axes on a (x, y, z, t) array this normalizes each time
        frame's volume independently. ``axis`` is exposed for consistency with
        the ``normalize`` helper used in the CarSON stage. The default keeps
        the previous behaviour.
        """
        mu = x.mean(axis=axis, keepdims=True)
        sd = x.std(axis=axis, keepdims=True)
        # epsilon keeps constant inputs finite instead of dividing by zero
        return (x - mu) / (sd + 1e-8)
121
122
    # --- Load CarMEN (motion-estimation network) -------------------------
    CARMEN_WEIGHTS = '/home/mmorales/main_python/DeepStrain/pretrained_models/carmen_Jan2021.h5'
    DATA_ROOT = '/mnt/alp/Research Data Sets/DeepStrain_vs_CVI'

    # options
    opt = TestOptions().parse()
    preprocess = opt.preprocess
    opt.number_of_slices = 16
    opt.preprocess = opt.preprocess_carmen + '_' + preprocess
    # Use one path for both the option and the explicit load below so the
    # two cannot name different files.
    opt.pretrained_models_netME = CARMEN_WEIGHTS
    model = deep_strain_model.DeepStrain(Adam, opt)
    netME = model.get_netME()
    netME.load_weights(CARMEN_WEIGHTS)

    batches = (['batch_%d' % j for j in range(1, 11)]
               + ['HFpEF_batch_%d' % j for j in range(1, 5)])

    # calculate using CarSON segmentations. Note that segmentations based on
    # other segmentation models is also possible (adapt the labels below)
    for method in ['_with_CarSON']:
        # verify these labels!
        if method == '_with_CarSON':
            tissue_label_blood_pool, tissue_label_myocardium, tissue_label_rv = 3, 2, 1
        else:
            tissue_label_blood_pool, tissue_label_myocardium, tissue_label_rv = 1, 2, 3

        for batch in batches:
            print(batch)
            # only use data whose cines and corresponding segmentations have been prepared
            niftis_folder_out = '%s/%s/input_to_DeepStrain%s' % (DATA_ROOT, batch, method)
            output_root = os.path.join(os.path.dirname(niftis_folder_out),
                                       'output_from_DeepStrain%s' % method)

            RUN_CARMEN = True
            if RUN_CARMEN:
                for SubjectID_folder in glob.glob(os.path.join(niftis_folder_out, '*')):
                    for nifti_path in glob.glob(os.path.join(SubjectID_folder, '*_segmentation.nii.gz')):
                        output_folder = os.path.join(output_root, os.path.basename(SubjectID_folder))
                        # resume support: subjects that already have outputs are skipped
                        if os.path.isdir(output_folder):
                            continue
                        print(output_folder)

                        V_nifti = nib.load(nifti_path.replace('_segmentation', ''))
                        M_nifti = nib.load(nifti_path)
                        V_nifti = resample_nifti(V_nifti, order=1, number_of_slices=16)
                        M_nifti = resample_nifti(M_nifti, order=0, number_of_slices=16)

                        M_full = M_nifti.get_fdata()
                        center = center_of_mass(M_full == tissue_label_myocardium)
                        V = _roll2center_crop(x=V_nifti.get_fdata(), center=center)
                        M = _roll2center_crop(x=M_full, center=center)

                        # Slice with the most RV voxels; if it lies in the upper
                        # half of the stack the slices are assumed to run apex->base
                        # and are flipped.
                        rv_peak_slice = np.argmax((M == tissue_label_rv).sum(axis=(0, 1, 3)))
                        if rv_peak_slice > M.shape[2] // 2:
                            print('Apex to Base. Inverting.')
                            V = V[:, :, ::-1]
                            M = M[:, :, ::-1]

                        V = normalize(V)
                        nt = V.shape[3]

                        # Blood-pool volume per frame over a 5-slice mid-ventricular
                        # slab (mid-2 .. mid+2) to estimate end-diastole. The upper
                        # bound used to be `nz+3`, which ran to the last slice.
                        mid = M_full.shape[2] // 2
                        slab = M_full[:, :, max(mid - 2, 0):mid + 3]
                        volumes = (slab == tissue_label_blood_pool).sum(axis=(0, 1, 2))
                        if not volumes.any():
                            print('Need to use all volume to estimate ED/ES')
                            volumes = (M_full == tissue_label_blood_pool).sum(axis=(0, 1, 2))

                        ED = np.argmax(volumes)

                        # set end-diastole as the reference time frame
                        M_0 = M[..., ED]
                        V_0 = np.repeat(np.expand_dims(V[..., ED], -1), nt, axis=-1)
                        V_t = V

                        # move time frames to the batch dimension to predict all at once
                        V_0 = np.transpose(V_0, (3, 0, 1, 2))
                        V_t = np.transpose(V_t, (3, 0, 1, 2))
                        y_t = netME([V_0, V_t]).numpy()

                        os.makedirs(output_folder, exist_ok=True)
                        # save for calculation. Only the end-diastolic mask is necessary
                        np.save(os.path.join(output_folder, 'V_0.npy'), V_0)
                        np.save(os.path.join(output_folder, 'V_t.npy'), V_t)
                        np.save(os.path.join(output_folder, 'y_t.npy'), y_t)
                        np.save(os.path.join(output_folder, 'M_0.npy'), M_0)

            # --- strain per subject and time frame ---------------------------
            # 'RadialStain' (sic) is kept unchanged so the CSV column name is the
            # same as in previously written files.
            df = {'SubjectID': [], 'RadialStain': [], 'CircumferentialStrain': [], 'TimeFrame': []}
            for subject_folder in glob.glob(os.path.join(output_root, '*')):
                subject_id = os.path.basename(subject_folder)
                y_t = np.load(os.path.join(subject_folder, 'y_t.npy'))
                M_0 = np.load(os.path.join(subject_folder, 'M_0.npy'))

                # y_t is rank 5 (frame first); smooth only along axes 1 and 2
                y_t = gaussian_filter(y_t, sigma=(0, 2, 2, 0, 0))

                for time_frame in range(len(y_t)):
                    try:
                        strain = myocardial_strain.MyocardialStrain(mask=M_0, flow=y_t[time_frame, :, :, :, :])
                        strain.calculate_strain(lv_label=tissue_label_blood_pool)
                        myocardium = strain.mask_rot == tissue_label_myocardium
                        radial = 100 * strain.Err[myocardium].mean()
                        circumferential = 100 * strain.Ecc[myocardium].mean()
                    except Exception as err:
                        print('Error in ', subject_folder, err)
                        continue
                    # append only once every value exists so the columns always
                    # have equal length (a mid-row failure used to break DataFrame)
                    df['SubjectID'].append(subject_id)
                    df['RadialStain'].append(radial)
                    df['CircumferentialStrain'].append(circumferential)
                    df['TimeFrame'].append(time_frame)

            pd.DataFrame(df).to_csv(output_root + '.csv')
241