#!/usr/bin/env python
# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import os
from multiprocessing import Pool

import numpy as np


def get_class_balanced_patients(class_targets, batch_size, num_classes, slack_factor=0.1):
    '''
    Samples patients towards an equilibrium of classes on the RoI level. For highly imbalanced datasets this may be
    too strong a requirement, hence a slack factor determines the fraction of the batch that is sampled randomly
    before class balancing kicks in.
    :param class_targets: list of patient targets, where each patient target is a list of class labels of the
    respective RoIs.
    :param batch_size: number of patients to sample for one batch.
    :param num_classes: number of classes to balance over.
    :param slack_factor: fraction of the batch that is sampled randomly, i.e. without enforcing class balance.
    :return: batch_ixs: list of indices into class_targets, sampled to build one batch.
    '''
    batch_ixs = []
    class_count = {k: 0 for k in range(num_classes)}
    weakest_class = 0
    for ix in range(batch_size):

        keep_looking = True
        while keep_looking:
            # choose a random patient.
            cand = np.random.choice(len(class_targets), 1)[0]
            # check the least occurring class among this patient's RoIs.
            tmp_weakest_class = np.argmin([class_targets[cand].count(ii) for ii in range(num_classes)])
            # if the current batch is already bigger than the slack_factor ratio, keep the patient only if its
            # weakest class is not the batch's weakest class (the latter needs to be boosted) and at least one of
            # its RoIs belongs to the batch's weakest class. Otherwise keep looking.
            if (tmp_weakest_class != weakest_class and class_targets[cand].count(weakest_class) > 0) \
                    or ix < int(batch_size * slack_factor):
                keep_looking = False

        for c in range(num_classes):
            class_count[c] += class_targets[cand].count(c)
        weakest_class = np.argmin([class_count[c] for c in range(num_classes)])
        batch_ixs.append(cand)

    return batch_ixs
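

# A minimal usage sketch (illustrative, not part of the original module): three toy patients with RoI labels for
# num_classes=2. The sampler has to pick patient 0 first (it is the only one contributing the initially weakest
# class) and afterwards favors patients that contribute the under-represented class 1.
def _example_class_balanced_sampling():
    class_targets = [[0, 0, 0], [0, 1], [1, 1]]  # hypothetical per-patient RoI labels
    batch_ixs = get_class_balanced_patients(class_targets, batch_size=6, num_classes=2)
    print(batch_ixs)  # e.g. [0, 2, 1, 2, 2, 1]; exact indices depend on the random state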
|
|
class fold_generator:
    """
    Generates splits of indices for a given dataset length to perform n-fold cross-validation.
    Splits each fold into three subsets for training, validation and testing.
    This form of cross-validation uses an inner-loop test set, which is useful if test scores shall be reported on
    a statistically reliable number of patients despite a limited dataset size.
    If a hold-out test set is provided and hence no inner-loop test set is needed, simply add the test indices to
    the training data in the dataloader; this yields straightforward train-val splits.
    :returns names_list: list of length n_splits. Each element is a list holding train_ix, val_ix, test_ix and the
    fold id.
    """
    def __init__(self, seed, n_splits, len_data):
        """
        :param seed: random seed for the splits.
        :param n_splits: number of splits, e.g. 5 splits for 5-fold cross-validation.
        :param len_data: number of elements in the dataset.
        """
        self.tr_ix = []
        self.val_ix = []
        self.te_ix = []
        self.slicer = None
        self.missing = 0
        self.fold = 0
        self.len_data = len_data
        self.n_splits = n_splits
        self.myseed = seed
        self.boost_val = 0

    def init_indices(self):

        t = list(np.arange(self.l))  # self.l is set in get_fold_names before this method is called.
        # round up to the next splittable data amount.
        split_length = int(np.ceil(len(t) / float(self.n_splits)))
        self.slicer = split_length
        self.mod = len(t) % self.n_splits
        if self.mod > 0:
            # missing is the number of folds in which the splits are reduced to account for missing data.
            self.missing = self.n_splits - self.mod

        self.te_ix = t[:self.slicer]
        self.tr_ix = t[self.slicer:]
        self.val_ix = self.tr_ix[:self.slicer]
        self.tr_ix = self.tr_ix[self.slicer:]

    def new_fold(self):

        slicer = self.slicer
        if self.fold < self.missing:
            slicer = self.slicer - 1

        temp = self.te_ix

        # catch the exception mod == 1: the test set would collect one extra element, since it walks through two
        # rounded-up splits. Account for this by reducing the last fold's split by 1.
        if self.fold == self.n_splits - 2 and self.mod == 1:
            temp += self.val_ix[-1:]
            self.val_ix = self.val_ix[:-1]

        self.te_ix = self.val_ix
        self.val_ix = self.tr_ix[:slicer]
        self.tr_ix = self.tr_ix[slicer:] + temp

    def get_fold_names(self):
        names_list = []
        rgen = np.random.RandomState(self.myseed)
        cv_names = np.arange(self.len_data)

        rgen.shuffle(cv_names)
        self.l = len(cv_names)
        self.init_indices()

        for split in range(self.n_splits):
            train_names, val_names, test_names = cv_names[self.tr_ix], cv_names[self.val_ix], cv_names[self.te_ix]
            names_list.append([train_names, val_names, test_names, self.fold])
            self.new_fold()
            self.fold += 1

        return names_list
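

# A minimal usage sketch (illustrative, not part of the original module): 5-fold splits over a hypothetical
# dataset of 23 patients. Each fold yields disjoint train / val / test index arrays covering all 23 indices.
def _example_fold_generator():
    fg = fold_generator(seed=42, n_splits=5, len_data=23)
    for train_ix, val_ix, test_ix, fold in fg.get_fold_names():
        print(fold, len(train_ix), len(val_ix), len(test_ix))  # e.g. "0 13 5 5"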
|
|
def get_patch_crop_coords(img, patch_size, min_overlap=30):
    """
    :param img: ndarray of shape (y, x(, z)).
    :param patch_size: list of len 2 (2D) or 3 (3D).
    :param min_overlap: minimum required overlap of patches. If too small, some areas are represented only at the
    edges of single patches.
    :return: ndarray of shape (n_patches, 2 * dim): crop coordinates for each patch.
    """
    crop_coords = []
    for dim in range(len(img.shape)):
        n_patches = int(np.ceil(img.shape[dim] / patch_size[dim]))

        # no crops required in this dimension, add image shape as coordinates.
        if n_patches == 1:
            crop_coords.append([(0, img.shape[dim])])
            continue

        # fix the two outermost patches at a center distance of patch_size / 2 from the borders and interpolate
        # the remaining patch centers in between.
        center_dists = (img.shape[dim] - patch_size[dim]) / (n_patches - 1)

        if (patch_size[dim] - center_dists) < min_overlap:
            n_patches += 1
            center_dists = (img.shape[dim] - patch_size[dim]) / (n_patches - 1)

        patch_centers = np.round([(patch_size[dim] / 2 + (center_dists * ii)) for ii in range(n_patches)])
        dim_crop_coords = [(center - patch_size[dim] / 2, center + patch_size[dim] / 2) for center in patch_centers]
        crop_coords.append(dim_crop_coords)

    coords_mesh_grid = []
    for ymin, ymax in crop_coords[0]:
        for xmin, xmax in crop_coords[1]:
            if len(crop_coords) == 3 and patch_size[2] > 1:
                for zmin, zmax in crop_coords[2]:
                    coords_mesh_grid.append([ymin, ymax, xmin, xmax, zmin, zmax])
            elif len(crop_coords) == 3 and patch_size[2] == 1:
                for zmin in range(img.shape[2]):
                    coords_mesh_grid.append([ymin, ymax, xmin, xmax, zmin, zmin + 1])
            else:
                coords_mesh_grid.append([ymin, ymax, xmin, xmax])
    return np.array(coords_mesh_grid).astype(int)
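

# A minimal usage sketch (illustrative, not part of the original module): tile a hypothetical 2D image of shape
# (300, 300) into overlapping 128 x 128 patches and inspect the resulting crop coordinates.
def _example_patch_crop_coords():
    img = np.zeros((300, 300))
    coords = get_patch_crop_coords(img, patch_size=[128, 128], min_overlap=30)
    print(coords.shape)  # (9, 4): ymin, ymax, xmin, xmax for each of the 3 x 3 patches
    print(coords[0])     # array([  0, 128,   0, 128])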
|
|
def pad_nd_image(image, new_shape=None, mode="edge", kwargs=None, return_slicer=False, shape_must_be_divisible_by=None):
    """
    One padder to pad them all. Documentation? Well okay. A little bit. By Fabian Isensee.

    :param image: nd image. can be anything.
    :param new_shape: what shape do you want? new_shape does not have to have the same dimensionality as image. If
    len(new_shape) < len(image.shape), then the last axes of image will be padded. If new_shape < image.shape in any
    of the axes, then we will not pad that axis, but also not crop! (interpret new_shape as new_min_shape)
    Example:
    image.shape = (10, 1, 512, 512); new_shape = (768, 768) -> result: (10, 1, 768, 768). Cool, huh?
    image.shape = (10, 1, 512, 512); new_shape = (364, 768) -> result: (10, 1, 512, 768).
    :param mode: see np.pad for documentation.
    :param return_slicer: if True, this function will also return the coordinates you need to use when cropping back
    to the original shape.
    :param shape_must_be_divisible_by: for network prediction. After applying new_shape, make sure the new shape is
    divisible by that number (can also be a list with an entry for each axis). Whatever is missing to match that
    will be padded (so the result may be larger than new_shape if shape_must_be_divisible_by is not None).
    :param kwargs: see np.pad for documentation.
    """
    if kwargs is None:
        kwargs = {}

    if new_shape is not None:
        old_shape = np.array(image.shape[-len(new_shape):])
    else:
        assert shape_must_be_divisible_by is not None
        assert isinstance(shape_must_be_divisible_by, (list, tuple, np.ndarray))
        new_shape = image.shape[-len(shape_must_be_divisible_by):]
        old_shape = new_shape

    num_axes_nopad = len(image.shape) - len(new_shape)

    new_shape = [max(new_shape[i], old_shape[i]) for i in range(len(new_shape))]

    if not isinstance(new_shape, np.ndarray):
        new_shape = np.array(new_shape)

    if shape_must_be_divisible_by is not None:
        if not isinstance(shape_must_be_divisible_by, (list, tuple, np.ndarray)):
            shape_must_be_divisible_by = [shape_must_be_divisible_by] * len(new_shape)
        else:
            assert len(shape_must_be_divisible_by) == len(new_shape)

        for i in range(len(new_shape)):
            if new_shape[i] % shape_must_be_divisible_by[i] == 0:
                new_shape[i] -= shape_must_be_divisible_by[i]

        new_shape = np.array([new_shape[i] + shape_must_be_divisible_by[i] - new_shape[i] % shape_must_be_divisible_by[i]
                              for i in range(len(new_shape))])

    difference = new_shape - old_shape
    pad_below = difference // 2
    pad_above = difference // 2 + difference % 2
    pad_list = [[0, 0]] * num_axes_nopad + list([list(i) for i in zip(pad_below, pad_above)])
    res = np.pad(image, pad_list, mode, **kwargs)
    if not return_slicer:
        return res
    else:
        pad_list = np.array(pad_list)
        pad_list[:, 1] = np.array(res.shape) - pad_list[:, 1]
        slicer = list(slice(*i) for i in pad_list)
        return res, slicer
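

# A minimal usage sketch (illustrative, not part of the original module): pad a hypothetical batch to a spatial
# shape divisible by 32 and use the returned slicer to crop a prediction back to the input size.
def _example_pad_nd_image():
    image = np.random.rand(1, 1, 300, 300)
    padded, slicer = pad_nd_image(image, return_slicer=True, shape_must_be_divisible_by=[32, 32])
    print(padded.shape)                 # (1, 1, 320, 320)
    print(padded[tuple(slicer)].shape)  # back to (1, 1, 300, 300)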
|
|
#############################
# data packing / unpacking #
#############################

def get_case_identifiers(folder):
    # one identifier per packed case, read by stripping the ".npz" extension.
    case_identifiers = [i[:-4] for i in os.listdir(folder) if i.endswith("npz")]
    return case_identifiers


def convert_to_npy(npz_file, remove=False):
    # unpack the compressed .npz archive into a flat .npy file for faster loading.
    identifier = os.path.split(npz_file)[1][:-4]
    if not os.path.isfile(npz_file[:-4] + ".npy"):
        a = np.load(npz_file)[identifier]
        np.save(npz_file[:-4] + ".npy", a)
    if remove:
        os.remove(npz_file)


def unpack_dataset(folder, threads=8):
    # convert all .npz archives in folder to .npy in parallel, removing the archives afterwards.
    case_identifiers = get_case_identifiers(folder)
    p = Pool(threads)
    npz_files = [os.path.join(folder, i + ".npz") for i in case_identifiers]
    p.starmap(convert_to_npy, [(f, True) for f in npz_files])
    p.close()
    p.join()


def delete_npy(folder):
    # remove the unpacked .npy files of all cases in folder.
    case_identifiers = get_case_identifiers(folder)
    npy_files = [os.path.join(folder, i + ".npy") for i in case_identifiers]
    npy_files = [i for i in npy_files if os.path.isfile(i)]
    for n in npy_files:
        os.remove(n)
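

# A minimal usage sketch (illustrative, not part of the original module): unpack a hypothetical preprocessed
# dataset folder before training. Note that unpack_dataset calls convert_to_npy with remove=True, so the .npz
# archives are deleted after unpacking and get_case_identifiers (which lists .npz files) will then return [].
def _example_unpack(folder="/path/to/preprocessed_data"):
    print(get_case_identifiers(folder))  # case ids, read from the .npz files in folder
    unpack_dataset(folder, threads=4)    # creates <case_id>.npy next to each former <case_id>.npz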