a b/src/LFBNet/data_loader.py
1
"""
2
3
"""
4
import os
5
import glob
6
import numpy as np
7
from numpy import ndarray
8
import matplotlib.pyplot as plt
9
import nibabel as nib
10
from typing import List, Tuple
11
from numpy.random import seed
12
13
# seed numpy's global random number generator (the same function as numpy.random.seed) so runs are reproducible
np.random.seed(1)
15
16
17
class DataLoader:
    """
    Read preprocessed PET and ground-truth (gt) MIP data for training.

    Every case directory under ``data_dir`` is expected to contain four NIfTI files
    (``.nii`` or ``.nii.gz``) named ``pet_sagittal``, ``pet_coronal``,
    ``ground_truth_sagittal`` and ``ground_truth_coronal``.
    """

    def __init__(self, data_dir: str, ids_to_read: ndarray = None, shuffle=True, training: bool = True):
        """
        :param data_dir: directory containing one sub-directory per case.
        :param ids_to_read: names of the case directories to load; ``None`` loads every
            case found in ``data_dir``.
        :param shuffle: stored as ``self.shuffle``; not used by any method of this class.
        :param training: stored as ``self.training``; not used by any method of this class.
        """
        self.data_dir = data_dir
        self.ids_to_read = ids_to_read
        self.shuffle = shuffle
        self.training = training

    def get_batch_of_data(self) -> List:
        """
        Load, normalize and stack the MIPs of the requested cases.

        data structure:
        -- main directory
        ------case Name:
                -- pet_sagittal.nii(.gz), pet_coronal.nii(.gz)
                -- ground_truth_sagittal.nii(.gz), ground_truth_coronal.nii(.gz)

        Cases are visited in sorted directory-name order. A case that cannot be read
        (missing views, shape incompatible with the cases already read, unreadable file, ...)
        is reported with a ``Not read`` message and skipped as a whole, so the image and
        ground-truth batches always stay aligned.

        :return: ``[image_batch, ground_truth_batch]``; each entry is an ndarray that stacks
            the sagittal and coronal MIPs of every case along axis 0, or an empty list when
            no case was read.
        :raises Exception: if ``data_dir`` is not an existing directory or is empty.
        """
        self.directory_exist(self.data_dir)

        # sorted: os.listdir order is arbitrary, sorting makes the batch order reproducible
        case_ids = sorted(os.listdir(self.data_dir))
        if not case_ids:
            raise Exception("No files found in %s" % self.data_dir)

        # build the id lookup once; None means "read every case"
        wanted = None if self.ids_to_read is None else {str(case_id) for case_id in self.ids_to_read}

        pet_parts, gt_parts = [], []
        for get_id in case_ids:
            if wanted is not None and str(get_id) not in wanted:
                continue
            try:
                pet, gt = self._read_case(os.path.join(self.data_dir, str(get_id)))
                if pet is None:  # folder without any .nii/.nii.gz file
                    continue
                # np.concatenate along axis 0 requires every other axis to agree; check up front
                # so an incompatible case is skipped entirely instead of desynchronizing the batches
                if pet_parts and (pet.shape[1:] != pet_parts[0].shape[1:]
                                  or gt.shape[1:] != gt_parts[0].shape[1:]):
                    raise ValueError("shapes %s / %s do not match the cases already read"
                                     % (pet.shape, gt.shape))
            except Exception as error:
                print('Not read %s: %s' % (str(get_id), error))
                continue
            pet_parts.append(pet)
            gt_parts.append(gt)

        if not pet_parts:
            return [[], []]
        # single concatenation at the end instead of re-copying the growing batch per case
        return [np.concatenate(pet_parts, axis=0), np.concatenate(gt_parts, axis=0)]

    def _read_case(self, case_dir: str) -> Tuple:
        """Read and normalize one case folder; return ``(None, None)`` when it holds no NIfTI file."""
        pet, gt = self.get_nii_files_path(case_dir)
        if not len(pet):
            return None, None
        pet = self.data_normalization_standardization(pet, z_score=True, z_score_include_zeros=False)
        gt = self.data_normalization_standardization(gt, threshold=True)
        return pet, gt

    @staticmethod
    def directory_exist(dir_check: str = None) -> None:
        """
        Raise if ``dir_check`` is not an existing directory.

        :param dir_check: path to check.
        :raises Exception: when the path is ``None``, missing, or not a directory.
        """
        if dir_check is None or not os.path.isdir(dir_check):
            raise Exception(
                "Please provide the correct path to the processed data ! \n %s not found \n" % (dir_check))

    @staticmethod
    def mip_show(pet: ndarray = None, gt: ndarray = None, identifier: str = None) -> None:
        """
        Plot the two PET and two ground-truth MIPs of one case side by side, as ``log(x + 1)``.

        Index 0 is the first stacked view (sagittal in ``get_nii_files_path``) and index 1
        the second (coronal).

        :param pet: PET MIPs; presumably shaped so that ``pet[0]`` / ``pet[1]`` are 2-D
            images after squeezing — TODO confirm.
        :param gt: ground-truth MIPs, same layout as ``pet``.
        :param identifier: figure title, e.g. the case id.
        """
        pet = np.squeeze(pet)
        gt = np.squeeze(gt)

        fig, axs = plt.subplots(1, 4, figsize=(15, 15))
        # figure-level title: plt.title() would target the last axes and be overwritten below
        fig.suptitle(str(identifier))
        panels = (
            (pet[0], 'pet_project_on_axis_0'),
            (gt[0], 'gt_project_on_axis_0'),
            (pet[1], 'pet_project_on_axis_1'),
            (gt[1], 'gt_project_on_axis_1'),
        )
        for ax, (image, title) in zip(axs, panels):
            ax.imshow(np.rot90(np.log(image + 1)))
            ax.set_title(title)
        plt.show()

    @staticmethod
    def get_nii_files_path(data_directory: str) -> List[ndarray]:
        """
        Read the PET and ground-truth MIPs stored as .nii / .nii.gz files in ``data_directory``.

        If the same name exists both as ``.nii`` and ``.nii.gz``, the ``.nii.gz`` file is used.

        :param data_directory: folder of one case.
        :return: ``[pet, gt]`` where each array stacks the sagittal view (index 0) and the
            coronal view (index 1) along axis 0; ``[[], []]`` when the folder contains no
            .nii / .nii.gz file at all.
        :raises ValueError: when NIfTI files exist but one of the four expected views is missing.
        """
        expected = ("pet_sagittal", "pet_coronal", "ground_truth_sagittal", "ground_truth_coronal")

        nii_paths = []
        for pattern in ('*.nii', '*.nii.gz'):
            nii_paths.extend(glob.glob(os.path.join(str(data_directory), pattern)))

        if not nii_paths:
            return [[], []]

        views = {}
        for path in nii_paths:
            # file name without any extension, e.g. "pet_sagittal" for "pet_sagittal.nii.gz"
            base_name = os.path.basename(path).split('.')[0]
            if base_name in expected:
                # leading axis so the two views can be stacked along axis 0
                views[base_name] = np.expand_dims(np.asanyarray(nib.load(path).dataobj), axis=0)

        missing = [name for name in expected if name not in views]
        if missing:
            raise ValueError("Missing %s in %s" % (", ".join(missing), data_directory))

        pet = np.concatenate((views["pet_sagittal"], views["pet_coronal"]), axis=0)
        gt = np.concatenate((views["ground_truth_sagittal"], views["ground_truth_coronal"]), axis=0)
        return [pet, gt]

    @staticmethod
    def z_score(image: ndarray, include_zeros: bool = False):
        """
        Standardize ``image`` to zero mean and unit standard deviation.

        :param image: array to standardize.
        :param include_zeros: when False, mean and std are computed only over entries that
            are not (close to) zero; the same mask is used for both statistics.
        :return: the standardized array; an image with no non-zero entry is returned unchanged
            (its statistics would be undefined).
        """
        image = np.asarray(image)
        if include_zeros:
            return (image - np.mean(image)) / (np.std(image) + 1e-8)

        non_zero = ~np.isclose(image, 0)
        if not non_zero.any():
            return image
        values = image[non_zero]
        # 1e-8 guards against division by zero for constant non-zero content
        return (image - values.mean()) / (values.std() + 1e-8)

    def data_normalization_standardization(self, data: ndarray, threshold: bool = False, z_score: bool = False,
                                           z_score_include_zeros: bool = False,
                                           min_max_scale: bool = False, log_transform: bool = False) -> ndarray:
        """
        Data normalization and standardization function; steps are applied in the order below.

        :param data: array (or nested list) to transform; it is copied, never modified in place.
        :param threshold: binarize: values > 0 become 1, values <= 0 become 0.
        :param z_score: standardize with ``z_score``.
        :param z_score_include_zeros: forwarded to ``z_score`` as ``include_zeros``.
        :param min_max_scale: rescale to [0, 1]; constant input becomes all zeros.
        :param log_transform: apply ``log(x + 1)``.
        :return: the transformed array.
        """
        # always copy: lists become arrays, and the in-place thresholding never touches the caller's data
        data = np.array(data)

        if threshold:
            data[data > 0] = 1
            data[data <= 0] = 0

        if z_score:
            data = self.z_score(data, include_zeros=z_score_include_zeros)

        if min_max_scale:
            # np.min/np.max reduce over all axes (builtin min/max fail on multi-dimensional arrays)
            data_min, data_max = np.min(data), np.max(data)
            value_range = data_max - data_min
            data = (data - data_min) / value_range if value_range else np.zeros_like(data, dtype=float)

        if log_transform:
            data = np.log(data + 1)

        return data
220
221
222
if __name__ == '__main__':
    # Example: load every case found in the default MIP directory and report the array shapes.
    print("data_loader for preprocessed coronal and sagittal MIPs, pet, and gt")
    data_dir = "../data/vienna_default_MIP_dir/"
    ids_to_read = os.listdir(data_dir)

    data_loader = DataLoader(data_dir=data_dir, ids_to_read=ids_to_read)
    image_batch, ground_truth_batch = data_loader.get_batch_of_data()
    # report each shape separately: np.array() over the [pet, gt] pair fails (or builds a
    # ragged object array) as soon as the two batches differ in shape
    print("pet:", np.shape(image_batch), "gt:", np.shape(ground_truth_batch))