Diff of /dataset.py [000000] .. [9cc651]

Switch to unified view

a b/dataset.py
1
import os
2
import random
3
4
import numpy as np
5
import torch
6
from skimage.io import imread
7
from torch.utils.data import Dataset
8
9
from utils import crop_sample, pad_sample, resize_sample, normalize_volume
10
11
12
class BrainSegmentationDataset(Dataset):
    """Brain MRI dataset for FLAIR abnormality segmentation.

    On construction every ``.tif`` file found under ``images_dir`` is read
    into memory, grouped per directory (one directory is treated as one
    patient), split into train/validation subsets and preprocessed with the
    helpers imported from ``utils``.

    Indexing yields ``(image_tensor, mask_tensor)`` pairs of ``float32``
    tensors whose axes were moved from channels-last to ``(C, H, W)``.

    Attributes set by ``__init__``:
        patients: patient ids (directory names) included in this subset.
        volumes: list of ``(volume, mask)`` pairs, one per patient.
        slice_weights: per-patient sampling probabilities over slices.
        patient_slice_index: list mapping a flat index to
            ``(patient_index, slice_index)``.
        random_sampling: whether ``__getitem__`` ignores the requested index.
        transform: optional callable applied to ``(image, mask)``.
    """

    in_channels = 3
    out_channels = 1

    def __init__(
        self,
        images_dir,
        transform=None,
        image_size=256,
        subset="train",
        random_sampling=True,
        validation_cases=10,
        seed=42,
    ):
        """Read, split and preprocess all volumes found under ``images_dir``.

        Args:
            images_dir: root directory walked recursively for ``.tif`` files.
            transform: optional callable taking and returning an
                ``(image, mask)`` tuple; applied per slice in ``__getitem__``.
            image_size: target size forwarded to ``resize_sample``.
            subset: one of ``"all"``, ``"train"`` or ``"validation"``.
            random_sampling: if true, ``__getitem__`` draws a random patient
                and a mask-weighted random slice instead of using ``idx``.
            validation_cases: number of patients held out for validation.
            seed: seed used to make the train/validation split reproducible.

        Raises:
            AssertionError: if ``subset`` is not one of the accepted values.
            ValueError: from ``random.sample`` when ``validation_cases``
                exceeds the number of patients found (or is negative).
        """
        assert subset in ["all", "train", "validation"]

        # read images
        print("reading {} images...".format(subset))
        volumes, masks = self._read_volumes(images_dir)

        self.patients = sorted(volumes)

        # select cases to subset
        if subset != "all":
            # NOTE: this reseeds the *global* `random` module state; kept
            # as-is because callers may depend on that side effect.
            random.seed(seed)
            validation_patients = random.sample(self.patients, k=validation_cases)
            if subset == "validation":
                self.patients = validation_patients
            else:
                self.patients = sorted(
                    set(self.patients).difference(validation_patients)
                )

        print("preprocessing {} volumes...".format(subset))
        # create list of tuples (volume, mask)
        self.volumes = [(volumes[k], masks[k]) for k in self.patients]

        print("cropping {} volumes...".format(subset))
        # crop to smallest enclosing volume
        self.volumes = [crop_sample(v) for v in self.volumes]

        print("padding {} volumes...".format(subset))
        # pad to square
        self.volumes = [pad_sample(v) for v in self.volumes]

        print("resizing {} volumes...".format(subset))
        # resize
        self.volumes = [resize_sample(v, size=image_size) for v in self.volumes]

        print("normalizing {} volumes...".format(subset))
        # normalize channel-wise
        self.volumes = [(normalize_volume(v), m) for v, m in self.volumes]

        # probabilities for sampling slices based on masks
        self.slice_weights = [self._slice_weights(m) for _, m in self.volumes]

        # add channel dimension to masks
        self.volumes = [(v, m[..., np.newaxis]) for (v, m) in self.volumes]

        print("done creating {} dataset".format(subset))

        # create global index for patient and slice (idx -> (p_idx, s_idx))
        self.patient_slice_index = self._build_slice_index(
            [v.shape[0] for v, _ in self.volumes]
        )

        self.random_sampling = random_sampling

        self.transform = transform

    @staticmethod
    def _read_volumes(images_dir):
        """Load image and mask slices per directory under ``images_dir``.

        Returns a pair of dicts ``(volumes, masks)`` keyed by the name of the
        directory that contained the files.
        """
        volumes = {}
        masks = {}
        for dirpath, _dirnames, filenames in os.walk(images_dir):
            image_slices = []
            mask_slices = []
            # Order slices numerically by the 5th "_"-separated field of the
            # file stem, so that e.g. slice 10 does not sort before slice 2.
            for filename in sorted(
                (f for f in filenames if ".tif" in f),
                key=lambda x: int(x.split(".")[-2].split("_")[4]),
            ):
                filepath = os.path.join(dirpath, filename)
                if "mask" in filename:
                    mask_slices.append(imread(filepath, as_gray=True))
                else:
                    image_slices.append(imread(filepath))
            if image_slices:
                # basename(normpath(...)) works with any path separator and
                # tolerates a trailing separator, unlike split("/")[-1].
                patient_id = os.path.basename(os.path.normpath(dirpath))
                # The first and last slice of every volume are discarded.
                volumes[patient_id] = np.array(image_slices[1:-1])
                masks[patient_id] = np.array(mask_slices[1:-1])
        return volumes, masks

    @staticmethod
    def _slice_weights(mask):
        """Return per-slice sampling probabilities for one mask volume.

        Each slice's weight is its mask sum plus a uniform 10% share of the
        total, normalised so that the weights sum to one. When the whole
        mask is empty the formula would be 0/0 (NaN weights, which
        ``np.random.choice`` rejects), so a uniform distribution is used.
        """
        per_slice = mask.sum(axis=-1).sum(axis=-1)
        count = len(per_slice)
        total = per_slice.sum()
        if total == 0:
            return np.full(count, 1.0 / max(count, 1))
        return (per_slice + (total * 0.1 / count)) / (total * 1.1)

    @staticmethod
    def _build_slice_index(num_slices):
        """Return the flat list of ``(patient_index, slice_index)`` pairs."""
        return [
            (p_idx, s_idx)
            for p_idx, count in enumerate(num_slices)
            for s_idx in range(count)
        ]

    def __len__(self):
        """Return the total number of slices across all patients."""
        return len(self.patient_slice_index)

    def __getitem__(self, idx):
        """Return one ``(image_tensor, mask_tensor)`` pair.

        With ``random_sampling`` enabled, ``idx`` is only bounds-checked via
        the index lookup; the returned slice is drawn at random (uniform over
        patients, then weighted by ``slice_weights`` within the patient).
        """
        patient, slice_n = self.patient_slice_index[idx]

        if self.random_sampling:
            patient = np.random.randint(len(self.volumes))
            n_slices = self.volumes[patient][0].shape[0]
            slice_n = np.random.choice(n_slices, p=self.slice_weights[patient])

        v, m = self.volumes[patient]
        image = v[slice_n]
        mask = m[slice_n]

        if self.transform is not None:
            image, mask = self.transform((image, mask))

        # fix dimensions (C, H, W)
        image = image.transpose(2, 0, 1)
        mask = mask.transpose(2, 0, 1)

        image_tensor = torch.from_numpy(image.astype(np.float32))
        mask_tensor = torch.from_numpy(mask.astype(np.float32))

        # return tensors
        return image_tensor, mask_tensor