--- a +++ b/utils/NII.py @@ -0,0 +1,132 @@ +import copy + +import SimpleITK as sitk +import matplotlib.pyplot as plt +import numpy as np + + +class NII: + VIEW_MAPPING = {'saggital': 0, 'coronal': 1, 'axial': 2} + + def __init__(self, filename): + self.nii = sitk.ReadImage(filename, sitk.sitkFloat64) + self._update_attributes() + + # Remove NaNs + self.data[np.isnan(self.data)] = 0 + + def update_sitk(self): + self.nii = sitk.GetImageFromArray(self.data) + self.nii.SetOrigin(self.origin) + self.nii.SetDirection(self.direction) + + def _update_attributes(self): + self.origin = self.nii.GetOrigin() + self.direction = self.nii.GetDirection() + self.data = sitk.GetArrayFromImage(self.nii) + + def save(self, filename): + sitk.WriteImage(self.nii, filename) + + @property + def num_saggital_slices(self): + return self.data.shape[NII.VIEW_MAPPING['saggital']] + + @property + def num_coronal_slices(self): + return self.data.shape[NII.VIEW_MAPPING['coronal']] + + @property + def num_axial_slices(self): + return self.data.shape[NII.VIEW_MAPPING['axial']] + + @staticmethod + def set_view_mapping(mapping): + NII.VIEW_MAPPING = mapping + + def shape(self): + return self.data.shape + + def num_slices_along_axis(self, axis): + return self.data.shape[NII.VIEW_MAPPING[axis]] + + def normalize(self, method='scaling', lowerpercentile=None, upperpercentile=None): + # Convert the attribute "data" to float() + self.data = self.data.astype(np.float32) + + if lowerpercentile is not None: + qlow = np.percentile(self.data, lowerpercentile) + if upperpercentile is not None: + qup = np.percentile(self.data, upperpercentile) + + if lowerpercentile is not None: + self.data[self.data < qlow] = qlow + if upperpercentile is not None: + self.data[self.data > qup] = qup + + if method == 'scaling': + # Divide "data" by its maximum value + if self.data.max() > 0.0: + self.data = np.multiply(self.data, 1.0 / self.data.max()) + elif method == 'standardization': + self.data = self.data - np.mean(self.data) + self.data = self.data / np.std(self.data) + + self.update_sitk() + + def apply_skullmap(self, skullmap): + brainmask = skullmap.get_data() + brainmask[brainmask < 0.1] = 0 + brainmask[brainmask >= 0.1] = 1 + self.data = self.data * brainmask + + self.update_sitk() + + def denoise(self): + self.nii = sitk.CurvatureFlow(image1=self.nii, timeStep=0.125, numberOfIterations=3) + self._update_attributes() + + def subtract(self, filename): + nii_sub = NII(filename) + self.data = self.data - nii_sub.get_data() + self.update_sitk() + + def get_slice(self, the_slice, axis='axial'): + indices = [slice(None)] * self.data.ndim + indices[NII.VIEW_MAPPING[axis]] = the_slice + return self.data[tuple(indices)] + + def set_slice(self, the_slice, the_data, axis='axial'): + indices = [slice(None)] * self.data.ndim + indices[NII.VIEW_MAPPING[axis]] = the_slice + self.data[tuple(indices)] = the_data + self.update_sitk() + self._update_attributes() + + # The first index of the subvolume is expected to be the axis we iterate over + def set_subvolume(self, slice_start, slice_end, subvolume, axis='axial'): + for s in range(slice_start, slice_end): + self.set_slice(s, subvolume[s - slice_start, :, :], axis) + + def get_data(self): + return self.data + + def cast_to_float(self): + self.nii = sitk.Cast(self.nii, sitk.sitkFloat64) + self._update_attributes() + + def set_to_zero(self): + self.data.fill(0.0) + self.update_sitk() + + def visualize(self, axis='axial', pause=0.2): + num_slices = self.data.shape[NII.VIEW_MAPPING[axis]] + for i in range(num_slices): + img = self.get_slice(i, axis=axis) + plt.imshow(img) + plt.title(f"Slice {i}/{num_slices}") + plt.pause(pause) + plt.cla() + + def copy(self): + return copy.deepcopy(self)