dataset.py
|
import os
import random

import numpy as np
import torch
from skimage.io import imread
from torch.utils.data import Dataset

from utils import crop_sample, pad_sample, resize_sample, normalize_volume


class BrainSegmentationDataset(Dataset):
    """Brain MRI dataset for FLAIR abnormality segmentation"""

    in_channels = 3
    out_channels = 1

    def __init__(
        self,
        images_dir,
        transform=None,
        image_size=256,
        subset="train",
        random_sampling=True,
        validation_cases=10,
        seed=42,
    ):
        assert subset in ["all", "train", "validation"]

        # read images
        volumes = {}
        masks = {}
        print("reading {} images...".format(subset))
        for (dirpath, dirnames, filenames) in os.walk(images_dir):
            image_slices = []
            mask_slices = []
            # sort .tif slices by the slice index embedded in the filename
            for filename in sorted(
                filter(lambda f: ".tif" in f, filenames),
                key=lambda x: int(x.split(".")[-2].split("_")[4]),
            ):
                filepath = os.path.join(dirpath, filename)
                if "mask" in filename:
                    mask_slices.append(imread(filepath, as_gray=True))
                else:
                    image_slices.append(imread(filepath))
            if len(image_slices) > 0:
                patient_id = dirpath.split("/")[-1]
                # skip the first and last slice of each volume
                volumes[patient_id] = np.array(image_slices[1:-1])
                masks[patient_id] = np.array(mask_slices[1:-1])

        self.patients = sorted(volumes)

        # select cases to subset
        if not subset == "all":
            random.seed(seed)
            validation_patients = random.sample(self.patients, k=validation_cases)
            if subset == "validation":
                self.patients = validation_patients
            else:
                self.patients = sorted(
                    list(set(self.patients).difference(validation_patients))
                )

        print("preprocessing {} volumes...".format(subset))
        # create list of tuples (volume, mask)
        self.volumes = [(volumes[k], masks[k]) for k in self.patients]

        print("cropping {} volumes...".format(subset))
        # crop to smallest enclosing volume
        self.volumes = [crop_sample(v) for v in self.volumes]

        print("padding {} volumes...".format(subset))
        # pad to square
        self.volumes = [pad_sample(v) for v in self.volumes]

        print("resizing {} volumes...".format(subset))
        # resize
        self.volumes = [resize_sample(v, size=image_size) for v in self.volumes]

        print("normalizing {} volumes...".format(subset))
        # normalize channel-wise
        self.volumes = [(normalize_volume(v), m) for v, m in self.volumes]

        # probabilities for sampling slices based on masks:
        # each slice is weighted by its mask area plus a uniform component
        # (10% of the total mask area spread over all slices), so slices with
        # no abnormality still have a nonzero probability; weights sum to 1
        self.slice_weights = [m.sum(axis=-1).sum(axis=-1) for v, m in self.volumes]
        self.slice_weights = [
            (s + (s.sum() * 0.1 / len(s))) / (s.sum() * 1.1) for s in self.slice_weights
        ]

        # add channel dimension to masks
        self.volumes = [(v, m[..., np.newaxis]) for (v, m) in self.volumes]

        print("done creating {} dataset".format(subset))

        # create global index for patient and slice (idx -> (p_idx, s_idx))
        num_slices = [v.shape[0] for v, m in self.volumes]
        self.patient_slice_index = list(
            zip(
                sum([[i] * num_slices[i] for i in range(len(num_slices))], []),
                sum([list(range(x)) for x in num_slices], []),
            )
        )

        self.random_sampling = random_sampling

        self.transform = transform

    def __len__(self):
        return len(self.patient_slice_index)

    def __getitem__(self, idx):
        patient = self.patient_slice_index[idx][0]
        slice_n = self.patient_slice_index[idx][1]

        if self.random_sampling:
            patient = np.random.randint(len(self.volumes))
            slice_n = np.random.choice(
                range(self.volumes[patient][0].shape[0]), p=self.slice_weights[patient]
            )

        v, m = self.volumes[patient]
        image = v[slice_n]
        mask = m[slice_n]

        if self.transform is not None:
            image, mask = self.transform((image, mask))

        # fix dimensions (C, H, W)
        image = image.transpose(2, 0, 1)
        mask = mask.transpose(2, 0, 1)

        image_tensor = torch.from_numpy(image.astype(np.float32))
        mask_tensor = torch.from_numpy(mask.astype(np.float32))

        # return tensors
        return image_tensor, mask_tensor
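

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): shows how this
# dataset is typically wrapped in a PyTorch DataLoader. The directory path
# "./kaggle_3m" and the batch size are illustrative assumptions, not values
# prescribed by this file; the layout assumed by __init__ is one
# sub-directory per patient containing .tif slices and matching *_mask.tif
# files, and the `utils` helpers imported above must be available.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    train_dataset = BrainSegmentationDataset(
        images_dir="./kaggle_3m",  # hypothetical path to the patient folders
        image_size=256,
        subset="train",
        validation_cases=10,
    )
    loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

    images, masks = next(iter(loader))
    # expected shapes: images (16, 3, 256, 256), masks (16, 1, 256, 256)
    print(images.shape, masks.shape)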