utils/dataloader_utils.py

#!/usr/bin/env python
# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import os
from multiprocessing import Pool


def get_class_balanced_patients(class_targets, batch_size, num_classes, slack_factor=0.1):
    '''
    Samples patients towards an equilibrium of classes on an RoI level. For highly imbalanced datasets, this might be
    too strong a requirement. Hence, a slack factor determines the fraction of the batch that is sampled randomly
    before class balancing is triggered.
    :param class_targets: list of patient targets, where each patient target is a list of class labels of the
        respective RoIs.
    :param batch_size:
    :param num_classes:
    :param slack_factor:
    :return: batch_ixs: list of indices referring to a subset of the class_targets list, sampled to build one batch.
    '''
    batch_ixs = []
    class_count = {k: 0 for k in range(num_classes)}
    weakest_class = 0
    for ix in range(batch_size):

        keep_looking = True
        while keep_looking:

            # choose a random patient.
            cand = np.random.choice(len(class_targets), 1)[0]
            # check the least occurring class among this patient's RoIs.
            tmp_weakest_class = np.argmin([class_targets[cand].count(ii) for ii in range(num_classes)])
            # if the current batch is already bigger than the slack_factor ratio, check that the weakest class of
            # this patient differs from the weakest class in the current batch (which needs to be boosted), and that
            # at least one RoI of this patient belongs to that weakest class. If so, keep the patient, else keep looking.
            if (tmp_weakest_class != weakest_class and class_targets[cand].count(weakest_class) > 0) or ix < int(batch_size * slack_factor):
                keep_looking = False

        for c in range(num_classes):
            class_count[c] += class_targets[cand].count(c)
        weakest_class = np.argmin([class_count[c] for c in range(num_classes)])
        batch_ixs.append(cand)

    return batch_ixs
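

# A hedged usage sketch (added for illustration, not part of the original module): `_example_targets`
# is a hypothetical toy dataset in which class 1 is rare, so the sampler should over-represent
# patients that contain class-1 RoIs once the random slack fraction of the batch is filled.
def _demo_class_balanced_sampling():
    _example_targets = [[0, 0], [0, 1], [0, 0, 0], [1], [0], [0, 0]]
    batch_ixs = get_class_balanced_patients(_example_targets, batch_size=10, num_classes=2)
    # count how often each class appears among the RoIs of the sampled batch.
    sampled = [c for ix in batch_ixs for c in _example_targets[ix]]
    print({c: sampled.count(c) for c in range(2)})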


class fold_generator:
    """
    Generates splits of indices for a given dataset length to perform n-fold cross-validation.
    Splits each fold into three subsets for training, validation and testing.
    This form of cross-validation uses an inner-loop test set, which is useful if test scores shall be reported on a
    statistically reliable number of patients despite the limited size of a dataset.
    If a held-out test set is provided and hence no inner-loop test set is needed, just add test_idxs to the training
    data in the dataloader. This creates straightforward train-val splits.
    :returns names_list: list of len n_splits. each element is a list holding train_ix, val_ix, test_ix, and the fold id.
    """
    def __init__(self, seed, n_splits, len_data):
        """
        :param seed: random seed for splits.
        :param n_splits: number of splits, e.g. 5 splits for 5-fold cross-validation.
        :param len_data: number of elements in the dataset.
        """
        self.tr_ix = []
        self.val_ix = []
        self.te_ix = []
        self.slicer = None
        self.missing = 0
        self.fold = 0
        self.len_data = len_data
        self.n_splits = n_splits
        self.myseed = seed
        self.boost_val = 0

    def init_indices(self):
        # self.l is set by get_fold_names before this method is called.
        t = list(np.arange(self.l))
        # round up to the next splittable data amount.
        split_length = int(np.ceil(len(t) / float(self.n_splits)))
        self.slicer = split_length
        self.mod = len(t) % self.n_splits
        if self.mod > 0:
            # missing is the number of folds in which the new splits are reduced to account for missing data.
            self.missing = self.n_splits - self.mod

        self.te_ix = t[:self.slicer]
        self.tr_ix = t[self.slicer:]
        self.val_ix = self.tr_ix[:self.slicer]
        self.tr_ix = self.tr_ix[self.slicer:]

    def new_fold(self):
        slicer = self.slicer
        if self.fold < self.missing:
            slicer = self.slicer - 1

        temp = self.te_ix

        # catch the mod == 1 corner case: the test set collects one extra element, since it walks through both
        # rounded-up splits. Account for this by reducing the last fold's split by one.
        if self.fold == self.n_splits - 2 and self.mod == 1:
            temp += self.val_ix[-1:]
            self.val_ix = self.val_ix[:-1]

        self.te_ix = self.val_ix
        self.val_ix = self.tr_ix[:slicer]
        self.tr_ix = self.tr_ix[slicer:] + temp

    def get_fold_names(self):
        names_list = []
        rgen = np.random.RandomState(self.myseed)
        cv_names = np.arange(self.len_data)

        rgen.shuffle(cv_names)
        self.l = len(cv_names)
        self.init_indices()

        for split in range(self.n_splits):
            train_names, val_names, test_names = cv_names[self.tr_ix], cv_names[self.val_ix], cv_names[self.te_ix]
            names_list.append([train_names, val_names, test_names, self.fold])
            self.new_fold()
            self.fold += 1

        return names_list
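

# A hedged usage sketch (added for illustration): generate 5-fold splits for a hypothetical dataset
# of 23 patients and print the subset sizes of each fold.
def _demo_fold_generator():
    fg = fold_generator(seed=42, n_splits=5, len_data=23)
    for train_ix, val_ix, test_ix, fold in fg.get_fold_names():
        print("fold {}: {} train / {} val / {} test".format(fold, len(train_ix), len(val_ix), len(test_ix)))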


def get_patch_crop_coords(img, patch_size, min_overlap=30):
    """
    :param img: ndarray of shape (y, x(, z)).
    :param patch_size: list of len 2 (2D) or 3 (3D).
    :param min_overlap: minimum required overlap of patches.
        If too small, some areas are poorly represented only at edges of single patches.
    :return: ndarray of shape (n_patches, 2 * dim): crop coordinates for each patch.
    """
    crop_coords = []
    for dim in range(len(img.shape)):
        n_patches = int(np.ceil(img.shape[dim] / patch_size[dim]))

        # no crops required in this dimension, add image shape as coordinates.
        if n_patches == 1:
            crop_coords.append([(0, img.shape[dim])])
            continue

        # fix the two outside patches to coords patch_size / 2 and interpolate the patch centers in between.
        center_dists = (img.shape[dim] - patch_size[dim]) / (n_patches - 1)

        if (patch_size[dim] - center_dists) < min_overlap:
            n_patches += 1
            center_dists = (img.shape[dim] - patch_size[dim]) / (n_patches - 1)

        patch_centers = np.round([(patch_size[dim] / 2 + (center_dists * ii)) for ii in range(n_patches)])
        dim_crop_coords = [(center - patch_size[dim] / 2, center + patch_size[dim] / 2) for center in patch_centers]
        crop_coords.append(dim_crop_coords)

    coords_mesh_grid = []
    for ymin, ymax in crop_coords[0]:
        for xmin, xmax in crop_coords[1]:
            if len(crop_coords) == 3 and patch_size[2] > 1:
                for zmin, zmax in crop_coords[2]:
                    coords_mesh_grid.append([ymin, ymax, xmin, xmax, zmin, zmax])
            elif len(crop_coords) == 3 and patch_size[2] == 1:
                for zmin in range(img.shape[2]):
                    coords_mesh_grid.append([ymin, ymax, xmin, xmax, zmin, zmin + 1])
            else:
                coords_mesh_grid.append([ymin, ymax, xmin, xmax])
    return np.array(coords_mesh_grid).astype(int)
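

# A hedged usage sketch (added for illustration): tile a hypothetical 300 x 300 image into
# overlapping 128 x 128 patches; each returned row is (ymin, ymax, xmin, xmax).
def _demo_patch_crop_coords():
    img = np.zeros((300, 300))
    coords = get_patch_crop_coords(img, patch_size=[128, 128])
    print(coords)  # three patch positions per axis, overlapping by at least min_overlap pixels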


def pad_nd_image(image, new_shape=None, mode="edge", kwargs=None, return_slicer=False, shape_must_be_divisible_by=None):
    """
    one padder to pad them all. Documentation? Well okay. A little bit. by Fabian Isensee

    :param image: nd image. can be anything
    :param new_shape: what shape do you want? new_shape does not have to have the same dimensionality as image. If
        len(new_shape) < len(image.shape) then the last axes of image will be padded. If new_shape < image.shape in
        any of the axes then we will not pad that axis, but also not crop! (interpret new_shape as new_min_shape)

        Example:
        image.shape = (10, 1, 512, 512); new_shape = (768, 768) -> result: (10, 1, 768, 768). Cool, huh?
        image.shape = (10, 1, 512, 512); new_shape = (364, 768) -> result: (10, 1, 512, 768).

    :param mode: see np.pad for documentation
    :param return_slicer: if True then this function will also return what coords you will need to use when cropping
        back to original shape
    :param shape_must_be_divisible_by: for network prediction. After applying new_shape, make sure the new shape is
        divisible by that number (can also be a list with an entry for each axis). Whatever is missing to match that
        will be padded (so the result may be larger than new_shape if shape_must_be_divisible_by is not None)
    :param kwargs: see np.pad for documentation
    """
    if kwargs is None:
        kwargs = {}

    if new_shape is not None:
        old_shape = np.array(image.shape[-len(new_shape):])
    else:
        assert shape_must_be_divisible_by is not None
        assert isinstance(shape_must_be_divisible_by, (list, tuple, np.ndarray))
        new_shape = image.shape[-len(shape_must_be_divisible_by):]
        old_shape = new_shape

    num_axes_nopad = len(image.shape) - len(new_shape)

    new_shape = [max(new_shape[i], old_shape[i]) for i in range(len(new_shape))]

    if not isinstance(new_shape, np.ndarray):
        new_shape = np.array(new_shape)

    if shape_must_be_divisible_by is not None:
        if not isinstance(shape_must_be_divisible_by, (list, tuple, np.ndarray)):
            shape_must_be_divisible_by = [shape_must_be_divisible_by] * len(new_shape)
        else:
            assert len(shape_must_be_divisible_by) == len(new_shape)

        for i in range(len(new_shape)):
            if new_shape[i] % shape_must_be_divisible_by[i] == 0:
                new_shape[i] -= shape_must_be_divisible_by[i]

        new_shape = np.array([new_shape[i] + shape_must_be_divisible_by[i] - new_shape[i] %
                              shape_must_be_divisible_by[i] for i in range(len(new_shape))])

    difference = new_shape - old_shape
    pad_below = difference // 2
    pad_above = difference // 2 + difference % 2
    pad_list = [[0, 0]] * num_axes_nopad + list([list(i) for i in zip(pad_below, pad_above)])
    res = np.pad(image, pad_list, mode, **kwargs)

    if not return_slicer:
        return res
    else:
        pad_list = np.array(pad_list)
        pad_list[:, 1] = np.array(res.shape) - pad_list[:, 1]
        slicer = list(slice(*i) for i in pad_list)
        return res, slicer
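

# A hedged usage sketch (added for illustration): pad a hypothetical (10, 1, 512, 512) array up to
# (768, 768) in its last two axes, then crop it back to the original shape with the returned slicer.
def _demo_pad_nd_image():
    image = np.zeros((10, 1, 512, 512))
    padded, slicer = pad_nd_image(image, new_shape=(768, 768), return_slicer=True)
    print(padded.shape)                 # (10, 1, 768, 768)
    print(padded[tuple(slicer)].shape)  # (10, 1, 512, 512), i.e. the original shape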


#############################
# data packing / unpacking #
#############################

def get_case_identifiers(folder):
    # case identifiers are the .npz file names without their extension.
    case_identifiers = [i[:-4] for i in os.listdir(folder) if i.endswith("npz")]
    return case_identifiers


def convert_to_npy(npz_file, remove=False):
    identifier = os.path.split(npz_file)[1][:-4]
    if not os.path.isfile(npz_file[:-4] + ".npy"):
        # the array inside the .npz archive is stored under the case identifier as its key.
        a = np.load(npz_file)[identifier]
        np.save(npz_file[:-4] + ".npy", a)
    if remove:
        os.remove(npz_file)


def unpack_dataset(folder, threads=8):
    # decompress all .npz archives in folder to .npy files in parallel and remove the archives.
    case_identifiers = get_case_identifiers(folder)
    p = Pool(threads)
    npz_files = [os.path.join(folder, i + ".npz") for i in case_identifiers]
    p.starmap(convert_to_npy, [(f, True) for f in npz_files])
    p.close()
    p.join()


def delete_npy(folder):
    case_identifiers = get_case_identifiers(folder)
    npy_files = [os.path.join(folder, i + ".npy") for i in case_identifiers]
    npy_files = [i for i in npy_files if os.path.isfile(i)]
    for n in npy_files:
        os.remove(n)
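

# A hedged usage sketch (added for illustration): "/path/to/preprocessed" is a placeholder for a
# folder holding one <case_id>.npz archive per case. Unpacking trades disk space for much faster
# reads during training.
def _demo_unpacking(folder="/path/to/preprocessed"):
    # decompress every .npz archive in the folder to an .npy file; note that unpack_dataset passes
    # remove=True to convert_to_npy, so the original .npz archives are deleted after conversion.
    unpack_dataset(folder, threads=8)
    # delete_npy(folder) would remove unpacked .npy files again; note it derives case identifiers
    # from .npz archives still present in the folder.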