Diff of /test.py [000000] .. [4cda31]

Switch to unified view

a b/test.py
1
import os
2
import h5py
3
import timeit
4
import numpy as np
5
import nibabel as nib
6
from tensorflow.keras.optimizers import Adam
7
from options.test_options import TestOptions
8
from models import deep_strain_model
9
from data import nifti_dataset, h5py_dataset
10
from utils import myocardial_strain
11
12
# options
13
opt = TestOptions().parse()
14
os.makedirs(opt.results_dir, exist_ok=True)
15
preprocess = opt.preprocess
16
model   = deep_strain_model.DeepStrain(Adam, opt)
17
18
if 'segmentation' in opt.pipeline:
19
    
20
    opt.preprocess = opt.preprocess_carson + '_' + preprocess    
21
    dataset = nifti_dataset.NiftiDataset(opt)
22
    netS    = model.get_netS()
23
    for i, data in enumerate(dataset):
24
25
        filename = os.path.basename(dataset.filenames[i]).split('.')[0]
26
27
        x, nifti, nifti_resampled = data
28
        
29
        y = netS(x).numpy()
30
        y = dataset.transform.apply_inv(y)
31
        nifti_dataset.save_as_nifti(y, nifti, nifti_resampled,
32
                                    filename=os.path.join(opt.results_dir, filename+'_segmentation'))
33
        
34
    del netS
35
36
if 'motion' in opt.pipeline:
37
    opt.number_of_slices = 16 
38
    opt.preprocess = opt.preprocess_carmen + '_' + preprocess
39
    
40
    model   = deep_strain_model.DeepStrain(Adam, opt)
41
    dataset = nifti_dataset.NiftiDataset(opt)
42
    netME   = model.get_netME()
43
    
44
    
45
    for i, data in enumerate(dataset):
46
        filename = os.path.basename(dataset.filenames[i]).split('.')[0]
47
        
48
        # model was trained with cine data from base to apex.
49
        # a different orientation could yield different values. 
50
        # if you have the segmentation you can modify the images
51
        # as shown in the example notebooks.
52
        x, nifti, nifti_resampled = data
53
        x_0, x_t = np.array_split(x,2,-1)
54
        y_t = netME([x_0, x_t]).numpy()
55
        y_t = dataset.transform.apply_inv(y_t)
56
        
57
        HF = h5py.File(os.path.join(opt.results_dir, filename+'_motion.h5'), 'w')
58
        for time_frame in range(y_t.shape[-2]):
59
            hf = HF.create_group('frame_%d' %(time_frame))
60
            hf.create_dataset('u', data=y_t[:,:,:,time_frame])
61
        HF.close()
62
63
    del netME   
64
    
65
if 'strain' in opt.pipeline:   
66
    dataset = h5py_dataset.H5PYDataset(opt)
67
  
68
    
69
    for idx, u in enumerate(dataset): 
70
        
71
        filename  = dataset.filenames[idx].split('_motion.h5')[0]
72
        mask_path = filename+'_segmentation.nii'
73
        try:
74
            mask_nifti = nib.load(mask_path)
75
76
        except:
77
            print('Missing segmentation')
78
            continue
79
80
        mask_zooms = mask_nifti.header.get_zooms()
81
        mask_nifti = nifti_dataset.resample_nifti(mask_nifti, in_plane_resolution_mm=1.25, number_of_slices=16)
82
        mask = mask_nifti.get_fdata()
83
84
        Radial = np.zeros(mask.shape)
85
        Circumferential = np.zeros(mask.shape)
86
        for time_frame in range(u.shape[-1]):
87
            strain = myocardial_strain.MyocardialStrain(mask=mask[:,:,:,0], flow=u[:,:,:,:,time_frame])
88
            strain.calculate_strain(lv_label=3)
89
90
            strain.Err[strain.mask_rot!=2] = 0.0
91
            strain.Ecc[strain.mask_rot!=2] = 0.0
92
93
            Radial[:,:,:,time_frame]          += strain.Err
94
            Circumferential[:,:,:,time_frame] += strain.Ecc
95
96
            GRS = strain.Err[strain.mask_rot==2].mean()
97
            GCS = strain.Ecc[strain.mask_rot==2].mean()
98
            print(GRS, GCS)
99
100
101
        Radial = nib.Nifti1Image(Radial, mask_nifti.affine)
102
        Circumferential = nib.Nifti1Image(Circumferential, mask_nifti.affine)
103
104
        Radial.to_filename(filename+'_radial_strain.nii')
105
        Circumferential.to_filename(filename+'_circumferential_strain.nii')
106
107
108
109
110
111