Diff of /dataloaders/BRAINWEB.py [000000] .. [978658]

Switch to unified view

a b/dataloaders/BRAINWEB.py
1
"""Functions for reading BRAINWEB NII data."""
2
3
from __future__ import absolute_import
4
from __future__ import division
5
from __future__ import print_function
6
7
import glob
8
import math
9
import os.path
10
import pickle
11
12
import cv2
13
import matplotlib.pyplot
14
from imageio import imwrite
15
from scipy.ndimage import rotate
16
17
from utils.MINC import *
18
from utils.image_utils import crop, crop_center
19
from utils.tfrecord_utils import *
20
21
22
class BRAINWEB(object):
    """Slice-wise data set built from BRAINWEB volumes.

    Volumes are opened through the project's ``MINC`` helper, cut into 2D
    slices along ``options.axis`` and stored together with a binary lesion
    mask and the index (into ``SET_TYPES``) of the split the patient belongs
    to. The patient-to-split assignment is pickled next to the data so that
    it is reused on the next run; with ``options.cache`` the extracted slices
    are additionally written to / read from a tfrecord file.
    """

    FILTER_TYPES = ['NORMAL', 'MILDMS', 'MODERATEMS', 'SEVEREMS']
    SET_TYPES = ['TRAIN', 'VAL', 'TEST']
    LABELS = {'BACKGROUND': 0, 'CSF': 1, 'GM': 2, 'WM': 3, 'FAT': 4, 'MUSCLE': 5, 'SKIN': 6, 'SKULL': 7,
              'GLIALMATTER': 8, 'CONNECTIVE': 9, 'LESION': 10}
    VIEW_MAPPING = {'saggital': 0, 'coronal': 1, 'axial': 2}
    PROTOCOL_MAPPINGS = {'FLAIR': 'flair*', 'T2': 't2*'}

    class Options(object):
        """Plain container holding every knob of the data set."""

        def __init__(self):
            self.description = None  # optional tag that becomes part of the cache file name
            self.dir = os.path.dirname(os.path.realpath(__file__))
            self.folderNormal = 'normal'
            self.folderMildMS = os.path.join('lesions', 'mild')
            self.folderModerateMS = os.path.join('lesions', 'moderate')
            self.folderSevereMS = os.path.join('lesions', 'severe')
            self.folderGT = 'groundtruth'
            self.numSamples = -1  # values <= 0 mean "no limit"
            # Values in (0, 1] are fractions of all patients, anything else is an absolute count
            self.partition = {'TRAIN': 0.6, 'VAL': 0.15, 'TEST': 0.25}
            self.sliceStart = 20
            self.sliceEnd = 140
            self.useCrops = False
            self.cropType = 'random'  # random or center
            self.numRandomCropsPerSlice = 5
            self.rotations = [0]
            self.cropWidth = 128
            self.cropHeight = 128
            self.cache = False
            self.sliceResolution = None  # format: HxW
            self.addInstanceNoise = False  # Affects only the batch sampling. If True, a tiny bit of noise will be added to every batch
            self.filterProtocol = None  # T2 or FLAIR; restricts which files are globbed. None = all files
            self.filterType = None  # MILDMS, MODERATEMS, SEVEREMS, NORMAL. None = all types
            self.axis = 'axial'  # saggital, coronal or axial
            self.debug = False
            self.normalizationMethod = 'standardization'
            self.skullRemoval = False
            self.backgroundRemoval = False

    def __init__(self, options=None):
        """Load the data set from the cache or build it from the volumes.

        Args:
            options: a ``BRAINWEB.Options`` instance. ``None`` creates a fresh
                default instance (a default built in the signature would be
                one object shared by every BRAINWEB instance).
        """
        self.options = options if options is not None else BRAINWEB.Options()

        if self.options.cache and os.path.isfile(self.pckl_name()):
            self._load_from_cache()
            self._reset_epoch_state()
        else:
            self.patients = self._get_patients()
            self._load_or_create_split()
            self._build_from_volumes()
            # Counters must exist before caching, they are part of the pickled snapshot
            self._reset_epoch_state()
            if self.options.cache:
                self._write_cache()

    def _reset_epoch_state(self):
        """Start every split at epoch 0, sample 0."""
        self._epochs_completed = {'TRAIN': 0, 'VAL': 0, 'TEST': 0}
        self._index_in_epoch = {'TRAIN': 0, 'VAL': 0, 'TEST': 0}

    def _load_from_cache(self):
        """Restore patients, slices and the patient split from the cache files."""
        # NOTE(security): pickle.load executes arbitrary code; only open cache files you created yourself.
        # The snapshot is still unpickled so that a broken cache fails here, as it always did;
        # its epoch counters are not used because the caller resets them anyway.
        with open(self.pckl_name(), 'rb') as f:
            pickle.load(f)

        self.patients = self._get_patients()
        self._images, self._labels, self._sets = read_tf_record(self.tfrecord_name())

        with open(self.split_name(), 'rb') as f:
            self.patients_split = pickle.load(f)
        # Keep one backup of the split file before it is rewritten in the name-based format
        if not os.path.exists(self.split_name() + ".deprecated"):
            os.rename(self.split_name(), self.split_name() + ".deprecated")
        self._convert_patient_split()

    def _load_or_create_split(self):
        """Read the stored patient split or draw (and store) a new random one."""
        if os.path.isfile(self.split_name()):
            # NOTE(security): see _load_from_cache, the split file is a pickle as well.
            with open(self.split_name(), 'rb') as f:
                self.patients_split = pickle.load(f)
        else:
            self.patients_split = self._draw_random_split(len(self.patients))
        # The split is stored by patient name, which is OS agnostic (glob order is not).
        self._convert_patient_split()

    def _draw_random_split(self, num_patients):
        """Randomly distribute patient indices ``0..num_patients-1`` according to ``options.partition``.

        Returns:
            dict mapping split name to an array of patient indices. Because
            fractional sizes are floored, some patients may end up in no split.
        """
        shuffled = numpy.random.permutation(num_patients)
        drawn = {}
        taken = 0
        for split, share in self.options.partition.items():
            if 1.0 >= share > 0.0:
                # int(): math.floor returns a float on Python 2, which cannot be used for slicing
                wanted = max(1, int(math.floor(share * num_patients)))
            else:
                wanted = int(share)
            # Never hand out more patients than are left (and never a negative amount)
            wanted = max(0, min(wanted, num_patients - taken))
            drawn[split] = shuffled[taken:taken + wanted]
            taken += wanted
        return drawn

    def _sample_limit_reached(self, num_collected):
        """Tell whether ``options.numSamples`` (if positive) samples have been collected."""
        return 0 < self.options.numSamples <= num_collected

    def _build_from_volumes(self):
        """Extract slices, lesion masks and set indices from all patient volumes."""
        opts = self.options
        images, labels, sets = [], [], []

        membership = [(BRAINWEB.SET_TYPES.index(split), frozenset(self.patients_split.get(split, ())))
                      for split in BRAINWEB.SET_TYPES]

        for patient in self.patients:
            if self._sample_limit_reached(len(images)):
                break  # no need to load further volumes

            set_idx = next((idx for idx, names in membership if patient["name"] in names), None)
            if set_idx is None:
                # The patient belongs to no split (partition sizes are floored). Skipping it keeps
                # its slices from silently ending up in the set of the previously visited patient.
                continue

            minc, minc_seg, _ = self.load_volume_and_groundtruth(patient["fullpath"], patient)

            last_slice = min(opts.sliceEnd, minc.num_slices_along_axis(opts.axis))
            for s in range(opts.sliceStart, last_slice):
                if self._sample_limit_reached(len(images)):
                    break

                slice_data = minc.get_slice(s, opts.axis)
                slice_seg = minc_seg.get_slice(s, opts.axis)

                # Skip the slice if it is entirely black (constant)
                if numpy.unique(slice_data).size == 1:
                    continue

                slice_data, slice_seg = self._fit_to_resolution(slice_data, slice_seg)

                for angle in opts.rotations:
                    rotated_data, rotated_seg = BRAINWEB._rotate_pair(slice_data, slice_seg, angle)
                    for sample, sample_seg in self._extract_samples(rotated_data, rotated_seg):
                        images.append(sample)
                        labels.append(sample_seg)
                        sets.append(set_idx)

        self._images = numpy.array(images).astype(numpy.float32)
        self._labels = numpy.array(labels).astype(numpy.float32)
        if self._images.ndim < 4:
            # Add the trailing channel axis: N x H x W -> N x H x W x 1
            self._images = numpy.expand_dims(self._images, 3)
        self._sets = numpy.array(sets).astype(numpy.int32)

    def _fit_to_resolution(self, slice_data, slice_seg):
        """Bring a slice and its mask to ``options.sliceResolution`` (HxW).

        Slices that exceed the target in any dimension are resized, smaller
        ones are zero padded around the center. Without a target resolution
        both inputs are returned unchanged.
        """
        resolution = self.options.sliceResolution
        if resolution is None:
            return slice_data, slice_seg

        target_h, target_w = int(resolution[0]), int(resolution[1])
        src_h, src_w = slice_data.shape[0], slice_data.shape[1]

        if src_h > target_h or src_w > target_w:
            # cv2.resize takes the target size as (width, height), the option is stored as HxW
            dsize = (target_w, target_h)
            resized_data = cv2.resize(slice_data, dsize)
            # Nearest neighbour keeps the mask free of interpolated label values
            resized_seg = cv2.resize(slice_seg, dsize, interpolation=cv2.INTER_NEAREST)
            return resized_data, resized_seg

        padded_data = numpy.zeros((target_h, target_w))
        padded_seg = numpy.zeros((target_h, target_w))
        start_x = (target_w - src_w) // 2
        start_y = (target_h - src_h) // 2
        padded_data[start_y:start_y + src_h, start_x:start_x + src_w] = slice_data
        padded_seg[start_y:start_y + src_h, start_x:start_x + src_w] = slice_seg
        return padded_data, padded_seg

    @staticmethod
    def _rotate_pair(slice_data, slice_seg, angle):
        """Rotate a slice and its mask by ``angle`` degrees; an angle of 0 returns the inputs as they are."""
        if angle == 0:
            return slice_data, slice_seg
        rotated_data = rotate(slice_data, angle, reshape=False)
        # order=0 selects nearest-neighbour sampling so that the mask only contains original label
        # values; mode='nearest' alone merely controls how the border is extended.
        rotated_seg = rotate(slice_seg, angle, reshape=False, order=0, mode='nearest')
        return rotated_data, rotated_seg

    def _extract_samples(self, slice_data, slice_seg):
        """Turn one (rotated) slice into the list of ``(image, label)`` samples to store.

        Depending on the options this is the whole slice, one center crop or
        ``numRandomCropsPerSlice`` random crops. An unknown ``cropType``
        yields no samples.

        Raises:
            ValueError: if random crops are requested but the slice is smaller than the crop.
        """
        opts = self.options
        if not opts.useCrops:
            return [(slice_data, slice_seg)]

        if opts.cropType == 'random':
            max_x = slice_data.shape[1] - opts.cropWidth
            max_y = slice_data.shape[0] - opts.cropHeight
            if max_x < 0 or max_y < 0:
                raise ValueError('Crop of {}x{} does not fit into a slice of {}x{}'.format(
                    opts.cropHeight, opts.cropWidth, slice_data.shape[0], slice_data.shape[1]))
            # `high` is exclusive, hence +1 to also allow the last valid offset
            # (and to support slices that are exactly as large as the crop)
            rx = numpy.random.randint(0, high=max_x + 1, size=opts.numRandomCropsPerSlice)
            ry = numpy.random.randint(0, high=max_y + 1, size=opts.numRandomCropsPerSlice)
            # Image and mask are cut at the very same position
            return [(crop(slice_data, y, x, opts.cropHeight, opts.cropWidth),
                     crop(slice_seg, y, x, opts.cropHeight, opts.cropWidth))
                    for y, x in zip(ry, rx)]

        if opts.cropType == 'center':
            return [(crop_center(slice_data, opts.cropWidth, opts.cropHeight),
                     crop_center(slice_seg, opts.cropWidth, opts.cropHeight))]

        return []

    def _write_cache(self):
        """Write the slices to a tfrecord file and pickle a data-free snapshot of this object."""
        import copy  # local import: the module header does not import it explicitly

        write_tf_record(self._images, self._labels, self._sets, self.tfrecord_name())
        # Shallow copy without the arrays, they already live in the tfrecord file
        snapshot = copy.copy(self)
        snapshot._images = None
        snapshot._labels = None
        snapshot._sets = None
        with open(self.pckl_name(), 'wb') as f:
            pickle.dump(snapshot, f)

    def _get_patients(self):
        """Collect the patients described by this instance's options."""
        return BRAINWEB.get_patients(self.options)

    @staticmethod
    def get_patients(options):
        """Find all volumes below ``options.dir`` that match the type and protocol filters.

        Args:
            options: object providing ``dir``, the four ``folder*`` attributes,
                ``folderGT``, ``filterType`` (``None`` accepts every type) and
                ``filterProtocol`` (a key of ``PROTOCOL_MAPPINGS`` or a falsy
                value for all files).

        Returns:
            list of dicts with the keys ``name``, ``type``, ``fullpath``,
            ``filtered_files`` (same string as ``fullpath``) and
            ``groundtruth_filename``.
        """
        sources = [
            ('NORMAL', options.folderNormal, 'normal.mnc.gz'),
            ('MILDMS', options.folderMildMS, 'mild_lesions.mnc.gz'),
            ('MODERATEMS', options.folderModerateMS, 'moderate_lesions.mnc.gz'),
            ('SEVEREMS', options.folderSevereMS, 'severe_lesions.mnc.gz'),
        ]

        if options.filterProtocol:
            pattern = BRAINWEB.PROTOCOL_MAPPINGS[options.filterProtocol] + ".mnc.gz"
        else:
            pattern = "*.mnc.gz"

        patients = []
        for _type, minc_folder, groundtruth_name in sources:
            # Skip folders whose type is not wanted; no filter at all means "take everything"
            if options.filterType is not None and _type not in options.filterType:
                continue

            for fname in glob.glob(os.path.join(options.dir, minc_folder, pattern)):
                patients.append({
                    'name': os.path.basename(fname),
                    'type': _type,
                    'fullpath': fname,
                    'filtered_files': fname,
                    'groundtruth_filename': os.path.join(options.dir, options.folderGT, groundtruth_name),
                })

        return patients

    def load_volume_and_groundtruth(self, minc_filename, patient):
        """Load a patient's volume, its binary lesion mask and the skull/background map.

        Args:
            minc_filename: kept for backward compatibility; the file that is
                actually opened is always ``patient['fullpath']``.
            patient: patient dict as produced by ``get_patients``.

        Returns:
            tuple ``(minc, minc_seg, skullmap)``: the normalized volume, the
            ground-truth with lesion voxels set to 1 and all others to 0, and
            a map that is 0 wherever tissue was removed and 1 elsewhere.

        Raises:
            Exception: whatever ``MINC`` raises when the volume cannot be opened.
        """
        minc_filename = patient['fullpath']
        try:
            minc = MINC(minc_filename)  # NII also works with MINC
            minc.set_view_mapping(BRAINWEB.VIEW_MAPPING)
        except Exception:
            print('BRAINWEB: Failed to open file ' + minc_filename)
            # Nothing sensible can be returned without the volume, so let the caller know
            raise

        # Load the segmentation ground-truth; a second copy serves as an all-ones map to start from
        minc_seg_path = patient["groundtruth_filename"]
        minc_seg = MINC(minc_seg_path)
        skullmap = MINC(minc_seg_path)
        skullmap.data = (skullmap.data * 0.0) + 1.0
        skullmap.set_view_mapping(BRAINWEB.VIEW_MAPPING)
        minc_seg.set_view_mapping(BRAINWEB.VIEW_MAPPING)

        # If desired, knock the non-brain tissue classes out of the map
        if self.options.skullRemoval:
            for tissue in ('FAT', 'MUSCLE', 'SKIN', 'SKULL', 'CONNECTIVE'):
                skullmap.data[minc_seg.data == BRAINWEB.LABELS[tissue]] = 0

        if self.options.backgroundRemoval:
            skullmap.data[minc_seg.data == BRAINWEB.LABELS['BACKGROUND']] = 0

        # Binarize minc_seg; both masks are computed first because the data is modified in place
        lesion_idx = (minc_seg.data == BRAINWEB.LABELS['LESION'])
        nonlesion_idx = (minc_seg.data != BRAINWEB.LABELS['LESION'])
        minc_seg.data[lesion_idx] = 1
        minc_seg.data[nonlesion_idx] = 0

        if self.options.skullRemoval or self.options.backgroundRemoval:
            minc.apply_skullmap(skullmap)

        # In-place normalize the loaded volume.
        # 99.8 percentile described in L. G. Nyúl, Jayaram K. Udupa, and Xuan Zhang.
        # New variants of a method of MRI scale standardization.
        # IEEE Transactions on Medical Imaging, 19(2):143-150, 2000.
        minc.normalize(method=self.options.normalizationMethod, lowerpercentile=0.0, upperpercentile=99.8)

        return minc, minc_seg, skullmap

    def get_patient_idx(self, split='TRAIN'):
        """Return the indices (into ``self.patients``) of all patients belonging to ``split``."""
        members = self.patients_split[split]
        return [pidx for pidx, patient in enumerate(self.patients) if patient["name"] in members]

    def get_patient_split(self):
        """Return the dict mapping split name to the patient names in that split."""
        return self.patients_split

    @property
    def images(self):
        """All images of all sets."""
        return self._images

    def get_images(self, set=None):
        """Return the images of one set (a name from ``SET_TYPES``), or all images if ``set`` is None."""
        if set is None:
            return self._images
        _setIdx = BRAINWEB.SET_TYPES.index(set)
        images_in_set = numpy.where(self._sets == _setIdx)[0]
        return self._images[images_in_set]

    def get_image(self, i):
        """Return the i-th image."""
        return self._images[i]

    def get_label(self, i):
        """Return the i-th label map."""
        # Plain first-axis indexing, because labels are stored without the
        # channel axis that is appended to the images
        return self._labels[i]

    @property
    def labels(self):
        """All label maps of all sets."""
        return self._labels

    @property
    def sets(self):
        """Per-sample index into ``SET_TYPES``."""
        return self._sets

    @property
    def meta(self):
        """Meta data attached to the data set; ``None`` if none was ever set."""
        return getattr(self, '_meta', None)

    @property
    def num_examples(self):
        """Total number of samples over all sets."""
        return self._images.shape[0]

    @property
    def width(self):
        """Size of the images along axis 2."""
        return self._images.shape[2]

    @property
    def height(self):
        """Size of the images along axis 1."""
        return self._images.shape[1]

    @property
    def num_channels(self):
        """Size of the images along axis 3."""
        return self._images.shape[3]

    @property
    def epochs_completed(self):
        """Dict with the number of finished epochs per set."""
        return self._epochs_completed

    def name(self):
        """Build the name that identifies this configuration (used for the cache files)."""
        opts = self.options
        parts = ["BRAINWEB"]
        if opts.description:
            parts.append("_{}".format(opts.description))
        if opts.numSamples > 0:
            parts.append('_n{}'.format(opts.numSamples))
        parts.append("_p{}-{}-{}".format(opts.partition['TRAIN'], opts.partition['VAL'], opts.partition['TEST']))
        if opts.useCrops:
            parts.append("_{}crops{}x{}".format(opts.cropType, opts.cropWidth, opts.cropHeight))
            if opts.cropType == "random":
                parts.append("_{}cropsPerSlice".format(opts.numRandomCropsPerSlice))
        if opts.sliceResolution is not None:
            parts.append("_res{}x{}".format(opts.sliceResolution[0], opts.sliceResolution[1]))
        if opts.skullRemoval:
            parts.append("_noSkull")
        if opts.backgroundRemoval:
            parts.append("_noBackground")
        return "".join(parts)

    def pckl_name(self):
        """Path of the pickled snapshot belonging to this configuration."""
        return os.path.join(self.dir(), self.name() + ".pckl")

    def tfrecord_name(self):
        """Path of the tfrecord file belonging to this configuration."""
        return os.path.join(self.dir(), self.name() + ".tfrecord")

    def split_name(self):
        """Path of the pickled patient split; it only depends on the partition."""
        partition = self.options.partition
        return os.path.join(self.dir(),
                            'split-{}-{}-{}.pckl'.format(partition['TRAIN'], partition['VAL'], partition['TEST']))

    def dir(self):
        """Directory holding the data and all cache files."""
        return self.options.dir

    def export_slices(self, dir):
        """Write every image as ``<index>.png`` into the directory ``dir``."""
        for i in range(self.num_examples):
            # Scaled by 255 and cast to uint8 before writing
            pixels = numpy.squeeze(self.get_image(i) * 255).astype('uint8')
            imwrite(os.path.join(dir, '{}.png'.format(i)), pixels)

    def visualize(self, pause=1, set='TRAIN'):
        """Show a batch of 10 samples of ``set`` side by side with their labels, ``pause`` seconds each."""
        f, (ax1, ax2) = matplotlib.pyplot.subplots(1, 2)
        images_tmp, labels_tmp, _ = self.next_batch(10, set=set)
        for img, lbl in zip(images_tmp, labels_tmp):
            ax1.imshow(numpy.squeeze(img))
            ax1.set_title('Patch')
            ax2.imshow(numpy.squeeze(lbl))
            ax2.set_title('Groundtruth')
            matplotlib.pyplot.pause(pause)

    def num_batches(self, batchsize, set='TRAIN'):
        """Number of complete batches of size ``batchsize`` that fit into ``set``."""
        _setIdx = BRAINWEB.SET_TYPES.index(set)
        images_in_set = numpy.where(self._sets == _setIdx)[0]
        return len(images_in_set) // batchsize

    def _shuffle_set_in_place(self, indices):
        """Randomly permute the samples stored at ``indices``, keeping images, labels and sets aligned."""
        perm = numpy.arange(len(indices))
        numpy.random.shuffle(perm)
        shuffled = indices[perm]
        # Fancy indexing on the right-hand side creates copies, so nothing is overwritten too early
        self._images[indices] = self._images[shuffled]
        self._labels[indices] = self._labels[shuffled]
        self._sets[indices] = self._sets[shuffled]

    def next_batch(self, batch_size, shuffle=True, set='TRAIN', return_brainmask=False):
        """Return the next `batch_size` examples from this data set.

        Args:
            batch_size: number of samples to return.
            shuffle: shuffle the samples of the set before its first batch and
                whenever an epoch has been completed.
            set: name of the set to draw from (one of ``SET_TYPES``).
            return_brainmask: additionally derive a binary mask from the labels.

        Returns:
            tuple ``(images, labels, brainmasks)``; ``brainmasks`` is ``None``
            unless ``return_brainmask`` is set.
        """
        _setIdx = BRAINWEB.SET_TYPES.index(set)
        images_in_set = numpy.where(self._sets == _setIdx)[0]
        samples_in_set = len(images_in_set)

        start = self._index_in_epoch[set]
        # Shuffle for the first epoch; the epoch counter is kept per set
        if shuffle and start == 0 and self._epochs_completed[set] == 0:
            self._shuffle_set_in_place(images_in_set)

        if start + batch_size > samples_in_set:
            # Finished epoch
            self._epochs_completed[set] += 1

            # Take what is left of this epoch (copies, unaffected by the reshuffle below)
            rest_num_examples = samples_in_set - start
            images_rest_part = self._images[images_in_set[start:samples_in_set]]
            labels_rest_part = self._labels[images_in_set[start:samples_in_set]]

            if shuffle:
                self._shuffle_set_in_place(images_in_set)

            # Start the next epoch and fill the batch up from its beginning
            self._index_in_epoch[set] = batch_size - rest_num_examples
            end = self._index_in_epoch[set]
            images_new_part = self._images[images_in_set[0:end]]
            labels_new_part = self._labels[images_in_set[0:end]]

            images_tmp = numpy.concatenate((images_rest_part, images_new_part), axis=0)
            labels_tmp = numpy.concatenate((labels_rest_part, labels_new_part), axis=0)
        else:
            self._index_in_epoch[set] += batch_size
            end = self._index_in_epoch[set]
            images_tmp = self._images[images_in_set[start:end]]
            labels_tmp = self._labels[images_in_set[start:end]]

        if self.options.addInstanceNoise:
            # images_tmp is a copy, so the stored images stay free of noise
            noise = numpy.random.normal(0, 0.01, images_tmp.shape)
            images_tmp += noise

        # Check the batch
        assert images_tmp.size, "The batch is empty!"
        assert labels_tmp.size, "The labels of the current batch are empty!"

        if return_brainmask:
            brainmasks = numpy.copy(labels_tmp)
            for tissue in ('FAT', 'MUSCLE', 'SKIN', 'SKULL', 'CONNECTIVE', 'BACKGROUND'):
                brainmasks[brainmasks == BRAINWEB.LABELS[tissue]] = 0
            brainmasks[brainmasks > 0] = 1
            return images_tmp, labels_tmp, brainmasks

        return images_tmp, labels_tmp, None

    def _convert_patient_split(self):
        """Make sure the split lists patient names (not indices) and write it to ``split_name()``."""
        for split in list(self.patients_split.keys()):
            entries = self.patients_split[split]
            if any(isinstance(entry, str) for entry in entries):
                # Already stored by name, keep the list as it is
                names = entries
            else:
                # Old format: positions in self.patients
                names = [self.patients[pidx]['name'] for pidx in entries]
            self.patients_split[split] = names

        with open(self.split_name(), 'wb') as f:
            pickle.dump(self.patients_split, f)