#!/usr/bin/env python
# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import os
from multiprocessing import Pool
def get_class_balanced_patients(class_targets, batch_size, num_classes, slack_factor=0.1):
    '''
    Samples patients towards an equilibrium of classes on the roi level. For highly imbalanced datasets, this
    might be too strong a requirement, hence a slack factor determines the ratio of the batch that is randomly
    sampled before class balancing is triggered.
    :param class_targets: list of patient targets, where each patient target is a list of class labels of the
        respective rois.
    :param batch_size: number of patient indices to sample for one batch.
    :param num_classes: number of classes to balance over.
    :param slack_factor: ratio of the batch that is randomly sampled before class balancing is triggered.
    :return: batch_ixs: list of indices referring to a subset of the class_targets list, sampled to build one batch.
    '''
batch_ixs = []
class_count = {k: 0 for k in range(num_classes)}
weakest_class = 0
for ix in range(batch_size):
keep_looking = True
while keep_looking:
            # choose a random patient.
            cand = np.random.choice(len(class_targets), 1)[0]
            # determine the least occurring class among this patient's rois.
            tmp_weakest_class = np.argmin([class_targets[cand].count(ii) for ii in range(num_classes)])
            # keep the patient if the batch is still within the slack_factor ratio (randomly sampled part).
            # otherwise, keep the patient only if its weakest class is not also the batch's weakest class
            # (which needs to be boosted) and at least one of its rois belongs to the batch's weakest class.
if (tmp_weakest_class != weakest_class and class_targets[cand].count(weakest_class) > 0) or ix < int(batch_size * slack_factor):
keep_looking = False
for c in range(num_classes):
class_count[c] += class_targets[cand].count(c)
                weakest_class = np.argmin([class_count[c] for c in range(num_classes)])
batch_ixs.append(cand)
return batch_ixs
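

# Usage sketch (illustrative toy labels, not part of the original module): draw one
# class-balanced batch of patient indices for a dataset with two classes, where class 1
# is under-represented and gets boosted by the sampler.
def _demo_get_class_balanced_patients():
    class_targets = [[0, 0], [1], [0, 1], [0], [0, 0, 1]]  # per-patient roi labels
    batch_ixs = get_class_balanced_patients(class_targets, batch_size=6, num_classes=2)
    print("sampled patient indices:", batch_ixs)
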
class fold_generator:
"""
generates splits of indices for a given length of a dataset to perform n-fold cross-validation.
splits each fold into 3 subsets for training, validation and testing.
This form of cross validation uses an inner loop test set, which is useful if test scores shall be reported on a
statistically reliable amount of patients, despite limited size of a dataset.
If hold out test set is provided and hence no inner loop test set needed, just add test_idxs to the training data in the dataloader.
This creates straight-forward train-val splits.
:returns names list: list of len n_splits. each element is a list of len 3 for train_ix, val_ix, test_ix.
"""
def __init__(self, seed, n_splits, len_data):
"""
:param seed: Random seed for splits.
:param n_splits: number of splits, e.g. 5 splits for 5-fold cross-validation
:param len_data: number of elements in the dataset.
"""
self.tr_ix = []
self.val_ix = []
self.te_ix = []
self.slicer = None
self.missing = 0
self.fold = 0
self.len_data = len_data
self.n_splits = n_splits
self.myseed = seed
self.boost_val = 0
    def init_indices(self):
        # note: self.l is set in get_fold_names before this method is called.
        t = list(np.arange(self.l))
# round up to next splittable data amount.
split_length = int(np.ceil(len(t) / float(self.n_splits)))
self.slicer = split_length
self.mod = len(t) % self.n_splits
if self.mod > 0:
            # missing is the number of folds in which the splits are reduced by one to account for the
            # missing data.
self.missing = self.n_splits - self.mod
self.te_ix = t[:self.slicer]
self.tr_ix = t[self.slicer:]
self.val_ix = self.tr_ix[:self.slicer]
self.tr_ix = self.tr_ix[self.slicer:]
def new_fold(self):
slicer = self.slicer
        if self.fold < self.missing:
            slicer = self.slicer - 1
        temp = self.te_ix
        # special case mod == 1: the test set would collect one datum too many, since it walks through two
        # rounded-up splits. Account for this by reducing the last fold's test split by one (the moved
        # element returns to the training pool).
        if self.fold == self.n_splits - 2 and self.mod == 1:
            temp += self.val_ix[-1:]
            self.val_ix = self.val_ix[:-1]
self.te_ix = self.val_ix
self.val_ix = self.tr_ix[:slicer]
self.tr_ix = self.tr_ix[slicer:] + temp
def get_fold_names(self):
names_list = []
rgen = np.random.RandomState(self.myseed)
cv_names = np.arange(self.len_data)
rgen.shuffle(cv_names)
self.l = len(cv_names)
self.init_indices()
for split in range(self.n_splits):
train_names, val_names, test_names = cv_names[self.tr_ix], cv_names[self.val_ix], cv_names[self.te_ix]
names_list.append([train_names, val_names, test_names, self.fold])
self.new_fold()
self.fold += 1
return names_list
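

# Usage sketch (illustrative parameters, not part of the original module): create 5-fold
# train/val/test splits over a toy dataset of 23 cases; fold sizes differ by at most one.
def _demo_fold_generator():
    fg = fold_generator(seed=42, n_splits=5, len_data=23)
    for train_ix, val_ix, test_ix, fold in fg.get_fold_names():
        print("fold {}: train/val/test sizes {}/{}/{}".format(
            fold, len(train_ix), len(val_ix), len(test_ix)))
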
def get_patch_crop_coords(img, patch_size, min_overlap=30):
"""
_:param img (y, x, (z))
_:param patch_size: list of len 2 (2D) or 3 (3D).
_:param min_overlap: minimum required overlap of patches.
If too small, some areas are poorly represented only at edges of single patches.
_:return ndarray: shape (n_patches, 2*dim). crop coordinates for each patch.
"""
crop_coords = []
for dim in range(len(img.shape)):
n_patches = int(np.ceil(img.shape[dim] / patch_size[dim]))
# no crops required in this dimension, add image shape as coordinates.
if n_patches == 1:
crop_coords.append([(0, img.shape[dim])])
continue
        # fix the two outermost patches to be centered patch_size/2 from the image borders and
        # distribute the remaining patch centers evenly in between.
        center_dists = (img.shape[dim] - patch_size[dim]) / (n_patches - 1)
if (patch_size[dim] - center_dists) < min_overlap:
n_patches += 1
center_dists = (img.shape[dim] - patch_size[dim]) / (n_patches - 1)
patch_centers = np.round([(patch_size[dim] / 2 + (center_dists * ii)) for ii in range(n_patches)])
dim_crop_coords = [(center - patch_size[dim] / 2, center + patch_size[dim] / 2) for center in patch_centers]
crop_coords.append(dim_crop_coords)
coords_mesh_grid = []
for ymin, ymax in crop_coords[0]:
for xmin, xmax in crop_coords[1]:
if len(crop_coords) == 3 and patch_size[2] > 1:
for zmin, zmax in crop_coords[2]:
coords_mesh_grid.append([ymin, ymax, xmin, xmax, zmin, zmax])
elif len(crop_coords) == 3 and patch_size[2] == 1:
for zmin in range(img.shape[2]):
coords_mesh_grid.append([ymin, ymax, xmin, xmax, zmin, zmin + 1])
else:
coords_mesh_grid.append([ymin, ymax, xmin, xmax])
return np.array(coords_mesh_grid).astype(int)
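

# Usage sketch (assumed shapes, not part of the original module): tile a 2D image of
# shape (300, 400) with overlapping 128 x 128 patches.
def _demo_get_patch_crop_coords():
    img = np.zeros((300, 400))
    coords = get_patch_crop_coords(img, patch_size=[128, 128])
    print(coords.shape)  # (n_patches, 4): ymin, ymax, xmin, xmax per patch
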
def pad_nd_image(image, new_shape=None, mode="edge", kwargs=None, return_slicer=False, shape_must_be_divisible_by=None):
"""
one padder to pad them all. Documentation? Well okay. A little bit. by Fabian Isensee
:param image: nd image. can be anything
:param new_shape: what shape do you want? new_shape does not have to have the same dimensionality as image. If
len(new_shape) < len(image.shape) then the last axes of image will be padded. If new_shape < image.shape in any of
the axes then we will not pad that axis, but also not crop! (interpret new_shape as new_min_shape)
Example:
image.shape = (10, 1, 512, 512); new_shape = (768, 768) -> result: (10, 1, 768, 768). Cool, huh?
image.shape = (10, 1, 512, 512); new_shape = (364, 768) -> result: (10, 1, 512, 768).
:param mode: see np.pad for documentation
:param return_slicer: if True then this function will also return what coords you will need to use when cropping back
to original shape
    :param shape_must_be_divisible_by: for network prediction. After applying new_shape, make sure the new shape is
    divisible by that number (can also be a list with an entry for each axis). Whatever is missing to match that will
    be padded (so the result may be larger than new_shape if shape_must_be_divisible_by is not None)
:param kwargs: see np.pad for documentation
"""
if kwargs is None:
kwargs = {}
if new_shape is not None:
old_shape = np.array(image.shape[-len(new_shape):])
else:
assert shape_must_be_divisible_by is not None
assert isinstance(shape_must_be_divisible_by, (list, tuple, np.ndarray))
new_shape = image.shape[-len(shape_must_be_divisible_by):]
old_shape = new_shape
num_axes_nopad = len(image.shape) - len(new_shape)
new_shape = [max(new_shape[i], old_shape[i]) for i in range(len(new_shape))]
if not isinstance(new_shape, np.ndarray):
new_shape = np.array(new_shape)
if shape_must_be_divisible_by is not None:
if not isinstance(shape_must_be_divisible_by, (list, tuple, np.ndarray)):
shape_must_be_divisible_by = [shape_must_be_divisible_by] * len(new_shape)
else:
assert len(shape_must_be_divisible_by) == len(new_shape)
        for i in range(len(new_shape)):
            if new_shape[i] % shape_must_be_divisible_by[i] == 0:
                new_shape[i] -= shape_must_be_divisible_by[i]
        # round each axis up to the next multiple of shape_must_be_divisible_by. Axes that were already
        # divisible are restored to their original size by the subtraction above.
        new_shape = np.array([new_shape[i] + shape_must_be_divisible_by[i] - new_shape[i] % shape_must_be_divisible_by[i] for i in range(len(new_shape))])
difference = new_shape - old_shape
pad_below = difference // 2
pad_above = difference // 2 + difference % 2
pad_list = [[0, 0]]*num_axes_nopad + list([list(i) for i in zip(pad_below, pad_above)])
res = np.pad(image, pad_list, mode, **kwargs)
if not return_slicer:
return res
    else:
        pad_list = np.array(pad_list)
        # convert pad widths into crop coordinates: the second column becomes the (exclusive) end index.
        pad_list[:, 1] = np.array(res.shape) - pad_list[:, 1]
        slicer = list(slice(*i) for i in pad_list)
return res, slicer
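

# Usage sketch (assumed shapes, not part of the original module): pad the last two axes
# of a 4D array up to the next multiple of 32 and crop back via the returned slicer.
def _demo_pad_nd_image():
    arr = np.random.rand(10, 1, 317, 317)
    padded, slicer = pad_nd_image(arr, shape_must_be_divisible_by=[32, 32], return_slicer=True)
    print(padded.shape)                 # (10, 1, 320, 320)
    print(padded[tuple(slicer)].shape)  # back to (10, 1, 317, 317)
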
#############################
# data packing / unpacking #
#############################
def get_case_identifiers(folder):
    # case identifiers are the .npz file names without the extension.
    case_identifiers = [i[:-4] for i in os.listdir(folder) if i.endswith(".npz")]
    return case_identifiers
def convert_to_npy(npz_file, remove=False):
    # the array inside the .npz archive is expected to be stored under the case identifier as its key.
    identifier = os.path.split(npz_file)[1][:-4]
if not os.path.isfile(npz_file[:-4] + ".npy"):
a = np.load(npz_file)[identifier]
np.save(npz_file[:-4] + ".npy", a)
if remove:
os.remove(npz_file)
def unpack_dataset(folder, threads=8):
    case_identifiers = get_case_identifiers(folder)
    p = Pool(threads)
    npz_files = [os.path.join(folder, i + ".npz") for i in case_identifiers]
    # keep the source .npz files (remove=False): get_case_identifiers, and hence delete_npy, relies on
    # them to locate the cases later.
    p.starmap(convert_to_npy, [(f, False) for f in npz_files])
p.close()
p.join()
def delete_npy(folder):
case_identifiers = get_case_identifiers(folder)
npy_files = [os.path.join(folder, i + ".npy") for i in case_identifiers]
npy_files = [i for i in npy_files if os.path.isfile(i)]
for n in npy_files:
os.remove(n)
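

# Usage sketch (hypothetical folder path, not part of the original module): unpack all
# .npz archives in a preprocessed-data folder to fast-loading .npy files before training,
# then remove the unpacked copies afterwards.
def _demo_unpack_dataset():
    folder = "/path/to/preprocessed_data"  # hypothetical path
    unpack_dataset(folder, threads=4)
    # ... training reads the .npy files here ...
    delete_npy(folder)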