# dataset.py
import copy
import json
import os
import tarfile

import nibabel as nib
import numpy as np
import torch
from sklearn.utils import shuffle
from torch.utils.data import Dataset, DataLoader

from config import (
    DATASET_PATH, TASK_ID, TRAIN_VAL_TEST_SPLIT,
    TRAIN_BATCH_SIZE, VAL_BATCH_SIZE, TEST_BATCH_SIZE
)


# Utility function to extract a .tar archive into the ./Datasets directory
def ExtractTar(Directory):
    print("Extracting tar file ...")
    try:
        with tarfile.open(Directory) as archive:
            archive.extractall('./Datasets')
    except (tarfile.TarError, OSError) as error:
        raise RuntimeError(f"File extraction failed: {Directory}") from error
    print("Extraction completed!")


# The dict mapping Decathlon segmentation task IDs to their task names
task_names = {
    "01": "BrainTumour",
    "02": "Heart",
    "03": "Liver",
    "04": "Hippocampus",
    "05": "Prostate",
    "06": "Lung",
    "07": "Pancreas",
    "08": "HepaticVessel",
    "09": "Spleen",
    "10": "Colon"
}
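
# For example, task_number=4 is rectified to "04" and resolves to the file
# name "Task04_Hippocampus", so the constructor below looks for
# <dir_path>/Task04_Hippocampus.tar and extracts it on first use.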


class MedicalSegmentationDecathlon(Dataset):
    """
    The base dataset class for Medical Segmentation Decathlon tasks
    -- __init__()
    :param task_number -> the organ dataset ID (see task_names above)
    :param dir_path -> path to the directory containing the .tar dataset files
    :param split_ratios -> train/val/test split ratios (should sum to 1)
    :param transforms -> optional transforms applied to each sample; either a
        single callable or a list of [train, val, test] transforms
    :param mode -> one of "train", "val" or "test" (can be set later via set_mode)
    """
    def __init__(self, task_number, dir_path, split_ratios=[0.8, 0.1, 0.1], transforms=None, mode=None) -> None:
        super(MedicalSegmentationDecathlon, self).__init__()
        # Rectify the task ID representation (e.g. 4 -> "04")
        self.task_number = str(task_number)
        if len(self.task_number) == 1:
            self.task_number = "0" + self.task_number
        # Building the file name according to the task ID
        self.file_name = f"Task{self.task_number}_{task_names[self.task_number]}"
        # Extracting the .tar file if it has not been extracted yet
        if not os.path.exists(os.path.join(os.getcwd(), "Datasets", self.file_name)):
            ExtractTar(os.path.join(dir_path, f"{self.file_name}.tar"))
        # Path to the extracted dataset
        self.dir = os.path.join(os.getcwd(), "Datasets", self.file_name)
        # Metadata about the dataset
        with open(os.path.join(self.dir, "dataset.json")) as meta_file:
            self.meta = json.load(meta_file)
        self.splits = split_ratios
        self.transform = transforms
        # Calculating the number of images per split
        num_training_imgs = self.meta["numTraining"]
        train_val_test = [int(x * num_training_imgs) for x in split_ratios]
        # Assign any rounding remainder to the training split
        if sum(train_val_test) != num_training_imgs:
            train_val_test[0] += (num_training_imgs - sum(train_val_test))
        self.mode = mode
        # Splitting the dataset; note that sklearn's shuffle returns a
        # shuffled copy rather than shuffling in place
        samples = shuffle(self.meta["training"])
        n_train, n_val, n_test = train_val_test
        self.train = samples[:n_train]
        self.val = samples[n_train:n_train + n_val]
        self.test = samples[n_train + n_val:n_train + n_val + n_test]
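        # Worked example (illustrative numbers): with numTraining = 100 and
        # split_ratios = [0.8, 0.1, 0.1], train_val_test = [80, 10, 10], so
        # self.train = samples[0:80], self.val = samples[80:90] and
        # self.test = samples[90:100].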

    def set_mode(self, mode):
        self.mode = mode

    def __len__(self):
        if self.mode == "train":
            return len(self.train)
        elif self.mode == "val":
            return len(self.val)
        elif self.mode == "test":
            return len(self.test)
        return self.meta["numTraining"]

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # Obtaining the image file name for the given index and mode from the metadata
        if self.mode == "train":
            name = self.train[idx]['image'].split('/')[-1]
        elif self.mode == "val":
            name = self.val[idx]['image'].split('/')[-1]
        elif self.mode == "test":
            name = self.test[idx]['image'].split('/')[-1]
        else:
            name = self.meta["training"][idx]['image'].split('/')[-1]
        img_path = os.path.join(self.dir, "imagesTr", name)
        label_path = os.path.join(self.dir, "labelsTr", name)
        img_object = nib.load(img_path)
        label_object = nib.load(label_path)
        # Converting to channel-first numpy arrays
        img_array = np.moveaxis(img_object.get_fdata(), -1, 0)
        label_array = np.moveaxis(label_object.get_fdata(), -1, 0)
        processed_out = {'name': name, 'image': img_array, 'label': label_array}
        if self.transform:
            # A list of transforms is interpreted as [train, val, test]
            if self.mode == "train":
                processed_out = self.transform[0](processed_out)
            elif self.mode == "val":
                processed_out = self.transform[1](processed_out)
            elif self.mode == "test":
                processed_out = self.transform[2](processed_out)
            else:
                processed_out = self.transform(processed_out)
        # The output arrays are in channel-first format
        return processed_out
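

# A transform can be any callable mapping the sample dict to a new dict.
# Minimal illustrative sketch of one such transform (an assumption about the
# convention, not part of the original module): convert both arrays to
# float32 torch tensors while leaving the sample name untouched.
class ToTensorSample:
    def __call__(self, sample):
        return {
            'name': sample['name'],
            'image': torch.from_numpy(sample['image']).float(),
            'label': torch.from_numpy(sample['label']).float(),
        }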


def get_train_val_test_Dataloaders(train_transforms, val_transforms, test_transforms):
    """
    The utility function to generate the split train, validation and test dataloaders

    Note: all the configs used to generate the dataloaders are defined in "config.py"
    """
    dataset = MedicalSegmentationDecathlon(task_number=TASK_ID, dir_path=DATASET_PATH, split_ratios=TRAIN_VAL_TEST_SPLIT, transforms=[train_transforms, val_transforms, test_transforms])

    # Each split gets its own copy of the dataset so their modes do not interfere
    train_set, val_set, test_set = copy.deepcopy(dataset), copy.deepcopy(dataset), copy.deepcopy(dataset)
    train_set.set_mode('train')
    val_set.set_mode('val')
    test_set.set_mode('test')
    train_dataloader = DataLoader(dataset=train_set, batch_size=TRAIN_BATCH_SIZE, shuffle=False)
    val_dataloader = DataLoader(dataset=val_set, batch_size=VAL_BATCH_SIZE, shuffle=False)
    test_dataloader = DataLoader(dataset=test_set, batch_size=TEST_BATCH_SIZE, shuffle=False)

    return train_dataloader, val_dataloader, test_dataloader
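

# Usage sketch (assumes config.py supplies the imported constants, the .tar
# archives exist under DATASET_PATH, and the batch size is compatible with
# the volume shapes; ToTensorSample above is an illustrative transform):
if __name__ == "__main__":
    to_tensor = ToTensorSample()
    train_loader, val_loader, test_loader = get_train_val_test_Dataloaders(
        train_transforms=to_tensor, val_transforms=to_tensor, test_transforms=to_tensor
    )
    batch = next(iter(train_loader))
    print(batch['name'], batch['image'].shape, batch['label'].shape)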