--- /dev/null
+++ b/utils/dataloader_utils.py
@@ -0,0 +1,280 @@
+#!/usr/bin/env python
+# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import numpy as np
+import os
+from multiprocessing import Pool
+
+
+
+def get_class_balanced_patients(class_targets, batch_size, num_classes, slack_factor=0.1):
+    '''
+    Samples patients towards an equilibrium of classes on the roi-level. For highly imbalanced datasets this may be
+    too strong a requirement, hence a slack factor determines the ratio of the batch that is sampled randomly before
+    class balancing is triggered.
+    :param class_targets: list of patient targets, where each patient target is a list of class labels of the
+    respective rois.
+    :param batch_size: number of patients to sample for one batch.
+    :param num_classes: number of classes to balance over.
+    :param slack_factor: ratio of the batch that is filled by random sampling before class balancing kicks in.
+    :return: batch_ixs: list of indices into the class_targets list, sampled to build one batch.
+    '''
+    batch_ixs = []
+    class_count = {k: 0 for k in range(num_classes)}
+    weakest_class = 0
+    for ix in range(batch_size):
+
+        keep_looking = True
+        while keep_looking:
+            # choose a random patient.
+            cand = np.random.choice(len(class_targets), 1)[0]
+            # determine the least occurring class among this patient's rois.
+            tmp_weakest_class = np.argmin([class_targets[cand].count(ii) for ii in range(num_classes)])
+            # once the batch is filled past the slack_factor ratio, keep a patient only if it boosts the batch's
+            # weakest class: it must have at least one roi of that class, and that class must not also be the
+            # patient's own weakest. Otherwise keep looking.
+            if (tmp_weakest_class != weakest_class and class_targets[cand].count(weakest_class) > 0) or ix < int(batch_size * slack_factor):
+                keep_looking = False
+
+        for c in range(num_classes):
+            class_count[c] += class_targets[cand].count(c)
+        weakest_class = np.argmin([class_count[c] for c in range(num_classes)])
+        batch_ixs.append(cand)
+
+    return batch_ixs
+
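+# A minimal usage sketch (illustrative values, not part of the module): three
+# patients with per-roi class labels for a two-class problem. The returned list
+# holds four patient indices into class_targets, sampled so that the batch's
+# weakest class is boosted once the slack ratio is exhausted.
+# example_targets = [[0, 0, 1], [1, 1], [0]]
+# example_batch_ixs = get_class_balanced_patients(example_targets, batch_size=4, num_classes=2)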
+
+
+class fold_generator:
+    """
+    Generates splits of indices for a dataset of given length to perform n-fold cross-validation.
+    Each fold is split into three subsets for training, validation and testing.
+    This form of cross-validation uses an inner-loop test set, which is useful if test scores shall be reported on a
+    statistically reliable number of patients despite the limited size of a dataset.
+    If a hold-out test set is provided and hence no inner-loop test set is needed, simply add the test indices to the
+    training data in the dataloader. This yields straightforward train-val splits.
+    :returns names_list: list of len n_splits. Each element is a list holding train_ix, val_ix, test_ix and the fold
+    id of the respective split.
+    """
+    def __init__(self, seed, n_splits, len_data):
+        """
+        :param seed: Random seed for splits.
+        :param n_splits: number of splits, e.g. 5 splits for 5-fold cross-validation
+        :param len_data: number of elements in the dataset.
+        """
+        self.tr_ix = []
+        self.val_ix = []
+        self.te_ix = []
+        self.slicer = None
+        self.missing = 0
+        self.fold = 0
+        self.len_data = len_data
+        self.n_splits = n_splits
+        self.myseed = seed
+        self.boost_val = 0
+
+    def init_indices(self):
+
+        t = list(np.arange(self.len_data))
+        # round up to next splittable data amount.
+        split_length = int(np.ceil(len(t) / float(self.n_splits)))
+        self.slicer = split_length
+        self.mod = len(t) % self.n_splits
+        if self.mod > 0:
+            # missing is the number of folds in which the split length is reduced by one to account for the remainder.
+            self.missing = self.n_splits - self.mod
+
+        self.te_ix = t[:self.slicer]
+        self.tr_ix = t[self.slicer:]
+        self.val_ix = self.tr_ix[:self.slicer]
+        self.tr_ix = self.tr_ix[self.slicer:]
+
+    def new_fold(self):
+
+        slicer = self.slicer
+        if self.fold < self.missing:
+            slicer = self.slicer - 1
+
+        temp = self.te_ix
+
+        # catch the edge case mod == 1: the test set would collect one element too many, since it walks through
+        # both rounded-up splits. Account for this by shortening the split of the last fold by one.
+        if self.fold == self.n_splits - 2 and self.mod == 1:
+            temp += self.val_ix[-1:]
+            self.val_ix = self.val_ix[:-1]
+
+        self.te_ix = self.val_ix
+        self.val_ix = self.tr_ix[:slicer]
+        self.tr_ix = self.tr_ix[slicer:] + temp
+
+
+    def get_fold_names(self):
+        names_list = []
+        rgen = np.random.RandomState(self.myseed)
+        cv_names = np.arange(self.len_data)
+
+        rgen.shuffle(cv_names)
+        self.init_indices()
+
+        for split in range(self.n_splits):
+            train_names, val_names, test_names = cv_names[self.tr_ix], cv_names[self.val_ix], cv_names[self.te_ix]
+            names_list.append([train_names, val_names, test_names, self.fold])
+            self.new_fold()
+            self.fold += 1
+
+        return names_list
+
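+# Usage sketch (hypothetical numbers): 5-fold cross-validation over 100 patients.
+# Each split holds shuffled, disjoint index arrays plus the fold id.
+# splits = fold_generator(seed=42, n_splits=5, len_data=100).get_fold_names()
+# train_ix, val_ix, test_ix, fold = splits[0]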
+
+
+def get_patch_crop_coords(img, patch_size, min_overlap=30):
+    """
+
+    _:param img (y, x, (z))
+    _:param patch_size: list of len 2 (2D) or 3 (3D).
+    _:param min_overlap: minimum required overlap of patches.
+    If too small, some areas are poorly represented only at edges of single patches.
+    _:return ndarray: shape (n_patches, 2*dim). crop coordinates for each patch.
+    """
+    crop_coords = []
+    for dim in range(len(img.shape)):
+        n_patches = int(np.ceil(img.shape[dim] / patch_size[dim]))
+
+        # no crops required in this dimension, add image shape as coordinates.
+        if n_patches == 1:
+            crop_coords.append([(0, img.shape[dim])])
+            continue
+
+        # fix the two outermost patches to centers patch_size/2 from the borders and space the remaining centers evenly in between.
+        center_dists = (img.shape[dim] - patch_size[dim]) / (n_patches - 1)
+
+        if (patch_size[dim] - center_dists) < min_overlap:
+            n_patches += 1
+            center_dists = (img.shape[dim] - patch_size[dim]) / (n_patches - 1)
+
+        patch_centers = np.round([(patch_size[dim] / 2 + (center_dists * ii)) for ii in range(n_patches)])
+        dim_crop_coords = [(center - patch_size[dim] / 2, center + patch_size[dim] / 2) for center in patch_centers]
+        crop_coords.append(dim_crop_coords)
+
+    coords_mesh_grid = []
+    for ymin, ymax in crop_coords[0]:
+        for xmin, xmax in crop_coords[1]:
+            if len(crop_coords) == 3 and patch_size[2] > 1:
+                for zmin, zmax in crop_coords[2]:
+                    coords_mesh_grid.append([ymin, ymax, xmin, xmax, zmin, zmax])
+            elif len(crop_coords) == 3 and patch_size[2] == 1:
+                for zmin in range(img.shape[2]):
+                    coords_mesh_grid.append([ymin, ymax, xmin, xmax, zmin, zmin + 1])
+            else:
+                coords_mesh_grid.append([ymin, ymax, xmin, xmax])
+    return np.array(coords_mesh_grid).astype(int)
+
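+# Usage sketch (illustrative shapes): tile a 2D image of shape (500, 400) into
+# patches of size (300, 300) with the default minimum overlap. Each row of the
+# result is (ymin, ymax, xmin, xmax) for one patch.
+# coords = get_patch_crop_coords(np.zeros((500, 400)), patch_size=[300, 300])
+# coords.shape == (4, 4): two patch positions per axis.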
+
+
+def pad_nd_image(image, new_shape=None, mode="edge", kwargs=None, return_slicer=False, shape_must_be_divisible_by=None):
+    """
+    one padder to pad them all. Documentation? Well okay. A little bit. by Fabian Isensee
+
+    :param image: nd image. can be anything
+    :param new_shape: what shape do you want? new_shape does not have to have the same dimensionality as image. If
+    len(new_shape) < len(image.shape) then the last axes of image will be padded. If new_shape < image.shape in any of
+    the axes then we will not pad that axis, but also not crop! (interpret new_shape as new_min_shape)
+    Example:
+    image.shape = (10, 1, 512, 512); new_shape = (768, 768) -> result: (10, 1, 768, 768). Cool, huh?
+    image.shape = (10, 1, 512, 512); new_shape = (364, 768) -> result: (10, 1, 512, 768).
+
+    :param mode: see np.pad for documentation
+    :param return_slicer: if True then this function will also return what coords you will need to use when cropping back
+    to original shape
+    :param shape_must_be_divisible_by: for network prediction. After applying new_shape, make sure the new shape is
+    divisible by that number (can also be a list with an entry for each axis). Whatever is missing to match that will
+    be padded (so the result may be larger than new_shape if shape_must_be_divisible_by is not None)
+    :param kwargs: see np.pad for documentation
+    """
+    if kwargs is None:
+        kwargs = {}
+
+    if new_shape is not None:
+        old_shape = np.array(image.shape[-len(new_shape):])
+    else:
+        assert shape_must_be_divisible_by is not None
+        assert isinstance(shape_must_be_divisible_by, (list, tuple, np.ndarray))
+        new_shape = image.shape[-len(shape_must_be_divisible_by):]
+        old_shape = new_shape
+
+    num_axes_nopad = len(image.shape) - len(new_shape)
+
+    new_shape = [max(new_shape[i], old_shape[i]) for i in range(len(new_shape))]
+
+    if not isinstance(new_shape, np.ndarray):
+        new_shape = np.array(new_shape)
+
+    if shape_must_be_divisible_by is not None:
+        if not isinstance(shape_must_be_divisible_by, (list, tuple, np.ndarray)):
+            shape_must_be_divisible_by = [shape_must_be_divisible_by] * len(new_shape)
+        else:
+            assert len(shape_must_be_divisible_by) == len(new_shape)
+
+        # round each axis up to the next multiple of its divisor. Axes that are already divisible are first
+        # reduced by one multiple, so the round-up below restores them unchanged.
+        for i in range(len(new_shape)):
+            if new_shape[i] % shape_must_be_divisible_by[i] == 0:
+                new_shape[i] -= shape_must_be_divisible_by[i]
+
+        new_shape = np.array([new_shape[i] + shape_must_be_divisible_by[i] - new_shape[i] % shape_must_be_divisible_by[i] for i in range(len(new_shape))])
+
+    difference = new_shape - old_shape
+    pad_below = difference // 2
+    pad_above = difference // 2 + difference % 2
+    pad_list = [[0, 0]]*num_axes_nopad + list([list(i) for i in zip(pad_below, pad_above)])
+    res = np.pad(image, pad_list, mode, **kwargs)
+    if not return_slicer:
+        return res
+    else:
+        pad_list = np.array(pad_list)
+        pad_list[:, 1] = np.array(res.shape) - pad_list[:, 1]
+        slicer = list(slice(*i) for i in pad_list)
+        return res, slicer
+
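+# Usage sketch: pad the trailing two axes of a (10, 1, 512, 512) array to at
+# least (768, 768) and recover the original region via the returned slicer.
+# padded, slicer = pad_nd_image(np.ones((10, 1, 512, 512)), new_shape=(768, 768), return_slicer=True)
+# padded.shape == (10, 1, 768, 768); padded[tuple(slicer)] has the original shape.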
+
+#############################
+#  data packing / unpacking #
+#############################
+
+def get_case_identifiers(folder):
+    # strip the ".npz" extension to recover the case identifiers.
+    case_identifiers = [i[:-4] for i in os.listdir(folder) if i.endswith(".npz")]
+    return case_identifiers
+
+
+def convert_to_npy(npz_file, remove=False):
+    identifier = os.path.split(npz_file)[1][:-4]
+    if not os.path.isfile(npz_file[:-4] + ".npy"):
+        # the array is expected to be stored under the case identifier inside the .npz archive.
+        a = np.load(npz_file)[identifier]
+        np.save(npz_file[:-4] + ".npy", a)
+    if remove:
+        os.remove(npz_file)
+
+
+def unpack_dataset(folder, threads=8):
+    case_identifiers = get_case_identifiers(folder)
+    p = Pool(threads)
+    npz_files = [os.path.join(folder, i + ".npz") for i in case_identifiers]
+    p.starmap(convert_to_npy, [(f, True) for f in npz_files])
+    p.close()
+    p.join()
+
+
+def delete_npy(folder):
+    case_identifiers = get_case_identifiers(folder)
+    npy_files = [os.path.join(folder, i + ".npy") for i in case_identifiers]
+    npy_files = [i for i in npy_files if os.path.isfile(i)]
+    for n in npy_files:
+        os.remove(n)
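+
+
+# Usage sketch (hypothetical path): unpack every <case>.npz in a preprocessed
+# data folder to an uncompressed <case>.npy for faster loading. Note that
+# unpack_dataset passes remove=True, so the source .npz files are deleted after
+# conversion; delete_npy removes the unpacked copies again.
+# unpack_dataset("/path/to/preprocessed_data", threads=8)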