|
a |
|
b/examples/deepstrain_vs_cvi.py |
|
|
1 |
import os |
|
|
2 |
import glob |
|
|
3 |
import time |
|
|
4 |
import pydicom |
|
|
5 |
import numpy as np |
|
|
6 |
import pandas as pd |
|
|
7 |
import nibabel as nib |
|
|
8 |
|
|
|
9 |
PREPARE_INPUT_DATA_WITH_CARSON = False |
|
|
10 |
PREDICT = False |
|
|
11 |
|
|
|
if PREPARE_INPUT_DATA_WITH_CARSON:

    from data import base_dataset
    from data.nifti_dataset import resample_nifti
    from tensorflow.keras.optimizers import Adam
    from options.test_options import TestOptions
    from models import deep_strain_model

    # Root folder holding one sub-folder per batch.
    DATA_ROOT = '/mnt/alp/Research Data Sets/DeepStrain_vs_CVI'
    CARSON_WEIGHTS = '/home/mmorales/main_python/DeepStrain/pretrained_models/carson_Jan2021.h5'
    NIFTI_EXT = '.nii.gz'
    # Side length of the centered in-plane window that get_mask feeds to CarSON.
    CROP = 128

    def normalize(x, axis=(0,1,2)):
        """Return `x` standardized to zero mean and unit std over `axis`.

        The statistics are kept with keepdims=True, so they broadcast back over
        the axes that are not reduced. The 1e-8 term avoids a division by zero
        for constant inputs.
        """
        mu = x.mean(axis=axis, keepdims=True)
        sd = x.std(axis=axis, keepdims=True)
        return (x-mu)/(sd+1e-8)

    def get_mask(V, netS):
        """Segment every (slice, frame) image of a cine with CarSON.

        Parameters
        ----------
        V : ndarray of shape (nx, ny, nz, nt)
            Cine volume. Only a CROP x CROP window centered in-plane is
            segmented, so the heart must lie roughly at the image center.
        netS : callable
            Segmentation network. It is called on an array of shape
            (nz*nt, CROP, CROP, 1), and its output is argmax-ed over the last axis.

        Returns
        -------
        ndarray of shape (nx, ny, nz, nt)
            Label map inside the central window, 0 elsewhere.

        Raises
        ------
        ValueError
            If an in-plane dimension is smaller than CROP. In that case the
            centered window would not fit inside the image.
        """
        nx, ny, nz, nt = V.shape
        if nx < CROP or ny < CROP:
            raise ValueError('get_mask needs in-plane size >= %d, got (%d, %d)' % (CROP, nx, ny))
        x0 = nx//2 - CROP//2
        y0 = ny//2 - CROP//2
        rows = slice(x0, x0 + CROP)
        cols = slice(y0, y0 + CROP)

        M = np.zeros((nx,ny,nz,nt))
        # Stack all (z, t) images on a leading batch axis; batch index = z*nt + t.
        v = V.transpose((2,3,0,1)).reshape((-1,nx,ny))
        # NOTE: with the default axis this normalizes over the whole (nz*nt) stack.
        v = normalize(v)
        m = netS(v[:, rows, cols, None])
        labels = np.argmax(m, -1)  # (nz*nt, CROP, CROP)
        # Undo the stacking: (CROP, CROP, nz*nt) -> (CROP, CROP, nz, nt).
        M[rows, cols] += labels.transpose((1,2,0)).reshape((CROP,CROP,nz,nt))
        return M

    # Load CarSON, the segmentation network.
    opt = TestOptions().parse()
    model = deep_strain_model.DeepStrain(Adam, opt)
    netS = model.get_netS()
    netS.load_weights(CARSON_WEIGHTS)

    # Wall-clock timings in seconds. time_resample receives two entries per
    # processed file: the forward resampling and the inverse resampling.
    time_resample = []
    time_carson = []

    # Load subjects batch by batch.
    batches = ['batch_%d'%(j) for j in range(1,11)] + ['HFpEF_batch_%d'%(j) for j in range(1,5)]

    for batch in batches:

        niftis_folder = os.path.join(DATA_ROOT, batch, 'niftis', 'standard')
        niftis_folder_out_carson = os.path.join(DATA_ROOT, batch, 'input_to_DeepStrain_with_CarSON')

        for SubjectID_folder in glob.glob(os.path.join(niftis_folder, '*')):
            for nifti_path in glob.glob(os.path.join(SubjectID_folder, '*' + NIFTI_EXT)):

                try:
                    V_nifti = nib.load(nifti_path)
                    start = time.time()
                    V_nifti_resampled = resample_nifti(V_nifti, order=1, in_plane_resolution_mm=1.25, number_of_slices=None)
                    time_resample.append(time.time() - start)

                    # Normalize each 2D image (axes 0 and 1), not the whole volume.
                    V = normalize(V_nifti_resampled.get_fdata(), axis=(0,1))

                    # No segmentation exists yet to crop around the heart, so get_mask
                    # center-crops. ONLY valid if the heart is roughly near the center.
                    start = time.time()
                    M = get_mask(V, netS)
                    time_carson.append(time.time() - start)

                    # Resample the mask back to the original voxel size. order=0 is
                    # presumably nearest-neighbour, so labels are not blended.
                    M_nifti_resampled = nib.Nifti1Image(M, affine=V_nifti_resampled.affine)
                    start = time.time()
                    M_nifti = base_dataset.resample_nifti_inv(nifti_resampled=M_nifti_resampled,
                                                              zooms=V_nifti.header.get_zooms()[:3],
                                                              order=0, mode='nearest')
                    time_resample.append(time.time() - start)

                    # Remove the extension as a suffix. str.strip('.nii.gz') would instead
                    # strip any of the characters '.', 'n', 'i', 'g', 'z' from both ends.
                    fname = os.path.basename(nifti_path)
                    if fname.endswith(NIFTI_EXT):
                        fname = fname[:-len(NIFTI_EXT)]
                    fname = fname.replace('(','').replace(')','')
                    output_folder = os.path.join(niftis_folder_out_carson, os.path.basename(SubjectID_folder))

                    os.makedirs(output_folder, exist_ok=True)

                    # The original (not resampled) cine is written next to the mask,
                    # which was resampled back to the cine's zooms above.
                    V_nifti.to_filename(os.path.join(output_folder, fname + NIFTI_EXT))
                    M_nifti.to_filename(os.path.join(output_folder, fname + '_segmentation' + NIFTI_EXT))
                except Exception as err:
                    # Best effort: report the failing file and move on to the next one.
                    print("Error here, check!", nifti_path, repr(err))

    np.save(os.path.join(DATA_ROOT, 'time_resample'), time_resample)
    np.save(os.path.join(DATA_ROOT, 'time_carson'), time_carson)
