Diff of /dataloaders/MSLUB.py [000000] .. [978658]

Switch to unified view

a b/dataloaders/MSLUB.py
1
"""Functions for reading MSLUB NRRD data."""
2
3
from __future__ import absolute_import
4
from __future__ import division
5
from __future__ import print_function
6
7
import math
8
import os.path
9
import pickle
10
11
import matplotlib.pyplot
12
from imageio import imwrite
13
from scipy.ndimage import zoom
14
from six.moves import xrange  # pylint: disable=redefined-builtin
15
from skimage.measure import label, regionprops
16
17
from utils.NII import *
18
from utils.image_utils import crop, crop_center
19
from utils.tfrecord_utils import *
20
21
matplotlib.pyplot.ion()
22
23
24
# DATASET CLASS
25
class MSLUB(object):
26
    # Class-level constants.
    # Protocol key -> list of file-name tags; get_patients() uses the first tag
    # of each list when it builds the volume file names.
    PROTOCOL_MAPPINGS = {
        'FLAIR': ['FLAIR'],
        'T1': ['T1W'],
        'TWKS': ['T1WKS'],
        'T2': ['T2W'],
    }
    # Partition names; the position of a name in this list is the integer set
    # id stored per sample in `_sets`.
    SET_TYPES = ['TRAIN', 'VAL', 'TEST']
29
30
    class Options(object):
        """Plain settings container for the MSLUB loader.

        All fields are ordinary attributes that can be overwritten after
        construction.
        """

        def __init__(self):
            # --- Location and on-disk layout -------------------------------
            self.dir = os.path.dirname(os.path.realpath(__file__))
            self.format = "raw"  # raw or aligned; If aligned, nii-files will be crawled and loaded
            self.cache = False
            self.debug = False

            # --- Which data to use -----------------------------------------
            self.numSamples = -1
            self.partition = {'TRAIN': 0.7, 'VAL': 0.2, 'TEST': 0.1}
            self.filterProtocol = None  # FLAIR, T1, T2
            self.filterType = "train"  # train or test

            # --- Slicing of the volumes ------------------------------------
            self.axis = 'axial'  # saggital, coronal or axial
            self.viewMapping = {'saggital': 2, 'coronal': 1, 'axial': 0}
            self.sliceStart = 0
            self.sliceEnd = 155
            self.sliceResolution = None  # format: HxW

            # --- Preprocessing ---------------------------------------------
            self.skullStripping = True
            self.normalizationMethod = 'standardization'

            # --- Cropping --------------------------------------------------
            self.useCrops = False
            self.cropType = 'random'  # random or center
            self.numRandomCropsPerSlice = 5
            self.onlyPatchesWithLesions = False
            self.rotations = 0
            self.cropWidth = 128
            self.cropHeight = 128

            # --- Batch sampling --------------------------------------------
            # Affects only the batch sampling. If True, a tiny bit of noise
            # will be added to every batch.
            self.addInstanceNoise = False
