Diff of /utils/dataloader_utils.py [000000] .. [bb7f56]

#!/usr/bin/env python
# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import numpy as np
import os
from multiprocessing import Pool


def get_class_balanced_patients(class_targets, batch_size, num_classes, slack_factor=0.1):
    '''
    Samples patients towards an equilibrium of classes on an RoI level. For highly imbalanced datasets this can be
    too strong a requirement, hence a slack factor determines the ratio of the batch that is sampled randomly before
    class balancing is triggered.
    :param class_targets: list of patient targets, where each patient target is a list of class labels of the respective RoIs.
    :param batch_size: number of patients to sample for one batch.
    :param num_classes: number of distinct classes to balance over.
    :param slack_factor: ratio of the batch that is sampled randomly before class balancing is enforced.
    :return: batch_ixs: list of indices referring to a subset of the class_targets list, sampled to build one batch.
    '''
    batch_ixs = []
    class_count = {k: 0 for k in range(num_classes)}
    weakest_class = 0
    for ix in range(batch_size):

        keep_looking = True
        while keep_looking:
            # choose a random patient.
            cand = np.random.choice(len(class_targets), 1)[0]
            # check the least occurring class among this patient's rois.
            tmp_weakest_class = np.argmin([class_targets[cand].count(ii) for ii in range(num_classes)])
            # if the current batch is already bigger than the slack_factor ratio, check that the weakest class of this
            # patient is not the weakest class of the current batch (which needs to be boosted) and that at least one
            # roi of this patient belongs to the batch's weakest class. If so, keep the patient, else keep looking.
            if (tmp_weakest_class != weakest_class and class_targets[cand].count(weakest_class) > 0) or ix < int(batch_size * slack_factor):
                keep_looking = False

        for c in range(num_classes):
            class_count[c] += class_targets[cand].count(c)
        weakest_class = np.argmin(([class_count[c] for c in range(num_classes)]))
        batch_ixs.append(cand)

    return batch_ixs
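
# Usage sketch (hypothetical values, not part of the original module): each patient contributes a list of RoI class
# labels; the sampler returns patient indices for one batch, biased towards the batch's currently weakest class.
#   >>> targets = [[0, 0, 1], [0], [1, 1], [0, 0], [1]]
#   >>> get_class_balanced_patients(targets, batch_size=6, num_classes=2, slack_factor=0.1)
#   [2, 4, 0, 3, 4, 2]   # e.g.; indices into targets, len == batch_size, rarer class over-sampled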


class fold_generator:
    """
    Generates splits of indices for a given length of a dataset to perform n-fold cross-validation.
    Splits each fold into 3 subsets for training, validation and testing.
    This form of cross-validation uses an inner-loop test set, which is useful if test scores shall be reported on a
    statistically reliable number of patients despite the limited size of a dataset.
    If a hold-out test set is provided and hence no inner-loop test set is needed, just add the test indices to the
    training data in the dataloader. This creates straightforward train-val splits.
    :returns names_list: list of len n_splits. each element is a list of len 3 for train_ix, val_ix, test_ix.
    """
    def __init__(self, seed, n_splits, len_data):
        """
        :param seed: random seed for the splits.
        :param n_splits: number of splits, e.g. 5 splits for 5-fold cross-validation.
        :param len_data: number of elements in the dataset.
        """
        self.tr_ix = []
        self.val_ix = []
        self.te_ix = []
        self.slicer = None
        self.missing = 0
        self.fold = 0
        self.len_data = len_data
        self.n_splits = n_splits
        self.myseed = seed
        self.boost_val = 0

    def init_indices(self):

        t = list(np.arange(self.l))
        # round up to the next splittable data amount.
        split_length = int(np.ceil(len(t) / float(self.n_splits)))
        self.slicer = split_length
        self.mod = len(t) % self.n_splits
        if self.mod > 0:
            # missing is the number of folds in which the splits are reduced by one to account for missing data.
            self.missing = self.n_splits - self.mod

        self.te_ix = t[:self.slicer]
        self.tr_ix = t[self.slicer:]
        self.val_ix = self.tr_ix[:self.slicer]
        self.tr_ix = self.tr_ix[self.slicer:]

    def new_fold(self):

        slicer = self.slicer
        if self.fold < self.missing:
            slicer = self.slicer - 1

        temp = self.te_ix

        # catch the exception mod == 1: the test set collects one datum too many, since it walks through both
        # rounded-up splits. Account for this by reducing the last fold's split by 1.
        if self.fold == self.n_splits - 2 and self.mod == 1:
            temp += self.val_ix[-1:]
            self.val_ix = self.val_ix[:-1]

        self.te_ix = self.val_ix
        self.val_ix = self.tr_ix[:slicer]
        self.tr_ix = self.tr_ix[slicer:] + temp

    def get_fold_names(self):
        names_list = []
        rgen = np.random.RandomState(self.myseed)
        cv_names = np.arange(self.len_data)

        rgen.shuffle(cv_names)
        self.l = len(cv_names)
        self.init_indices()

        for split in range(self.n_splits):
            train_names, val_names, test_names = cv_names[self.tr_ix], cv_names[self.val_ix], cv_names[self.te_ix]
            names_list.append([train_names, val_names, test_names, self.fold])
            self.new_fold()
            self.fold += 1

        return names_list
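
# Usage sketch (hypothetical numbers, not part of the original module): build 5 train/val/test splits for a
# dataset of 100 patients.
#   >>> fg = fold_generator(seed=0, n_splits=5, len_data=100)
#   >>> splits = fg.get_fold_names()
#   >>> len(splits)                                        # one [train, val, test, fold_id] entry per fold
#   5
#   >>> train, val, test, fold_id = splits[0]
#   >>> len(train), len(val), len(test)                    # roughly 60/20/20 of the shuffled indices
#   (60, 20, 20)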


def get_patch_crop_coords(img, patch_size, min_overlap=30):
    """
    :param img: image of shape (y, x, (z)).
    :param patch_size: list of len 2 (2D) or 3 (3D).
    :param min_overlap: minimum required overlap of patches. If too small, some areas are only poorly represented at
    the edges of single patches.
    :return: ndarray of shape (n_patches, 2 * dim): crop coordinates for each patch.
    """
    crop_coords = []
    for dim in range(len(img.shape)):
        n_patches = int(np.ceil(img.shape[dim] / patch_size[dim]))

        # no crops required in this dimension, add the image shape as coordinates.
        if n_patches == 1:
            crop_coords.append([(0, img.shape[dim])])
            continue

        # fix the two outside patches at coords patch_size / 2 from the borders and interpolate the centers in between.
        center_dists = (img.shape[dim] - patch_size[dim]) / (n_patches - 1)

        if (patch_size[dim] - center_dists) < min_overlap:
            n_patches += 1
            center_dists = (img.shape[dim] - patch_size[dim]) / (n_patches - 1)

        patch_centers = np.round([(patch_size[dim] / 2 + (center_dists * ii)) for ii in range(n_patches)])
        dim_crop_coords = [(center - patch_size[dim] / 2, center + patch_size[dim] / 2) for center in patch_centers]
        crop_coords.append(dim_crop_coords)

    coords_mesh_grid = []
    for ymin, ymax in crop_coords[0]:
        for xmin, xmax in crop_coords[1]:
            if len(crop_coords) == 3 and patch_size[2] > 1:
                for zmin, zmax in crop_coords[2]:
                    coords_mesh_grid.append([ymin, ymax, xmin, xmax, zmin, zmax])
            elif len(crop_coords) == 3 and patch_size[2] == 1:
                for zmin in range(img.shape[2]):
                    coords_mesh_grid.append([ymin, ymax, xmin, xmax, zmin, zmin + 1])
            else:
                coords_mesh_grid.append([ymin, ymax, xmin, xmax])
    return np.array(coords_mesh_grid).astype(int)
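
# Usage sketch (hypothetical shapes, not part of the original module): tile a 2D image of shape (500, 500) into
# overlapping 300 x 300 patches.
#   >>> img = np.zeros((500, 500))
#   >>> get_patch_crop_coords(img, patch_size=[300, 300])
#   array([[  0, 300,   0, 300],
#          [  0, 300, 200, 500],
#          [200, 500,   0, 300],
#          [200, 500, 200, 500]])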


def pad_nd_image(image, new_shape=None, mode="edge", kwargs=None, return_slicer=False, shape_must_be_divisible_by=None):
    """
    One padder to pad them all. Documentation? Well okay. A little bit. By Fabian Isensee.

    :param image: nd image. can be anything.
    :param new_shape: what shape do you want? new_shape does not have to have the same dimensionality as image. If
    len(new_shape) < len(image.shape), then the last axes of image will be padded. If new_shape < image.shape in any
    of the axes, then that axis will not be padded, but also not cropped! (interpret new_shape as new_min_shape)
    Example:
    image.shape = (10, 1, 512, 512); new_shape = (768, 768) -> result: (10, 1, 768, 768). Cool, huh?
    image.shape = (10, 1, 512, 512); new_shape = (364, 768) -> result: (10, 1, 512, 768).

    :param mode: see np.pad for documentation.
    :param return_slicer: if True, this function also returns the coordinates needed to crop back to the original
    shape.
    :param shape_must_be_divisible_by: for network prediction. After applying new_shape, make sure the new shape is
    divisible by that number (can also be a list with an entry for each axis). Whatever is missing to match that will
    be padded (so the result may be larger than new_shape if shape_must_be_divisible_by is not None).
    :param kwargs: see np.pad for documentation.
    """
    if kwargs is None:
        kwargs = {}

    if new_shape is not None:
        old_shape = np.array(image.shape[-len(new_shape):])
    else:
        assert shape_must_be_divisible_by is not None
        assert isinstance(shape_must_be_divisible_by, (list, tuple, np.ndarray))
        new_shape = image.shape[-len(shape_must_be_divisible_by):]
        old_shape = new_shape

    num_axes_nopad = len(image.shape) - len(new_shape)

    new_shape = [max(new_shape[i], old_shape[i]) for i in range(len(new_shape))]

    if not isinstance(new_shape, np.ndarray):
        new_shape = np.array(new_shape)

    if shape_must_be_divisible_by is not None:
        if not isinstance(shape_must_be_divisible_by, (list, tuple, np.ndarray)):
            shape_must_be_divisible_by = [shape_must_be_divisible_by] * len(new_shape)
        else:
            assert len(shape_must_be_divisible_by) == len(new_shape)

        for i in range(len(new_shape)):
            # if an axis is already divisible, step it down by one unit so the formula below does not add a full
            # extra step of padding to it.
            if new_shape[i] % shape_must_be_divisible_by[i] == 0:
                new_shape[i] -= shape_must_be_divisible_by[i]

        new_shape = np.array([new_shape[i] + shape_must_be_divisible_by[i] - new_shape[i] % shape_must_be_divisible_by[i] for i in range(len(new_shape))])

    difference = new_shape - old_shape
    pad_below = difference // 2
    pad_above = difference // 2 + difference % 2
    pad_list = [[0, 0]] * num_axes_nopad + list([list(i) for i in zip(pad_below, pad_above)])
    res = np.pad(image, pad_list, mode, **kwargs)
    if not return_slicer:
        return res
    else:
        pad_list = np.array(pad_list)
        pad_list[:, 1] = np.array(res.shape) - pad_list[:, 1]
        slicer = list(slice(*i) for i in pad_list)
        return res, slicer
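
# Usage sketch (hypothetical shapes, not part of the original module): pad the spatial axes of a
# (batch, channel, y, x) array to at least (768, 768) and crop back afterwards via the returned slicer.
#   >>> img = np.zeros((10, 1, 512, 512))
#   >>> padded, slicer = pad_nd_image(img, new_shape=(768, 768), return_slicer=True)
#   >>> padded.shape
#   (10, 1, 768, 768)
#   >>> padded[tuple(slicer)].shape
#   (10, 1, 512, 512)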


#############################
#  data packing / unpacking #
#############################

def get_case_identifiers(folder):
    # collect case identifiers from all .npz files in the folder (file name without the ".npz" extension).
    case_identifiers = [i[:-4] for i in os.listdir(folder) if i.endswith("npz")]
    return case_identifiers


def convert_to_npy(npz_file, remove=False):
    # unpack a compressed .npz archive into a plain .npy file for faster loading; optionally delete the source .npz.
    identifier = os.path.split(npz_file)[1][:-4]
    if not os.path.isfile(npz_file[:-4] + ".npy"):
        a = np.load(npz_file)[identifier]
        np.save(npz_file[:-4] + ".npy", a)
    if remove:
        os.remove(npz_file)


def unpack_dataset(folder, threads=8):
    # convert all .npz files in the folder to .npy in parallel, removing the .npz files afterwards.
    case_identifiers = get_case_identifiers(folder)
    p = Pool(threads)
    npz_files = [os.path.join(folder, i + ".npz") for i in case_identifiers]
    p.starmap(convert_to_npy, [(f, True) for f in npz_files])
    p.close()
    p.join()
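
# Usage sketch (hypothetical path, not part of the original module): unpack a preprocessed dataset once before
# training, then remove the intermediate .npy files again when training is done.
#   >>> unpack_dataset("/path/to/preprocessed_data", threads=8)
#   >>> delete_npy("/path/to/preprocessed_data")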


def delete_npy(folder):
    # remove all .npy files in the folder that belong to known case identifiers (counterpart to unpack_dataset).
    case_identifiers = get_case_identifiers(folder)
    npy_files = [os.path.join(folder, i + ".npy") for i in case_identifiers]
    npy_files = [i for i in npy_files if os.path.isfile(i)]
    for n in npy_files:
        os.remove(n)