Diff of /dataset.py [000000] .. [ef4563]

--- a/dataset.py
+++ b/dataset.py
@@ -0,0 +1,148 @@
+import copy
+import nibabel as nib
+import numpy as np
+import os
+import tarfile
+import json
+from sklearn.utils import shuffle
+from torch.utils.data import Dataset, DataLoader
+import torch
+from config import (
+    DATASET_PATH, TASK_ID, TRAIN_VAL_TEST_SPLIT,
+    TRAIN_BATCH_SIZE, VAL_BATCH_SIZE, TEST_BATCH_SIZE
+)
+
+# Utility function to extract .tar archives into the ./Datasets directory
+def ExtractTar(Directory):
+    try:
+        print("Extracting tar file ...")
+        tarfile.open(Directory).extractall('./Datasets')
+    except (tarfile.TarError, OSError) as e:
+        raise RuntimeError("File extraction failed!") from e
+    print("Extraction completed!")
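+
+# Example (illustrative path): ExtractTar("/data/Task04_Hippocampus.tar")
+# unpacks the archive into ./Datasets/Task04_Hippocampus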
+
+
+# Dict mapping Medical Segmentation Decathlon task IDs to task names
+task_names = {
+    "01": "BrainTumour",
+    "02": "Heart",
+    "03": "Liver",
+    "04": "Hippocampus",
+    "05": "Prostate",
+    "06": "Lung",
+    "07": "Pancreas",
+    "08": "HepaticVessel",
+    "09": "Spleen",
+    "10": "Colon"
+}
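+# e.g. task_names["04"] -> "Hippocampus", used to build "Task04_Hippocampus"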
+
+
+class MedicalSegmentationDecathlon(Dataset):
+    """
+    The base dataset class for Decathlon segmentation tasks
+    -- __init__()
+    :param task_number -> represent the organ dataset ID (see task_names above for hints)
+    :param dir_path -> the dataset directory path to .tar files
+    :param transform -> optional - transforms to be applied on each instance
+    """
+    def __init__(self, task_number, dir_path, split_ratios = [0.8, 0.1, 0.1], transforms = None, mode = None) -> None:
+        super(MedicalSegmentationDecathlon, self).__init__()
+        # Normalize the task ID to a two-digit string representation
+        self.task_number = str(task_number).zfill(2)
+        # Build the archive/folder name from the task ID
+        self.file_name = f"Task{self.task_number}_{task_names[self.task_number]}"
+        # Extract the .tar file only if the dataset is not already unpacked
+        if not os.path.exists(os.path.join(os.getcwd(), "Datasets", self.file_name)):
+            ExtractTar(os.path.join(dir_path, f"{self.file_name}.tar"))
+        # Path to the extracted dataset
+        self.dir = os.path.join(os.getcwd(), "Datasets", self.file_name)
+        # Metadata shipped with each task in dataset.json
+        with open(os.path.join(self.dir, "dataset.json")) as f:
+            self.meta = json.load(f)
+        self.splits = split_ratios
+        self.transform = transforms
+        # Calculate the number of images in each split
+        num_training_imgs = self.meta["numTraining"]
+        train_val_test = [int(x * num_training_imgs) for x in split_ratios]
+        # Assign any rounding remainder to the training split
+        if sum(train_val_test) != num_training_imgs:
+            train_val_test[0] += num_training_imgs - sum(train_val_test)
+        self.mode = mode
+        # Shuffle and split the dataset (sklearn's shuffle returns a shuffled
+        # copy; it does not shuffle the list in place)
+        samples = shuffle(self.meta["training"])
+        n_train, n_val, n_test = train_val_test
+        self.train = samples[:n_train]
+        self.val = samples[n_train:n_train + n_val]
+        self.test = samples[n_train + n_val:n_train + n_val + n_test]
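+        # e.g. 100 training images with split_ratios [0.8, 0.1, 0.1] give
+        # [80, 10, 10] samples; with 99 images int() rounding yields [79, 9, 9]
+        # and the 2 leftover images are added back to the training split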
+
+    def set_mode(self, mode):
+        self.mode = mode
+
+    def __len__(self):
+        if self.mode == "train":
+            return len(self.train)
+        elif self.mode == "val":
+            return len(self.val)
+        elif self.mode == "test":
+            return len(self.test)
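+        # No mode set: fall back to the size of the full training pool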
+        return self.meta["numTraining"]
+
+    def __getitem__(self, idx):
+        if torch.is_tensor(idx):
+            idx = idx.tolist()
+        # Look up the image file name for the given index in the current split
+        if self.mode == "train":
+            name = self.train[idx]['image'].split('/')[-1]
+        elif self.mode == "val":
+            name = self.val[idx]['image'].split('/')[-1]
+        elif self.mode == "test":
+            name = self.test[idx]['image'].split('/')[-1]
+        else:
+            name = self.meta["training"][idx]['image'].split('/')[-1]
+        img_path = os.path.join(self.dir, "imagesTr", name)
+        label_path = os.path.join(self.dir, "labelsTr", name)
+        img_object = nib.load(img_path)
+        label_object = nib.load(label_path)
+        img_array = img_object.get_fdata()
+        # Move the last axis to the front (channel-first for multi-modal tasks,
+        # depth-first for single-modality 3D volumes)
+        img_array = np.moveaxis(img_array, -1, 0)
+        label_array = label_object.get_fdata()
+        label_array = np.moveaxis(label_array, -1, 0)
+        processed_out = {'name': name, 'image': img_array, 'label': label_array}
+        if self.transform:
+            if self.mode == "train":
+                processed_out = self.transform[0](processed_out)
+            elif self.mode == "val":
+                processed_out = self.transform[1](processed_out)
+            elif self.mode == "test":
+                processed_out = self.transform[2](processed_out)
+            else:
+                processed_out = self.transform(processed_out)
+
+        # The output arrays are in channel-first format
+        return processed_out
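+
+    # Illustrative output of __getitem__ (file names and shapes are task-dependent):
+    #   {'name': 'hippocampus_001.nii.gz', 'image': <ndarray>, 'label': <ndarray>}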
+
+
+
+def get_train_val_test_Dataloaders(train_transforms, val_transforms, test_transforms):
+    """
+    The utility function to generate splitted train, validation and test dataloaders
+    
+    Note: all the configs to generate dataloaders in included in "config.py"
+    """
+
+    dataset = MedicalSegmentationDecathlon(
+        task_number=TASK_ID,
+        dir_path=DATASET_PATH,
+        split_ratios=TRAIN_VAL_TEST_SPLIT,
+        transforms=[train_transforms, val_transforms, test_transforms]
+    )
+
+    # Splitting the dataset and building the respective DataLoaders; the three
+    # views share the same shuffled samples, and deep copies keep their mode
+    # flags independent of each other
+    train_set, val_set, test_set = copy.deepcopy(dataset), copy.deepcopy(dataset), copy.deepcopy(dataset)
+    train_set.set_mode('train')
+    val_set.set_mode('val')
+    test_set.set_mode('test')
+    train_dataloader = DataLoader(dataset=train_set, batch_size=TRAIN_BATCH_SIZE, shuffle=False)
+    val_dataloader = DataLoader(dataset=val_set, batch_size=VAL_BATCH_SIZE, shuffle=False)
+    test_dataloader = DataLoader(dataset=test_set, batch_size=TEST_BATCH_SIZE, shuffle=False)
+
+    return train_dataloader, val_dataloader, test_dataloader
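+
+
+# Minimal usage sketch (hypothetical; assumes config.py points at a valid
+# MSD .tar archive and that each transform accepts and returns the sample dict):
+#
+# if __name__ == "__main__":
+#     identity = lambda sample: sample  # placeholder transforms
+#     train_dl, val_dl, test_dl = get_train_val_test_Dataloaders(identity, identity, identity)
+#     sample = next(iter(train_dl))
+#     print(sample['name'], sample['image'].shape, sample['label'].shape)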