55
56
    def __init__(self, options=None):
        """Build the dataset from the on-disk cache or from the patient volumes.

        When `options.cache` is set and the pickle cache file exists, the
        patient list and split are restored from it and the arrays are read
        from the tfrecord file. Otherwise all patients are collected, the
        TRAIN/VAL/TEST split is loaded from (or created and written to) the
        split file, and the slices are extracted into numpy arrays.

        Args:
            options: an `MSLUB.Options` instance. When omitted, a fresh
                default instance is created for this object.
        """
        # Imported locally so the method does not rely on `copy` leaking in
        # through one of the module's star imports.
        import copy

        # A default built in the signature would be a single object shared
        # (and mutated) across every MSLUB instance.
        if options is None:
            options = MSLUB.Options()
        self.options = options

        # The epoch bookkeeping always starts from zero, cached or not.
        self._epochs_completed = {'TRAIN': 0, 'VAL': 0, 'TEST': 0}
        self._index_in_epoch = {'TRAIN': 0, 'VAL': 0, 'TEST': 0}

        if options.cache and os.path.isfile(self.pckl_name()):
            # NOTE: pickle.load can execute arbitrary code; only point the
            # loader at cache files that were produced by this class.
            with open(self.pckl_name(), 'rb') as f:
                tmp = pickle.load(f)
            # The cached object is a copy of an MSLUB instance, so the split
            # lives under `patientsSplit`.
            self.patientsSplit = tmp.patientsSplit
            self.patients = tmp.patients
            self._images, self._labels, self._sets = read_tf_record(self.tfrecord_name())
        else:
            # Collect all patients
            self.patients = self._get_patients()
            self.patientsSplit = {}

            if not os.path.isfile(self.split_name()):
                num_patients = len(self.patients)
                shuffled = numpy.random.permutation(num_patients)

                taken = 0
                for split in self.options.partition.keys():
                    share = self.options.partition[split]
                    # Values up to 1.0 are fractions of all patients, larger
                    # values are absolute patient counts.
                    if share <= 1.0:
                        count = int(math.floor(share * num_patients))
                    else:
                        count = int(share)
                    # Never hand out more patients than are left.
                    count = min(count, num_patients - taken)

                    self.patientsSplit[split] = shuffled[taken:taken + count]
                    taken += count

                with open(self.split_name(), 'wb') as f:
                    pickle.dump(self.patientsSplit, f)
            else:
                # NOTE: same pickle caveat as above applies to the split file.
                with open(self.split_name(), 'rb') as f:
                    self.patientsSplit = pickle.load(f)

            self._create_numpy_arrays()

            if self.options.cache:
                write_tf_record(self._images, self._labels, self._sets, self.tfrecord_name())
                # Pickle a shallow copy without the arrays; they live in the
                # tfrecord file written above.
                tmp = copy.copy(self)
                tmp._images = None
                tmp._labels = None
                tmp._sets = None
                with open(self.pckl_name(), 'wb') as f:
                    pickle.dump(tmp, f)
114
115
    def _create_numpy_arrays(self):
116
        # Iterate over all patients and extract slices
117
        _images = []
118
        _labels = []
119
        _sets = []
120
        for p, patient in enumerate(self.patients):
121
            if p in self.patientsSplit['TRAIN']:
122
                _set_of_current_patient = MSLUB.SET_TYPES.index('TRAIN')
123
            elif p in self.patientsSplit['VAL']:
124
                _set_of_current_patient = MSLUB.SET_TYPES.index('VAL')
125
            elif p in self.patientsSplit['TEST']:
126
                _set_of_current_patient = MSLUB.SET_TYPES.index('TEST')
127
128
            for n, nii_filename in enumerate(patient['filtered_files']):
129
                # try:
130
                _images_tmp, _labels_tmp = self.gather_data(patient, nii_filename)
131
                _images += _images_tmp
132
                _labels += _labels_tmp
133
                # _mask += _mask_tmp
134
                _sets += [_set_of_current_patient] * len(_images_tmp)
135
                # except:
136
                #  print('MSLUB: Failed to open file ' + nii_filename)
137
                #  continue
138
139
        self._images = numpy.array(_images).astype(numpy.float32)
140
        self._labels = numpy.array(_labels).astype(numpy.float32)
141
        if self._images.ndim < 4:
142
            self._images = numpy.expand_dims(self._images, 3)
143
        self._sets = numpy.array(_sets).astype(numpy.int32)
144
145
    def gather_data(self, patient, nii_filename):
146
        _images = []
147
        _labels = []
148
149
        nii, nii_seg, nii_skullmap = self.load_volume_and_groundtruth(nii_filename, patient)
150
151
        # Iterate over all slices and collect them
152
        # We only want to select in the range from 15 to 125 (in axial view)
153
        for s in xrange(self.options.sliceStart, min(self.options.sliceEnd, nii.num_slices_along_axis(self.options.axis))):
154
            if 0 < self.options.numSamples < len(_images):
155
                break
156
157
            slice_data = nii.get_slice(s, self.options.axis)
158
            slice_seg = nii_seg.get_slice(s, self.options.axis)
159
160
            # Skip the slice if it is "empty"
161
            # if numpy.max(slice_data) < empty_thresh:
162
            if numpy.percentile(slice_data, 90) < 0.2:
163
                continue
164
165
            # assert numpy.max(slice_data) <= 1.0, "Slice range is outside [0; 1]!"
