Diff of /dataset.py [000000] .. [ef4563]

import copy
import nibabel as nib
import numpy as np
import os
import tarfile
import json
from sklearn.utils import shuffle
from torch.utils.data import Dataset, DataLoader, random_split
import torch
from config import (
    DATASET_PATH, TASK_ID, TRAIN_VAL_TEST_SPLIT,
    TRAIN_BATCH_SIZE, VAL_BATCH_SIZE, TEST_BATCH_SIZE
)

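#The names above are imported from config.py. A rough sketch of what that module
#is assumed to provide (the concrete values here are illustrative placeholders,
#not taken from the actual project config):
#
#   DATASET_PATH = "./"                     #directory holding the Task*.tar archives
#   TASK_ID = 4                             #e.g. 4 -> Task04_Hippocampus
#   TRAIN_VAL_TEST_SPLIT = [0.8, 0.1, 0.1]  #ratios must sum to 1
#   TRAIN_BATCH_SIZE = 1
#   VAL_BATCH_SIZE = 1
#   TEST_BATCH_SIZE = 1
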
#Utility function to extract .tar archives into the ./Datasets directory
def ExtractTar(Directory):
    try:
        print("Extracting tar file ...")
        tarfile.open(Directory).extractall('./Datasets')
    except (tarfile.TarError, OSError) as e:
        #Raising a bare string is invalid in Python 3; wrap the original error instead
        raise RuntimeError("File extraction failed!") from e
    print("Extraction completed!")
    return


#The dict representing segmentation tasks along with their IDs
task_names = {
    "01": "BrainTumour",
    "02": "Heart",
    "03": "Liver",
    "04": "Hippocampus",
    "05": "Prostate",
    "06": "Lung",
    "07": "Pancreas",
    "08": "HepaticVessel",
    "09": "Spleen",
    "10": "Colon"
}


class MedicalSegmentationDecathlon(Dataset):
    """
    The base dataset class for Decathlon segmentation tasks
    -- __init__()
    :param task_number -> the organ dataset ID (see task_names above for hints)
    :param dir_path -> the dataset directory path containing the .tar files
    :param split_ratios -> optional - train/val/test split ratios (must sum to 1)
    :param transforms -> optional - transforms to be applied on each instance
    :param mode -> optional - one of "train", "val" or "test"
    """
    def __init__(self, task_number, dir_path, split_ratios = [0.8, 0.1, 0.1], transforms = None, mode = None) -> None:
        super(MedicalSegmentationDecathlon, self).__init__()
        #Rectify the task ID representation (e.g. 4 -> "04")
        self.task_number = str(task_number)
        if len(self.task_number) == 1:
            self.task_number = "0" + self.task_number
        #Building the file name according to task ID
        self.file_name = f"Task{self.task_number}_{task_names[self.task_number]}"
        #Extracting the .tar file if the dataset has not been extracted yet
        if not os.path.exists(os.path.join(os.getcwd(), "Datasets", self.file_name)):
            ExtractTar(os.path.join(dir_path, f"{self.file_name}.tar"))
        #Path to extracted dataset
        self.dir = os.path.join(os.getcwd(), "Datasets", self.file_name)
        #Meta data about the dataset
        with open(os.path.join(self.dir, "dataset.json")) as f:
            self.meta = json.load(f)
        self.splits = split_ratios
        self.transform = transforms
        #Calculating the number of images per split
        num_training_imgs = self.meta["numTraining"]
        train_val_test = [int(x * num_training_imgs) for x in split_ratios]
        #Assign any rounding remainder to the training split
        if sum(train_val_test) != num_training_imgs:
            train_val_test[0] += (num_training_imgs - sum(train_val_test))
        self.mode = mode
        #Splitting the dataset; sklearn's shuffle returns a shuffled copy, so the
        #result must be assigned back (zero-sized splits simply yield empty lists)
        samples = shuffle(self.meta["training"])
        self.train = samples[0:train_val_test[0]]
        self.val = samples[train_val_test[0]:train_val_test[0] + train_val_test[1]]
        self.test = samples[train_val_test[0] + train_val_test[1]:]

    def set_mode(self, mode):
        self.mode = mode

    def __len__(self):
        if self.mode == "train":
            return len(self.train)
        elif self.mode == "val":
            return len(self.val)
        elif self.mode == "test":
            return len(self.test)
        return self.meta["numTraining"]

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        #Obtaining the image name for the given index and mode using the meta data
        if self.mode == "train":
            name = self.train[idx]['image'].split('/')[-1]
        elif self.mode == "val":
            name = self.val[idx]['image'].split('/')[-1]
        elif self.mode == "test":
            name = self.test[idx]['image'].split('/')[-1]
        else:
            name = self.meta["training"][idx]['image'].split('/')[-1]
        img_path = os.path.join(self.dir, "imagesTr", name)
        label_path = os.path.join(self.dir, "labelsTr", name)
        img_object = nib.load(img_path)
        label_object = nib.load(label_path)
        img_array = img_object.get_fdata()
        #Converting to a channel-first numpy array
        img_array = np.moveaxis(img_array, -1, 0)
        label_array = label_object.get_fdata()
        label_array = np.moveaxis(label_array, -1, 0)
        processed_out = {'name': name, 'image': img_array, 'label': label_array}
        if self.transform:
            if self.mode == "train":
                processed_out = self.transform[0](processed_out)
            elif self.mode == "val":
                processed_out = self.transform[1](processed_out)
            elif self.mode == "test":
                processed_out = self.transform[2](processed_out)
            else:
                processed_out = self.transform(processed_out)

        #The output numpy arrays are in channel-first format
        return processed_out


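#A minimal usage sketch for the class above (illustrative only: the task ID and
#directory are assumptions and depend on where the Decathlon .tar archives live):
#
#   dataset = MedicalSegmentationDecathlon(task_number=4, dir_path="./", mode="train")
#   sample = dataset[0]
#   #sample is a dict with keys 'name', 'image' and 'label'; 'image' and 'label'
#   #are channel-first numpy arrays loaded from the NIfTI volumes.
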
def get_train_val_test_Dataloaders(train_transforms, val_transforms, test_transforms):
    """
    The utility function to generate split train, validation and test dataloaders

    Note: all the configs used to generate the dataloaders are included in "config.py"
    """

    dataset = MedicalSegmentationDecathlon(task_number=TASK_ID, dir_path=DATASET_PATH, split_ratios=TRAIN_VAL_TEST_SPLIT, transforms=[train_transforms, val_transforms, test_transforms])

    #Splitting the dataset and building the respective DataLoaders
    train_set, val_set, test_set = copy.deepcopy(dataset), copy.deepcopy(dataset), copy.deepcopy(dataset)
    train_set.set_mode('train')
    val_set.set_mode('val')
    test_set.set_mode('test')
    train_dataloader = DataLoader(dataset=train_set, batch_size=TRAIN_BATCH_SIZE, shuffle=False)
    val_dataloader = DataLoader(dataset=val_set, batch_size=VAL_BATCH_SIZE, shuffle=False)
    test_dataloader = DataLoader(dataset=test_set, batch_size=TEST_BATCH_SIZE, shuffle=False)

    return train_dataloader, val_dataloader, test_dataloader
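

#A rough usage sketch of the function above (the transform callables are assumed
#to be defined elsewhere in the project; the names below are placeholders):
#
#   train_loader, val_loader, test_loader = get_train_val_test_Dataloaders(
#       train_transforms, val_transforms, test_transforms
#   )
#   for batch in train_loader:
#       images, labels = batch['image'], batch['label']
#       ...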