Diff of /dataloaders/NRRD.py [000000] .. [978658]

Switch to unified view

a b/dataloaders/NRRD.py
1
import matplotlib.pyplot as plt
2
import nrrd
3
import numpy as np
4
5
6
# Class for working with a NII file in the context of machine learning
7
class NRRD:
8
    VIEW_MAPPING = {'saggital': 0, 'coronal': 1, 'axial': 2}
9
10
    def __init__(self, filename):
11
        self.data, self.info = nrrd.read(filename)
12
13
    @property
14
    def num_saggital_slices(self):
15
        return self.data.shape[NRRD.VIEW_MAPPING['saggital']]
16
17
    @property
18
    def num_coronal_slices(self):
19
        return self.data.shape[NRRD.VIEW_MAPPING['coronal']]
20
21
    @property
22
    def num_axial_slices(self):
23
        return self.data.shape[NRRD.VIEW_MAPPING['axial']]
24
25
    @staticmethod
26
    def set_view_mapping(mapping):
27
        NRRD.VIEW_MAPPING = mapping
28
29
    def shape(self):
30
        return self.data.shape
31
32
    @staticmethod
33
    def get_axis_index(axis):
34
        return NRRD.VIEW_MAPPING[axis]
35
36
    def num_slices_along_axis(self, axis):
37
        return self.data.shape[NRRD.VIEW_MAPPING[axis]]
38
39
    def normalize(self, method='scaling', lowerpercentile=None, upperpercentile=None):
40
        # Convert the attribute "data" to float()
41
        self.data = self.data.astype(np.float32)
42
43
        if lowerpercentile is not None:
44
            qlow = np.percentile(self.data, lowerpercentile)
45
        if upperpercentile is not None:
46
            qup = np.percentile(self.data, upperpercentile)
47
48
        if lowerpercentile is not None:
49
            self.data[self.data < qlow] = qlow
50
        if upperpercentile is not None:
51
            self.data[self.data > qup] = qup
52
53
        if method == 'scaling':
54
            # Divide "data" by its maximum value
55
            self.data -= self.data.min()
56
            self.data = np.multiply(self.data, 1.0 / self.data.max())
57
        elif method == 'standardization':
58
            self.data = self.data - np.mean(self.data)
59
            self.data = self.data / np.std(self.data)
60
61
    def get_slice(self, the_slice, axis='axial'):
62
        indices = [slice(None)] * self.data.ndim
63
        indices[NRRD.VIEW_MAPPING[axis]] = the_slice
64
        return self.data[indices]
65
66
    def get_data(self):
67
        return self.data
68
69
    def set_to_zero(self):
70
        self.data.fill(0.0)
71
72
    def visualize(self, axis='axial', pause=0.2):
73
        for i in range(self.data.shape[NRRD.VIEW_MAPPING[axis]]):
74
            img = self.get_slice(i, axis=axis)
75
            plt.imshow(img)
76
            plt.pause(pause)