|
|
if PREDICT:

    from data.nifti_dataset import resample_nifti
    from data.base_dataset import _roll2center_crop
    # scipy.ndimage.measurements is a deprecated namespace; the function lives in scipy.ndimage.
    from scipy.ndimage import center_of_mass

    from aux import myocardial_strain
    from scipy.ndimage import gaussian_filter

    from tensorflow.keras.optimizers import Adam
    from options.test_options import TestOptions
    from models import deep_strain_model

    # Root folder holding one sub-folder per batch.
    DATA_ROOT = '/mnt/alp/Research Data Sets/DeepStrain_vs_CVI'
    # One path for both opt.pretrained_models_netME and the explicit load_weights
    # call below, so the two cannot point at different files.
    CARMEN_WEIGHTS = '/home/mmorales/main_python/DeepStrain/pretrained_models/carmen_Jan2021.h5'
    # Every cine and mask is resampled to this many slices before CarMEN.
    NUMBER_OF_SLICES = 16
    SEGMENTATION_SUFFIX = '_segmentation.nii.gz'
    # Set to False to skip motion estimation and only recompute strain from saved outputs.
    RUN_CARMEN = True

    def normalize(x):
        """Standardize `x` over its first three axes (one (x, y, z) volume per remaining index)."""
        mu = x.mean(axis=(0,1,2), keepdims=True)
        sd = x.std(axis=(0,1,2), keepdims=True)
        return (x-mu)/(sd+1e-8)

    def _blood_pool_volumes(seg, label):
        """Return per-frame voxel counts of `label` in a (nx, ny, nz, nt) label map.

        Counts are taken over a 5-slice mid-ventricular slab. If the label is
        absent from that slab in every frame, the whole volume is used instead.
        """
        nz = seg.shape[2]
        slab = seg[:, :, max(nz//2 - 2, 0):nz//2 + 3]
        volumes = (slab == label).sum(axis=(0,1,2))
        if not volumes.any():
            print('Need to use all volume to estimate ED/ES')
            volumes = (seg == label).sum(axis=(0,1,2))
        return volumes

    # options
    opt = TestOptions().parse()
    preprocess = opt.preprocess
    # NOTE(review): this instance is replaced below and never used. It is kept in
    # case the constructor has side effects on `opt` that the second call relies on.
    model = deep_strain_model.DeepStrain(Adam, opt)

    opt.number_of_slices = NUMBER_OF_SLICES
    opt.preprocess = opt.preprocess_carmen + '_' + preprocess
    opt.pretrained_models_netME = CARMEN_WEIGHTS
    model = deep_strain_model.DeepStrain(Adam, opt)
    netME = model.get_netME()
    netME.load_weights(CARMEN_WEIGHTS)

    batches = ['batch_%d'%(j) for j in range(1,11)] + ['HFpEF_batch_%d'%(j) for j in range(1,5)]

    # Strain is calculated from CarSON segmentations; masks from other segmentation models are also possible.
    for method in ['_with_CarSON']:
        # Label conventions differ between segmentation sources -- verify these labels!
        if method == '_with_CarSON':
            tissue_label_blood_pool, tissue_label_myocardium, tissue_label_rv = 3, 2, 1
        else:
            tissue_label_blood_pool, tissue_label_myocardium, tissue_label_rv = 1, 2, 3

        for batch in batches:
            print(batch)
            # Only subjects whose cines and segmentations were prepared are processed.
            niftis_folder_out = os.path.join(DATA_ROOT, batch, 'input_to_DeepStrain%s' % (method))
            output_root = os.path.join(DATA_ROOT, batch, 'output_from_DeepStrain%s' % (method))

            if RUN_CARMEN:
                for SubjectID_folder in glob.glob(os.path.join(niftis_folder_out, '*')):
                    for nifti_path in glob.glob(os.path.join(SubjectID_folder, '*' + SEGMENTATION_SUFFIX)):

                        output_folder = os.path.join(output_root, os.path.basename(SubjectID_folder))
                        # Existing outputs are skipped, so an interrupted run can be resumed.
                        if os.path.isdir(output_folder):
                            continue
                        print(output_folder)

                        # The cine sits next to its mask: <name>.nii.gz / <name>_segmentation.nii.gz.
                        volume_path = nifti_path[:-len(SEGMENTATION_SUFFIX)] + '.nii.gz'
                        V_nifti = resample_nifti(nib.load(volume_path), order=1, number_of_slices=NUMBER_OF_SLICES)
                        M_nifti = resample_nifti(nib.load(nifti_path), order=0, number_of_slices=NUMBER_OF_SLICES)
                        seg = M_nifti.get_fdata()

                        # Crop the cine and the mask around the myocardium's center of mass.
                        center = center_of_mass(seg == tissue_label_myocardium)
                        V = _roll2center_crop(x=V_nifti.get_fdata(), center=center)
                        M = _roll2center_crop(x=seg, center=center)

                        # If the slice with the most RV voxels is in the upper half of
                        # the stack, the slice order is reversed.
                        rv_slice = np.argmax((M == tissue_label_rv).sum(axis=(0,1,3)))
                        if rv_slice > M.shape[2]//2:
                            print('Apex to Base. Inverting.')
                            V = V[:,:,::-1]
                            M = M[:,:,::-1]

                        V = normalize(V)
                        nt = V.shape[3]

                        # End-diastole = frame with the largest blood-pool volume.
                        ED = np.argmax(_blood_pool_volumes(seg, tissue_label_blood_pool))

                        # End-diastole is the reference frame, paired with every frame.
                        M_0 = M[..., ED]
                        V_0 = np.repeat(np.expand_dims(V[..., ED], -1), nt, axis=-1)
                        V_t = V

                        # Move time frames to the batch axis to predict all of them at once.
                        V_0 = np.transpose(V_0, (3,0,1,2))
                        V_t = np.transpose(V_t, (3,0,1,2))
                        y_t = netME([V_0, V_t]).numpy()

                        os.makedirs(output_folder, exist_ok=True)

                        # Saved for the strain calculation. Only the end-diastolic mask is necessary.
                        np.save(os.path.join(output_folder, 'V_0.npy'), V_0)
                        np.save(os.path.join(output_folder, 'V_t.npy'), V_t)
                        np.save(os.path.join(output_folder, 'y_t.npy'), y_t)
                        np.save(os.path.join(output_folder, 'M_0.npy'), M_0)

            # Per-frame strain, averaged over the myocardium.
            # NOTE: the 'RadialStain' column name is kept as-is for existing CSV consumers.
            records = {'SubjectID':[], 'RadialStain':[], 'CircumferentialStrain':[], 'TimeFrame':[]}
            for subject_folder in glob.glob(os.path.join(output_root, '*')):
                y_t = np.load(os.path.join(subject_folder, 'y_t.npy'))
                M_0 = np.load(os.path.join(subject_folder, 'M_0.npy'))
                subject_id = os.path.basename(subject_folder)

                # Smooth along axes 1 and 2 only. Assumes y_t is
                # (nt, nx, ny, nz, channels) -- TODO confirm against netME's output.
                y_t = gaussian_filter(y_t, sigma=(0,2,2,0,0))

                for time_frame in range(len(y_t)):
                    try:
                        strain = myocardial_strain.MyocardialStrain(mask=M_0, flow=y_t[time_frame,:,:,:,:])
                        strain.calculate_strain(lv_label=tissue_label_blood_pool)
                        myocardium = strain.mask_rot == tissue_label_myocardium
                        radial = 100*strain.Err[myocardium].mean()
                        circumferential = 100*strain.Ecc[myocardium].mean()
                    except Exception as err:
                        print('Error in ', subject_folder, repr(err))
                        continue
                    # Append only once every value is computed, so the columns stay equal length.
                    records['SubjectID'].append(subject_id)
                    records['RadialStain'].append(radial)
                    records['CircumferentialStrain'].append(circumferential)
                    records['TimeFrame'].append(time_frame)

            df = pd.DataFrame(records)
            df.to_csv(output_root + '.csv')