--- /dev/null
+++ b/dataset.py
@@ -0,0 +1,148 @@
import copy
import json
import os
import tarfile

import nibabel as nib
import numpy as np
import torch
from sklearn.utils import shuffle
from torch.utils.data import Dataset, DataLoader
from config import (
    DATASET_PATH, TASK_ID, TRAIN_VAL_TEST_SPLIT,
    TRAIN_BATCH_SIZE, VAL_BATCH_SIZE, TEST_BATCH_SIZE
)

#Utility function to extract .tar archives into the ./Datasets directory
def ExtractTar(Directory):
    print("Extracting tar file ...")
    try:
        with tarfile.open(Directory) as archive:
            archive.extractall('./Datasets')
    except (OSError, tarfile.TarError) as err:
        #Raising a plain string is invalid in Python 3; wrap the cause instead
        raise RuntimeError("File extraction failed!") from err
    print("Extraction completed!")


#The dict mapping segmentation task IDs to task names
task_names = {
    "01": "BrainTumour",
    "02": "Heart",
    "03": "Liver",
    "04": "Hippocampus",
    "05": "Prostate",
    "06": "Lung",
    "07": "Pancreas",
    "08": "HepaticVessel",
    "09": "Spleen",
    "10": "Colon"
}


class MedicalSegmentationDecathlon(Dataset):
    """
    The base dataset class for Decathlon segmentation tasks
    -- __init__()
    :param task_number -> the organ dataset ID (see task_names above for hints)
    :param dir_path -> the directory path containing the dataset .tar files
    :param split_ratios -> optional - train/val/test split ratios (must sum to 1)
    :param transforms -> optional - transforms to be applied on each instance,
                         given as a [train, val, test] list or a single callable
    :param mode -> optional - one of "train", "val" or "test"
    """
    def __init__(self, task_number, dir_path, split_ratios=[0.8, 0.1, 0.1], transforms=None, mode=None) -> None:
        super(MedicalSegmentationDecathlon, self).__init__()
        #Rectify the task ID representation (e.g. 4 -> "04")
        self.task_number = str(task_number)
        if len(self.task_number) == 1:
            self.task_number = "0" + self.task_number
        #Building the file name according to the task ID
        self.file_name = f"Task{self.task_number}_{task_names[self.task_number]}"
        #Extracting the .tar file, unless the dataset is already unpacked
        if not os.path.exists(os.path.join(os.getcwd(), "Datasets", self.file_name)):
            ExtractTar(os.path.join(dir_path, f"{self.file_name}.tar"))
        #Path to the extracted dataset
        self.dir = os.path.join(os.getcwd(), "Datasets", self.file_name)
        #Metadata about the dataset
        with open(os.path.join(self.dir, "dataset.json")) as f:
            self.meta = json.load(f)
        self.splits = split_ratios
        self.transform = transforms
        #Calculating the number of images per split
        num_training_imgs = self.meta["numTraining"]
        train_val_test = [int(x * num_training_imgs) for x in split_ratios]
        #Assign any rounding remainder to the training split
        if sum(train_val_test) != num_training_imgs:
            train_val_test[0] += (num_training_imgs - sum(train_val_test))
        self.mode = mode
        #Splitting the dataset; note sklearn's shuffle returns a shuffled copy,
        #it does not shuffle in place
        samples = shuffle(self.meta["training"])
        self.train = samples[0:train_val_test[0]]
        self.val = samples[train_val_test[0]:train_val_test[0] + train_val_test[1]]
        #The test slice starts where the validation slice ends
        self.test = samples[train_val_test[0] + train_val_test[1]:]

    def set_mode(self, mode):
        self.mode = mode

    def __len__(self):
        if self.mode == "train":
            return len(self.train)
        elif self.mode == "val":
            return len(self.val)
        elif self.mode == "test":
            return len(self.test)
        return self.meta["numTraining"]

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        #Resolve the image file name for the given index and mode from the metadata
        if self.mode == "train":
            name = self.train[idx]['image'].split('/')[-1]
        elif self.mode == "val":
            name = self.val[idx]['image'].split('/')[-1]
        elif self.mode == "test":
            name = self.test[idx]['image'].split('/')[-1]
        else:
            name = self.meta["training"][idx]['image'].split('/')[-1]
        img_path = os.path.join(self.dir, "imagesTr", name)
        label_path = os.path.join(self.dir, "labelsTr", name)
        img_object = nib.load(img_path)
        label_object = nib.load(label_path)
        img_array = img_object.get_fdata()
        #Converting to channel-first numpy arrays
        img_array = np.moveaxis(img_array, -1, 0)
        label_array = label_object.get_fdata()
        label_array = np.moveaxis(label_array, -1, 0)
        processed_out = {'name': name, 'image': img_array, 'label': label_array}
        #Apply the mode-specific transform (transforms is a [train, val, test] list)
        if self.transform:
            if self.mode == "train":
                processed_out = self.transform[0](processed_out)
            elif self.mode == "val":
                processed_out = self.transform[1](processed_out)
            elif self.mode == "test":
                processed_out = self.transform[2](processed_out)
            else:
                processed_out = self.transform(processed_out)

        #The output arrays are in channel-first format
        return processed_out
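

#A minimal usage sketch for the dataset class on its own. Hedged: it assumes
#the Decathlon archive for the chosen task (e.g. Task04_Hippocampus.tar, used
#here purely as an illustration) sits under DATASET_PATH. Kept as a comment so
#that importing this module stays side-effect free:
#
#    dataset = MedicalSegmentationDecathlon(task_number=4, dir_path=DATASET_PATH)
#    dataset.set_mode("train")
#    sample = dataset[0]
#    #'image' and 'label' are channel-first numpy arrays, e.g. (slices, H, W)
#    print(sample["name"], sample["image"].shape, sample["label"].shape)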
self.meta["training"][idx]['image'].split('/')[-1] + img_path = os.path.join(self.dir, "imagesTr", name) + label_path = os.path.join(self.dir, "labelsTr", name) + img_object = nib.load(img_path) + label_object = nib.load(label_path) + img_array = img_object.get_fdata() + #Converting to channel-first numpy array + img_array = np.moveaxis(img_array, -1, 0) + label_array = label_object.get_fdata() + label_array = np.moveaxis(label_array, -1, 0) + proccessed_out = {'name': name, 'image': img_array, 'label': label_array} + if self.transform: + if self.mode == "train": + proccessed_out = self.transform[0](proccessed_out) + elif self.mode == "val": + proccessed_out = self.transform[1](proccessed_out) + elif self.mode == "test": + proccessed_out = self.transform[2](proccessed_out) + else: + proccessed_out = self.transform(proccessed_out) + + #The output numpy array is in channel-first format + return proccessed_out + + + +def get_train_val_test_Dataloaders(train_transforms, val_transforms, test_transforms): + """ + The utility function to generate splitted train, validation and test dataloaders + + Note: all the configs to generate dataloaders in included in "config.py" + """ + + dataset = MedicalSegmentationDecathlon(task_number=TASK_ID, dir_path=DATASET_PATH, split_ratios=TRAIN_VAL_TEST_SPLIT, transforms=[train_transforms, val_transforms, test_transforms]) + + #Spliting dataset and building their respective DataLoaders + train_set, val_set, test_set = copy.deepcopy(dataset), copy.deepcopy(dataset), copy.deepcopy(dataset) + train_set.set_mode('train') + val_set.set_mode('val') + test_set.set_mode('test') + train_dataloader = DataLoader(dataset= train_set, batch_size= TRAIN_BATCH_SIZE, shuffle= False) + val_dataloader = DataLoader(dataset= val_set, batch_size= VAL_BATCH_SIZE, shuffle= False) + test_dataloader = DataLoader(dataset= test_set, batch_size= TEST_BATCH_SIZE, shuffle= False) + + return train_dataloader, val_dataloader, test_dataloader