Diff of /dataloaders/MSISBI2015.py [000000] .. [978658]

"""Functions for reading MSISBI2015 NRRD data."""
2
3
from __future__ import absolute_import
4
from __future__ import division
5
from __future__ import print_function
6
7
import glob
8
import math
9
import os.path
10
import pickle
11
12
import matplotlib.pyplot
13
from imageio import imwrite
14
from scipy.ndimage import zoom
15
from six.moves import xrange
16
from skimage.measure import label, regionprops
17
18
from utils.NII import *
19
from utils.image_utils import crop, crop_center
20
from utils.tfrecord_utils import *
21
22
matplotlib.pyplot.ion()
23
24
25
class MSISBI2015(object):
    PROTOCOL_MAPPINGS = {'FLAIR': ['flair'], 'MPRAGE': ['mprage'], 'PD': ['pd'], 'T2': ['t2']}
    SET_TYPES = ['TRAIN', 'VAL', 'TEST']

    class Options(object):
        def __init__(self):
            self.dir = os.path.dirname(os.path.realpath(__file__))
            self.numSamples = -1
            self.partition = {'TRAIN': 0.7, 'VAL': 0.2, 'TEST': 0.1}
            self.useCrops = False
            self.cropType = 'random'  # random, center or lesions
            self.numRandomCropsPerSlice = 5
            self.onlyPatchesWithLesions = False
            self.rotations = 0
            self.cropWidth = 128
            self.cropHeight = 128
            self.cache = False
            self.sliceResolution = None  # format: HxW
            self.addInstanceNoise = False  # Affects only the batch sampling. If True, a tiny bit of noise is added to every batch
            self.filterProtocols = []  # any subset of FLAIR, MPRAGE, PD, T2
            self.filterType = "train"  # train or test
            self.axis = 'axial'  # saggital, coronal or axial
            self.debug = False
            self.normalizationMethod = 'standardization'
            self.sliceStart = 0
            self.sliceEnd = 155
            self.format = "raw"  # raw or aligned; if aligned, the *.aligned.nii.gz files are crawled and loaded
            self.skullStripping = True
            self.viewMapping = {'saggital': 2, 'coronal': 1, 'axial': 0}

    def __init__(self, options=Options()):
        self.options = options

        if options.cache and os.path.isfile(self.pckl_name()):
            f = open(self.pckl_name(), 'rb')
            tmp = pickle.load(f)
            f.close()
            self._epochs_completed = tmp._epochs_completed
            self._index_in_epoch = tmp._index_in_epoch
            self.patientsSplit = tmp.patientsSplit
            self.patients = tmp.patients
            self._images, self._labels, self._sets = read_tf_record(self.tfrecord_name())
            self._epochs_completed = {'TRAIN': 0, 'VAL': 0, 'TEST': 0}
            self._index_in_epoch = {'TRAIN': 0, 'VAL': 0, 'TEST': 0}
        else:
            # Collect all patients
            self.patients = self._get_patients()
            self.patientsSplit = {}

            if not os.path.isfile(self.split_name()):
                _numPatients = len(self.patients)
                _ridx = numpy.random.permutation(_numPatients)

                _already_taken = 0
                for split in self.options.partition.keys():
                    if self.options.partition[split] <= 1.0:
                        num_patients_for_current_split = math.floor(self.options.partition[split] * _numPatients)
                    else:
                        num_patients_for_current_split = self.options.partition[split]

                    if num_patients_for_current_split > (_numPatients - _already_taken):
                        num_patients_for_current_split = _numPatients - _already_taken

                    self.patientsSplit[split] = _ridx[_already_taken:_already_taken + num_patients_for_current_split]
                    _already_taken += num_patients_for_current_split

                f = open(self.split_name(), 'wb')
                pickle.dump(self.patientsSplit, f)
                f.close()
            else:
                f = open(self.split_name(), 'rb')
                self.patientsSplit = pickle.load(f)
                f.close()

            self._create_numpy_arrays()

            self._epochs_completed = {'TRAIN': 0, 'VAL': 0, 'TEST': 0}
            self._index_in_epoch = {'TRAIN': 0, 'VAL': 0, 'TEST': 0}

            if self.options.cache:
                write_tf_record(self._images, self._labels, self._sets, self.tfrecord_name())
                tmp = copy.copy(self)
                tmp._images = None
                tmp._labels = None
                tmp._sets = None
                f = open(self.pckl_name(), 'wb')
                pickle.dump(tmp, f)
                f.close()

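    # Worked example for the split logic above (the patient count is hypothetical): with 21
    # patients and the default partition {'TRAIN': 0.7, 'VAL': 0.2, 'TEST': 0.1}, the splits
    # receive floor(14.7) = 14, floor(4.2) = 4 and floor(2.1) = 2 patients; the single patient
    # left over by the floor operations is skipped in _create_numpy_arrays below.
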
    def _create_numpy_arrays(self):
        # Iterate over all patients and extract slices
        _images = []
        _labels = []
        _sets = []
        for p, patient in enumerate(self.patients):
            if p in self.patientsSplit['TRAIN']:
                _set_of_current_patient = MSISBI2015.SET_TYPES.index('TRAIN')
            elif p in self.patientsSplit['VAL']:
                _set_of_current_patient = MSISBI2015.SET_TYPES.index('VAL')
            elif p in self.patientsSplit['TEST']:
                _set_of_current_patient = MSISBI2015.SET_TYPES.index('TEST')
            else:
                # Patient was not assigned to any split (possible due to the floor-based partitioning)
                continue

            for n, nii_filename in enumerate(patient['filtered_files']):
                _images_tmp, _labels_tmp = self.gather_data(patient, nii_filename)
                _images += _images_tmp
                _labels += _labels_tmp
                _sets += [_set_of_current_patient] * len(_images_tmp)

        self._images = numpy.array(_images).astype(numpy.float32)
        self._labels = numpy.array(_labels).astype(numpy.float32)
        if self._images.ndim < 4:
            self._images = numpy.expand_dims(self._images, 3)
        self._sets = numpy.array(_sets).astype(numpy.int32)

    def gather_data(self, patient, nii_filename):
        _images = []
        _labels = []

        nii, nii_seg, nii_skullmap = self.load_volume_and_groundtruth(nii_filename, patient)

        # Iterate over all slices and collect them
        # Only slices between options.sliceStart and options.sliceEnd along the chosen axis are considered
        for s in xrange(self.options.sliceStart, min(self.options.sliceEnd, nii.num_slices_along_axis(self.options.axis))):
            if 0 < self.options.numSamples < len(_images):
                break

            slice_data = nii.get_slice(s, self.options.axis)
            slice_seg = nii_seg.get_slice(s, self.options.axis)

            # Skip the slice if it is "empty"
            if numpy.percentile(slice_data, 90) < 0.2:
                continue

            if self.options.sliceResolution is not None:
                # Pad with zeros at the borders if the slice is smaller than the target resolution
                before_y, after_y, before_x, after_x = 0, 0, 0, 0
                if slice_data.shape[0] < self.options.sliceResolution[0]:
                    before_y = math.floor((self.options.sliceResolution[0] - slice_data.shape[0]) / 2.0)
                    after_y = math.ceil((self.options.sliceResolution[0] - slice_data.shape[0]) / 2.0)
                if slice_data.shape[1] < self.options.sliceResolution[1]:
                    before_x = math.floor((self.options.sliceResolution[1] - slice_data.shape[1]) / 2.0)
                    after_x = math.ceil((self.options.sliceResolution[1] - slice_data.shape[1]) / 2.0)
                if slice_data.shape[0] < self.options.sliceResolution[0] or slice_data.shape[1] < self.options.sliceResolution[1]:
                    slice_data = np.pad(slice_data, ((before_y, after_y), (before_x, after_x)), 'constant', constant_values=(0, 0))
                    slice_seg = np.pad(slice_seg, ((before_y, after_y), (before_x, after_x)), 'constant', constant_values=(0, 0))
                slice_data = zoom(slice_data, float(self.options.sliceResolution[0]) / float(slice_data.shape[0]))
                slice_seg = zoom(slice_seg, float(self.options.sliceResolution[0]) / float(slice_seg.shape[0]), mode="nearest")
                slice_seg[slice_seg < 0.9] = 0.0
                slice_seg[slice_seg >= 0.9] = 1.0
                # assert numpy.max(slice_data) <= 1.0, "Resized slice range is outside [0; 1]!"

            # Either collect crops
            if self.options.useCrops:
                if self.options.cropType == 'random':
                    rx = numpy.random.randint(0, high=(slice_data.shape[1] - self.options.cropWidth),
                                              size=self.options.numRandomCropsPerSlice)
                    ry = numpy.random.randint(0, high=(slice_data.shape[0] - self.options.cropHeight),
                                              size=self.options.numRandomCropsPerSlice)
                    for r in range(self.options.numRandomCropsPerSlice):
                        _images.append(crop(slice_data, ry[r], rx[r], self.options.cropHeight, self.options.cropWidth))
                        _labels.append(crop(slice_seg, ry[r], rx[r], self.options.cropHeight, self.options.cropWidth))
                elif self.options.cropType == 'center':
                    slice_data_cropped = crop_center(slice_data, self.options.cropWidth, self.options.cropHeight)
                    slice_seg_cropped = crop_center(slice_seg, self.options.cropWidth, self.options.cropHeight)
                    _images.append(slice_data_cropped)
                    _labels.append(slice_seg_cropped)
                elif self.options.cropType == 'lesions':
                    cc_slice = label(slice_seg)
                    props = regionprops(cc_slice)
                    if len(props) > 0:
                        for prop in props:
                            cx = prop['centroid'][1]
                            cy = prop['centroid'][0]
                            # Clamp the crop center so the crop stays inside the slice
                            if cy < self.options.cropHeight // 2:
                                cy = self.options.cropHeight // 2
                            if cy > (slice_data.shape[0] - (self.options.cropHeight // 2)):
                                cy = (slice_data.shape[0] - (self.options.cropHeight // 2))
                            if cx < self.options.cropWidth // 2:
                                cx = self.options.cropWidth // 2
                            if cx > (slice_data.shape[1] - (self.options.cropWidth // 2)):
                                cx = (slice_data.shape[1] - (self.options.cropWidth // 2))
                            image_crop = crop(slice_data, int(cy) - (self.options.cropHeight // 2), int(cx) - (self.options.cropWidth // 2),
                                              self.options.cropHeight, self.options.cropWidth)
                            seg_crop = crop(slice_seg, int(cy) - (self.options.cropHeight // 2), int(cx) - (self.options.cropWidth // 2),
                                            self.options.cropHeight, self.options.cropWidth)
                            if image_crop.shape[0] != self.options.cropHeight or image_crop.shape[1] != self.options.cropWidth:
                                continue
                            _images.append(image_crop)
                            _labels.append(seg_crop)

            # Or whole slices
            else:
                _images.append(slice_data)
                _labels.append(slice_seg)

        return _images, _labels

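    # Worked example for the resampling above (slice sizes are hypothetical): with
    # sliceResolution = [128, 128], a 96x96 slice is first zero-padded by (16, 16) rows and
    # columns to 128x128, so the subsequent zoom factor is 1.0; a 230x230 slice is not padded
    # and is zoomed by 128/230 down to 128x128.
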
    def load_volume_and_groundtruth(self, nii_filename, patient):
        # Load the NIfTI volume and its ground-truth segmentation
        try:
            nii = NII(nii_filename)
            nii_groundtruth = NII(patient['groundtruth'])

            nii.denoise()
            nii.set_view_mapping(self.options.viewMapping)
        except Exception:
            print('MSISBI2015: Failed to open file ' + nii_filename)
            raise

        # Make sure the ground-truth is binary and the volume doesn't contain NaNs
        nii.data[np.isnan(nii.data)] = 0.0
        nii_groundtruth.data[nii_groundtruth.data < 0.9] = 0.0
        nii_groundtruth.data[nii_groundtruth.data >= 0.9] = 1.0

        # Do skull-stripping, if desired
        nii_skullmap = None
        if self.options.skullStripping:
            try:
                nii_skullmap = NII(patient['skullmap'])
                nii_skullmap.set_view_mapping(self.options.viewMapping)
                nii.apply_skullmap(nii_skullmap)
            except Exception:
                print('MSISBI2015: Failed to open file ' + patient['skullmap'] + ', skipping skull removal')

        # In-place normalize the loaded volume
        nii.normalize(method=self.options.normalizationMethod, lowerpercentile=0, upperpercentile=99.8)
        # nii_skullmap.data = nii_skullmap.data > 0.0

        return nii, nii_groundtruth, nii_skullmap

    # Hidden helper function, not supposed to be called from outside!
    def _get_patients(self):
        return MSISBI2015.get_patients(self.options)

    @staticmethod
    def get_patients(options):
        folders = ["training01", "training02", "training03", "training04", "training05"]

        # Iterate over all folders and collect all patients
        patients = []
        for f, folder in enumerate(folders):
            # Get all files that can be used for training and validation
            _patients = glob.glob(os.path.join(options.dir, folder, "preprocessed", folder + "_*_flair_pp.nii"))
            for p, pname in enumerate(_patients):
                patient = {}
                _tmp = os.path.normpath(pname).split(os.path.sep)
                patient['name'] = _tmp[-1].replace("_flair_pp.nii", "")
                patient['fullpath'] = os.path.join(options.dir, folder, "preprocessed")

                patient["filtered_files"] = []
                for protocol, protocol_array in MSISBI2015.PROTOCOL_MAPPINGS.items():
                    if len(options.filterProtocols) > 0 and protocol not in options.filterProtocols:
                        continue
                    if options.format == "raw":
                        _file = os.path.join(options.dir, folder, "preprocessed",
                                             patient['name'] + '_' + protocol_array[0] + '_pp.nii')
                    elif options.format == "aligned":
                        _file = os.path.join(options.dir, folder, "preprocessed",
                                             patient['name'] + '_' + protocol_array[0] + '.aligned.nii.gz')
                    else:
                        continue
                    patient[protocol] = _file
                    patient["filtered_files"] += [_file]

                if options.format == "raw":
                    patient['groundtruth'] = os.path.join(options.dir, folder, "masks", patient['name'] + "_mask1.nii")
                    patient['skullmap'] = os.path.join(options.dir, folder, "preprocessed", patient['name'] + "_skullmap.nii.gz")
                elif options.format == "aligned":
                    patient['groundtruth'] = os.path.join(options.dir, folder, "preprocessed", patient['name'] + "_mask1.aligned.nii.gz")
                    patient['skullmap'] = os.path.join(options.dir, folder, "preprocessed", patient['name'] + "_skullmap.aligned.nii.gz")

                # Append to the list of all patients
                patients.append(patient)

        return patients

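    # For reference, the directory layout expected by get_patients() above, as inferred from
    # the glob patterns and path joins; "training01" and "<id>" are illustrative placeholders:
    #
    #   <options.dir>/
    #     training01/
    #       preprocessed/
    #         training01_<id>_flair_pp.nii                 (also *_mprage_pp.nii, *_pd_pp.nii, *_t2_pp.nii)
    #         training01_<id>_<protocol>.aligned.nii.gz    (only used when options.format == "aligned")
    #         training01_<id>_skullmap.nii.gz
    #       masks/
    #         training01_<id>_mask1.nii
    #     training02/ ... training05/
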
    # Returns the indices of the patients belonging to the given split (TRAIN, VAL or TEST)
    def get_patient_idx(self, split='TRAIN'):
        return self.patientsSplit[split]

    def get_patient_split(self):
        return self.patientsSplit

    @property
    def images(self):
        return self._images

    def get_images(self, set=None):
        _setIdx = MSISBI2015.SET_TYPES.index(set)
        images_in_set = numpy.where(self._sets == _setIdx)[0]
        return self._images[images_in_set]

    def get_image(self, i):
        return self._images[i, :, :, :]

    def get_label(self, i):
        return self._labels[i, :, :, :]

    def get_patient(self, i):
        return self.patients[i]

    @property
    def labels(self):
        return self._labels

    @property
    def sets(self):
        return self._sets

    @property
    def meta(self):
        return self._meta

    @property
    def num_examples(self):
        return self._images.shape[0]

    @property
    def width(self):
        return self._images.shape[2]

    @property
    def height(self):
        return self._images.shape[1]

    @property
    def num_channels(self):
        return self._images.shape[3]

    @property
    def epochs_completed(self):
        return self._epochs_completed

    def name(self):
        _name = "MSISBI2015"
        if self.options.numSamples > 0:
            _name += '_n{}'.format(self.options.numSamples)
        _name += "_p{}-{}".format(self.options.partition['TRAIN'], self.options.partition['VAL'])
        if self.options.useCrops:
            _name += "_{}crops{}x{}".format(self.options.cropType, self.options.cropWidth, self.options.cropHeight)
            if self.options.cropType == "random":
                _name += "_{}cropsPerSlice".format(self.options.numRandomCropsPerSlice)
        if self.options.sliceResolution is not None:
            _name += "_res{}x{}".format(self.options.sliceResolution[0], self.options.sliceResolution[1])
        _name += "_{}".format(self.options.format)
        return _name

    def split_name(self):
        return os.path.join(self.dir(), 'split-{}-{}.pckl'.format(self.options.partition['TRAIN'], self.options.partition['VAL']))

    def pckl_name(self):
        return os.path.join(self.dir(), self.name() + ".pckl")

    def tfrecord_name(self):
        return os.path.join(self.dir(), self.name() + ".tfrecord")

    def dir(self):
        return self.options.dir

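    # For illustration: with the default Options (partition 0.7/0.2, format "raw", no crops,
    # no fixed slice resolution), name() evaluates to "MSISBI2015_p0.7-0.2_raw", so the cache
    # files used above are <options.dir>/MSISBI2015_p0.7-0.2_raw.pckl and .tfrecord, and the
    # split file is <options.dir>/split-0.7-0.2.pckl.
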
    def export_slices(self, dir):
        for i in range(self.num_examples):
            imwrite(os.path.join(dir, '{}.png'.format(i)), np.squeeze(self.get_image(i) * 255).astype('uint8'))

    def visualize(self, pause=1):
        f, (ax1, ax2) = matplotlib.pyplot.subplots(1, 2)
        images_tmp, labels_tmp, _ = self.next_batch(10)
        for i in range(images_tmp.shape[0]):
            img = numpy.squeeze(images_tmp[i])
            lbl = numpy.squeeze(labels_tmp[i])
            ax1.imshow(img)
            ax1.set_title('Patch')
            ax2.imshow(lbl)
            ax2.set_title('Groundtruth')
            matplotlib.pyplot.pause(pause)

    def num_batches(self, batchsize, set='TRAIN'):
        _setIdx = MSISBI2015.SET_TYPES.index(set)
        images_in_set = numpy.where(self._sets == _setIdx)[0]
        return len(images_in_set) // batchsize

    def next_batch(self, batch_size, shuffle=True, set='TRAIN', return_brainmask=True):
        """Return the next `batch_size` examples from this data set."""
        _setIdx = MSISBI2015.SET_TYPES.index(set)
        images_in_set = numpy.where(self._sets == _setIdx)[0]
        samples_in_set = len(images_in_set)

        start = self._index_in_epoch[set]
        # Shuffle for the first epoch
        if self._epochs_completed[set] == 0 and start == 0 and shuffle:
            perm0 = numpy.arange(samples_in_set)
            numpy.random.shuffle(perm0)
            self._images[images_in_set] = self.images[images_in_set[perm0]]
            self._labels[images_in_set] = self.labels[images_in_set[perm0]]
            self._sets[images_in_set] = self.sets[images_in_set[perm0]]

        # Go to the next epoch
        if start + batch_size > samples_in_set:
            # Finished epoch
            self._epochs_completed[set] += 1

            # Get the remaining examples of this epoch
            rest_num_examples = samples_in_set - start
            images_rest_part = self._images[images_in_set[start:samples_in_set]]
            labels_rest_part = self._labels[images_in_set[start:samples_in_set]]

            # Shuffle the data
            if shuffle:
                perm = numpy.arange(samples_in_set)
                numpy.random.shuffle(perm)
                self._images[images_in_set] = self.images[images_in_set[perm]]
                self._labels[images_in_set] = self.labels[images_in_set[perm]]
                self._sets[images_in_set] = self.sets[images_in_set[perm]]

            # Start the next epoch
            start = 0
            self._index_in_epoch[set] = batch_size - rest_num_examples
            end = self._index_in_epoch[set]
            images_new_part = self._images[images_in_set[start:end]]
            labels_new_part = self._labels[images_in_set[start:end]]

            images_tmp = numpy.concatenate((images_rest_part, images_new_part), axis=0)
            labels_tmp = numpy.concatenate((labels_rest_part, labels_new_part), axis=0)
        else:
            self._index_in_epoch[set] += batch_size
            end = self._index_in_epoch[set]
            images_tmp = self._images[images_in_set[start:end]]
            labels_tmp = self._labels[images_in_set[start:end]]

        if self.options.addInstanceNoise:
            noise = numpy.random.normal(0, 0.01, images_tmp.shape)
            images_tmp += noise

        # Check the batch
        assert images_tmp.size, "The batch is empty!"
        assert labels_tmp.size, "The labels of the current batch are empty!"

        if return_brainmask:
            brainmasks = images_tmp > 0.05
        else:
            brainmasks = None

        return images_tmp, labels_tmp, brainmasks
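

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the loader itself). It assumes the ISBI
# "trainingXX" folders live under options.dir and that the utils.* helpers are
# importable; the option values below are illustrative only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    options = MSISBI2015.Options()
    options.sliceResolution = [128, 128]  # resample every slice to 128x128
    options.cache = False                 # set True to write/reuse the .tfrecord/.pckl cache

    dataset = MSISBI2015(options)
    print("Loaded {} slices of size {}x{}x{}".format(
        dataset.num_examples, dataset.height, dataset.width, dataset.num_channels))

    # Draw a few training batches of images, labels and brain masks
    for _ in range(min(3, dataset.num_batches(8, set='TRAIN'))):
        images, labels, brainmasks = dataset.next_batch(8, set='TRAIN')
        print(images.shape, labels.shape, brainmasks.shape)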