Diff of /utils/NII.py [000000] .. [978658]

Switch to side-by-side view

--- a
+++ b/utils/NII.py
@@ -0,0 +1,132 @@
+import copy
+
+import SimpleITK as sitk
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+class NII:
+    VIEW_MAPPING = {'saggital': 0, 'coronal': 1, 'axial': 2}
+
+    def __init__(self, filename):
+        self.nii = sitk.ReadImage(filename, sitk.sitkFloat64)
+        self._update_attributes()
+
+        # Remove NaNs
+        self.data[np.isnan(self.data)] = 0
+
+    def update_sitk(self):
+        self.nii = sitk.GetImageFromArray(self.data)
+        self.nii.SetOrigin(self.origin)
+        self.nii.SetDirection(self.direction)
+
+    def _update_attributes(self):
+        self.origin = self.nii.GetOrigin()
+        self.direction = self.nii.GetDirection()
+        self.data = sitk.GetArrayFromImage(self.nii)
+
+    def save(self, filename):
+        sitk.WriteImage(self.nii, filename)
+
+    @property
+    def num_saggital_slices(self):
+        return self.data.shape[NII.VIEW_MAPPING['saggital']]
+
+    @property
+    def num_coronal_slices(self):
+        return self.data.shape[NII.VIEW_MAPPING['coronal']]
+
+    @property
+    def num_axial_slices(self):
+        return self.data.shape[NII.VIEW_MAPPING['axial']]
+
+    @staticmethod
+    def set_view_mapping(mapping):
+        NII.VIEW_MAPPING = mapping
+
+    def shape(self):
+        return self.data.shape
+
+    def num_slices_along_axis(self, axis):
+        return self.data.shape[NII.VIEW_MAPPING[axis]]
+
+    def normalize(self, method='scaling', lowerpercentile=None, upperpercentile=None):
+        # Convert the attribute "data" to float()
+        self.data = self.data.astype(np.float32)
+
+        if lowerpercentile is not None:
+            qlow = np.percentile(self.data, lowerpercentile)
+        if upperpercentile is not None:
+            qup = np.percentile(self.data, upperpercentile)
+
+        if lowerpercentile is not None:
+            self.data[self.data < qlow] = qlow
+        if upperpercentile is not None:
+            self.data[self.data > qup] = qup
+
+        if method == 'scaling':
+            # Divide "data" by its maximum value
+            if self.data.max() > 0.0:
+                self.data = np.multiply(self.data, 1.0 / self.data.max())
+        elif method == 'standardization':
+            self.data = self.data - np.mean(self.data)
+            self.data = self.data / np.std(self.data)
+
+        self.update_sitk()
+
+    def apply_skullmap(self, skullmap):
+        brainmask = skullmap.get_data()
+        brainmask[brainmask < 0.1] = 0
+        brainmask[brainmask >= 0.1] = 1
+        self.data = self.data * brainmask
+
+        self.update_sitk()
+
+    def denoise(self):
+        self.nii = sitk.CurvatureFlow(image1=self.nii, timeStep=0.125, numberOfIterations=3)
+        self._update_attributes()
+
+    def subtract(self, filename):
+        nii_sub = NII(filename)
+        self.data = self.data - nii_sub.get_data()
+        self.update_sitk()
+
+    def get_slice(self, the_slice, axis='axial'):
+        indices = [slice(None)] * self.data.ndim
+        indices[NII.VIEW_MAPPING[axis]] = the_slice
+        return self.data[tuple(indices)]
+
+    def set_slice(self, the_slice, the_data, axis='axial'):
+        indices = [slice(None)] * self.data.ndim
+        indices[NII.VIEW_MAPPING[axis]] = the_slice
+        self.data[tuple(indices)] = the_data
+        self.update_sitk()
+        self._update_attributes()
+
+    # The first index of the subvolume is expected to be the axis we iterate over
+    def set_subvolume(self, slice_start, slice_end, subvolume, axis='axial'):
+        for s in range(slice_start, slice_end):
+            self.set_slice(s, subvolume[s - slice_start, :, :], axis)
+
+    def get_data(self):
+        return self.data
+
+    def cast_to_float(self):
+        self.nii = sitk.Cast(self.nii, sitk.sitkFloat64)
+        self._update_attributes()
+
+    def set_to_zero(self):
+        self.data.fill(0.0)
+        self.update_sitk()
+
+    def visualize(self, axis='axial', pause=0.2):
+        num_slices = self.data.shape[NII.VIEW_MAPPING[axis]]
+        for i in range(num_slices):
+            img = self.get_slice(i, axis=axis)
+            plt.imshow(img)
+            plt.title(f"Slice {i}/{num_slices}")
+            plt.pause(pause)
+            plt.cla()
+
+    def copy(self):
+        return copy.deepcopy(self)