Switch to side-by-side view

--- a
+++ b/examples/deepstrain_vs_cvi.py
@@ -0,0 +1,241 @@
+import os
+import glob
+import time
+import pydicom
+import numpy as np
+import pandas as pd
+import nibabel as nib
+
+PREPARE_INPUT_DATA_WITH_CARSON = False
+PREDICT = False
+
+if PREPARE_INPUT_DATA_WITH_CARSON:
+    
+    from data import base_dataset
+    from data.nifti_dataset import resample_nifti
+    from tensorflow.keras.optimizers import Adam
+    from options.test_options import TestOptions
+    from models import deep_strain_model
+
+    def normalize(x, axis=(0,1,2)):
+        # normalize per volume (x,y,z) frame
+        mu = x.mean(axis=axis, keepdims=True)
+        sd = x.std(axis=axis, keepdims=True)
+        return (x-mu)/(sd+1e-8)
+
+    def get_mask(V, netS):
+        nx, ny, nz, nt = V.shape
+        
+        M = np.zeros((nx,ny,nz,nt))
+        v = V.transpose((2,3,0,1)).reshape((-1,nx,ny)) # (nz*nt,nx,ny)
+        v = normalize(v)
+        m = netS(v[:,nx//2-64:nx//2+64,ny//2-64:ny//2+64,None])
+        M[nx//2-64:nx//2+64,ny//2-64:ny//2+64] += np.argmax(m, -1).transpose((1,2,0)).reshape((128,128,nz,nt))
+        
+        return M
+
+    # options
+    opt   = TestOptions().parse()
+    model = deep_strain_model.DeepStrain(Adam, opt)
+    netS  = model.get_netS()
+    netS.load_weights('/home/mmorales/main_python/DeepStrain/pretrained_models/carson_Jan2021.h5')
+
+    time_resample = []
+    time_carson   = []
+    
+    # load subjects by batches
+    batches = ['batch_%d'%(j) for j in range(1,11)] + ['HFpEF_batch_%d'%(j) for j in range(1,5)]
+
+    for batch in batches:
+
+        niftis_folder    = '/mnt/alp/Research Data Sets/DeepStrain_vs_CVI/%s/niftis/standard'%(batch)
+        niftis_folder_out_carson = '/mnt/alp/Research Data Sets/DeepStrain_vs_CVI/%s/input_to_DeepStrain_with_CarSON'%(batch)
+
+        for SubjectID_folder in glob.glob(os.path.join(niftis_folder, '*')):
+            for nifti_path in glob.glob(os.path.join(SubjectID_folder, '*.nii.gz')):
+                
+                try:
+                    V_nifti = nib.load(nifti_path)
+                    start = time.time()
+                    V_nifti_resampled = resample_nifti(V_nifti, order=1, in_plane_resolution_mm=1.25, number_of_slices=None)
+                    end = time.time()
+                    time_resample += [end - start]
+
+                    # here we normalize per image, not volume
+                    V = V_nifti_resampled.get_fdata()
+                    V = normalize(V, axis=(0,1))
+
+                    # In this case we don't yet have a segmentation we can use to crop the image. 
+                    # In most cases we can simply center crop (see `get_mask` function): 
+                    start = time.time()
+                    M = get_mask(V, netS)
+                    end = time.time()
+                    time_carson += [end - start]
+
+                    # ONLY IF YOU KNOW YOUR IMAGE IS ROUGHLY NEAR CENTER 
+                    M_nifti_resampled = nib.Nifti1Image(M, affine=V_nifti_resampled.affine)
+                    # resample back to original resolution
+                    start = time.time()
+                    M_nifti = base_dataset.resample_nifti_inv(nifti_resampled=M_nifti_resampled, 
+                                                              zooms=V_nifti.header.get_zooms()[:3], 
+                                                              order=0, mode='nearest')
+                    end = time.time()
+                    time_resample += [end - start]
+                    fname = os.path.basename(nifti_path).strip('.nii.gz').replace('(','').replace(')','')
+                    output_folder = os.path.join(niftis_folder_out_carson, os.path.basename(SubjectID_folder))
+
+                    os.makedirs(output_folder, exist_ok=True)
+
+                    V_nifti.to_filename(os.path.join(output_folder, fname+'.nii.gz'))
+                    M_nifti.to_filename(os.path.join(output_folder, fname+'_segmentation.nii.gz'))
+                except:
+                    print("Error here, check!", nifti_path)
+                    continue
+
+    np.save('/mnt/alp/Research Data Sets/DeepStrain_vs_CVI/time_resample', time_resample)
+    np.save('/mnt/alp/Research Data Sets/DeepStrain_vs_CVI/time_carson', time_carson)
+
+
+
+
+
+if PREDICT:
+
+    from data.nifti_dataset import resample_nifti
+    from data.base_dataset import _roll2center_crop
+    from scipy.ndimage.measurements import center_of_mass
+
+   
+    from aux import myocardial_strain
+    from scipy.ndimage import gaussian_filter
+
+    from tensorflow.keras.optimizers import Adam
+    from options.test_options import TestOptions
+    from models import deep_strain_model
+
+    def normalize(x):
+        # normalize per volume (x,y,z) frame
+        mu = x.mean(axis=(0,1,2), keepdims=True)
+        sd = x.std(axis=(0,1,2), keepdims=True)
+        return (x-mu)/(sd+1e-8)
+
+    # options
+    opt = TestOptions().parse()
+    preprocess = opt.preprocess
+    model   = deep_strain_model.DeepStrain(Adam, opt)
+    
+    opt.number_of_slices = 16 
+    opt.preprocess = opt.preprocess_carmen + '_' + preprocess
+    opt.pretrained_models_netME = '/home/mmorales/main_python/DeepStrain/pretrained_models/carmenJan2021.h5'
+    model   = deep_strain_model.DeepStrain(Adam, opt)
+    netME   = model.get_netME()
+    netME.load_weights('/home/mmorales/main_python/DeepStrain/pretrained_models/carmen_Jan2021.h5')
+
+    batches = ['batch_%d'%(j) for j in range(1,11)] + ['HFpEF_batch_%d'%(j) for j in range(1,5)]
+
+    # calculate using CarSON segmentations. Note that segmentations based on other segmentation models is also possible
+    for method in ['_with_CarSON']:    
+        # verify these labels!
+        if method == '_with_CarSON':
+            tissue_label_blood_pool=3; tissue_label_myocardium=2; tissue_label_rv=1
+        else:
+            tissue_label_blood_pool=1; tissue_label_myocardium=2; tissue_label_rv=3
+            
+        for batch in batches:
+            print(batch)
+            # only use data whose cines and corresponding segmentations have been prepared
+            niftis_folder_out = '/mnt/alp/Research Data Sets/DeepStrain_vs_CVI/%s/input_to_DeepStrain%s'%(batch, method)
+
+            RUN_CARMEN = True
+            if RUN_CARMEN:
+                for SubjectID_folder in glob.glob(os.path.join(niftis_folder_out, '*')):
+                    
+                    for nifti_path in glob.glob(os.path.join(SubjectID_folder, '*_segmentation.nii.gz')):
+
+                        output_folder = os.path.join(os.path.dirname(niftis_folder_out), 
+                                                    'output_from_DeepStrain%s'%(method),
+                                                    os.path.basename(SubjectID_folder))
+                        
+                        if os.path.isdir(output_folder): continue
+
+                        print(output_folder)
+
+                        V_nifti = nib.load(nifti_path.replace('_segmentation', ''))
+                        M_nifti = nib.load(nifti_path)
+
+                        V_nifti = resample_nifti(V_nifti, order=1, number_of_slices=16)
+                        M_nifti = resample_nifti(M_nifti, order=0, number_of_slices=16)
+                        
+                        
+
+                        center = center_of_mass(M_nifti.get_fdata()==tissue_label_myocardium)
+                        V = _roll2center_crop(x=V_nifti.get_fdata(), center=center)
+                        M = _roll2center_crop(x=M_nifti.get_fdata(), center=center)
+
+                        I = np.argmax((M==tissue_label_rv).sum(axis=(0,1,3)))
+                        if I > M.shape[2]//2:
+                            print('Apex to Base. Inverting.')
+                            V = V[:,:,::-1]
+                            M = M[:,:,::-1]
+                        
+                        V = normalize(V)
+
+                        nx, ny, nz, nt = V.shape
+
+                        try:
+                            # calculate volume across the mid-ventricular section to estimate end-diastole
+                            volumes = (M_nifti.get_fdata()[:,:,nz//2-2:nz+3]==tissue_label_blood_pool).sum(axis=(0,1,2))
+                        except:
+                            print('Need to use all volume to estimate ED/ES')
+                            volumes = (M_nifti.get_fdata()==tissue_label_blood_pool).sum(axis=(0,1,2))
+
+                        ED = np.argmax(volumes)
+                        ES = np.argmin(volumes)
+                        
+                        # set end-diastole as the reference time frame
+                        M_0 = M[..., ED]
+                        V_0 = np.repeat(np.expand_dims(V[..., ED],-1), nt, axis=-1)
+                        V_t = V
+
+                        # move time frames to the batch dimension to predict all at onces
+                        V_0 = np.transpose(V_0, (3,0,1,2))
+                        V_t = np.transpose(V_t, (3,0,1,2))
+                        y_t = netME([V_0, V_t]).numpy()
+
+                        
+                        os.makedirs(output_folder, exist_ok=True)
+
+                        # save for calculation. Only the the end-diastolic mask is necessary
+                        np.save(os.path.join(output_folder, 'V_0.npy'), V_0)
+                        np.save(os.path.join(output_folder, 'V_t.npy'), V_t)
+                        np.save(os.path.join(output_folder, 'y_t.npy'), y_t)
+                        np.save(os.path.join(output_folder, 'M_0.npy'), M_0)
+
+
+                
+            folder = '/mnt/alp/Research Data Sets/DeepStrain_vs_CVI/%s/output_from_DeepStrain%s'%(batch, method)
+
+            df = {'SubjectID':[], 'RadialStain':[], 'CircumferentialStrain':[], 'TimeFrame':[]}
+            for j, subject_folder in enumerate(glob.glob(os.path.join(folder, '*'))):
+                V_0 = np.load(os.path.join(subject_folder, 'V_0.npy'))
+                V_t = np.load(os.path.join(subject_folder, 'V_t.npy'))
+                y_t = np.load(os.path.join(subject_folder, 'y_t.npy'))
+                M_0 = np.load(os.path.join(subject_folder, 'M_0.npy'))
+
+                y_t = gaussian_filter(y_t, sigma=(0,2,2,0,0))
+
+                for time_frame in range(len(y_t)):
+                    try:
+                        strain = myocardial_strain.MyocardialStrain(mask=M_0, flow=y_t[time_frame,:,:,:,:])
+                        strain.calculate_strain(lv_label=tissue_label_blood_pool)
+
+                        df['SubjectID']             += [os.path.basename(subject_folder)]
+                        df['RadialStain']           += [100*strain.Err[strain.mask_rot==tissue_label_myocardium].mean()]   
+                        df['CircumferentialStrain'] += [100*strain.Ecc[strain.mask_rot==tissue_label_myocardium].mean()]
+                        df['TimeFrame']             += [time_frame]
+                    except:
+                        print('Error in ', subject_folder)
+
+            df = pd.DataFrame(df)
+            df.to_csv('/mnt/alp/Research Data Sets/DeepStrain_vs_CVI/%s/output_from_DeepStrain%s.csv'%(batch, method))
+