b/adpkd_segmentation/datasets/datasets.py

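"""Dataset classes and dataset getters for ADPKD segmentation.

Provides SegmentationDataset and InferenceDataset, plus the getter classes
(DatasetGetter, JsonDatasetGetter, InferenceDatasetGetter) that build them
from labeled or unlabeled DICOM studies.
"""
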
import json
import numpy as np
import torch
from pathlib import Path
import pandas as pd
import pydicom
from ast import literal_eval

from adpkd_segmentation.data.data_utils import (
    get_labeled,
    get_y_Path,
    int16_to_uint8,
    make_dcmdicts,
    path_2dcm_int16,
    path_2label,
    TKV_update,
)

from adpkd_segmentation.data.data_utils import (
    KIDNEY_PIXELS,
    STUDY_TKV,
    VOXEL_VOLUME,
)

from adpkd_segmentation.datasets.filters import PatientFiltering


class SegmentationDataset(torch.utils.data.Dataset):
    """Segmentation dataset of DICOM slices paired with label masks."""

    def __init__(
        self,
        label2mask,
        dcm2attribs,
        patient2dcm,
        patient_IDS=None,
        augmentation=None,
        smp_preprocessing=None,
        normalization=None,
        output_idx=False,
        attrib_types=None,
    ):

        super().__init__()
        self.label2mask = label2mask
        self.dcm2attribs = dcm2attribs
        self.pt2dcm = patient2dcm
        self.patient_IDS = patient_IDS
        self.augmentation = augmentation
        self.smp_preprocessing = smp_preprocessing
        self.normalization = normalization
        self.output_idx = output_idx
        self.attrib_types = attrib_types

        # store some attributes as PyTorch tensors
        if self.attrib_types is None:
            self.attrib_types = {
                STUDY_TKV: "float32",
                KIDNEY_PIXELS: "float32",
                VOXEL_VOLUME: "float32",
            }

        self.patients = list(patient2dcm.keys())
        # kept for compatibility with previous experiments
        # following patient order in patient_IDS
        if patient_IDS is not None:
            self.patients = patient_IDS

        self.dcm_paths = []
        for p in self.patients:
            self.dcm_paths.extend(patient2dcm[p])
        self.label_paths = [get_y_Path(dcm) for dcm in self.dcm_paths]

        # study_id to TKV and TKV for each dcm
        self.studies, self.dcm2attribs = TKV_update(dcm2attribs)
        # storing attribute values as tensors indexed by dataset position
        self.tensor_dict = self.prepare_tensor_dict(self.attrib_types)

    def __getitem__(self, index):

        if isinstance(index, slice):
            return [self[ii] for ii in range(*index.indices(len(self)))]

        # numpy int16, (H, W)
        im_path = self.dcm_paths[index]
        image = path_2dcm_int16(im_path)
        # image local scaling by default to convert to uint8
        if self.normalization is None:
            image = int16_to_uint8(image)
        else:
            image = self.normalization(image, self.dcm2attribs[im_path])

        label = path_2label(self.label_paths[index])

        # numpy uint8, one hot encoded (C, H, W)
        mask = self.label2mask(label[np.newaxis, ...])

        if self.augmentation is not None:
            # requires (H, W, C) or (H, W)
            mask = mask.transpose(1, 2, 0)
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample["image"], sample["mask"]
            # get back to (C, H, W)
            mask = mask.transpose(2, 0, 1)

        # convert to float
        image = (image / 255).astype(np.float32)
        mask = mask.astype(np.float32)

        # smp preprocessing requires (H, W, 3)
        if self.smp_preprocessing is not None:
            image = np.repeat(image[..., np.newaxis], 3, axis=-1)
            image = self.smp_preprocessing(image).astype(np.float32)
            # get back to (3, H, W)
            image = image.transpose(2, 0, 1)
        else:
            # stack image to (3, H, W)
            image = np.repeat(image[np.newaxis, ...], 3, axis=0)

        if self.output_idx:
            return image, mask, index
        return image, mask

    def __len__(self):
        return len(self.dcm_paths)

    def get_verbose(self, index):
        """returns more details than __getitem__()

        Args:
            index (int): index in dataset

        Returns:
            tuple: sample, dcm_path, attributes dict
        """

        sample = self[index]
        dcm_path = self.dcm_paths[index]
        attribs = self.dcm2attribs[dcm_path]

        return sample, dcm_path, attribs

    def get_extra_dict(self, batch_of_idx):
        return {k: v[batch_of_idx] for k, v in self.tensor_dict.items()}

    def prepare_tensor_dict(self, attrib_types):
        tensor_dict = {}
        for k, v in attrib_types.items():
            tensor_dict[k] = torch.zeros(
                self.__len__(), dtype=getattr(torch, v)
            )

        for idx, _ in enumerate(self):
            dcm_path = self.dcm_paths[idx]
            attribs = self.dcm2attribs[dcm_path]
            for k, v in tensor_dict.items():
                v[idx] = attribs[k]

        return tensor_dict


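# Usage sketch (illustrative): with ``output_idx=True`` the batch indices
# returned by a DataLoader can be passed back into ``get_extra_dict`` to look
# up the per-slice attribute tensors (STUDY_TKV, KIDNEY_PIXELS, VOXEL_VOLUME
# by default). ``dataset`` below is a hypothetical SegmentationDataset
# instance created with ``output_idx=True``:
#
#     loader = torch.utils.data.DataLoader(dataset, batch_size=4)
#     for image, mask, idx in loader:
#         extras = dataset.get_extra_dict(idx)  # dict of per-sample tensors

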
class DatasetGetter:
    """Create SegmentationDataset"""

    def __init__(
        self,
        splitter,
        splitter_key,
        label2mask,
        augmentation=None,
        smp_preprocessing=None,
        filters=None,
        normalization=None,
        output_idx=False,
        attrib_types=None,
    ):
        super().__init__()
        self.splitter = splitter
        self.splitter_key = splitter_key
        self.label2mask = label2mask
        self.augmentation = augmentation
        self.smp_preprocessing = smp_preprocessing
        self.filters = filters
        self.normalization = normalization
        self.output_idx = output_idx
        self.attrib_types = attrib_types

        dcms_paths = sorted(get_labeled())
        print(
            "The number of images before splitting and filtering: {}".format(
                len(dcms_paths)
            )
        )
        dcm2attribs, patient2dcm = make_dcmdicts(tuple(dcms_paths))

        if filters is not None:
            dcm2attribs, patient2dcm = filters(dcm2attribs, patient2dcm)

        self.all_patient_IDS = list(patient2dcm.keys())
        # train, val, or test
        self.patient_IDS = self.splitter(self.all_patient_IDS)[
            self.splitter_key
        ]

        patient_filter = PatientFiltering(self.patient_IDS)
        self.dcm2attribs, self.patient2dcm = patient_filter(
            dcm2attribs, patient2dcm
        )
        if self.normalization is not None:
            self.normalization.update_dcm2attribs(self.dcm2attribs)

    def __call__(self):
        return SegmentationDataset(
            label2mask=self.label2mask,
            dcm2attribs=self.dcm2attribs,
            patient2dcm=self.patient2dcm,
            patient_IDS=self.patient_IDS,
            augmentation=self.augmentation,
            smp_preprocessing=self.smp_preprocessing,
            normalization=self.normalization,
            output_idx=self.output_idx,
            attrib_types=self.attrib_types,
        )


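# Usage sketch (illustrative; ``my_splitter`` and ``my_label2mask`` are
# hypothetical placeholders). The splitter is expected to map the full patient
# ID list to a dict keyed by "train", "val", and "test":
#
#     getter = DatasetGetter(
#         splitter=my_splitter,
#         splitter_key="train",
#         label2mask=my_label2mask,
#     )
#     train_dataset = getter()

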
class JsonDatasetGetter:
    """Get the dataset from a prepared patient ID split"""

    def __init__(
        self,
        json_path,
        splitter_key,
        label2mask,
        augmentation=None,
        smp_preprocessing=None,
        normalization=None,
        output_idx=False,
        attrib_types=None,
    ):
        super().__init__()

        self.label2mask = label2mask
        self.augmentation = augmentation
        self.smp_preprocessing = smp_preprocessing
        self.normalization = normalization
        self.output_idx = output_idx
        self.attrib_types = attrib_types

        dcms_paths = sorted(get_labeled())
        print(
            "The number of images before splitting and filtering: {}".format(
                len(dcms_paths)
            )
        )
        dcm2attribs, patient2dcm = make_dcmdicts(tuple(dcms_paths))

        print("Loading ", json_path)
        with open(json_path, "r") as f:
            dataset_split = json.load(f)
        self.patient_IDS = dataset_split[splitter_key]

        # filter info dicts to correspond to patient_IDS
        patient_filter = PatientFiltering(self.patient_IDS)
        self.dcm2attribs, self.patient2dcm = patient_filter(
            dcm2attribs, patient2dcm
        )
        if self.normalization is not None:
            self.normalization.update_dcm2attribs(self.dcm2attribs)

    def __call__(self):
        return SegmentationDataset(
            label2mask=self.label2mask,
            dcm2attribs=self.dcm2attribs,
            patient2dcm=self.patient2dcm,
            patient_IDS=self.patient_IDS,
            augmentation=self.augmentation,
            smp_preprocessing=self.smp_preprocessing,
            normalization=self.normalization,
            output_idx=self.output_idx,
            attrib_types=self.attrib_types,
        )


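# Illustrative sketch of the split file consumed by JsonDatasetGetter above
# (structure inferred from ``dataset_split[splitter_key]``; the patient IDs
# shown are hypothetical):
#
#     {
#         "train": ["patient_001", "patient_002"],
#         "val": ["patient_003"],
#         "test": ["patient_004"]
#     }

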
class InferenceDataset(torch.utils.data.Dataset):
    """Inference dataset of DICOM slices without labels."""

    def __init__(
        self,
        dcm2attribs,
        patient2dcm,
        augmentation=None,
        smp_preprocessing=None,
        normalization=None,
        output_idx=False,
        attrib_types=None,
    ):

        super().__init__()
        self.dcm2attribs = dcm2attribs
        self.pt2dcm = patient2dcm
        self.augmentation = augmentation
        self.smp_preprocessing = smp_preprocessing
        self.normalization = normalization
        self.output_idx = output_idx
        self.attrib_types = attrib_types

        self.patients = list(patient2dcm.keys())

        self.dcm_paths = []
        for p in self.patients:
            self.dcm_paths.extend(patient2dcm[p])

        # sort slices within each study by Z axis position
        studies = [
            pydicom.dcmread(path).SeriesDescription for path in self.dcm_paths
        ]
        folders = [path.parent.name for path in self.dcm_paths]
        patients = [pydicom.dcmread(path).PatientID for path in self.dcm_paths]
        x_dims = [pydicom.dcmread(path).Rows for path in self.dcm_paths]
        y_dims = [pydicom.dcmread(path).Columns for path in self.dcm_paths]
        z_pos = [
            literal_eval(str(pydicom.dcmread(path).ImagePositionPatient))[2]
            for path in self.dcm_paths
        ]
        acc_nums = [
            pydicom.dcmread(path).AccessionNumber for path in self.dcm_paths
        ]
        ser_nums = [
            pydicom.dcmread(path).SeriesNumber for path in self.dcm_paths
        ]

        data = {
            "dcm_paths": self.dcm_paths,
            "folders": folders,
            "studies": studies,
            "patients": patients,
            "x_dims": x_dims,
            "y_dims": y_dims,
            "z_pos": z_pos,
            "acc_nums": acc_nums,
            "ser_nums": ser_nums,
        }

        group_keys = [
            "folders",
            "studies",
            "patients",
            "x_dims",
            "y_dims",
            "acc_nums",
            "ser_nums",
        ]

        dataset = pd.DataFrame.from_dict(data)
        dataset["slice_pos"] = ""

        grouped_dataset = dataset.groupby(group_keys)

        for (name, group) in grouped_dataset:
            sort_key = "z_pos"

            # handle missing slice position with filename
            if group[sort_key].isna().any():
                sort_key = "dcm_paths"

            zs = list(group[sort_key])

            sorted_idxs = np.argsort(zs)
            slice_map = {
                zs[idx]: pos for idx, pos in zip(sorted_idxs, range(len(zs)))
            }
            zs_slice_pos = group[sort_key].map(slice_map)

            for i in group.index:
                dataset.at[i, "slice_pos"] = zs_slice_pos.get(i)

        grouped_dataset = dataset.groupby(group_keys)
        for (name, group) in grouped_dataset:
            group.sort_values(by="slice_pos", inplace=True)

        self.df = dataset
        self.dcm_paths = list(dataset["dcm_paths"])

    def __getitem__(self, index):

        if isinstance(index, slice):
            return [self[ii] for ii in range(*index.indices(len(self)))]

        # numpy int16, (H, W)
        im_path = self.dcm_paths[index]
        image = path_2dcm_int16(im_path)
        # image local scaling by default to convert to uint8
        if self.normalization is None:
            image = int16_to_uint8(image)
        else:
            image = self.normalization(image, self.dcm2attribs[im_path])

        if self.augmentation is not None:
            sample = self.augmentation(image=image)
            image = sample["image"]

        # convert to float
        image = (image / 255).astype(np.float32)

        # smp preprocessing requires (H, W, 3)
        if self.smp_preprocessing is not None:
            image = np.repeat(image[..., np.newaxis], 3, axis=-1)
            image = self.smp_preprocessing(image).astype(np.float32)
            # get back to (3, H, W)
            image = image.transpose(2, 0, 1)
        else:
            # stack image to (3, H, W)
            image = np.repeat(image[np.newaxis, ...], 3, axis=0)

        if self.output_idx:
            return image, index

        return image

    def __len__(self):
        return len(self.dcm_paths)

    def get_verbose(self, index):
        """returns more details than __getitem__()

        Args:
            index (int): index in dataset

        Returns:
            tuple: sample, dcm_path, attributes dict
        """

        sample = self[index]
        dcm_path = self.dcm_paths[index]
        attribs = self.dcm2attribs[dcm_path]

        return sample, dcm_path, attribs


class InferenceDatasetGetter:
    """Create InferenceDataset from a directory of DICOM files"""

    def __init__(
        self,
        inference_path,
        augmentation=None,
        smp_preprocessing=None,
        normalization=None,
        output_idx=False,
        attrib_types=None,
    ):
        super().__init__()

        self.augmentation = augmentation
        self.smp_preprocessing = smp_preprocessing
        self.normalization = normalization
        self.output_idx = output_idx
        self.attrib_types = attrib_types

        self.inference_path = Path(inference_path)

        all_paths = set(self.inference_path.glob("**/*"))
        dcms_paths = []
        for path in all_paths:
            if path.is_file():
                try:
                    pydicom.filereader.dcmread(path, stop_before_pixels=True)
                    dcms_paths.append(path)
                except pydicom.errors.InvalidDicomError:
                    continue

        self.dcm2attribs, self.patient2dcm = make_dcmdicts(
            tuple(dcms_paths), label_status=False, WCM=False
        )

        if self.normalization is not None:
            self.normalization.update_dcm2attribs(self.dcm2attribs)

    def __call__(self):
        return InferenceDataset(
            dcm2attribs=self.dcm2attribs,
            patient2dcm=self.patient2dcm,
            augmentation=self.augmentation,
            smp_preprocessing=self.smp_preprocessing,
            normalization=self.normalization,
            output_idx=self.output_idx,
            attrib_types=self.attrib_types,
        )
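

# Usage sketch (illustrative; the path is a hypothetical placeholder):
#
#     getter = InferenceDatasetGetter(inference_path="/path/to/dicom_folder")
#     inference_dataset = getter()
#     image = inference_dataset[0]  # float32 array of shape (3, H, W)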