166
167
            if self.options.sliceResolution is not None:
168
                # Pad withzeros to top and bottom, if the image is too small
169
                if slice_data.shape[0] < self.options.sliceResolution[0]:
170
                    before_y = math.floor((self.options.sliceResolution[0] - slice_data.shape[0]) / 2.0)
171
                    after_y = math.ceil((self.options.sliceResolution[0] - slice_data.shape[0]) / 2.0)
172
                if slice_data.shape[1] < self.options.sliceResolution[1]:
173
                    before_x = math.floor((self.options.sliceResolution[1] - slice_data.shape[1]) / 2.0)
174
                    after_x = math.ceil((self.options.sliceResolution[1] - slice_data.shape[1]) / 2.0)
175
                if slice_data.shape[0] < self.options.sliceResolution[0] or slice_data.shape[1] < self.options.sliceResolution[1]:
176
                    slice_data = np.pad(slice_data, ((before_y, after_y), (before_x, after_x)), 'constant',
177
                                        constant_values=(0, 0))
178
                    slice_seg = np.pad(slice_seg, ((before_y, after_y), (before_x, after_x)), 'constant', constant_values=(0, 0))
179
                slice_data = zoom(slice_data, float(self.options.sliceResolution[0]) / float(slice_data.shape[0]))
180
                slice_seg = zoom(slice_seg, float(self.options.sliceResolution[0]) / float(slice_seg.shape[0]), mode="nearest")
181
                slice_seg[slice_seg < 0.9] = 0.0
182
                slice_seg[slice_seg >= 0.9] = 1.0
183
                # assert numpy.max(slice_data) <= 1.0, "Resized slice range is outside [0; 1]!"
184
185
            # Either collect crops
186
            if self.options.useCrops:
187
                if self.options.cropType == 'random':
188
                    rx = numpy.random.randint(0, high=(slice_data.shape[1] - self.options.cropWidth),
189
                                              size=self.options.numRandomCropsPerSlice)
190
                    ry = numpy.random.randint(0, high=(slice_data.shape[0] - self.options.cropHeight),
191
                                              size=self.options.numRandomCropsPerSlice)
192
                    for r in range(self.options.numRandomCropsPerSlice):
193
                        _images.append(crop(slice_data, ry(r), rx(r), self.options.cropHeight, self.options.cropWidth))
194
                        _labels.append(crop(slice_data, ry(r), rx(r), self.options.cropHeight, self.options.cropWidth))
195
                elif self.options.cropType == 'center':
196
                    slice_data_cropped = crop_center(slice_data, self.options.cropWidth, self.options.cropHeight)
197
                    slice_seg_cropped = crop_center(slice_seg, self.options.cropWidth, self.options.cropHeight)
198
                    _images.append(slice_data_cropped)
199
                    _labels.append(slice_seg_cropped)
200
                elif self.options.cropType == 'lesions':
201
                    cc_slice = label(slice_seg)
202
                    props = regionprops(cc_slice)
203
                    if len(props) > 0:
204
                        for prop in props:
205
                            cx = prop['centroid'][1]
206
                            cy = prop['centroid'][0]
207
                            if cy < self.options.cropHeight // 2:
208
                                cy = self.options.cropHeight // 2
