Diff of /utils/NII.py [000000] .. [978658]

Switch to unified view

a b/utils/NII.py
1
import copy
2
3
import SimpleITK as sitk
4
import matplotlib.pyplot as plt
5
import numpy as np
6
7
8
class NII:
9
    VIEW_MAPPING = {'saggital': 0, 'coronal': 1, 'axial': 2}
10
11
    def __init__(self, filename):
12
        self.nii = sitk.ReadImage(filename, sitk.sitkFloat64)
13
        self._update_attributes()
14
15
        # Remove NaNs
16
        self.data[np.isnan(self.data)] = 0
17
18
    def update_sitk(self):
19
        self.nii = sitk.GetImageFromArray(self.data)
20
        self.nii.SetOrigin(self.origin)
21
        self.nii.SetDirection(self.direction)
22
23
    def _update_attributes(self):
24
        self.origin = self.nii.GetOrigin()
25
        self.direction = self.nii.GetDirection()
26
        self.data = sitk.GetArrayFromImage(self.nii)
27
28
    def save(self, filename):
29
        sitk.WriteImage(self.nii, filename)
30
31
    @property
32
    def num_saggital_slices(self):
33
        return self.data.shape[NII.VIEW_MAPPING['saggital']]
34
35
    @property
36
    def num_coronal_slices(self):
37
        return self.data.shape[NII.VIEW_MAPPING['coronal']]
38
39
    @property
40
    def num_axial_slices(self):
41
        return self.data.shape[NII.VIEW_MAPPING['axial']]
42
43
    @staticmethod
44
    def set_view_mapping(mapping):
45
        NII.VIEW_MAPPING = mapping
46
47
    def shape(self):
48
        return self.data.shape
49
50
    def num_slices_along_axis(self, axis):
51
        return self.data.shape[NII.VIEW_MAPPING[axis]]
52
53
    def normalize(self, method='scaling', lowerpercentile=None, upperpercentile=None):
54
        # Convert the attribute "data" to float()
55
        self.data = self.data.astype(np.float32)
56
57
        if lowerpercentile is not None:
58
            qlow = np.percentile(self.data, lowerpercentile)
59
        if upperpercentile is not None:
60
            qup = np.percentile(self.data, upperpercentile)
61
62
        if lowerpercentile is not None:
63
            self.data[self.data < qlow] = qlow
64
        if upperpercentile is not None:
65
            self.data[self.data > qup] = qup
66
67
        if method == 'scaling':
68
            # Divide "data" by its maximum value
69
            if self.data.max() > 0.0:
70
                self.data = np.multiply(self.data, 1.0 / self.data.max())
71
        elif method == 'standardization':
72
            self.data = self.data - np.mean(self.data)
73
            self.data = self.data / np.std(self.data)
74
75
        self.update_sitk()
76
77
    def apply_skullmap(self, skullmap):
78
        brainmask = skullmap.get_data()
79
        brainmask[brainmask < 0.1] = 0
80
        brainmask[brainmask >= 0.1] = 1
81
        self.data = self.data * brainmask
82
83
        self.update_sitk()
84
85
    def denoise(self):
86
        self.nii = sitk.CurvatureFlow(image1=self.nii, timeStep=0.125, numberOfIterations=3)
87
        self._update_attributes()
88
89
    def subtract(self, filename):
90
        nii_sub = NII(filename)
91
        self.data = self.data - nii_sub.get_data()
92
        self.update_sitk()
93
94
    def get_slice(self, the_slice, axis='axial'):
95
        indices = [slice(None)] * self.data.ndim
96
        indices[NII.VIEW_MAPPING[axis]] = the_slice
97
        return self.data[tuple(indices)]
98
99
    def set_slice(self, the_slice, the_data, axis='axial'):
100
        indices = [slice(None)] * self.data.ndim
101
        indices[NII.VIEW_MAPPING[axis]] = the_slice
102
        self.data[tuple(indices)] = the_data
103
        self.update_sitk()
104
        self._update_attributes()
105
106
    # The first index of the subvolume is expected to be the axis we iterate over
107
    def set_subvolume(self, slice_start, slice_end, subvolume, axis='axial'):
108
        for s in range(slice_start, slice_end):
109
            self.set_slice(s, subvolume[s - slice_start, :, :], axis)
110
111
    def get_data(self):
112
        return self.data
113
114
    def cast_to_float(self):
115
        self.nii = sitk.Cast(self.nii, sitk.sitkFloat64)
116
        self._update_attributes()
117
118
    def set_to_zero(self):
119
        self.data.fill(0.0)
120
        self.update_sitk()
121
122
    def visualize(self, axis='axial', pause=0.2):
123
        num_slices = self.data.shape[NII.VIEW_MAPPING[axis]]
124
        for i in range(num_slices):
125
            img = self.get_slice(i, axis=axis)
126
            plt.imshow(img)
127
            plt.title(f"Slice {i}/{num_slices}")
128
            plt.pause(pause)
129
            plt.cla()
130
131
    def copy(self):
132
        return copy.deepcopy(self)