[8fb459]: / test.py

Download this file

112 lines (78 with data), 3.7 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
import os
import h5py
import timeit
import numpy as np
import nibabel as nib
from tensorflow.keras.optimizers import Adam
from options.test_options import TestOptions
from models import deep_strain_model
from data import nifti_dataset, h5py_dataset
from utils import myocardial_strain
# options
# Parse the command-line test options once; every stage below reads `opt`.
opt = TestOptions().parse()
# Ensure the output directory exists before any stage writes into it.
os.makedirs(name=opt.results_dir, exist_ok=True)

# The segmentation and motion stages overwrite `opt.preprocess` with their own
# prefix, so remember the user-supplied value to rebuild each stage's tag.
preprocess = opt.preprocess
# Model instance used by the segmentation stage (the motion stage builds a new
# one after changing the options).
model = deep_strain_model.DeepStrain(Adam, opt)
if 'segmentation' in opt.pipeline:
    # Segmentation stage: run netS over every input volume, undo the dataset
    # transform, and save the label map as "<stem>_segmentation" in results_dir.
    opt.preprocess = f'{opt.preprocess_carson}_{preprocess}'
    dataset = nifti_dataset.NiftiDataset(opt)
    netS = model.get_netS()
    for index, sample in enumerate(dataset):
        # Stem = base file name up to (not including) the first '.'.
        stem = os.path.basename(dataset.filenames[index]).partition('.')[0]
        x, nifti, nifti_resampled = sample
        # Forward pass, then map the prediction back through the inverse of
        # the dataset's preprocessing transform.
        prediction = dataset.transform.apply_inv(netS(x).numpy())
        # NOTE(review): the strain stage later loads "<stem>_segmentation.nii",
        # so save_as_nifti presumably appends ".nii" — confirm.
        nifti_dataset.save_as_nifti(
            prediction,
            nifti,
            nifti_resampled,
            filename=os.path.join(opt.results_dir, stem + '_segmentation'),
        )
    # Drop the reference to the segmentation network once all volumes are done.
    del netS
if 'motion' in opt.pipeline:
    # Motion stage: run netME on each input volume and write the estimated
    # displacement field, one HDF5 group per time frame, to
    # "<results_dir>/<stem>_motion.h5".
    opt.number_of_slices = 16
    opt.preprocess = opt.preprocess_carmen + '_' + preprocess

    # Rebuild the model after changing the options above (presumably so the
    # motion network/dataset see the new slice count and preprocess tag).
    model = deep_strain_model.DeepStrain(Adam, opt)
    dataset = nifti_dataset.NiftiDataset(opt)
    netME = model.get_netME()
    for i, data in enumerate(dataset):
        # Same stem rule as the segmentation stage, so the strain stage can
        # pair "<stem>_motion.h5" with "<stem>_segmentation.nii".
        filename = os.path.basename(dataset.filenames[i]).split('.')[0]

        # model was trained with cine data from base to apex.
        # a different orientation could yield different values.
        # if you have the segmentation you can modify the images
        # as shown in the example notebooks.
        x, nifti, nifti_resampled = data
        # Split the last axis of the input into two halves; netME takes both
        # halves as its two inputs.
        x_0, x_t = np.array_split(x, 2, -1)
        y_t = netME([x_0, x_t]).numpy()
        y_t = dataset.transform.apply_inv(y_t)

        # Use a context manager so the HDF5 file is closed even if writing a
        # frame raises; previously an exception leaked the open handle.
        motion_path = os.path.join(opt.results_dir, filename + '_motion.h5')
        with h5py.File(motion_path, 'w') as HF:
            # NOTE(review): iterates shape[-2] but indexes axis 3, which
            # assumes y_t is 5-D (axis 3 == axis -2) — confirm.
            for time_frame in range(y_t.shape[-2]):
                hf = HF.create_group('frame_%d' % (time_frame))
                hf.create_dataset('u', data=y_t[:, :, :, time_frame])

    # Drop the reference to the motion network once all volumes are done.
    del netME
if 'strain' in opt.pipeline:
    # Strain stage: for every motion file, load the matching segmentation,
    # compute radial (Err) and circumferential (Ecc) strain per time frame,
    # and save both as 4-D NIfTI volumes using the motion file's stem.
    dataset = h5py_dataset.H5PYDataset(opt)
    for idx, u in enumerate(dataset):
        filename = dataset.filenames[idx].split('_motion.h5')[0]
        mask_path = filename + '_segmentation.nii'
        try:
            mask_nifti = nib.load(mask_path)
        except FileNotFoundError:
            # Only a missing file means "no segmentation": skip this case.
            # Other errors (corrupt file, Ctrl-C) are no longer swallowed.
            print('Missing segmentation:', mask_path)
            continue

        # Resample the mask to 1.25 mm in-plane and 16 slices; the slice
        # count matches opt.number_of_slices used by the motion stage
        # (presumably the in-plane grid matches too — confirm).
        mask_nifti = nifti_dataset.resample_nifti(
            mask_nifti, in_plane_resolution_mm=1.25, number_of_slices=16)
        mask = mask_nifti.get_fdata()

        Radial = np.zeros(mask.shape)
        Circumferential = np.zeros(mask.shape)
        # NOTE(review): assumes u is 5-D with time on its last axis and that
        # mask has at least u.shape[-1] frames on axis 3 — confirm.
        for time_frame in range(u.shape[-1]):
            # The first mask frame is used as the reference for every frame.
            strain = myocardial_strain.MyocardialStrain(
                mask=mask[:, :, :, 0], flow=u[:, :, :, :, time_frame])
            strain.calculate_strain(lv_label=3)

            # Keep strain only where the rotated mask has label 2.
            inside = strain.mask_rot == 2
            outside = strain.mask_rot != 2
            strain.Err[outside] = 0.0
            strain.Ecc[outside] = 0.0

            Radial[:, :, :, time_frame] += strain.Err
            Circumferential[:, :, :, time_frame] += strain.Ecc

            # Global (mean) radial / circumferential strain for this frame.
            GRS = strain.Err[inside].mean()
            GCS = strain.Ecc[inside].mean()
            print(GRS, GCS)

        radial_img = nib.Nifti1Image(Radial, mask_nifti.affine)
        circumferential_img = nib.Nifti1Image(Circumferential, mask_nifti.affine)
        radial_img.to_filename(filename + '_radial_strain.nii')
        circumferential_img.to_filename(filename + '_circumferential_strain.nii')