# adpkd_segmentation/datasets/datasets.py
import json
import numpy as np
import torch
from pathlib import Path
import pandas as pd
import pydicom
from ast import literal_eval

from adpkd_segmentation.data.data_utils import (
    get_labeled,
    get_y_Path,
    int16_to_uint8,
    make_dcmdicts,
    path_2dcm_int16,
    path_2label,
    TKV_update,
)

from adpkd_segmentation.data.data_utils import (
    KIDNEY_PIXELS,
    STUDY_TKV,
    VOXEL_VOLUME,
)

from adpkd_segmentation.datasets.filters import PatientFiltering

class SegmentationDataset(torch.utils.data.Dataset):
    """Labeled DICOM slice dataset returning (image, mask) pairs."""

    def __init__(
        self,
        label2mask,
        dcm2attribs,
        patient2dcm,
        patient_IDS=None,
        augmentation=None,
        smp_preprocessing=None,
        normalization=None,
        output_idx=False,
        attrib_types=None,
    ):

        super().__init__()
        self.label2mask = label2mask
        self.dcm2attribs = dcm2attribs
        self.pt2dcm = patient2dcm
        self.patient_IDS = patient_IDS
        self.augmentation = augmentation
        self.smp_preprocessing = smp_preprocessing
        self.normalization = normalization
        self.output_idx = output_idx
        self.attrib_types = attrib_types

        # store some attributes as PyTorch tensors
        if self.attrib_types is None:
            self.attrib_types = {
                STUDY_TKV: "float32",
                KIDNEY_PIXELS: "float32",
                VOXEL_VOLUME: "float32",
            }

        self.patients = list(patient2dcm.keys())
        # kept for compatibility with previous experiments
        # following patient order in patient_IDS
        if patient_IDS is not None:
            self.patients = patient_IDS

        self.dcm_paths = []
        for p in self.patients:
            self.dcm_paths.extend(patient2dcm[p])
        self.label_paths = [get_y_Path(dcm) for dcm in self.dcm_paths]

        # study_id to TKV mapping, and TKV attributes for each dcm
        self.studies, self.dcm2attribs = TKV_update(dcm2attribs)
        # storing selected attributes as tensors
        self.tensor_dict = self.prepare_tensor_dict(self.attrib_types)

    def __getitem__(self, index):

        if isinstance(index, slice):
            return [self[ii] for ii in range(*index.indices(len(self)))]

        # numpy int16, (H, W)
        im_path = self.dcm_paths[index]
        image = path_2dcm_int16(im_path)
        # local intensity scaling to uint8 by default
        if self.normalization is None:
            image = int16_to_uint8(image)
        else:
            image = self.normalization(image, self.dcm2attribs[im_path])

        label = path_2label(self.label_paths[index])

        # numpy uint8, one-hot encoded (C, H, W)
        mask = self.label2mask(label[np.newaxis, ...])

        if self.augmentation is not None:
            # augmentation requires (H, W, C) or (H, W)
            mask = mask.transpose(1, 2, 0)
            sample = self.augmentation(image=image, mask=mask)
            image, mask = sample["image"], sample["mask"]
            # get back to (C, H, W)
            mask = mask.transpose(2, 0, 1)

        # convert to float
        image = (image / 255).astype(np.float32)
        mask = mask.astype(np.float32)

        # smp preprocessing requires (H, W, 3)
        if self.smp_preprocessing is not None:
            image = np.repeat(image[..., np.newaxis], 3, axis=-1)
            image = self.smp_preprocessing(image).astype(np.float32)
            # get back to (3, H, W)
            image = image.transpose(2, 0, 1)
        else:
            # stack image to (3, H, W)
            image = np.repeat(image[np.newaxis, ...], 3, axis=0)

        if self.output_idx:
            return image, mask, index
        return image, mask

    def __len__(self):
        return len(self.dcm_paths)

    def get_verbose(self, index):
        """Returns more details than __getitem__()

        Args:
            index (int): index in dataset

        Returns:
            tuple: sample, dcm_path, attributes dict
        """

        sample = self[index]
        dcm_path = self.dcm_paths[index]
        attribs = self.dcm2attribs[dcm_path]

        return sample, dcm_path, attribs

    def get_extra_dict(self, batch_of_idx):
        return {k: v[batch_of_idx] for k, v in self.tensor_dict.items()}

    def prepare_tensor_dict(self, attrib_types):
        tensor_dict = {}
        for k, v in attrib_types.items():
            tensor_dict[k] = torch.zeros(
                self.__len__(), dtype=getattr(torch, v)
            )

        for idx, _ in enumerate(self):
            dcm_path = self.dcm_paths[idx]
            attribs = self.dcm2attribs[dcm_path]
            for k, v in tensor_dict.items():
                v[idx] = attribs[k]

        return tensor_dict

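# Usage sketch (illustrative only, not executed on import). It shows how the
# dataset is typically constructed; `label2mask` is an assumed mask transform,
# `dcm2attribs` / `patient2dcm` come from `make_dcmdicts`, and the patient IDs
# below are hypothetical.
#
#     dataset = SegmentationDataset(
#         label2mask=label2mask,
#         dcm2attribs=dcm2attribs,
#         patient2dcm=patient2dcm,
#         patient_IDS=["patient_001", "patient_002"],
#     )
#     image, mask = dataset[0]  # (3, H, W) float32 and (C, H, W) float32
#     loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)
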
class DatasetGetter:
    """Create a SegmentationDataset for one patient-level split."""

    def __init__(
        self,
        splitter,
        splitter_key,
        label2mask,
        augmentation=None,
        smp_preprocessing=None,
        filters=None,
        normalization=None,
        output_idx=False,
        attrib_types=None,
    ):
        super().__init__()
        self.splitter = splitter
        self.splitter_key = splitter_key
        self.label2mask = label2mask
        self.augmentation = augmentation
        self.smp_preprocessing = smp_preprocessing
        self.filters = filters
        self.normalization = normalization
        self.output_idx = output_idx
        self.attrib_types = attrib_types

        dcms_paths = sorted(get_labeled())
        print(
            "The number of images before splitting and filtering: {}".format(
                len(dcms_paths)
            )
        )
        dcm2attribs, patient2dcm = make_dcmdicts(tuple(dcms_paths))

        if filters is not None:
            dcm2attribs, patient2dcm = filters(dcm2attribs, patient2dcm)

        self.all_patient_IDS = list(patient2dcm.keys())
        # train, val, or test
        self.patient_IDS = self.splitter(self.all_patient_IDS)[
            self.splitter_key
        ]

        patient_filter = PatientFiltering(self.patient_IDS)
        self.dcm2attribs, self.patient2dcm = patient_filter(
            dcm2attribs, patient2dcm
        )
        if self.normalization is not None:
            self.normalization.update_dcm2attribs(self.dcm2attribs)

    def __call__(self):
        return SegmentationDataset(
            label2mask=self.label2mask,
            dcm2attribs=self.dcm2attribs,
            patient2dcm=self.patient2dcm,
            patient_IDS=self.patient_IDS,
            augmentation=self.augmentation,
            smp_preprocessing=self.smp_preprocessing,
            normalization=self.normalization,
            output_idx=self.output_idx,
            attrib_types=self.attrib_types,
        )

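# Usage sketch (illustrative only). `splitter` is assumed to be a callable
# mapping a list of patient IDs to a dict keyed by split name, which matches
# how `splitter_key` is used above; the concrete splitter, transforms, and
# key values come from the experiment configuration.
#
#     getter = DatasetGetter(
#         splitter=splitter,
#         splitter_key="train",
#         label2mask=label2mask,
#         augmentation=augmentation,
#     )
#     train_dataset = getter()  # SegmentationDataset for the "train" patients
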
class JsonDatasetGetter:
    """Get the dataset from a prepared patient ID split stored in JSON."""

    def __init__(
        self,
        json_path,
        splitter_key,
        label2mask,
        augmentation=None,
        smp_preprocessing=None,
        normalization=None,
        output_idx=False,
        attrib_types=None,
    ):
        super().__init__()

        self.label2mask = label2mask
        self.augmentation = augmentation
        self.smp_preprocessing = smp_preprocessing
        self.normalization = normalization
        self.output_idx = output_idx
        self.attrib_types = attrib_types

        dcms_paths = sorted(get_labeled())
        print(
            "The number of images before splitting and filtering: {}".format(
                len(dcms_paths)
            )
        )
        dcm2attribs, patient2dcm = make_dcmdicts(tuple(dcms_paths))

        print("Loading ", json_path)
        with open(json_path, "r") as f:
            dataset_split = json.load(f)
        self.patient_IDS = dataset_split[splitter_key]

        # filter info dicts to correspond to patient_IDS
        patient_filter = PatientFiltering(self.patient_IDS)
        self.dcm2attribs, self.patient2dcm = patient_filter(
            dcm2attribs, patient2dcm
        )
        if self.normalization is not None:
            self.normalization.update_dcm2attribs(self.dcm2attribs)

    def __call__(self):
        return SegmentationDataset(
            label2mask=self.label2mask,
            dcm2attribs=self.dcm2attribs,
            patient2dcm=self.patient2dcm,
            patient_IDS=self.patient_IDS,
            augmentation=self.augmentation,
            smp_preprocessing=self.smp_preprocessing,
            normalization=self.normalization,
            output_idx=self.output_idx,
            attrib_types=self.attrib_types,
        )

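# Usage sketch (illustrative only). The JSON file is assumed to map split
# names to lists of patient IDs, since `dataset_split[splitter_key]` above is
# used directly as `patient_IDS`; the file name and IDs are hypothetical.
#
#     # split.json
#     # {"train": ["patient_001", ...], "val": [...], "test": [...]}
#     getter = JsonDatasetGetter(
#         json_path="split.json",
#         splitter_key="val",
#         label2mask=label2mask,
#     )
#     val_dataset = getter()
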
class InferenceDataset(torch.utils.data.Dataset):
    """DICOM slice dataset for inference; slice order is stored in self.df."""

    def __init__(
        self,
        dcm2attribs,
        patient2dcm,
        augmentation=None,
        smp_preprocessing=None,
        normalization=None,
        output_idx=False,
        attrib_types=None,
    ):

        super().__init__()
        self.dcm2attribs = dcm2attribs
        self.pt2dcm = patient2dcm
        self.augmentation = augmentation
        self.smp_preprocessing = smp_preprocessing
        self.normalization = normalization
        self.output_idx = output_idx
        self.attrib_types = attrib_types

        self.patients = list(patient2dcm.keys())

        self.dcm_paths = []
        for p in self.patients:
            self.dcm_paths.extend(patient2dcm[p])

        # gather per-slice metadata and order slices along the z axis
        # within each study
        studies = [
            pydicom.dcmread(path).SeriesDescription for path in self.dcm_paths
        ]
        folders = [path.parent.name for path in self.dcm_paths]
        patients = [pydicom.dcmread(path).PatientID for path in self.dcm_paths]
        x_dims = [pydicom.dcmread(path).Rows for path in self.dcm_paths]
        y_dims = [pydicom.dcmread(path).Columns for path in self.dcm_paths]
        z_pos = [
            literal_eval(str(pydicom.dcmread(path).ImagePositionPatient))[2]
            for path in self.dcm_paths
        ]
        acc_nums = [
            pydicom.dcmread(path).AccessionNumber for path in self.dcm_paths
        ]
        ser_nums = [
            pydicom.dcmread(path).SeriesNumber for path in self.dcm_paths
        ]

        data = {
            "dcm_paths": self.dcm_paths,
            "folders": folders,
            "studies": studies,
            "patients": patients,
            "x_dims": x_dims,
            "y_dims": y_dims,
            "z_pos": z_pos,
            "acc_nums": acc_nums,
            "ser_nums": ser_nums,
        }

        group_keys = [
            "folders",
            "studies",
            "patients",
            "x_dims",
            "y_dims",
            "acc_nums",
            "ser_nums",
        ]

        dataset = pd.DataFrame.from_dict(data)
        dataset["slice_pos"] = ""

        grouped_dataset = dataset.groupby(group_keys)

        for (name, group) in grouped_dataset:
            sort_key = "z_pos"

            # handle missing slice position by falling back to the filename
            if group[sort_key].isna().any():
                sort_key = "dcm_paths"

            zs = list(group[sort_key])

            sorted_idxs = np.argsort(zs)
            slice_map = {
                zs[idx]: pos for idx, pos in zip(sorted_idxs, range(len(zs)))
            }
            zs_slice_pos = group[sort_key].map(slice_map)

            for i in group.index:
                dataset.at[i, "slice_pos"] = zs_slice_pos.get(i)

        grouped_dataset = dataset.groupby(group_keys)
        for (name, group) in grouped_dataset:
            group.sort_values(by="slice_pos", inplace=True)

        self.df = dataset
        self.dcm_paths = list(dataset["dcm_paths"])

    def __getitem__(self, index):

        if isinstance(index, slice):
            return [self[ii] for ii in range(*index.indices(len(self)))]

        # numpy int16, (H, W)
        im_path = self.dcm_paths[index]
        image = path_2dcm_int16(im_path)
        # local intensity scaling to uint8 by default
        if self.normalization is None:
            image = int16_to_uint8(image)
        else:
            image = self.normalization(image, self.dcm2attribs[im_path])

        if self.augmentation is not None:
            sample = self.augmentation(image=image)
            image = sample["image"]

        # convert to float
        image = (image / 255).astype(np.float32)

        # smp preprocessing requires (H, W, 3)
        if self.smp_preprocessing is not None:
            image = np.repeat(image[..., np.newaxis], 3, axis=-1)
            image = self.smp_preprocessing(image).astype(np.float32)
            # get back to (3, H, W)
            image = image.transpose(2, 0, 1)
        else:
            # stack image to (3, H, W)
            image = np.repeat(image[np.newaxis, ...], 3, axis=0)

        if self.output_idx:
            return image, index

        return image

    def __len__(self):
        return len(self.dcm_paths)

    def get_verbose(self, index):
        """Returns more details than __getitem__()

        Args:
            index (int): index in dataset

        Returns:
            tuple: sample, dcm_path, attributes dict
        """

        sample = self[index]
        dcm_path = self.dcm_paths[index]
        attribs = self.dcm2attribs[dcm_path]

        return sample, dcm_path, attribs

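# Usage sketch (illustrative only). `InferenceDataset` keeps per-slice
# metadata in `self.df` (including the computed `slice_pos`), so predictions
# can later be regrouped into ordered stacks per series; batch size and
# column selection below are arbitrary.
#
#     inference_dataset = InferenceDataset(dcm2attribs, patient2dcm)
#     loader = torch.utils.data.DataLoader(inference_dataset, batch_size=8)
#     meta = inference_dataset.df[["patients", "studies", "slice_pos"]]
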
class InferenceDatasetGetter:
    """Create an InferenceDataset from a directory of DICOM files."""

    def __init__(
        self,
        inference_path,
        augmentation=None,
        smp_preprocessing=None,
        normalization=None,
        output_idx=False,
        attrib_types=None,
    ):
        super().__init__()

        self.augmentation = augmentation
        self.smp_preprocessing = smp_preprocessing
        self.normalization = normalization
        self.output_idx = output_idx
        self.attrib_types = attrib_types

        self.inference_path = Path(inference_path)

        all_paths = set(self.inference_path.glob("**/*"))
        dcms_paths = []
        for path in all_paths:
            if path.is_file():
                try:
                    pydicom.filereader.dcmread(path, stop_before_pixels=True)
                    dcms_paths.append(path)
                except pydicom.errors.InvalidDicomError:
                    continue

        self.dcm2attribs, self.patient2dcm = make_dcmdicts(
            tuple(dcms_paths), label_status=False, WCM=False
        )

        if self.normalization is not None:
            self.normalization.update_dcm2attribs(self.dcm2attribs)

    def __call__(self):
        return InferenceDataset(
            dcm2attribs=self.dcm2attribs,
            patient2dcm=self.patient2dcm,
            augmentation=self.augmentation,
            smp_preprocessing=self.smp_preprocessing,
            normalization=self.normalization,
            output_idx=self.output_idx,
            attrib_types=self.attrib_types,
        )
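

# End-to-end usage sketch (illustrative only). The directory path is
# hypothetical, and `model` stands in for a trained segmentation network
# that accepts (N, 3, H, W) float32 batches.
#
#     getter = InferenceDatasetGetter(inference_path="/data/new_studies")
#     dataset = getter()
#     loader = torch.utils.data.DataLoader(dataset, batch_size=8)
#     with torch.no_grad():
#         for batch in loader:
#             preds = model(torch.as_tensor(batch))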