--- a
+++ b/dataloaders/MSSEG2008.py
@@ -0,0 +1,493 @@
+"""Functions for reading MSSEG2008 NRRD data."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+import math
+import os.path
+import pickle
+
+import matplotlib.pyplot
+import numpy
+import numpy as np  # both the `numpy` and `np` names are used throughout this module
+from imageio import imwrite
+from scipy.ndimage import zoom
+from six.moves import xrange  # pylint: disable=redefined-builtin
+from skimage.measure import label, regionprops
+
+from dataloaders.NRRD import *
+from utils.NII import *
+from utils.image_utils import crop, crop_center
+from utils.tfrecord_utils import *
+
+
+class MSSEG2008(object):
+    PROTOCOL_MAPPINGS = ['FLAIR', 'T1', 'T2']
+    SET_TYPES = ['TRAIN', 'VAL', 'TEST']
+
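+    # Groups all dataset configuration (folder layout, patient partition, cropping, slice
+    # selection and normalization) so that one Options instance fully describes a dataset variant.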
+    class Options(object):
+        def __init__(self):
+            self.dir = os.path.dirname(os.path.realpath(__file__))
+            self.folderTrainUNC = 'UNC_train'
+            self.folderTestUNC = 'UNC_test'
+            self.folderTrainCHB = 'CHB_train'
+            self.folderTestCHB = 'CHB_test'
+            self.numSamples = -1
+            self.partition = {'TRAIN': 0.7, 'VAL': 0.2, 'TEST': 0.1}
+            self.useCrops = False
+            self.cropType = 'random'  # random, center or lesions
+            self.numRandomCropsPerSlice = 5
+            self.onlyPatchesWithLesions = False
+            self.rotations = 0
+            self.cropWidth = 128
+            self.cropHeight = 128
+            self.cache = False
+            self.sliceResolution = None  # format: HxW
+            self.addInstanceNoise = False  # Affects only the batch sampling. If True, a tiny bit of noise will be added to every batch
+            self.filterProtocols = []  # list of protocols to keep, e.g. ['FLAIR']; an empty list keeps all
+            self.filterScanner = "UNC"  # UNC or CHB
+            self.filterType = "train"  # train or test
+            self.axis = 'axial'  # saggital, coronal or axial
+            self.debug = False
+            self.normalizationMethod = 'standardization'
+            self.sliceStart = 0
+            self.sliceEnd = 155
+            self.format = "raw"  # raw or aligned; If aligned, nii-files will be crawled and loaded
+            self.skullStripping = True
+            self.viewMapping = {'saggital': 2, 'coronal': 1, 'axial': 0}
+
+    def __init__(self, options=Options()):
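+        # Either restore a previously cached dataset (pickle + tfrecord), or crawl the data folders,
+        # draw/load the patient split and extract all slices into numpy arrays.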
+        self.options = options
+
+        if options.cache and os.path.isfile(self.pckl_name()):
+            with open(self.pckl_name(), 'rb') as f:
+                tmp = pickle.load(f)
+            self.patientsSplit = tmp.patientsSplit
+            self.patients = tmp.patients
+            self._images, self._labels, self._sets = read_tf_record(self.tfrecord_name())
+            self._epochs_completed = {'TRAIN': 0, 'VAL': 0, 'TEST': 0}
+            self._index_in_epoch = {'TRAIN': 0, 'VAL': 0, 'TEST': 0}
+        else:
+            # Collect all patients
+            self.patients = self._get_patients()
+            self.patientsSplit = {}
+
+            if not os.path.isfile(self.split_name()):
+                _numPatients = len(self.patients)
+                _ridx = numpy.random.permutation(_numPatients)
+
+                _already_taken = 0
+                for split in self.options.partition.keys():
+                    if self.options.partition[split] <= 1.0:
+                        numPatientsForCurrentSplit = math.floor(self.options.partition[split] * _numPatients)
+                    else:
+                        numPatientsForCurrentSplit = self.options.partition[split]
+
+                    if numPatientsForCurrentSplit > (_numPatients - _already_taken):
+                        numPatientsForCurrentSplit = _numPatients - _already_taken
+
+                    self.patientsSplit[split] = _ridx[_already_taken:_already_taken + numPatientsForCurrentSplit]
+                    _already_taken += numPatientsForCurrentSplit
+
+                with open(self.split_name(), 'wb') as f:
+                    pickle.dump(self.patientsSplit, f)
+            else:
+                with open(self.split_name(), 'rb') as f:
+                    self.patientsSplit = pickle.load(f)
+
+            self._create_numpy_arrays()
+
+            self._epochs_completed = {'TRAIN': 0, 'VAL': 0, 'TEST': 0}
+            self._index_in_epoch = {'TRAIN': 0, 'VAL': 0, 'TEST': 0}
+
+            if self.options.cache:
+                write_tf_record(self._images, self._labels, self._sets, self.tfrecord_name())
+                tmp = copy.copy(self)
+                tmp._images = None
+                tmp._labels = None
+                tmp._sets = None
+                with open(self.pckl_name(), 'wb') as f:
+                    pickle.dump(tmp, f)
+
+    def _create_numpy_arrays(self):
+        # Iterate over all patients and extract slices
+        _images = []
+        _labels = []
+        _sets = []
+        for p, patient in enumerate(self.patients):
+            if p in self.patientsSplit['TRAIN']:
+                _set_of_current_patient = MSSEG2008.SET_TYPES.index('TRAIN')
+            elif p in self.patientsSplit['VAL']:
+                _set_of_current_patient = MSSEG2008.SET_TYPES.index('VAL')
+            elif p in self.patientsSplit['TEST']:
+                _set_of_current_patient = MSSEG2008.SET_TYPES.index('TEST')
+            else:
+                # Patient is not covered by any split (possible if the partition does not sum to 1.0)
+                continue
+
+            for n, nrrd_filename in enumerate(patient['filtered_files']):
+                # try:
+                _images_tmp, _labels_tmp = self.gather_data(patient, nrrd_filename)
+                _images += _images_tmp
+                _labels += _labels_tmp
+                # _mask += _mask_tmp
+                _sets += [_set_of_current_patient] * len(_images_tmp)
+                # except:
+                #  print('MSSEG2008: Failed to open file ' + nrrd_filename)
+                #  continue
+
+        self._images = numpy.array(_images).astype(numpy.float32)
+        self._labels = numpy.array(_labels).astype(numpy.float32)
+        if self._images.ndim < 4:
+            self._images = numpy.expand_dims(self._images, 3)
+        self._sets = numpy.array(_sets).astype(numpy.int32)
+
+    def gather_data(self, patient, nrrd_filename):
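+        # Extract 2D samples (whole slices or crops, depending on the options) and the matching
+        # lesion masks from a single volume of the given patient.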
+        _images = []
+        _labels = []
+
+        nrrd, nrrd_seg, nrrd_skullmap = self.load_volume_and_groundtruth(nrrd_filename, patient)
+
+        # Iterate over the slices along the chosen axis and collect them,
+        # restricted to the configured [sliceStart, sliceEnd) range
+        for s in xrange(self.options.sliceStart, min(self.options.sliceEnd, nrrd.num_slices_along_axis(self.options.axis))):
+            if 0 < self.options.numSamples < len(_images):
+                break
+
+            slice_data = nrrd.get_slice(s, self.options.axis)
+            slice_seg = nrrd_seg.get_slice(s, self.options.axis)
+            slice_skullmap = nrrd_skullmap.get_slice(s, self.options.axis) if nrrd_skullmap is not None else None
+
+            # Skip the slice if it is "empty"
+            # if numpy.max(slice_data) < empty_thresh:
+            if numpy.percentile(slice_data, 90) < 0.2:
+                continue
+
+            # assert numpy.max(slice_data) <= 1.0, "Slice range is outside [0; 1]!"
+
+            if self.options.sliceResolution is not None:
+                # Pad with zeros (top/bottom and left/right) if the slice is smaller than the target resolution
+                before_y = after_y = before_x = after_x = 0
+                if slice_data.shape[0] < self.options.sliceResolution[0]:
+                    before_y = math.floor((self.options.sliceResolution[0] - slice_data.shape[0]) / 2.0)
+                    after_y = math.ceil((self.options.sliceResolution[0] - slice_data.shape[0]) / 2.0)
+                if slice_data.shape[1] < self.options.sliceResolution[1]:
+                    before_x = math.floor((self.options.sliceResolution[1] - slice_data.shape[1]) / 2.0)
+                    after_x = math.ceil((self.options.sliceResolution[1] - slice_data.shape[1]) / 2.0)
+                if slice_data.shape[0] < self.options.sliceResolution[0] or slice_data.shape[1] < self.options.sliceResolution[1]:
+                    slice_data = np.pad(slice_data, ((before_y, after_y), (before_x, after_x)), 'constant', constant_values=(0, 0))
+                    slice_seg = np.pad(slice_seg, ((before_y, after_y), (before_x, after_x)), 'constant', constant_values=(0, 0))
+                slice_data = zoom(slice_data, float(self.options.sliceResolution[0]) / float(slice_data.shape[0]))
+                slice_seg = zoom(slice_seg, float(self.options.sliceResolution[0]) / float(slice_seg.shape[0]), mode="nearest")
+                # Re-binarize the segmentation after interpolation
+                slice_seg[slice_seg < 0.9] = 0.0
+                slice_seg[slice_seg >= 0.9] = 1.0
+
+            # Either collect crops
+            if self.options.useCrops:
+                if self.options.cropType == 'random':
+                    rx = numpy.random.randint(0, high=(slice_data.shape[1] - self.options.cropWidth),
+                                              size=self.options.numRandomCropsPerSlice)
+                    ry = numpy.random.randint(0, high=(slice_data.shape[0] - self.options.cropHeight),
+                                              size=self.options.numRandomCropsPerSlice)
+                    for r in range(self.options.numRandomCropsPerSlice):
+                        _images.append(crop(slice_data, ry[r], rx[r], self.options.cropHeight, self.options.cropWidth))
+                        _labels.append(crop(slice_seg, ry[r], rx[r], self.options.cropHeight, self.options.cropWidth))
+                elif self.options.cropType == 'center':
+                    slice_data_cropped = crop_center(slice_data, self.options.cropWidth, self.options.cropHeight)
+                    slice_seg_cropped = crop_center(slice_seg, self.options.cropWidth, self.options.cropHeight)
+                    _images.append(slice_data_cropped)
+                    _labels.append(slice_seg_cropped)
+                elif self.options.cropType == 'lesions':
+                    cc_slice = label(slice_seg)
+                    props = regionprops(cc_slice)
+                    if len(props) > 0:
+                        for prop in props:
+                            cx = prop['centroid'][1]
+                            cy = prop['centroid'][0]
+                            if cy < self.options.cropHeight // 2:
+                                cy = self.options.cropHeight // 2
+                            if cy > (slice_data.shape[0] - (self.options.cropHeight // 2)):
+                                cy = (slice_data.shape[0] - (self.options.cropHeight // 2))
+                            if cx < self.options.cropWidth // 2:
+                                cx = self.options.cropWidth // 2
+                            if cx > (slice_data.shape[1] - (self.options.cropWidth // 2)):
+                                cx = (slice_data.shape[1] - (self.options.cropWidth // 2))
+                            image_crop = crop(slice_data, int(cy) - (self.options.cropHeight // 2), int(cx) - (self.options.cropWidth // 2),
+                                              self.options.cropHeight, self.options.cropWidth)
+                            seg_crop = crop(slice_seg, int(cy) - (self.options.cropHeight // 2), int(cx) - (self.options.cropWidth // 2),
+                                            self.options.cropHeight, self.options.cropWidth)
+                            if image_crop.shape[0] != self.options.cropHeight or image_crop.shape[1] != self.options.cropWidth:
+                                continue
+                            _images.append(image_crop)
+                            _labels.append(seg_crop)
+                            # _masks.append(crop(slice_data, prop['centroid'][0], prop['centroid'][1], self.options.cropHeight, self.options.cropWidth))
+                        # find connected components in segmentation slice
+                        # for every connected component, do a center crop from the segmentation slice, the mask and the actual slice
+            # Or whole slices
+            else:
+                _images.append(slice_data)
+                _labels.append(slice_seg)
+
+        return _images, _labels
+
+    def load_volume_and_groundtruth(self, nrrd_filename, patient):
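+        # Load one volume together with its lesion ground-truth (and, if enabled, its skull map),
+        # binarize the ground-truth, remove NaNs and normalize the intensities in-place.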
+        # Load the nrrd
+        try:
+            if self.options.format == "raw":
+                nrrd = NRRD(nrrd_filename)
+                nrrd_groundtruth = NRRD(patient['groundtruth'])
+
+                nrrd.denoise()
+                nrrd.set_view_mapping(self.options.viewMapping)
+            elif self.options.format == "aligned":
+                nrrd = NII(nrrd_filename)
+                nrrd_groundtruth = NII(patient['groundtruth'])
+                nrrd.denoise()
+                nrrd.set_view_mapping(self.options.viewMapping)
+        except Exception:
+            print('MSSEG2008: Failed to open file ' + nrrd_filename)
+            raise
+
+        # Make sure the ground-truth is binary and the volume doesn't contain NaNs
+        nrrd.data[np.isnan(nrrd.data)] = 0.0
+        nrrd_groundtruth.data[nrrd_groundtruth.data < 0.9] = 0.0
+        nrrd_groundtruth.data[nrrd_groundtruth.data >= 0.9] = 1.0
+
+        # Do skull-stripping, if desired
+        nii_skullmap = None
+        if self.options.skullStripping:
+            try:
+                nii_skullmap = NII(patient['skullmap'])
+                nii_skullmap.set_view_mapping(self.options.viewMapping)
+                nrrd.apply_skullmap(nii_skullmap)
+            except Exception:
+                print('MSSEG2008: Failed to open file ' + patient['skullmap'] + ', skipping skull removal')
+                nii_skullmap = None
+
+        # In-place normalize the loaded volume
+        nrrd.normalize(method=self.options.normalizationMethod, lowerpercentile=0, upperpercentile=99.8)
+
+        return nrrd, nrrd_groundtruth, nii_skullmap
+
+    # Hidden helper function, not supposed to be called from outside!
+    def _get_patients(self):
+        return MSSEG2008.get_patients(self.options)
+
+    @staticmethod
+    def get_patients(options):
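+        # Crawl the UNC/CHB train and test folders and build one dict per patient holding the paths
+        # to its protocol volumes, lesion ground-truth and skull map.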
+        folders = [options.folderTrainUNC, options.folderTestUNC, options.folderTrainCHB, options.folderTestCHB]
+
+        # Iterate over the UNC/CHB train and test folders and collect patients
+        patients = []
+        for f, folder in enumerate(folders):
+            if options.filterScanner and options.filterScanner not in folder:
+                continue
+            if options.filterType and options.filterType not in folder:
+                continue
+
+            # Every sub-directory of a dataset folder corresponds to one patient
+            _patients = [d.name for d in os.scandir(os.path.join(options.dir, folder)) if d.is_dir()]
+            for p, pname in enumerate(_patients):
+                patient = {
+                    'name': pname,
+                    'fullpath': os.path.join(options.dir, folder, pname)
+                }
+                if "train" in folder:
+                    patient["type"] = "train"
+                else:
+                    patient["type"] = "test"
+
+                patient["filtered_files"] = []
+                for pr, protocol in enumerate(MSSEG2008.PROTOCOL_MAPPINGS):
+                    if options.format == "raw":
+                        patient[protocol] = os.path.join(options.dir, folder, pname, pname + '_' + protocol + '.nhdr')
+                    elif options.format == "aligned":
+                        patient[protocol] = os.path.join(options.dir, folder, pname, pname + '_' + protocol + '.aligned.nii.gz')
+
+                    if len(options.filterProtocols) > 0 and protocol not in options.filterProtocols:
+                        continue
+                    else:
+                        if options.format == "raw":
+                            patient["filtered_files"] += [os.path.join(options.dir, folder, pname, pname + '_' + protocol + '.nhdr')]
+                        elif options.format == "aligned":
+                            patient["filtered_files"] += [os.path.join(options.dir, folder, pname, pname + '_' + protocol + '.aligned.nii.gz')]
+
+                if options.format == "raw":
+                    patient['groundtruth'] = os.path.join(options.dir, folder, pname, pname + '_lesion.nhdr')
+                    patient['skullmap'] = os.path.join(options.dir, folder, pname, pname + '_skullmap.nhdr')
+                elif options.format == "aligned":
+                    patient['groundtruth'] = os.path.join(options.dir, folder, pname, pname + '_lesion.aligned.nii.gz')
+                    patient['skullmap'] = os.path.join(options.dir, folder, pname, pname + '_skullmap.nii.gz')
+
+                # Append to the list of all patients
+                patients.append(patient)
+
+        return patients
+
+    # Returns the indices of the patients that belong to the given split (TRAIN, VAL or TEST)
+    def get_patient_idx(self, split='TRAIN'):
+        return self.patientsSplit[split]
+
+    def get_patient_split(self):
+        return self.patientsSplit
+
+    @property
+    def images(self):
+        return self._images
+
+    def get_images(self, set='TRAIN'):
+        _setIdx = self.SET_TYPES.index(set)
+        images_in_set = numpy.where(self._sets == _setIdx)[0]
+        return self._images[images_in_set]
+
+    def get_image(self, i):
+        return self._images[i, :, :, :]
+
+    def get_label(self, i):
+        return self._labels[i, :, :, :]
+
+    def get_patient(self, i):
+        return self.patients[i]
+
+    @property
+    def labels(self):
+        return self._labels
+
+    @property
+    def sets(self):
+        return self._sets
+
+    @property
+    def meta(self):
+        return self._meta
+
+    @property
+    def num_examples(self):
+        return self._images.shape[0]
+
+    @property
+    def width(self):
+        return self._images.shape[2]
+
+    @property
+    def height(self):
+        return self._images.shape[1]
+
+    @property
+    def num_channels(self):
+        return self._images.shape[3]
+
+    @property
+    def epochs_completed(self):
+        return self._epochs_completed
+
+    def name(self):
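+        # Build a unique identifier from the active options; used to name the cache, split and tfrecord files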
+        _name = "MSSEG2008"
+        if self.options.filterScanner:
+            _name += self.options.filterScanner
+        if self.options.numSamples > 0:
+            _name += '_n{}'.format(self.options.numSamples)
+        _name += "_p{}-{}".format(self.options.partition['TRAIN'], self.options.partition['VAL'])
+        if self.options.useCrops:
+            _name += "_{}crops{}x{}".format(self.options.cropType, self.options.cropWidth, self.options.cropHeight)
+            if self.options.cropType == "random":
+                _name += "_{}cropsPerSlice".format(self.options.numRandomCropsPerSlice)
+        if self.options.sliceResolution is not None:
+            _name += "_res{}x{}".format(self.options.sliceResolution[0], self.options.sliceResolution[1])
+        _name += "_{}".format(self.options.format)
+        return _name
+
+    def split_name(self):
+        return os.path.join(self.dir(), 'split-{}-{}.pckl'.format(self.options.partition['TRAIN'], self.options.partition['VAL']))
+
+    def pckl_name(self):
+        return os.path.join(self.dir(), self.name() + ".pckl")
+
+    def tfrecord_name(self):
+        return os.path.join(self.dir(), self.name() + ".tfrecord")
+
+    def dir(self):
+        return self.options.dir
+
+    def export_slices(self, dir):
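+        # Write every slice as an 8-bit PNG into the given directory (assumes intensities lie in [0, 1])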
+        for i in range(self.num_examples):
+            imwrite(os.path.join(dir, '{}.png'.format(i)), np.squeeze(self.get_image(i) * 255).astype('uint8'))
+
+    def visualize(self, pause=1):
+        f, (ax1, ax2) = matplotlib.pyplot.subplots(1, 2)
+        images_tmp, labels_tmp, _ = self.next_batch(10)
+        for i in range(images_tmp.shape[0]):
+            img = numpy.squeeze(images_tmp[i])
+            lbl = numpy.squeeze(labels_tmp[i])
+            ax1.imshow(img)
+            ax1.set_title('Patch')
+            ax2.imshow(lbl)
+            ax2.set_title('Groundtruth')
+            matplotlib.pyplot.pause(pause)
+
+    def num_batches(self, batchsize, set='TRAIN'):
+        _setIdx = MSSEG2008.SET_TYPES.index(set)
+        images_in_set = numpy.where(self._sets == _setIdx)[0]
+        return len(images_in_set) // batchsize
+
+    def next_batch(self, batch_size, shuffle=True, set='TRAIN', return_brainmask=True):
+        """Return the next `batch_size` examples from this data set."""
+        _setIdx = MSSEG2008.SET_TYPES.index(set)
+        images_in_set = numpy.where(self._sets == _setIdx)[0]
+        samples_in_set = len(images_in_set)
+
+        start = self._index_in_epoch[set]
+        # Shuffle for the first epoch
+        if self._epochs_completed[set] == 0 and start == 0 and shuffle:
+            perm0 = numpy.arange(samples_in_set)
+            numpy.random.shuffle(perm0)
+            self._images[images_in_set] = self.images[images_in_set[perm0]]
+            self._labels[images_in_set] = self.labels[images_in_set[perm0]]
+            self._sets[images_in_set] = self.sets[images_in_set[perm0]]
+
+        # Go to the next epoch
+        if start + batch_size > samples_in_set:
+            # Finished epoch
+            self._epochs_completed[set] += 1
+
+            # Get the rest examples in this epoch
+            rest_num_examples = samples_in_set - start
+            images_rest_part = self._images[images_in_set[start:samples_in_set]]
+            labels_rest_part = self._labels[images_in_set[start:samples_in_set]]
+
+            # Shuffle the data
+            if shuffle:
+                perm = numpy.arange(samples_in_set)
+                numpy.random.shuffle(perm)
+                self._images[images_in_set] = self.images[images_in_set[perm]]
+                self._labels[images_in_set] = self.labels[images_in_set[perm]]
+                self._sets[images_in_set] = self.sets[images_in_set[perm]]
+
+            # Start next epoch
+            start = 0
+            self._index_in_epoch[set] = batch_size - rest_num_examples
+            end = self._index_in_epoch[set]
+            images_new_part = self._images[images_in_set[start:end]]
+            labels_new_part = self._labels[images_in_set[start:end]]
+
+            images_tmp = numpy.concatenate((images_rest_part, images_new_part), axis=0)
+            labels_tmp = numpy.concatenate((labels_rest_part, labels_new_part), axis=0)
+        else:
+            self._index_in_epoch[set] += batch_size
+            end = self._index_in_epoch[set]
+            images_tmp = self._images[images_in_set[start:end]]
+            labels_tmp = self._labels[images_in_set[start:end]]
+
+        if self.options.addInstanceNoise:
+            noise = numpy.random.normal(0, 0.01, images_tmp.shape)
+            images_tmp += noise
+
+        # Check the batch
+        assert images_tmp.size, "The batch is empty!"
+        assert labels_tmp.size, "The labels of the current batch are empty!"
+
+        if return_brainmask:
+            brainmasks = images_tmp > 0.05
+        else:
+            brainmasks = None
+
+        return images_tmp, labels_tmp, brainmasks
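+
+
+if __name__ == '__main__':
+    # Minimal usage sketch, illustrative only: it assumes the MSSEG2008 folders
+    # (UNC_train/UNC_test/CHB_train/CHB_test) exist next to this file and that the
+    # option values below are sensible for the data at hand.
+    opts = MSSEG2008.Options()
+    opts.sliceResolution = [128, 128]
+    opts.filterProtocols = ['FLAIR']
+    dataset = MSSEG2008(opts)
+    print('Loaded {} slices'.format(dataset.num_examples))
+    images, labels, brainmasks = dataset.next_batch(8, set='TRAIN')
+    print(images.shape, labels.shape)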