--- a +++ b/dataloaders/NRRD.py @@ -0,0 +1,76 @@ +import matplotlib.pyplot as plt +import nrrd +import numpy as np + + +# Class for working with a NII file in the context of machine learning +class NRRD: + VIEW_MAPPING = {'saggital': 0, 'coronal': 1, 'axial': 2} + + def __init__(self, filename): + self.data, self.info = nrrd.read(filename) + + @property + def num_saggital_slices(self): + return self.data.shape[NRRD.VIEW_MAPPING['saggital']] + + @property + def num_coronal_slices(self): + return self.data.shape[NRRD.VIEW_MAPPING['coronal']] + + @property + def num_axial_slices(self): + return self.data.shape[NRRD.VIEW_MAPPING['axial']] + + @staticmethod + def set_view_mapping(mapping): + NRRD.VIEW_MAPPING = mapping + + def shape(self): + return self.data.shape + + @staticmethod + def get_axis_index(axis): + return NRRD.VIEW_MAPPING[axis] + + def num_slices_along_axis(self, axis): + return self.data.shape[NRRD.VIEW_MAPPING[axis]] + + def normalize(self, method='scaling', lowerpercentile=None, upperpercentile=None): + # Convert the attribute "data" to float() + self.data = self.data.astype(np.float32) + + if lowerpercentile is not None: + qlow = np.percentile(self.data, lowerpercentile) + if upperpercentile is not None: + qup = np.percentile(self.data, upperpercentile) + + if lowerpercentile is not None: + self.data[self.data < qlow] = qlow + if upperpercentile is not None: + self.data[self.data > qup] = qup + + if method == 'scaling': + # Divide "data" by its maximum value + self.data -= self.data.min() + self.data = np.multiply(self.data, 1.0 / self.data.max()) + elif method == 'standardization': + self.data = self.data - np.mean(self.data) + self.data = self.data / np.std(self.data) + + def get_slice(self, the_slice, axis='axial'): + indices = [slice(None)] * self.data.ndim + indices[NRRD.VIEW_MAPPING[axis]] = the_slice + return self.data[indices] + + def get_data(self): + return self.data + + def set_to_zero(self): + self.data.fill(0.0) + + def visualize(self, axis='axial', pause=0.2): + for i in range(self.data.shape[NRRD.VIEW_MAPPING[axis]]): + img = self.get_slice(i, axis=axis) + plt.imshow(img) + plt.pause(pause)