Diff of /data/nifti_dataset.py [000000] .. [4cda31]

Switch to side-by-side view

--- a
+++ b/data/nifti_dataset.py
@@ -0,0 +1,117 @@
+# Manuel A. Morales (moralesq@mit.edu)
+# Harvard-MIT Department of Health Sciences & Technology  
+# Athinoula A. Martinos Center for Biomedical Imaging
+
+import os
+import glob
+import warnings
+import numpy as np
+import nibabel as nib
+
+from dipy.align.reslice import reslice
+from data.base_dataset import BaseDataset, Transforms
+from data.image_folder import make_dataset
+
+
+class NiftiDataset(BaseDataset):
+
+    def __init__(self, opt):
+        BaseDataset.__init__(self, opt)
+        self.filenames = sorted(make_dataset(opt.dataroot, opt.max_dataset_size, opt.dataformat))
+        self.transform = Transforms(opt)
+    
+    def __len__(self):
+        return len(self.filenames)
+                
+    def __getitem__(self, idx):      
+        nifti           = nib.load(self.filenames[idx])
+        nifti_resampled = resample_nifti(nifti, 
+                                         order=self.opt.order,
+                                         mode=self.opt.mode,
+                                         in_plane_resolution_mm=self.opt.in_plane_resolution_mm,
+                                         slice_thickness_mm=self.opt.slice_thickness_mm,
+                                         number_of_slices=self.opt.number_of_slices)
+        
+        x = self.transform.apply(nifti_resampled.get_fdata())
+        return x, nifti, nifti_resampled
+    
+
+def save_as_nifti(y, nifti, nifti_resampled, filename):
+    
+    if len(y.shape) == 4:
+        y_nifti_resampled = nib.Nifti1Image(y, nifti_resampled.affine)
+        y_nifti = resample_nifti(y_nifti_resampled, 
+                                 in_plane_resolution_mm=nifti.header.get_zooms()[0],
+                                 slice_thickness_mm=nifti.header.get_zooms()[2])
+        
+        y_nifti = nib.Nifti1Image(np.array(np.argmax(y_nifti.get_fdata(),-1)), 
+                                  affine=nifti.affine,
+                                  header=nifti.header)
+        
+    elif len(y.shape) == 5:
+        
+        Y = []
+        for label in range(y.shape[-1]):
+            y_nifti_resampled = nib.Nifti1Image(y[:,:,:,:,label], nifti_resampled.affine)
+            y_nifti = resample_nifti(y_nifti_resampled, 
+                                    in_plane_resolution_mm=nifti.header.get_zooms()[0],
+                                    slice_thickness_mm=nifti.header.get_zooms()[2])
+            Y += [y_nifti.get_fdata()]
+            
+        y_nifti = nib.Nifti1Image(np.argmax(np.stack(Y,-1),-1).astype(int), nifti.affine)
+
+    nib.Nifti1Image(np.array(y_nifti.get_fdata()), affine=y_nifti.affine).to_filename(filename+'.nii') 
+ 
+
+def resample_nifti(nifti, 
+                   order=1,
+                   mode='nearest',
+                   in_plane_resolution_mm=1.25,
+                   slice_thickness_mm=None,
+                   number_of_slices=None):
+    
+    # sometimes dicom to nifti programs don't define affine correctly.
+    resolution = np.array(nifti.header.get_zooms()[:3] + (1,))
+    if (np.abs(nifti.affine)==np.identity(4)).all():
+        nifti.set_sform(nifti.affine*resolution)
+        warnings.warn("Affine in nifti might be set incorrectly. Setting to affine=affine*zooms")
+
+    data   = nifti.get_fdata().copy()
+    shape  = nifti.shape[:3]
+    affine = nifti.affine.copy()
+    zooms  = nifti.header.get_zooms()[:3]
+
+    if number_of_slices is not None:
+        new_zooms = (in_plane_resolution_mm,
+                     in_plane_resolution_mm,
+                     (zooms[2] * shape[2]) / number_of_slices)
+    elif slice_thickness_mm is not None:
+        new_zooms = (in_plane_resolution_mm,
+                     in_plane_resolution_mm,
+                     slice_thickness_mm)            
+    else:
+        new_zooms = (in_plane_resolution_mm,
+                     in_plane_resolution_mm,
+                     zooms[2])
+
+    new_zooms = np.array(new_zooms)
+    for i, (n_i, res_i, res_new_i) in enumerate(zip(shape, zooms, new_zooms)):
+        n_new_i = (n_i * res_i) / res_new_i
+        # to avoid rounding ambiguities
+        if (n_new_i  % 1) == 0.5: 
+            new_zooms[i] -= 0.001
+
+    data_resampled, affine_resampled = reslice(data, affine, zooms, new_zooms, order=order, mode=mode)
+    nifti_resampled = nib.Nifti1Image(data_resampled, affine_resampled)
+
+    x=nifti_resampled.header.get_zooms()[:3]
+    y=new_zooms
+    if not np.allclose(x,y, rtol=1e-02):
+        print(x,y)
+        warnings.warn('Output resolutions are different than expected!')
+
+    return nifti_resampled       
+    
+    
+    
+    
\ No newline at end of file