--- /dev/null
+++ b/dataset.py
@@ -0,0 +1,136 @@
+import os
+import random
+
+import numpy as np
+import torch
+from skimage.io import imread
+from torch.utils.data import Dataset
+
+from utils import crop_sample, pad_sample, resize_sample, normalize_volume
+
+
+class BrainSegmentationDataset(Dataset):
+    """Brain MRI dataset for FLAIR abnormality segmentation"""
+
+    in_channels = 3
+    out_channels = 1
+
+    def __init__(
+        self,
+        images_dir,
+        transform=None,
+        image_size=256,
+        subset="train",
+        random_sampling=True,
+        validation_cases=10,
+        seed=42,
+    ):
+        assert subset in ["all", "train", "validation"]
+
+        # read images
+        volumes = {}
+        masks = {}
+        print("reading {} images...".format(subset))
+        for (dirpath, dirnames, filenames) in os.walk(images_dir):
+            image_slices = []
+            mask_slices = []
+            for filename in sorted(
+                filter(lambda f: ".tif" in f, filenames),
+                key=lambda x: int(x.split(".")[-2].split("_")[4]),
+            ):
+                filepath = os.path.join(dirpath, filename)
+                if "mask" in filename:
+                    mask_slices.append(imread(filepath, as_gray=True))
+                else:
+                    image_slices.append(imread(filepath))
+            if len(image_slices) > 0:
+                patient_id = dirpath.split("/")[-1]
+                volumes[patient_id] = np.array(image_slices[1:-1])
+                masks[patient_id] = np.array(mask_slices[1:-1])
+
+        self.patients = sorted(volumes)
+
+        # select cases to subset
+        if not subset == "all":
+            random.seed(seed)
+            validation_patients = random.sample(self.patients, k=validation_cases)
+            if subset == "validation":
+                self.patients = validation_patients
+            else:
+                self.patients = sorted(
+                    list(set(self.patients).difference(validation_patients))
+                )
+
+        print("preprocessing {} volumes...".format(subset))
+        # create list of tuples (volume, mask)
+        self.volumes = [(volumes[k], masks[k]) for k in self.patients]
+
+        print("cropping {} volumes...".format(subset))
+        # crop to smallest enclosing volume
+        self.volumes = [crop_sample(v) for v in self.volumes]
+
+        print("padding {} volumes...".format(subset))
+        # pad to square
+        self.volumes = [pad_sample(v) for v in self.volumes]
+
+        print("resizing {} volumes...".format(subset))
+        # resize
+        self.volumes = [resize_sample(v, size=image_size) for v in self.volumes]
+
+        print("normalizing {} volumes...".format(subset))
+        # normalize channel-wise
+        self.volumes = [(normalize_volume(v), m) for v, m in self.volumes]
+
+        # probabilities for sampling slices based on masks
+        self.slice_weights = [m.sum(axis=-1).sum(axis=-1) for v, m in self.volumes]
+        self.slice_weights = [
+            (s + (s.sum() * 0.1 / len(s))) / (s.sum() * 1.1) for s in self.slice_weights
+        ]
+
+        # add channel dimension to masks
+        self.volumes = [(v, m[..., np.newaxis]) for (v, m) in self.volumes]
+
+        print("done creating {} dataset".format(subset))
+
+        # create global index for patient and slice (idx -> (p_idx, s_idx))
+        num_slices = [v.shape[0] for v, m in self.volumes]
+        self.patient_slice_index = list(
+            zip(
+                sum([[i] * num_slices[i] for i in range(len(num_slices))], []),
+                sum([list(range(x)) for x in num_slices], []),
+            )
+        )
+
+        self.random_sampling = random_sampling
+
+        self.transform = transform
+
+    def __len__(self):
+        return len(self.patient_slice_index)
+
+    def __getitem__(self, idx):
+        patient = self.patient_slice_index[idx][0]
+        slice_n = self.patient_slice_index[idx][1]
+
+        if self.random_sampling:
+            patient = np.random.randint(len(self.volumes))
+            slice_n = np.random.choice(
+                range(self.volumes[patient][0].shape[0]), p=self.slice_weights[patient]
+            )
+
+        v, m = self.volumes[patient]
+        image = v[slice_n]
+        mask = m[slice_n]
+
+        if self.transform is not None:
+            image, mask = self.transform((image, mask))
+
+        # fix dimensions (C, H, W)
+        image = image.transpose(2, 0, 1)
+        mask = mask.transpose(2, 0, 1)
+
+        image_tensor = torch.from_numpy(image.astype(np.float32))
+        mask_tensor = torch.from_numpy(mask.astype(np.float32))
+
+        # return tensors
+        return image_tensor, mask_tensor
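
For reviewers, a minimal usage sketch (not part of the diff): it assumes the TIFF volumes live under a directory such as ./kaggle_3m (the path is an assumption) and that utils.py provides the four helpers imported above (crop_sample, pad_sample, resize_sample, normalize_volume).

from torch.utils.data import DataLoader

from dataset import BrainSegmentationDataset

# "./kaggle_3m" is a hypothetical path; point it at the extracted TIFF data.
train_ds = BrainSegmentationDataset(images_dir="./kaggle_3m", subset="train")

# No shuffle needed: with random_sampling=True (the default), __getitem__
# ignores idx and draws a random patient plus a mask-weighted random slice.
loader = DataLoader(train_ds, batch_size=16, num_workers=2)

images, masks = next(iter(loader))
print(images.shape)  # torch.Size([16, 3, 256, 256])
print(masks.shape)   # torch.Size([16, 1, 256, 256])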
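One non-obvious step worth calling out is the slice_weights computation in __init__: it is additive smoothing of the per-slice mask areas. A uniform floor worth 10% of the total mask area is added before renormalizing, so about 9% (0.1/1.1) of the probability mass is spread evenly across all slices and slices with empty masks can still be sampled. A self-contained sketch of the same arithmetic, using made-up areas:

import numpy as np

# Toy volume with 4 slices; two have empty masks (area 0).
s = np.array([0.0, 120.0, 80.0, 0.0])

# Same formula as in __init__: add a uniform floor of
# s.sum() * 0.1 / len(s) per slice, then renormalize.
w = (s + s.sum() * 0.1 / len(s)) / (s.sum() * 1.1)

print(w)        # [0.02272727 0.56818182 0.38636364 0.02272727]
print(w.sum())  # 1.0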