209
                            if cy > (slice_data.shape[0] - (self.options.cropHeight // 2)):
210
                                cy = (slice_data.shape[0] - (self.options.cropHeight // 2))
211
                            if cx < self.options.cropWidth // 2:
212
                                cx = self.options.cropWidth // 2
213
                            if cx > (slice_data.shape[1] - (self.options.cropWidth // 2)):
214
                                cx = (slice_data.shape[1] - (self.options.cropWidth // 2))
215
                            image_crop = crop(slice_data, int(cy) - (self.options.cropHeight // 2), int(cx) - (self.options.cropWidth // 2),
216
                                              self.options.cropHeight, self.options.cropWidth)
217
                            seg_crop = crop(slice_seg, int(cy) - (self.options.cropHeight // 2), int(cx) - (self.options.cropWidth // 2),
218
                                            self.options.cropHeight, self.options.cropWidth)
219
                            if image_crop.shape[0] != self.options.cropHeight or image_crop.shape[1] != self.options.cropWidth:
220
                                continue
221
                            _images.append(image_crop)
222
                            _labels.append(seg_crop)
223
                            # _masks.append(crop(slice_data, prop['centroid'][0], prop['centroid'][1], self.options.cropHeight, self.options.cropWidth))
224
                        # find connected components in segmentation slice
225
                        # for every connected component, do a center crop from the segmentation slice, the mask and the actual slice
226
            # Or whole slices
227
            else:
228
                _images.append(slice_data)
229
                _labels.append(slice_seg)
230
231
        return _images, _labels
232
233
    def load_volume_and_groundtruth(self, nii_filename, patient):
        """Load one volume together with its ground truth and optional skull map.

        The volume is denoised, cleaned of NaNs, optionally skull-stripped and
        normalized in place with `options.normalizationMethod`. The ground
        truth is binarized at 0.9.

        Args:
            nii_filename: path of the volume to load.
            patient: patient record providing the 'groundtruth' and
                'skullmap' paths.

        Returns:
            Tuple `(nii, nii_groundtruth, nii_skullmap)`. `nii_skullmap` is
            None when skull stripping is disabled or the skull map could not
            be loaded.

        Raises:
            Whatever loading the volume or the ground truth raises; the
            failure is reported on stdout first.
        """
        try:
            nii = NII(nii_filename)
            nii_groundtruth = NII(patient['groundtruth'])

            nii.denoise()
            nii.set_view_mapping(self.options.viewMapping)
            # Image, skull map and ground truth are all sliced along
            # options.axis, so the ground truth gets the same view mapping
            # as the other two.
            nii_groundtruth.set_view_mapping(self.options.viewMapping)
        except Exception:
            # Nothing below can work without the volume, so report and re-raise
            # rather than fail later with an unrelated NameError.
            print('MSLUB: Failed to open file ' + nii_filename)
            raise

        # Make sure ground-truth is binary and the volume doesn't have NaNs
        nii.data[np.isnan(nii.data)] = 0.0
        nii_groundtruth.data[nii_groundtruth.data < 0.9] = 0.0
        nii_groundtruth.data[nii_groundtruth.data >= 0.9] = 1.0

        # Do skull-stripping, if desired. This step is best effort: a missing
        # skull map is reported and the unstripped volume is used.
        nii_skullmap = None
        if self.options.skullStripping:
            try:
                nii_skullmap = NII(patient['skullmap'])
                nii_skullmap.set_view_mapping(self.options.viewMapping)
                nii.apply_skullmap(nii_skullmap)
            except Exception:
                print('MSLUB: Failed to open file ' + patient['skullmap'] + ', skipping skullremoval')

        # In-place normalize the loaded volume
        nii.normalize(method=self.options.normalizationMethod, lowerpercentile=0, upperpercentile=99.8)

        return nii, nii_groundtruth, nii_skullmap
266
267
    # Hidden helper function, not supposed to be called from outside!
268
    def _get_patients(self):
269
        return MSLUB.get_patients(self.options)
270
271
    @staticmethod
272
    def get_patients(options):
273
        folders = ["data"]
274
275
        # Iterate over all folders in folders and collect all patients
276
        patients = []
277
        for f, folder in enumerate(folders):
278
            # Get all files that can be used for training and validation
279
            _patients = [f.name for f in os.scandir(os.path.join(options.dir, folder)) if f.is_dir()]
280
            for p, pname in enumerate(_patients):
281
                patient = {
282
                    'name': pname,
283
                    'fullpath': os.path.join(options.dir, folder, pname),
284
                    "filtered_files": []
285
                }
286
287
                for protocol, protocol_array in MSLUB.PROTOCOL_MAPPINGS.items():
288
                    if options.format == "raw":
289
                        patient[protocol] = os.path.join(patient['fullpath'], patient['name'] + '_' + protocol_array[0] + '.nii.gz')
290
                    elif options.format == "aligned":
291
                        patient[protocol] = os.path.join(patient['fullpath'],
292
                                                         patient['name'] + '_' + protocol_array[0] + '.aligned.nii.gz')
293
294
                    if len(options.filterProtocols) > 0 and protocol not in options.filterProtocols:
295
                        continue
296
                    else:
297
                        if options.format == "raw":
298
                            patient["filtered_files"] += [os.path.join(patient['fullpath'], patient['name'] + '_' + protocol_array[0] + '.nii.gz')]
299
                        elif options.format == "aligned":
300
                            patient["filtered_files"] += [os.path.join(patient['fullpath'],
301
                                                                       patient['name'] + '_' + protocol_array[0] + '.aligned.nii.gz')]
302
303
                if options.format == "raw":
304
                    patient['groundtruth'] = os.path.join(patient['fullpath'], patient['name'] + "_consensus_gt.nii.gz")
305
                    patient['skullmap'] = os.path.join(patient['fullpath'], patient['name'] + "_brainmask.nii.gz")
306
                elif options.format == "aligned":
307
                    patient['groundtruth'] = os.path.join(patient['fullpath'], patient['name'] + "_consensus_gt.aligned.nii.gz")
308
                    patient['skullmap'] = os.path.join(patient['fullpath'], patient['name'] + "_brainmask.aligned.nii.gz")
309
310
                # Append to the list of all patients
311
                patients.append(patient)
312
313
        return patients
314
315
    # Returns the indices of patients which belong to either TRAIN, VAL or TEST. Your choice
316
    def get_patient_idx(self, split='TRAIN'):
317
        return self.patientsSplit[split]
318
319
    def get_patient_split(self):
320
        return self.patientsSplit
321
322
    @property
323
    def images(self):
324
        return self._images
325
326
    def get_images(self, set=None):
327
        _setIdx = MSLUB.SET_TYPES.index(set)
328
        images_in_set = numpy.where(self._sets == _setIdx)[0]
329
        return self._images[images_in_set]
330
331
    def get_image(self, i):
332
        return self._images[i, :, :, :]
333
334
    def get_label(self, i):
335
        return self._labels[i, :, :, :]
336
337
    def get_patient(self, i):
338
        return self.patients[i]
339
340
    @property
341
    def labels(self):
342
        return self._labels
343
344
    @property
345
    def sets(self):
346
        return self._sets
347
348
    @property
349
    def meta(self):
350
        return self._meta
351
352
    @property
353
    def num_examples(self):
354
        return self._images.shape[0]
355
356
    @property
357
    def width(self):
358
        return self._images.shape[2]
359
360
    @property
361
    def height(self):
362
        return self._images.shape[1]
363
364
    @property
365
    def num_channels(self):
366
        return self._images.shape[3]
367
368
    @property
369
    def epochs_completed(self):
370
        return self._epochs_completed
371
372
    def name(self):
373
        _name = "MSLUB"
374
        if self.options.numSamples > 0:
375
            _name += '_n{}'.format(self.options.numSamples)
376
        _name += "_p{}-{}".format(self.options.partition['TRAIN'], self.options.partition['VAL'])
377
        if self.options.useCrops:
378
            _name += "_{}crops{}x{}".format(self.options.cropType, self.options.cropWidth, self.options.cropHeight)
379
            if self.options.cropType == "random":
380
                _name += "_{}cropsPerSlice".format(self.options.numRandomCropsPerSlice)
381
        if self.options.sliceResolution is not None:
382
            _name += "_res{}x{}".format(self.options.sliceResolution[0], self.options.sliceResolution[1])
383
        _name += "_{}".format(self.options.format)
384
        return _name
385
386
    def split_name(self):
387
        return os.path.join(self.dir(),
388
                            'split-{}-{}-{}.pckl'.format(self.options.partition['TRAIN'], self.options.partition['VAL'], self.options.partition['TEST']))
389
390
    def pckl_name(self):
391
        return os.path.join(self.dir(), self.name() + ".pckl")
392
393
    def tfrecord_name(self):
394
        return os.path.join(self.dir(), self.name() + ".tfrecord")
395
396
    def dir(self):
397
        return self.options.dir
398
399
    def export_slices(self, dir):
        """Write every stored image as `<index>.png` into the directory `dir`.

        Each image is multiplied by 255 and stored as uint8, so intensities
        are expected in [0, 1]; values outside that range are clipped.
        """
        for i in range(self.num_examples):
            # Clip before the cast: converting out-of-range floats to uint8
            # would otherwise wrap around and corrupt the picture.
            pixels = np.clip(np.squeeze(self.get_image(i) * 255), 0, 255).astype('uint8')
            imwrite(os.path.join(dir, '{}.png'.format(i)), pixels)
402
403
    def visualize(self, pause=1):
        """Show one batch of ten samples, image next to ground truth.

        Each sample is displayed for `pause` seconds (passed to
        matplotlib.pyplot.pause).
        """
        _, (image_axis, label_axis) = matplotlib.pyplot.subplots(1, 2)
        batch_images, batch_labels, _ = self.next_batch(10)
        for idx in range(batch_images.shape[0]):
            patch = numpy.squeeze(batch_images[idx])
            groundtruth = numpy.squeeze(batch_labels[idx])
            image_axis.imshow(patch)
            image_axis.set_title('Patch')
            label_axis.imshow(groundtruth)
            label_axis.set_title('Groundtruth')
            matplotlib.pyplot.pause(pause)
414
415
    def num_batches(self, batchsize, set='TRAIN'):
416
        _setIdx = MSLUB.SET_TYPES.index(set)
417
        images_in_set = numpy.where(self._sets == _setIdx)[0]
418
        return len(images_in_set) // batchsize
419
420
    def next_batch(self, batch_size, shuffle=True, set='TRAIN', return_brainmask=True):
421
        """Return the next `batch_size` examples from this data set."""
422
        _setIdx = MSLUB.SET_TYPES.index(set)
423
        images_in_set = numpy.where(self._sets == _setIdx)[0]
424
        samples_in_set = len(images_in_set)
425
426
        start = self._index_in_epoch[set]
427
        # Shuffle for the first epoch
428
        if self._epochs_completed[set] == 0 and start == 0 and shuffle:
429
            perm0 = numpy.arange(samples_in_set)
430
            numpy.random.shuffle(perm0)
431
            self._images[images_in_set] = self.images[images_in_set[perm0]]
432
            self._labels[images_in_set] = self.labels[images_in_set[perm0]]
433
            self._sets[images_in_set] = self.sets[images_in_set[perm0]]
434
435
        # Go to the next epoch
436
        if start + batch_size > samples_in_set:
437
            # Finished epoch
438
            self._epochs_completed[set] += 1
439
440
            # Get the rest examples in this epoch
441
            rest_num_examples = samples_in_set - start
442
            images_rest_part = self._images[images_in_set[start:samples_in_set]]
443
            labels_rest_part = self._labels[images_in_set[start:samples_in_set]]
444
445
            # Shuffle the data
446
            if shuffle:
447
                perm = numpy.arange(samples_in_set)
448
                numpy.random.shuffle(perm)
449
                self._images[images_in_set] = self.images[images_in_set[perm]]
450
                self._labels[images_in_set] = self.labels[images_in_set[perm]]
451
                self._sets[images_in_set] = self.sets[images_in_set[perm]]
452
453
            # Start next epoch
454
            start = 0
455
            self._index_in_epoch[set] = batch_size - rest_num_examples
456
            end = self._index_in_epoch[set]
457
            images_new_part = self._images[images_in_set[start:end]]
458
            labels_new_part = self._labels[images_in_set[start:end]]
459
460
            images_tmp = numpy.concatenate((images_rest_part, images_new_part), axis=0)
461
            labels_tmp = numpy.concatenate((labels_rest_part, labels_new_part), axis=0)
462
        else:
463
            self._index_in_epoch[set] += batch_size
464
            end = self._index_in_epoch[set]
465
            images_tmp = self._images[images_in_set[start:end]]
466
            labels_tmp = self._labels[images_in_set[start:end]]
467
468
        if self.options.addInstanceNoise:
469
            noise = numpy.random.normal(0, 0.01, images_tmp.shape)
470
            images_tmp += noise
471
472
        # Check the batch
473
        assert images_tmp.size, "The batch is empty!"
474
        assert labels_tmp.size, "The labels of the current batch are empty!"
475
476
        if return_brainmask:
477
            brainmasks = images_tmp > 0.05
478
        else:
479
            brainmasks = None
480
481
        return images_tmp, labels_tmp, brainmasks