--- a +++ b/dataloaders/MSISBI2015.py @@ -0,0 +1,473 @@ +"""Functions for reading MSISBI2015 NRRD data.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import glob +import math +import os.path +import pickle + +import matplotlib.pyplot +from imageio import imwrite +from scipy.ndimage import zoom +from six.moves import xrange +from skimage.measure import label, regionprops + +from utils.NII import * +from utils.image_utils import crop, crop_center +from utils.tfrecord_utils import * + +matplotlib.pyplot.ion() + + +class MSISBI2015(object): + PROTOCOL_MAPPINGS = {'FLAIR': ['flair'], 'MPRAGE': ['mprage'], 'PD': ['pd'], 'T2': ['t2']} + SET_TYPES = ['TRAIN', 'VAL', 'TEST'] + + class Options(object): + def __init__(self): + self.dir = os.path.dirname(os.path.realpath(__file__)) + self.numSamples = -1 + self.partition = {'TRAIN': 0.7, 'VAL': 0.2, 'TEST': 0.1} + self.useCrops = False + self.cropType = 'random' # random or center + self.numRandomCropsPerSlice = 5 + self.onlyPatchesWithLesions = False + self.rotations = 0 + self.cropWidth = 128 + self.cropHeight = 128 + self.cache = False + self.sliceResolution = None # format: HxW + self.addInstanceNoise = False # Affects only the batch sampling. If True, a tiny bit of noise will be added to every batch + self.filterProtocol = None # FLAIR, T1, T2 + self.filterType = "train" # train or test + self.axis = 'axial' # saggital, coronal or axial + self.debug = False + self.normalizationMethod = 'standardization' + self.sliceStart = 0 + self.sliceEnd = 155 + self.format = "raw" # raw or aligned; If aligned, nii-files will be crawled and loaded + self.skullStripping = True + self.viewMapping = {'saggital': 2, 'coronal': 1, 'axial': 0} + + def __init__(self, options=Options()): + self.options = options + + if options.cache and os.path.isfile(self.pckl_name()): + f = open(self.pckl_name(), 'rb') + tmp = pickle.load(f) + f.close() + self._epochs_completed = tmp._epochs_completed + self._index_in_epoch = tmp._index_in_epoch + self.patientsSplit = tmp.patients_split + self.patients = tmp.patients + self._images, self._labels, self._sets = read_tf_record(self.tfrecord_name()) + self._epochs_completed = {'TRAIN': 0, 'VAL': 0, 'TEST': 0} + self._index_in_epoch = {'TRAIN': 0, 'VAL': 0, 'TEST': 0} + else: + # Collect all patients + self.patients = self._get_patients() + self.patientsSplit = {} + + if not os.path.isfile(self.split_name()): + _numPatients = len(self.patients) + _ridx = numpy.random.permutation(_numPatients) + + _already_taken = 0 + for split in self.options.partition.keys(): + if self.options.partition[split] <= 1.0: + num_patients_for_current_split = math.floor(self.options.partition[split] * _numPatients) + else: + num_patients_for_current_split = self.options.partition[split] + + if num_patients_for_current_split > (_numPatients - _already_taken): + num_patients_for_current_split = _numPatients - _already_taken + + self.patientsSplit[split] = _ridx[_already_taken:_already_taken + num_patients_for_current_split] + _already_taken += num_patients_for_current_split + + f = open(self.split_name(), 'wb') + pickle.dump(self.patientsSplit, f) + f.close() + else: + f = open(self.split_name(), 'rb') + self.patientsSplit = pickle.load(f) + f.close() + + self._create_numpy_arrays() + + self._epochs_completed = {'TRAIN': 0, 'VAL': 0, 'TEST': 0} + self._index_in_epoch = {'TRAIN': 0, 'VAL': 0, 'TEST': 0} + + if self.options.cache: + write_tf_record(self._images, self._labels, self._sets, self.tfrecord_name()) + tmp = copy.copy(self) + tmp._images = None + tmp._labels = None + tmp._sets = None + f = open(self.pckl_name(), 'wb') + pickle.dump(tmp, f) + f.close() + + def _create_numpy_arrays(self): + # Iterate over all patients and extract slices + _images = [] + _labels = [] + _sets = [] + for p, patient in enumerate(self.patients): + if p in self.patientsSplit['TRAIN']: + _set_of_current_patient = MSISBI2015.SET_TYPES.index('TRAIN') + elif p in self.patientsSplit['VAL']: + _set_of_current_patient = MSISBI2015.SET_TYPES.index('VAL') + elif p in self.patientsSplit['TEST']: + _set_of_current_patient = MSISBI2015.SET_TYPES.index('TEST') + + for n, nii_filename in enumerate(patient['filtered_files']): + # try: + _images_tmp, _labels_tmp = self.gather_data(patient, nii_filename) + _images += _images_tmp + _labels += _labels_tmp + _sets += [_set_of_current_patient] * len(_images_tmp) + + self._images = numpy.array(_images).astype(numpy.float32) + self._labels = numpy.array(_labels).astype(numpy.float32) + if self._images.ndim < 4: + self._images = numpy.expand_dims(self._images, 3) + self._sets = numpy.array(_sets).astype(numpy.int32) + + def gather_data(self, patient, nii_filename): + _images = [] + _labels = [] + + nii, nii_seg, nii_skullmap = self.load_volume_and_groundtruth(nii_filename, patient) + + # Iterate over all slices and collect them + # We only want to select in the range from 15 to 125 (in axial view) + for s in xrange(self.options.sliceStart, min(self.options.sliceEnd, nii.num_slices_along_axis(self.options.axis))): + if 0 < self.options.numSamples < len(_images): + break + + slice_data = nii.get_slice(s, self.options.axis) + slice_seg = nii_seg.get_slice(s, self.options.axis) + + # Skip the slice if it is "empty" + if numpy.percentile(slice_data, 90) < 0.2: + continue + + if self.options.sliceResolution is not None: + # Pad withzeros to top and bottom, if the image is too small + if slice_data.shape[0] < self.options.sliceResolution[0]: + before_y = math.floor((self.options.sliceResolution[0] - slice_data.shape[0]) / 2.0) + after_y = math.ceil((self.options.sliceResolution[0] - slice_data.shape[0]) / 2.0) + if slice_data.shape[1] < self.options.sliceResolution[1]: + before_x = math.floor((self.options.sliceResolution[1] - slice_data.shape[1]) / 2.0) + after_x = math.ceil((self.options.sliceResolution[1] - slice_data.shape[1]) / 2.0) + if slice_data.shape[0] < self.options.sliceResolution[0] or slice_data.shape[1] < self.options.sliceResolution[1]: + slice_data = np.pad(slice_data, ((before_y, after_y), (before_x, after_x)), 'constant', constant_values=(0, 0)) + slice_seg = np.pad(slice_seg, ((before_y, after_y), (before_x, after_x)), 'constant', constant_values=(0, 0)) + slice_data = zoom(slice_data, float(self.options.sliceResolution[0]) / float( + slice_data.shape[0])) + slice_seg = zoom( + slice_seg, float(self.options.sliceResolution[0]) / float(slice_seg.shape[0]), mode="nearest" + ) + slice_seg[slice_seg < 0.9] = 0.0 + slice_seg[slice_seg >= 0.9] = 1.0 + # assert numpy.max(slice_data) <= 1.0, "Resized slice range is outside [0; 1]!" + + # Either collect crops + if self.options.useCrops: + if self.options.cropType == 'random': + rx = numpy.random.randint(0, high=(slice_data.shape[1] - self.options.cropWidth), + size=self.options.numRandomCropsPerSlice) + ry = numpy.random.randint(0, high=(slice_data.shape[0] - self.options.cropHeight), + size=self.options.numRandomCropsPerSlice) + for r in range(self.options.numRandomCropsPerSlice): + _images.append(crop(slice_data, ry(r), rx(r), self.options.cropHeight, self.options.cropWidth)) + _labels.append(crop(slice_data, ry(r), rx(r), self.options.cropHeight, self.options.cropWidth)) + elif self.options.cropType == 'center': + slice_data_cropped = crop_center(slice_data, self.options.cropWidth, self.options.cropHeight) + slice_seg_cropped = crop_center(slice_seg, self.options.cropWidth, self.options.cropHeight) + _images.append(slice_data_cropped) + _labels.append(slice_seg_cropped) + elif self.options.cropType == 'lesions': + cc_slice = label(slice_seg) + props = regionprops(cc_slice) + if len(props) > 0: + for prop in props: + cx = prop['centroid'][1] + cy = prop['centroid'][0] + if cy < self.options.cropHeight // 2: + cy = self.options.cropHeight // 2 + if cy > (slice_data.shape[0] - (self.options.cropHeight // 2)): + cy = (slice_data.shape[0] - (self.options.cropHeight // 2)) + if cx < self.options.cropWidth // 2: + cx = self.options.cropWidth // 2 + if cx > (slice_data.shape[1] - (self.options.cropWidth // 2)): + cx = (slice_data.shape[1] - (self.options.cropWidth // 2)) + image_crop = crop(slice_data, int(cy) - (self.options.cropHeight // 2), int(cx) - (self.options.cropWidth // 2), + self.options.cropHeight, self.options.cropWidth) + seg_crop = crop(slice_seg, int(cy) - (self.options.cropHeight // 2), int(cx) - (self.options.cropWidth // 2), + self.options.cropHeight, self.options.cropWidth) + if image_crop.shape[0] != self.options.cropHeight or image_crop.shape[1] != self.options.cropWidth: + continue + _images.append(image_crop) + _labels.append(seg_crop) + + # Or whole slices + else: + _images.append(slice_data) + _labels.append(slice_seg) + + return _images, _labels + + def load_volume_and_groundtruth(self, nii_filename, patient): + # Load the nrrd + try: + nii = NII(nii_filename) + nii_groundtruth = NII(patient['groundtruth']) + + nii.denoise() + nii.set_view_mapping(self.options.viewMapping) + except: + print('MSISBI2015: Failed to open file ' + nii_filename) + + # Make sure ground-truth is binary and nrrd doesnt have NaNs + nii.data[np.isnan(nii.data)] = 0.0 + nii_groundtruth.data[nii_groundtruth.data < 0.9] = 0.0 + nii_groundtruth.data[nii_groundtruth.data >= 0.9] = 1.0 + + # Do skull-stripping, if desired + if self.options.skullStripping: + try: + nii_skullmap = NII(patient['skullmap']) + nii_skullmap.set_view_mapping(self.options.viewMapping) + nii.apply_skullmap(nii_skullmap) + except: + print('MSISBI2015: Failed to open file ' + patient['skullmap'] + ', skipping skullremoval') + + # In-place normalize the loaded volume + nii.normalize(method=self.options.normalizationMethod, lowerpercentile=0, upperpercentile=99.8) + # nii_skullmap.data = nii_skullmap.data > 0.0 + + return nii, nii_groundtruth, nii_skullmap + + # Hidden helper function, not supposed to be called from outside! + def _get_patients(self): + return MSISBI2015.get_patients(self.options) + + @staticmethod + def get_patients(options): + folders = ["training01", "training02", "training03", "training04", "training05"] + + # Iterate over all folders in folders and collect all patients + patients = [] + for f, folder in enumerate(folders): + # Get all files that can be used for training and validation + _patients = glob.glob(os.path.join(options.dir, folder, "preprocessed", folder + "_*_flair_pp.nii")) + for p, pname in enumerate(_patients): + patient = {} + _tmp = os.path.normpath(pname).split(os.path.sep) + patient['name'] = _tmp[-1].replace("_flair_pp.nii", "") + patient['fullpath'] = os.path.join(options.dir, folder, "preprocessed") + + patient["filtered_files"] = [] + for protocol, protocol_array in MSISBI2015.PROTOCOL_MAPPINGS.items(): + if len(options.filterProtocols) > 0 and protocol not in options.filterProtocols: + continue + else: + if options.format == "raw": + patient[protocol] = os.path.join(options.dir, folder, "preprocessed", patient['name'] + '_' + protocol_array[0] + '_pp.nii') + elif options.format == "aligned": + patient[protocol] = os.path.join(options.dir, folder, "preprocessed", + patient['name'] + '_' + protocol_array[0] + '.aligned.nii.gz') + + if len(options.filterProtocols) > 0 and protocol not in options.filterProtocols: + continue + else: + if options.format == "raw": + patient["filtered_files"] += [ + os.path.join(options.dir, folder, "preprocessed", patient['name'] + '_' + protocol_array[0] + '_pp.nii')] + elif options.format == "aligned": + patient["filtered_files"] += [ + os.path.join(options.dir, folder, "preprocessed", patient['name'] + '_' + protocol_array[0] + '.aligned.nii.gz')] + + if options.format == "raw": + patient['groundtruth'] = os.path.join(options.dir, folder, "masks", patient['name'] + "_mask1.nii") + patient['skullmap'] = os.path.join(options.dir, folder, "preprocessed", patient['name'] + "_skullmap.nii.gz") + elif options.format == "aligned": + patient['groundtruth'] = os.path.join(options.dir, folder, "preprocessed", patient['name'] + "_mask1.aligned.nii.gz") + patient['skullmap'] = os.path.join(options.dir, folder, "preprocessed", patient['name'] + "_skullmap.aligned.nii.gz") + + # Append to the list of all patients + patients.append(patient) + + return patients + + # Returns the indices of patients which belong to either TRAIN, VAL or TEST. Your choice + def get_patient_idx(self, split='TRAIN'): + return self.patientsSplit[split] + + def get_patient_split(self): + return self.patientsSplit + + @property + def images(self): + return self._images + + def get_images(self, set=None): + _setIdx = MSISBI2015.SET_TYPES.index(set) + images_in_set = numpy.where(self._sets == _setIdx)[0] + return self._images[images_in_set] + + def get_image(self, i): + return self._images[i, :, :, :] + + def get_label(self, i): + return self._labels[i, :, :, :] + + def get_patient(self, i): + return self.patients[i] + + @property + def labels(self): + return self._labels + + @property + def sets(self): + return self._sets + + @property + def meta(self): + return self._meta + + @property + def num_examples(self): + return self._images.shape[0] + + @property + def width(self): + return self._images.shape[2] + + @property + def height(self): + return self._images.shape[1] + + @property + def num_channels(self): + return self._images.shape[3] + + @property + def epochs_completed(self): + return self._epochs_completed + + def name(self): + _name = "MSISBI2015" + if self.options.numSamples > 0: + _name += '_n{}'.format(self.options.numSamples) + _name += "_p{}-{}".format(self.options.partition['TRAIN'], self.options.partition['VAL']) + if self.options.useCrops: + _name += "_{}crops{}x{}".format(self.options.cropType, self.options.cropWidth, self.options.cropHeight) + if self.options.cropType == "random": + _name += "_{}cropsPerSlice".format(self.options.numRandomCropsPerSlice) + if self.options.sliceResolution is not None: + _name += "_res{}x{}".format(self.options.sliceResolution[0], self.options.sliceResolution[1]) + _name += "_{}".format(self.options.format) + return _name + + def split_name(self): + return os.path.join(self.dir(), 'split-{}-{}.pckl'.format(self.options.partition['TRAIN'], self.options.partition['VAL'])) + + def pckl_name(self): + return os.path.join(self.dir(), self.name() + ".pckl") + + def tfrecord_name(self): + return os.path.join(self.dir(), self.name() + ".tfrecord") + + def dir(self): + return self.options.dir + + def export_slices(self, dir): + for i in range(self.num_examples): + imwrite(os.path.join(dir, '{}.png'.format(i)), np.squeeze(self.get_image(i) * 255).astype('uint8')) + + def visualize(self, pause=1): + f, (ax1, ax2) = matplotlib.pyplot.subplots(1, 2) + images_tmp, labels_tmp, _ = self.next_batch(10) + for i in range(images_tmp.shape[0]): + img = numpy.squeeze(images_tmp[i]) + lbl = numpy.squeeze(labels_tmp[i]) + ax1.imshow(img) + ax1.set_title('Patch') + ax2.imshow(lbl) + ax2.set_title('Groundtruth') + matplotlib.pyplot.pause(pause) + + def num_batches(self, batchsize, set='TRAIN'): + _setIdx = MSISBI2015.SET_TYPES.index(set) + images_in_set = numpy.where(self._sets == _setIdx)[0] + return len(images_in_set) // batchsize + + def next_batch(self, batch_size, shuffle=True, set='TRAIN', return_brainmask=True): + """Return the next `batch_size` examples from this data set.""" + _setIdx = MSISBI2015.SET_TYPES.index(set) + images_in_set = numpy.where(self._sets == _setIdx)[0] + samples_in_set = len(images_in_set) + + start = self._index_in_epoch[set] + # Shuffle for the first epoch + if self._epochs_completed[set] == 0 and start == 0 and shuffle: + perm0 = numpy.arange(samples_in_set) + numpy.random.shuffle(perm0) + self._images[images_in_set] = self.images[images_in_set[perm0]] + self._labels[images_in_set] = self.labels[images_in_set[perm0]] + self._sets[images_in_set] = self.sets[images_in_set[perm0]] + + # Go to the next epoch + if start + batch_size > samples_in_set: + # Finished epoch + self._epochs_completed[set] += 1 + + # Get the rest examples in this epoch + rest_num_examples = samples_in_set - start + images_rest_part = self._images[images_in_set[start:samples_in_set]] + labels_rest_part = self._labels[images_in_set[start:samples_in_set]] + + # Shuffle the data + if shuffle: + perm = numpy.arange(samples_in_set) + numpy.random.shuffle(perm) + self._images[images_in_set] = self.images[images_in_set[perm]] + self._labels[images_in_set] = self.labels[images_in_set[perm]] + self._sets[images_in_set] = self.sets[images_in_set[perm]] + + # Start next epoch + start = 0 + self._index_in_epoch[set] = batch_size - rest_num_examples + end = self._index_in_epoch[set] + images_new_part = self._images[images_in_set[start:end]] + labels_new_part = self._labels[images_in_set[start:end]] + + images_tmp = numpy.concatenate((images_rest_part, images_new_part), axis=0) + labels_tmp = numpy.concatenate((labels_rest_part, labels_new_part), axis=0) + else: + self._index_in_epoch[set] += batch_size + end = self._index_in_epoch[set] + images_tmp = self._images[images_in_set[start:end]] + labels_tmp = self._labels[images_in_set[start:end]] + + if self.options.addInstanceNoise: + noise = numpy.random.normal(0, 0.01, images_tmp.shape) + images_tmp += noise + + # Check the batch + assert images_tmp.size, "The batch is empty!" + assert labels_tmp.size, "The labels of the current batch are empty!" + + if return_brainmask: + brainmasks = images_tmp > 0.05 + else: + brainmasks = None + + return images_tmp, labels_tmp, brainmasks