Diff of /dataloaders/NRRD.py [000000] .. [978658]

Switch to side-by-side view

--- a
+++ b/dataloaders/NRRD.py
@@ -0,0 +1,76 @@
+import matplotlib.pyplot as plt
+import nrrd
+import numpy as np
+
+
+# Class for working with a NII file in the context of machine learning
+class NRRD:
+    VIEW_MAPPING = {'saggital': 0, 'coronal': 1, 'axial': 2}
+
+    def __init__(self, filename):
+        self.data, self.info = nrrd.read(filename)
+
+    @property
+    def num_saggital_slices(self):
+        return self.data.shape[NRRD.VIEW_MAPPING['saggital']]
+
+    @property
+    def num_coronal_slices(self):
+        return self.data.shape[NRRD.VIEW_MAPPING['coronal']]
+
+    @property
+    def num_axial_slices(self):
+        return self.data.shape[NRRD.VIEW_MAPPING['axial']]
+
+    @staticmethod
+    def set_view_mapping(mapping):
+        NRRD.VIEW_MAPPING = mapping
+
+    def shape(self):
+        return self.data.shape
+
+    @staticmethod
+    def get_axis_index(axis):
+        return NRRD.VIEW_MAPPING[axis]
+
+    def num_slices_along_axis(self, axis):
+        return self.data.shape[NRRD.VIEW_MAPPING[axis]]
+
+    def normalize(self, method='scaling', lowerpercentile=None, upperpercentile=None):
+        # Convert the attribute "data" to float()
+        self.data = self.data.astype(np.float32)
+
+        if lowerpercentile is not None:
+            qlow = np.percentile(self.data, lowerpercentile)
+        if upperpercentile is not None:
+            qup = np.percentile(self.data, upperpercentile)
+
+        if lowerpercentile is not None:
+            self.data[self.data < qlow] = qlow
+        if upperpercentile is not None:
+            self.data[self.data > qup] = qup
+
+        if method == 'scaling':
+            # Divide "data" by its maximum value
+            self.data -= self.data.min()
+            self.data = np.multiply(self.data, 1.0 / self.data.max())
+        elif method == 'standardization':
+            self.data = self.data - np.mean(self.data)
+            self.data = self.data / np.std(self.data)
+
+    def get_slice(self, the_slice, axis='axial'):
+        indices = [slice(None)] * self.data.ndim
+        indices[NRRD.VIEW_MAPPING[axis]] = the_slice
+        return self.data[indices]
+
+    def get_data(self):
+        return self.data
+
+    def set_to_zero(self):
+        self.data.fill(0.0)
+
+    def visualize(self, axis='axial', pause=0.2):
+        for i in range(self.data.shape[NRRD.VIEW_MAPPING[axis]]):
+            img = self.get_slice(i, axis=axis)
+            plt.imshow(img)
+            plt.pause(pause)