Diff of /dataloaders/MSSEG2008.py [000000] .. [978658]

"""Functions for reading MSSEG2008 NRRD data."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import math
import os.path
import pickle

import matplotlib.pyplot
import numpy
import numpy as np  # both the "numpy" and "np" aliases are used below
from imageio import imwrite
from scipy.ndimage import zoom
from six.moves import xrange  # pylint: disable=redefined-builtin
from skimage.measure import label, regionprops

from dataloaders.NRRD import *
from utils.NII import *
from utils.image_utils import crop, crop_center
from utils.tfrecord_utils import *


class MSSEG2008(object):
    PROTOCOL_MAPPINGS = ['FLAIR', 'T1', 'T2']
    SET_TYPES = ['TRAIN', 'VAL', 'TEST']

    class Options(object):
        def __init__(self):
            self.dir = os.path.dirname(os.path.realpath(__file__))
            self.folderTrainUNC = 'UNC_train'
            self.folderTestUNC = 'UNC_test'
            self.folderTrainCHB = 'CHB_train'
            self.folderTestCHB = 'CHB_test'
            self.numSamples = -1
            self.partition = {'TRAIN': 0.7, 'VAL': 0.2, 'TEST': 0.1}
            self.useCrops = False
            self.cropType = 'random'  # random, center or lesions
            self.numRandomCropsPerSlice = 5
            self.onlyPatchesWithLesions = False
            self.rotations = 0
            self.cropWidth = 128
            self.cropHeight = 128
            self.cache = False
            self.sliceResolution = None  # format: HxW
            self.addInstanceNoise = False  # Affects only the batch sampling. If True, a tiny bit of noise is added to every batch
            self.filterProtocols = []  # any subset of ['FLAIR', 'T1', 'T2']; an empty list keeps all protocols
            self.filterScanner = "UNC"  # UNC or CHB
            self.filterType = "train"  # train or test
            self.axis = 'axial'  # saggital, coronal or axial
            self.debug = False
            self.normalizationMethod = 'standardization'
            self.sliceStart = 0
            self.sliceEnd = 155
            self.format = "raw"  # raw or aligned; if aligned, NII files will be crawled and loaded
            self.skullStripping = True
            self.viewMapping = {'saggital': 2, 'coronal': 1, 'axial': 0}

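    # A minimal configuration sketch (illustrative only; the data root below is an
    # assumption and not something shipped with this loader):
    #
    #   opts = MSSEG2008.Options()
    #   opts.dir = '/data/MSSEG2008'        # hypothetical root containing UNC_train/, UNC_test/, ...
    #   opts.filterProtocols = ['FLAIR']    # restrict loading to FLAIR volumes
    #   opts.sliceResolution = [128, 128]   # resample every extracted slice to 128x128
    #   dataset = MSSEG2008(opts)
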
    def __init__(self, options=Options()):
        self.options = options

        if options.cache and os.path.isfile(self.pckl_name()):
            f = open(self.pckl_name(), 'rb')
            tmp = pickle.load(f)
            f.close()
            self._epochs_completed = tmp._epochs_completed
            self._index_in_epoch = tmp._index_in_epoch
            self.patientsSplit = tmp.patientsSplit
            self.patients = tmp.patients
            self._images, self._labels, self._sets = read_tf_record(self.tfrecord_name())
            self._epochs_completed = {'TRAIN': 0, 'VAL': 0, 'TEST': 0}
            self._index_in_epoch = {'TRAIN': 0, 'VAL': 0, 'TEST': 0}
        else:
            # Collect all patients
            self.patients = self._get_patients()
            self.patientsSplit = {}

            if not os.path.isfile(self.split_name()):
                _numPatients = len(self.patients)
                _ridx = numpy.random.permutation(_numPatients)

                _already_taken = 0
                for split in self.options.partition.keys():
                    if self.options.partition[split] <= 1.0:
                        numPatientsForCurrentSplit = math.floor(self.options.partition[split] * _numPatients)
                    else:
                        numPatientsForCurrentSplit = self.options.partition[split]

                    if numPatientsForCurrentSplit > (_numPatients - _already_taken):
                        numPatientsForCurrentSplit = _numPatients - _already_taken

                    self.patientsSplit[split] = _ridx[_already_taken:_already_taken + numPatientsForCurrentSplit]
                    _already_taken += numPatientsForCurrentSplit

                f = open(self.split_name(), 'wb')
                pickle.dump(self.patientsSplit, f)
                f.close()
            else:
                f = open(self.split_name(), 'rb')
                self.patientsSplit = pickle.load(f)
                f.close()

            self._create_numpy_arrays()

            self._epochs_completed = {'TRAIN': 0, 'VAL': 0, 'TEST': 0}
            self._index_in_epoch = {'TRAIN': 0, 'VAL': 0, 'TEST': 0}

            if self.options.cache:
                write_tf_record(self._images, self._labels, self._sets, self.tfrecord_name())
                tmp = copy.copy(self)
                tmp._images = None
                tmp._labels = None
                tmp._sets = None
                f = open(self.pckl_name(), 'wb')
                pickle.dump(tmp, f)
                f.close()

    def _create_numpy_arrays(self):
        # Iterate over all patients and extract slices
        _images = []
        _labels = []
        _sets = []
        for p, patient in enumerate(self.patients):
            if p in self.patientsSplit['TRAIN']:
                _set_of_current_patient = MSSEG2008.SET_TYPES.index('TRAIN')
            elif p in self.patientsSplit['VAL']:
                _set_of_current_patient = MSSEG2008.SET_TYPES.index('VAL')
            elif p in self.patientsSplit['TEST']:
                _set_of_current_patient = MSSEG2008.SET_TYPES.index('TEST')
            else:
                # Skip patients that are not assigned to any split (possible when the partition does not cover everyone)
                continue

            for nrrd_filename in patient['filtered_files']:
                _images_tmp, _labels_tmp = self.gather_data(patient, nrrd_filename)
                _images += _images_tmp
                _labels += _labels_tmp
                _sets += [_set_of_current_patient] * len(_images_tmp)

        self._images = numpy.array(_images).astype(numpy.float32)
        self._labels = numpy.array(_labels).astype(numpy.float32)
        if self._images.ndim < 4:
            self._images = numpy.expand_dims(self._images, 3)
        self._sets = numpy.array(_sets).astype(numpy.int32)

    def gather_data(self, patient, nrrd_filename):
        _images = []
        _labels = []

        nrrd, nrrd_seg, nrrd_skullmap = self.load_volume_and_groundtruth(nrrd_filename, patient)

        # Iterate over the slices in the configured [sliceStart, sliceEnd) range and collect them
        for s in xrange(self.options.sliceStart, min(self.options.sliceEnd, nrrd.num_slices_along_axis(self.options.axis))):
            if 0 < self.options.numSamples < len(_images):
                break

            slice_data = nrrd.get_slice(s, self.options.axis)
            slice_seg = nrrd_seg.get_slice(s, self.options.axis)
            slice_skullmap = nrrd_skullmap.get_slice(s, self.options.axis) if nrrd_skullmap is not None else None

            # Skip the slice if it is "empty"
            if numpy.percentile(slice_data, 90) < 0.2:
                continue

            if self.options.sliceResolution is not None:
                # Pad with zeros if the slice is smaller than the target resolution
                before_y = after_y = before_x = after_x = 0
                if slice_data.shape[0] < self.options.sliceResolution[0]:
                    before_y = math.floor((self.options.sliceResolution[0] - slice_data.shape[0]) / 2.0)
                    after_y = math.ceil((self.options.sliceResolution[0] - slice_data.shape[0]) / 2.0)
                if slice_data.shape[1] < self.options.sliceResolution[1]:
                    before_x = math.floor((self.options.sliceResolution[1] - slice_data.shape[1]) / 2.0)
                    after_x = math.ceil((self.options.sliceResolution[1] - slice_data.shape[1]) / 2.0)
                if slice_data.shape[0] < self.options.sliceResolution[0] or slice_data.shape[1] < self.options.sliceResolution[1]:
                    slice_data = np.pad(slice_data, ((before_y, after_y), (before_x, after_x)), 'constant', constant_values=(0, 0))
                    slice_seg = np.pad(slice_seg, ((before_y, after_y), (before_x, after_x)), 'constant', constant_values=(0, 0))
                slice_data = zoom(slice_data, float(self.options.sliceResolution[0]) / float(slice_data.shape[0]))
                slice_seg = zoom(slice_seg, float(self.options.sliceResolution[0]) / float(slice_seg.shape[0]), mode="nearest")
                # Re-binarize the segmentation after interpolation
                slice_seg[slice_seg < 0.9] = 0.0
                slice_seg[slice_seg >= 0.9] = 1.0

            # Either collect crops
            if self.options.useCrops:
                if self.options.cropType == 'random':
                    rx = numpy.random.randint(0, high=(slice_data.shape[1] - self.options.cropWidth),
                                              size=self.options.numRandomCropsPerSlice)
                    ry = numpy.random.randint(0, high=(slice_data.shape[0] - self.options.cropHeight),
                                              size=self.options.numRandomCropsPerSlice)
                    for r in range(self.options.numRandomCropsPerSlice):
                        _images.append(crop(slice_data, ry[r], rx[r], self.options.cropHeight, self.options.cropWidth))
                        _labels.append(crop(slice_seg, ry[r], rx[r], self.options.cropHeight, self.options.cropWidth))
                elif self.options.cropType == 'center':
                    slice_data_cropped = crop_center(slice_data, self.options.cropWidth, self.options.cropHeight)
                    slice_seg_cropped = crop_center(slice_seg, self.options.cropWidth, self.options.cropHeight)
                    _images.append(slice_data_cropped)
                    _labels.append(slice_seg_cropped)
                elif self.options.cropType == 'lesions':
                    # Find connected components in the segmentation slice and take a crop around every lesion centroid
                    cc_slice = label(slice_seg)
                    props = regionprops(cc_slice)
                    if len(props) > 0:
                        for prop in props:
                            cx = prop['centroid'][1]
                            cy = prop['centroid'][0]
                            # Clamp the centroid so that the crop stays inside the slice
                            if cy < self.options.cropHeight // 2:
                                cy = self.options.cropHeight // 2
                            if cy > (slice_data.shape[0] - (self.options.cropHeight // 2)):
                                cy = (slice_data.shape[0] - (self.options.cropHeight // 2))
                            if cx < self.options.cropWidth // 2:
                                cx = self.options.cropWidth // 2
                            if cx > (slice_data.shape[1] - (self.options.cropWidth // 2)):
                                cx = (slice_data.shape[1] - (self.options.cropWidth // 2))
                            image_crop = crop(slice_data, int(cy) - (self.options.cropHeight // 2), int(cx) - (self.options.cropWidth // 2),
                                              self.options.cropHeight, self.options.cropWidth)
                            seg_crop = crop(slice_seg, int(cy) - (self.options.cropHeight // 2), int(cx) - (self.options.cropWidth // 2),
                                            self.options.cropHeight, self.options.cropWidth)
                            if image_crop.shape[0] != self.options.cropHeight or image_crop.shape[1] != self.options.cropWidth:
                                continue
                            _images.append(image_crop)
                            _labels.append(seg_crop)
            # Or whole slices
            else:
                _images.append(slice_data)
                _labels.append(slice_seg)

        return _images, _labels

    def load_volume_and_groundtruth(self, nrrd_filename, patient):
        # Load the volume and its ground-truth segmentation
        try:
            if self.options.format == "raw":
                nrrd = NRRD(nrrd_filename)
                nrrd_groundtruth = NRRD(patient['groundtruth'])
                nrrd.denoise()
                nrrd.set_view_mapping(self.options.viewMapping)
            elif self.options.format == "aligned":
                nrrd = NII(nrrd_filename)
                nrrd_groundtruth = NII(patient['groundtruth'])
                nrrd.denoise()
                nrrd.set_view_mapping(self.options.viewMapping)
        except:
            print('MSSEG2008: Failed to open file ' + nrrd_filename)
            raise

        # Make sure the ground-truth is binary and the volume does not contain NaNs
        nrrd.data[np.isnan(nrrd.data)] = 0.0
        nrrd_groundtruth.data[nrrd_groundtruth.data < 0.9] = 0.0
        nrrd_groundtruth.data[nrrd_groundtruth.data >= 0.9] = 1.0

        # Do skull-stripping, if desired
        nii_skullmap = None
        if self.options.skullStripping:
            try:
                nii_skullmap = NII(patient['skullmap'])
                nii_skullmap.set_view_mapping(self.options.viewMapping)
                nrrd.apply_skullmap(nii_skullmap)
            except:
                print('MSSEG2008: Failed to open file ' + patient['skullmap'] + ', skipping skull removal')
                nii_skullmap = None

        # In-place normalize the loaded volume
        nrrd.normalize(method=self.options.normalizationMethod, lowerpercentile=0, upperpercentile=99.8)

        return nrrd, nrrd_groundtruth, nii_skullmap

    # Hidden helper function, not supposed to be called from outside!
    def _get_patients(self):
        return MSSEG2008.get_patients(self.options)

    @staticmethod
    def get_patients(options):
        folders = [options.folderTrainUNC, options.folderTestUNC, options.folderTrainCHB, options.folderTestCHB]

        # Iterate over the train/test folders of both scanners and collect patients
        patients = []
        for folder in folders:
            if options.filterScanner and options.filterScanner not in folder:
                continue
            if options.filterType and options.filterType not in folder:
                continue

            # Get all patient sub-directories that can be used for training and validation
            _patients = [entry.name for entry in os.scandir(os.path.join(options.dir, folder)) if entry.is_dir()]
            for pname in _patients:
                patient = {
                    'name': pname,
                    'fullpath': os.path.join(options.dir, folder, pname)
                }
                if "train" in folder:
                    patient["type"] = "train"
                else:
                    patient["type"] = "test"

                patient["filtered_files"] = []
                for protocol in MSSEG2008.PROTOCOL_MAPPINGS:
                    if options.format == "raw":
                        patient[protocol] = os.path.join(options.dir, folder, pname, pname + '_' + protocol + '.nhdr')
                    elif options.format == "aligned":
                        patient[protocol] = os.path.join(options.dir, folder, pname, pname + '_' + protocol + '.aligned.nii.gz')

                    # Respect the protocol filter; an empty filterProtocols list keeps every protocol
                    if options.filterProtocols and protocol not in options.filterProtocols:
                        continue
                    patient["filtered_files"].append(patient[protocol])

                if options.format == "raw":
                    patient['groundtruth'] = os.path.join(options.dir, folder, pname, pname + '_lesion.nhdr')
                    patient['skullmap'] = os.path.join(options.dir, folder, pname, pname + '_skullmap.nhdr')
                elif options.format == "aligned":
                    patient['groundtruth'] = os.path.join(options.dir, folder, pname, pname + '_lesion.aligned.nii.gz')
                    patient['skullmap'] = os.path.join(options.dir, folder, pname, pname + '_skullmap.nii.gz')

                # Append to the list of all patients
                patients.append(patient)

        return patients

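    # For reference, each entry returned by get_patients() is a plain dict; a sketch
    # for a raw-format training case (the case and file names below are illustrative,
    # following the pattern built above):
    #
    #   {
    #       'name': 'UNC_train_Case01',
    #       'fullpath': '<dir>/UNC_train/UNC_train_Case01',
    #       'type': 'train',
    #       'FLAIR': '.../UNC_train_Case01_FLAIR.nhdr', 'T1': '...', 'T2': '...',
    #       'filtered_files': [...],
    #       'groundtruth': '.../UNC_train_Case01_lesion.nhdr',
    #       'skullmap': '.../UNC_train_Case01_skullmap.nhdr',
    #   }
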
    # Returns the indices of patients which belong to TRAIN, VAL or TEST, depending on the requested split
    def get_patient_idx(self, split='TRAIN'):
        return self.patientsSplit[split]

    def get_patient_split(self):
        return self.patientsSplit

    @property
    def images(self):
        return self._images

    def get_images(self, set=None):
        _setIdx = self.SET_TYPES.index(set)
        images_in_set = numpy.where(self._sets == _setIdx)[0]
        return self._images[images_in_set]

    def get_image(self, i):
        return self._images[i, :, :, :]

    def get_label(self, i):
        return self._labels[i, :, :, :]

    def get_patient(self, i):
        return self.patients[i]

    @property
    def labels(self):
        return self._labels

    @property
    def sets(self):
        return self._sets

    @property
    def meta(self):
        return self._meta

    @property
    def num_examples(self):
        return self._images.shape[0]

    @property
    def width(self):
        return self._images.shape[2]

    @property
    def height(self):
        return self._images.shape[1]

    @property
    def num_channels(self):
        return self._images.shape[3]

    @property
    def epochs_completed(self):
        return self._epochs_completed

    def name(self):
        _name = "MSSEG2008"
        if self.options.filterScanner:
            _name += self.options.filterScanner
        if self.options.numSamples > 0:
            _name += '_n{}'.format(self.options.numSamples)
        _name += "_p{}-{}".format(self.options.partition['TRAIN'], self.options.partition['VAL'])
        if self.options.useCrops:
            _name += "_{}crops{}x{}".format(self.options.cropType, self.options.cropWidth, self.options.cropHeight)
            if self.options.cropType == "random":
                _name += "_{}cropsPerSlice".format(self.options.numRandomCropsPerSlice)
        if self.options.sliceResolution is not None:
            _name += "_res{}x{}".format(self.options.sliceResolution[0], self.options.sliceResolution[1])
        _name += "_{}".format(self.options.format)
        return _name

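    # Illustrative example: with the default Options above (UNC scanner filter, no crops,
    # no slice resampling, raw format), name() evaluates to "MSSEG2008UNC_p0.7-0.2_raw".
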
    def split_name(self):
        return os.path.join(self.dir(), 'split-{}-{}.pckl'.format(self.options.partition['TRAIN'], self.options.partition['VAL']))

    def pckl_name(self):
        return os.path.join(self.dir(), self.name() + ".pckl")

    def tfrecord_name(self):
        return os.path.join(self.dir(), self.name() + ".tfrecord")

    def dir(self):
        return self.options.dir

    def export_slices(self, dir):
        for i in range(self.num_examples):
            imwrite(os.path.join(dir, '{}.png'.format(i)), np.squeeze(self.get_image(i) * 255).astype('uint8'))

    def visualize(self, pause=1):
        f, (ax1, ax2) = matplotlib.pyplot.subplots(1, 2)
        images_tmp, labels_tmp, _ = self.next_batch(10)
        for i in range(images_tmp.shape[0]):
            img = numpy.squeeze(images_tmp[i])
            lbl = numpy.squeeze(labels_tmp[i])
            ax1.imshow(img)
            ax1.set_title('Patch')
            ax2.imshow(lbl)
            ax2.set_title('Groundtruth')
            matplotlib.pyplot.pause(pause)

    def num_batches(self, batchsize, set='TRAIN'):
        _setIdx = MSSEG2008.SET_TYPES.index(set)
        images_in_set = numpy.where(self._sets == _setIdx)[0]
        return len(images_in_set) // batchsize

    def next_batch(self, batch_size, shuffle=True, set='TRAIN', return_brainmask=True):
        """Return the next `batch_size` examples from this data set."""
        _setIdx = MSSEG2008.SET_TYPES.index(set)
        images_in_set = numpy.where(self._sets == _setIdx)[0]
        samples_in_set = len(images_in_set)

        start = self._index_in_epoch[set]
        # Shuffle for the first epoch
        if self._epochs_completed[set] == 0 and start == 0 and shuffle:
            perm0 = numpy.arange(samples_in_set)
            numpy.random.shuffle(perm0)
            self._images[images_in_set] = self.images[images_in_set[perm0]]
            self._labels[images_in_set] = self.labels[images_in_set[perm0]]
            self._sets[images_in_set] = self.sets[images_in_set[perm0]]

        # Go to the next epoch
        if start + batch_size > samples_in_set:
            # Finished epoch
            self._epochs_completed[set] += 1

            # Get the remaining examples of this epoch
            rest_num_examples = samples_in_set - start
            images_rest_part = self._images[images_in_set[start:samples_in_set]]
            labels_rest_part = self._labels[images_in_set[start:samples_in_set]]

            # Shuffle the data
            if shuffle:
                perm = numpy.arange(samples_in_set)
                numpy.random.shuffle(perm)
                self._images[images_in_set] = self.images[images_in_set[perm]]
                self._labels[images_in_set] = self.labels[images_in_set[perm]]
                self._sets[images_in_set] = self.sets[images_in_set[perm]]

            # Start the next epoch and fill the batch up with freshly shuffled examples
            start = 0
            self._index_in_epoch[set] = batch_size - rest_num_examples
            end = self._index_in_epoch[set]
            images_new_part = self._images[images_in_set[start:end]]
            labels_new_part = self._labels[images_in_set[start:end]]

            images_tmp = numpy.concatenate((images_rest_part, images_new_part), axis=0)
            labels_tmp = numpy.concatenate((labels_rest_part, labels_new_part), axis=0)
        else:
            self._index_in_epoch[set] += batch_size
            end = self._index_in_epoch[set]
            images_tmp = self._images[images_in_set[start:end]]
            labels_tmp = self._labels[images_in_set[start:end]]

        if self.options.addInstanceNoise:
            noise = numpy.random.normal(0, 0.01, images_tmp.shape)
            images_tmp += noise

        # Check the batch
        assert images_tmp.size, "The batch is empty!"
        assert labels_tmp.size, "The labels of the current batch are empty!"

        if return_brainmask:
            brainmasks = images_tmp > 0.05
        else:
            brainmasks = None

        return images_tmp, labels_tmp, brainmasks
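

# Minimal usage sketch (an assumption about how this loader is driven, not part of the
# original module): it expects the MSSEG 2008 data to sit next to this file in the
# UNC_train/ etc. folders and the NRRD/NII helpers from dataloaders/ and utils/ to be importable.
if __name__ == '__main__':
    opts = MSSEG2008.Options()
    opts.sliceResolution = [128, 128]
    dataset = MSSEG2008(opts)
    print('Loaded {} slices from {} patients'.format(dataset.num_examples, len(dataset.patients)))
    for _ in range(dataset.num_batches(8, set='TRAIN')):
        images, labels, brainmasks = dataset.next_batch(8, set='TRAIN')
        # images: float32 batch of shape [8, H, W, 1]; labels: binary lesion masks of the same spatial size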