--- a +++ b/ext/lab2im/utils.py @@ -0,0 +1,1057 @@ +""" +This file contains all the utilities used in that project. They are classified in 5 categories: +1- loading/saving functions: + -load_volume + -save_volume + -get_volume_info + -get_list_labels + -load_array_if_path + -write_pickle + -read_pickle + -write_model_summary +2- reformatting functions + -reformat_to_list + -reformat_to_n_channels_array +3- path related functions + -list_images_in_folder + -list_files + -list_subfolders + -strip_extension + -strip_suffix + -mkdir + -mkcmd +4- shape-related functions + -get_dims + -get_resample_shape + -add_axis + -get_padding_margin +5- build affine matrices/tensors + -create_affine_transformation_matrix + -sample_affine_transform + -create_rotation_transform + -create_shearing_transform +6- miscellaneous + -infer + -LoopInfo + -get_mapping_lut + -build_training_generator + -find_closest_number_divisible_by_m + -build_binary_structure + -draw_value_from_distribution + -build_exp + + +If you use this code, please cite the first SynthSeg paper: +https://github.com/BBillot/lab2im/blob/master/bibtex.bib + +Copyright 2020 Benjamin Billot + +Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in +compliance with the License. You may obtain a copy of the License at +https://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software distributed under the License is +distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +implied. See the License for the specific language governing permissions and limitations under the +License. +""" + + +import os +import glob +import math +import time +import pickle +import numpy as np +import nibabel as nib +import tensorflow as tf +import keras.layers as KL +import keras.backend as K +from datetime import timedelta +from scipy.ndimage.morphology import distance_transform_edt + + +# ---------------------------------------------- loading/saving functions ---------------------------------------------- + + +def load_volume(path_volume, im_only=True, squeeze=True, dtype=None, aff_ref=None): + """ + Load volume file. + :param path_volume: path of the volume to load. Can either be a nii, nii.gz, mgz, or npz format. + If npz format, 1) the variable name is assumed to be 'vol_data', + 2) the volume is associated with an identity affine matrix and blank header. + :param im_only: (optional) if False, the function also returns the affine matrix and header of the volume. + :param squeeze: (optional) whether to squeeze the volume when loading. + :param dtype: (optional) if not None, convert the loaded volume to this numpy dtype. + :param aff_ref: (optional) If not None, the loaded volume is aligned to this affine matrix. + The returned affine matrix is also given in this new space. Must be a numpy array of dimension 4x4. + :return: the volume, with corresponding affine matrix and header if im_only is False. + """ + assert path_volume.endswith(('.nii', '.nii.gz', '.mgz', '.npz')), 'Unknown data file: %s' % path_volume + + if path_volume.endswith(('.nii', '.nii.gz', '.mgz')): + x = nib.load(path_volume) + if squeeze: + volume = np.squeeze(x.get_fdata()) + else: + volume = x.get_fdata() + aff = x.affine + header = x.header + else: # npz + volume = np.load(path_volume)['vol_data'] + if squeeze: + volume = np.squeeze(volume) + aff = np.eye(4) + header = nib.Nifti1Header() + if dtype is not None: + if 'int' in dtype: + volume = np.round(volume) + volume = volume.astype(dtype=dtype) + + # align image to reference affine matrix + if aff_ref is not None: + from ext.lab2im import edit_volumes # the import is done here to avoid import loops + n_dims, _ = get_dims(list(volume.shape), max_channels=10) + volume, aff = edit_volumes.align_volume_to_ref(volume, aff, aff_ref=aff_ref, return_aff=True, n_dims=n_dims) + + if im_only: + return volume + else: + return volume, aff, header + + +def save_volume(volume, aff, header, path, res=None, dtype=None, n_dims=3): + """ + Save a volume. + :param volume: volume to save + :param aff: affine matrix of the volume to save. If aff is None, the volume is saved with an identity affine matrix. + aff can also be set to 'FS', in which case the volume is saved with the affine matrix of FreeSurfer outputs. + :param header: header of the volume to save. If None, the volume is saved with a blank header. + :param path: path where to save the volume. + :param res: (optional) update the resolution in the header before saving the volume. + :param dtype: (optional) numpy dtype for the saved volume. + :param n_dims: (optional) number of dimensions, to avoid confusion in multi-channel case. Default is None, where + n_dims is automatically inferred. + """ + + mkdir(os.path.dirname(path)) + if '.npz' in path: + np.savez_compressed(path, vol_data=volume) + else: + if header is None: + header = nib.Nifti1Header() + if isinstance(aff, str): + if aff == 'FS': + aff = np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]]) + elif aff is None: + aff = np.eye(4) + if dtype is not None: + if 'int' in dtype: + volume = np.round(volume) + volume = volume.astype(dtype=dtype) + nifty = nib.Nifti1Image(volume, aff, header) + nifty.set_data_dtype(dtype) + else: + nifty = nib.Nifti1Image(volume, aff, header) + if res is not None: + if n_dims is None: + n_dims, _ = get_dims(volume.shape) + res = reformat_to_list(res, length=n_dims, dtype=None) + nifty.header.set_zooms(res) + nib.save(nifty, path) + + +def get_volume_info(path_volume, return_volume=False, aff_ref=None, max_channels=10): + """ + Gather information about a volume: shape, affine matrix, number of dimensions and channels, header, and resolution. + :param path_volume: path of the volume to get information form. + :param return_volume: (optional) whether to return the volume along with the information. + :param aff_ref: (optional) If not None, the loaded volume is aligned to this affine matrix. + All info relative to the volume is then given in this new space. Must be a numpy array of dimension 4x4. + :param max_channels: maximum possible number of channels for the input volume. + :return: volume (if return_volume is true), and corresponding info. If aff_ref is not None, the returned aff is + the original one, i.e. the affine of the image before being aligned to aff_ref. + """ + # read image + im, aff, header = load_volume(path_volume, im_only=False) + + # understand if image is multichannel + im_shape = list(im.shape) + n_dims, n_channels = get_dims(im_shape, max_channels=max_channels) + im_shape = im_shape[:n_dims] + + # get labels res + if '.nii' in path_volume: + data_res = np.array(header['pixdim'][1:n_dims + 1]) + elif '.mgz' in path_volume: + data_res = np.array(header['delta']) # mgz image + else: + data_res = np.array([1.0] * n_dims) + + # align to given affine matrix + if aff_ref is not None: + from ext.lab2im import edit_volumes # the import is done here to avoid import loops + ras_axes = edit_volumes.get_ras_axes(aff, n_dims=n_dims) + ras_axes_ref = edit_volumes.get_ras_axes(aff_ref, n_dims=n_dims) + im = edit_volumes.align_volume_to_ref(im, aff, aff_ref=aff_ref, n_dims=n_dims) + im_shape = np.array(im_shape) + data_res = np.array(data_res) + im_shape[ras_axes_ref] = im_shape[ras_axes] + data_res[ras_axes_ref] = data_res[ras_axes] + im_shape = im_shape.tolist() + + # return info + if return_volume: + return im, im_shape, aff, n_dims, n_channels, header, data_res + else: + return im_shape, aff, n_dims, n_channels, header, data_res + + +def get_list_labels(label_list=None, labels_dir=None, save_label_list=None, FS_sort=False): + """This function reads or computes a list of all label values used in a set of label maps. + It can also sort all labels according to FreeSurfer lut. + :param label_list: (optional) already computed label_list. Can be a sequence, a 1d numpy array, or the path to + a numpy 1d array. + :param labels_dir: (optional) if path_label_list is None, the label list is computed by reading all the label maps + in the given folder. Can also be the path to a single label map. + :param save_label_list: (optional) path where to save the label list. + :param FS_sort: (optional) whether to sort label values according to the FreeSurfer classification. + If true, the label values will be ordered as follows: neutral labels first (i.e. non-sided), left-side labels, + and right-side labels. If FS_sort is True, this function also returns the number of neutral labels in label_list. + :return: the label list (numpy 1d array), and the number of neutral (i.e. non-sided) labels if FS_sort is True. + If one side of the brain is not represented at all in label_list, all labels are considered as neutral, and + n_neutral_labels = len(label_list). + """ + + # load label list if previously computed + if label_list is not None: + label_list = np.array(reformat_to_list(label_list, load_as_numpy=True, dtype='int')) + + # compute label list from all label files + elif labels_dir is not None: + print('Compiling list of unique labels') + # go through all labels files and compute unique list of labels + labels_paths = list_images_in_folder(labels_dir) + label_list = np.empty(0) + loop_info = LoopInfo(len(labels_paths), 10, 'processing', print_time=True) + for lab_idx, path in enumerate(labels_paths): + loop_info.update(lab_idx) + y = load_volume(path, dtype='int32') + y_unique = np.unique(y) + label_list = np.unique(np.concatenate((label_list, y_unique))).astype('int') + + else: + raise Exception('either label_list, path_label_list or labels_dir should be provided') + + # sort labels in neutral/left/right according to FS labels + n_neutral_labels = 0 + if FS_sort: + neutral_FS_labels = [0, 14, 15, 16, 21, 22, 23, 24, 72, 77, 80, 85, 100, 101, 102, 103, 104, 105, 106, 107, 108, + 109, 165, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, + 251, 252, 253, 254, 255, 258, 259, 260, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340, + 502, 506, 507, 508, 509, 511, 512, 514, 515, 516, 517, 530, + 531, 532, 533, 534, 535, 536, 537] + neutral = list() + left = list() + right = list() + for la in label_list: + if la in neutral_FS_labels: + if la not in neutral: + neutral.append(la) + elif (0 < la < 14) | (16 < la < 21) | (24 < la < 40) | (135 < la < 139) | (1000 <= la <= 1035) | \ + (la == 865) | (20100 < la < 20110): + if la not in left: + left.append(la) + elif (39 < la < 72) | (162 < la < 165) | (2000 <= la <= 2035) | (20000 < la < 20010) | (la == 139) | \ + (la == 866): + if la not in right: + right.append(la) + else: + raise Exception('label {} not in our current FS classification, ' + 'please update get_list_labels in utils.py'.format(la)) + label_list = np.concatenate([sorted(neutral), sorted(left), sorted(right)]) + if ((len(left) > 0) & (len(right) > 0)) | ((len(left) == 0) & (len(right) == 0)): + n_neutral_labels = len(neutral) + else: + n_neutral_labels = len(label_list) + + # save labels if specified + if save_label_list is not None: + np.save(save_label_list, np.int32(label_list)) + + if FS_sort: + return np.int32(label_list), n_neutral_labels + else: + return np.int32(label_list), None + + +def load_array_if_path(var, load_as_numpy=True): + """If var is a string and load_as_numpy is True, this function loads the array writen at the path indicated by var. + Otherwise it simply returns var as it is.""" + if (isinstance(var, str)) & load_as_numpy: + assert os.path.isfile(var), 'No such path: %s' % var + var = np.load(var) + return var + + +def write_pickle(filepath, obj): + """ write a python object with a pickle at a given path""" + with open(filepath, 'wb') as file: + pickler = pickle.Pickler(file) + pickler.dump(obj) + + +def read_pickle(filepath): + """ read a python object with a pickle""" + with open(filepath, 'rb') as file: + unpickler = pickle.Unpickler(file) + return unpickler.load() + + +def write_model_summary(model, filepath='./model_summary.txt', line_length=150): + """Write the summary of a keras model at a given path, with a given length for each line""" + with open(filepath, 'w') as fh: + model.summary(print_fn=lambda x: fh.write(x + '\n'), line_length=line_length) + + +# ----------------------------------------------- reformatting functions ----------------------------------------------- + + +def reformat_to_list(var, length=None, load_as_numpy=False, dtype=None): + """This function takes a variable and reformat it into a list of desired + length and type (int, float, bool, str). + If variable is a string, and load_as_numpy is True, it will be loaded as a numpy array. + If variable is None, this function returns None. + :param var: a str, int, float, list, tuple, or numpy array + :param length: (optional) if var is a single item, it will be replicated to a list of this length + :param load_as_numpy: (optional) whether var is the path to a numpy array + :param dtype: (optional) convert all item to this type. Can be 'int', 'float', 'bool', or 'str' + :return: reformatted list + """ + + # convert to list + if var is None: + return None + var = load_array_if_path(var, load_as_numpy=load_as_numpy) + if isinstance(var, (int, float, np.int, np.int32, np.int64, np.float, np.float32, np.float64)): + var = [var] + elif isinstance(var, tuple): + var = list(var) + elif isinstance(var, np.ndarray): + if var.shape == (1,): + var = [var[0]] + else: + var = np.squeeze(var).tolist() + elif isinstance(var, str): + var = [var] + elif isinstance(var, bool): + var = [var] + if isinstance(var, list): + if length is not None: + if len(var) == 1: + var = var * length + elif len(var) != length: + raise ValueError('if var is a list/tuple/numpy array, it should be of length 1 or {0}, ' + 'had {1}'.format(length, var)) + else: + raise TypeError('var should be an int, float, tuple, list, numpy array, or path to numpy array') + + # convert items type + if dtype is not None: + if dtype == 'int': + var = [int(v) for v in var] + elif dtype == 'float': + var = [float(v) for v in var] + elif dtype == 'bool': + var = [bool(v) for v in var] + elif dtype == 'str': + var = [str(v) for v in var] + else: + raise ValueError("dtype should be 'str', 'float', 'int', or 'bool'; had {}".format(dtype)) + return var + + +def reformat_to_n_channels_array(var, n_dims=3, n_channels=1): + """This function takes an int, float, list or tuple and reformat it to an array of shape (n_channels, n_dims). + If resolution is a str, it will be assumed to be the path of a numpy array. + If resolution is a numpy array, it will be checked to have shape (n_channels, n_dims). + Finally if resolution is None, this function returns None as well.""" + if var is None: + return [None] * n_channels + if isinstance(var, str): + var = np.load(var) + # convert to numpy array + if isinstance(var, (int, float, list, tuple)): + var = reformat_to_list(var, n_dims) + var = np.tile(np.array(var), (n_channels, 1)) + # check shape if numpy array + elif isinstance(var, np.ndarray): + if n_channels == 1: + var = var.reshape((1, n_dims)) + else: + if np.squeeze(var).shape == (n_dims,): + var = np.tile(var.reshape((1, n_dims)), (n_channels, 1)) + elif var.shape != (n_channels, n_dims): + raise ValueError('if array, var should be {0} or {1}'.format((1, n_dims), (n_channels, n_dims))) + else: + raise TypeError('var should be int, float, list, tuple or ndarray') + return np.round(var, 3) + + +# ----------------------------------------------- path-related functions ----------------------------------------------- + + +def list_images_in_folder(path_dir, include_single_image=True, check_if_empty=True): + """List all files with extension nii, nii.gz, mgz, or npz within a folder.""" + basename = os.path.basename(path_dir) + if include_single_image & \ + (('.nii.gz' in basename) | ('.nii' in basename) | ('.mgz' in basename) | ('.npz' in basename)): + assert os.path.isfile(path_dir), 'file %s does not exist' % path_dir + list_images = [path_dir] + else: + if os.path.isdir(path_dir): + list_images = sorted(glob.glob(os.path.join(path_dir, '*nii.gz')) + + glob.glob(os.path.join(path_dir, '*nii')) + + glob.glob(os.path.join(path_dir, '*.mgz')) + + glob.glob(os.path.join(path_dir, '*.npz'))) + else: + raise Exception('Folder does not exist: %s' % path_dir) + if check_if_empty: + assert len(list_images) > 0, 'no .nii, .nii.gz, .mgz or .npz image could be found in %s' % path_dir + return list_images + + +def list_files(path_dir, whole_path=True, expr=None, cond_type='or'): + """This function returns a list of files contained in a folder, with possible regexp. + :param path_dir: path of a folder + :param whole_path: (optional) whether to return whole path or just the filenames. + :param expr: (optional) regexp for files to list. Can be a str or a list of str. + :param cond_type: (optional) if exp is a list, specify the logical link between expressions in exp. + Can be 'or', or 'and'. + :return: a list of files + """ + assert isinstance(whole_path, bool), "whole_path should be bool" + assert cond_type in ['or', 'and'], "cond_type should be either 'or', or 'and'" + if whole_path: + files_list = sorted([os.path.join(path_dir, f) for f in os.listdir(path_dir) + if os.path.isfile(os.path.join(path_dir, f))]) + else: + files_list = sorted([f for f in os.listdir(path_dir) if os.path.isfile(os.path.join(path_dir, f))]) + if expr is not None: # assumed to be either str or list of str + if isinstance(expr, str): + expr = [expr] + elif not isinstance(expr, (list, tuple)): + raise Exception("if specified, 'expr' should be a string or list of strings.") + matched_list_files = list() + for match in expr: + tmp_matched_files_list = sorted([f for f in files_list if match in os.path.basename(f)]) + if cond_type == 'or': + files_list = [f for f in files_list if f not in tmp_matched_files_list] + matched_list_files += tmp_matched_files_list + elif cond_type == 'and': + files_list = tmp_matched_files_list + matched_list_files = tmp_matched_files_list + files_list = sorted(matched_list_files) + return files_list + + +def list_subfolders(path_dir, whole_path=True, expr=None, cond_type='or'): + """This function returns a list of subfolders contained in a folder, with possible regexp. + :param path_dir: path of a folder + :param whole_path: (optional) whether to return whole path or just the subfolder names. + :param expr: (optional) regexp for files to list. Can be a str or a list of str. + :param cond_type: (optional) if exp is a list, specify the logical link between expressions in exp. + Can be 'or', or 'and'. + :return: a list of subfolders + """ + assert isinstance(whole_path, bool), "whole_path should be bool" + assert cond_type in ['or', 'and'], "cond_type should be either 'or', or 'and'" + if whole_path: + subdirs_list = sorted([os.path.join(path_dir, f) for f in os.listdir(path_dir) + if os.path.isdir(os.path.join(path_dir, f))]) + else: + subdirs_list = sorted([f for f in os.listdir(path_dir) if os.path.isdir(os.path.join(path_dir, f))]) + if expr is not None: # assumed to be either str or list of str + if isinstance(expr, str): + expr = [expr] + elif not isinstance(expr, (list, tuple)): + raise Exception("if specified, 'expr' should be a string or list of strings.") + matched_list_subdirs = list() + for match in expr: + tmp_matched_list_subdirs = sorted([f for f in subdirs_list if match in os.path.basename(f)]) + if cond_type == 'or': + subdirs_list = [f for f in subdirs_list if f not in tmp_matched_list_subdirs] + matched_list_subdirs += tmp_matched_list_subdirs + elif cond_type == 'and': + subdirs_list = tmp_matched_list_subdirs + matched_list_subdirs = tmp_matched_list_subdirs + subdirs_list = sorted(matched_list_subdirs) + return subdirs_list + + +def get_image_extension(path): + name = os.path.basename(path) + if name[-7:] == '.nii.gz': + return 'nii.gz' + elif name[-4:] == '.mgz': + return 'mgz' + elif name[-4:] == '.nii': + return 'nii' + elif name[-4:] == '.npz': + return 'npz' + + +def strip_extension(path): + """Strip classical image extensions (.nii.gz, .nii, .mgz, .npz) from a filename.""" + return path.replace('.nii.gz', '').replace('.nii', '').replace('.mgz', '').replace('.npz', '') + + +def strip_suffix(path): + """Strip classical image suffix from a filename.""" + path = path.replace('_aseg', '') + path = path.replace('aseg', '') + path = path.replace('.aseg', '') + path = path.replace('_aseg_1', '') + path = path.replace('_aseg_2', '') + path = path.replace('aseg_1_', '') + path = path.replace('aseg_2_', '') + path = path.replace('_orig', '') + path = path.replace('orig', '') + path = path.replace('.orig', '') + path = path.replace('_norm', '') + path = path.replace('norm', '') + path = path.replace('.norm', '') + path = path.replace('_talairach', '') + path = path.replace('GSP_FS_4p5', 'GSP') + path = path.replace('.nii_crispSegmentation', '') + path = path.replace('_crispSegmentation', '') + path = path.replace('_seg', '') + path = path.replace('.seg', '') + path = path.replace('seg', '') + path = path.replace('_seg_1', '') + path = path.replace('_seg_2', '') + path = path.replace('seg_1_', '') + path = path.replace('seg_2_', '') + return path + + +def mkdir(path_dir): + """Recursively creates the current dir as well as its parent folders if they do not already exist.""" + if path_dir[-1] == '/': + path_dir = path_dir[:-1] + if not os.path.isdir(path_dir): + list_dir_to_create = [path_dir] + while not os.path.isdir(os.path.dirname(list_dir_to_create[-1])): + list_dir_to_create.append(os.path.dirname(list_dir_to_create[-1])) + for dir_to_create in reversed(list_dir_to_create): + os.mkdir(dir_to_create) + + +def mkcmd(*args): + """Creates terminal command with provided inputs. + Example: mkcmd('mv', 'source', 'dest') will give 'mv source dest'.""" + return ' '.join([str(arg) for arg in args]) + + +# ---------------------------------------------- shape-related functions ----------------------------------------------- + + +def get_dims(shape, max_channels=10): + """Get the number of dimensions and channels from the shape of an array. + The number of dimensions is assumed to be the length of the shape, as long as the shape of the last dimension is + inferior or equal to max_channels (default 3). + :param shape: shape of an array. Can be a sequence or a 1d numpy array. + :param max_channels: maximum possible number of channels. + :return: the number of dimensions and channels associated with the provided shape. + example 1: get_dims([150, 150, 150], max_channels=10) = (3, 1) + example 2: get_dims([150, 150, 150, 3], max_channels=10) = (3, 3) + example 3: get_dims([150, 150, 150, 15], max_channels=10) = (4, 1), because 5>3""" + if shape[-1] <= max_channels: + n_dims = len(shape) - 1 + n_channels = shape[-1] + else: + n_dims = len(shape) + n_channels = 1 + return n_dims, n_channels + + +def get_resample_shape(patch_shape, factor, n_channels=None): + """Compute the shape of a resampled array given a shape factor. + :param patch_shape: size of the initial array (without number of channels). + :param factor: resampling factor. Can be a number, sequence, or 1d numpy array. + :param n_channels: (optional) if not None, add a number of channel at the end of the computed shape. + :return: list containing the shape of the input array after being resampled by the given factor. + """ + factor = reformat_to_list(factor, length=len(patch_shape)) + shape = [math.ceil(patch_shape[i] * factor[i]) for i in range(len(patch_shape))] + if n_channels is not None: + shape += [n_channels] + return shape + + +def add_axis(x, axis=0): + """Add axis to a numpy array. + :param x: input array + :param axis: index of the new axis to add. Can also be a list of indices to add several axes at the same time.""" + axis = reformat_to_list(axis) + for ax in axis: + x = np.expand_dims(x, axis=ax) + return x + + +def get_padding_margin(cropping, loss_cropping): + """Compute padding margin""" + if (cropping is not None) & (loss_cropping is not None): + cropping = reformat_to_list(cropping) + loss_cropping = reformat_to_list(loss_cropping) + n_dims = max(len(cropping), len(loss_cropping)) + cropping = reformat_to_list(cropping, length=n_dims) + loss_cropping = reformat_to_list(loss_cropping, length=n_dims) + padding_margin = [int((cropping[i] - loss_cropping[i]) / 2) for i in range(n_dims)] + if len(padding_margin) == 1: + padding_margin = padding_margin[0] + else: + padding_margin = None + return padding_margin + + +# -------------------------------------------- build affine matrices/tensors ------------------------------------------- + + +def create_affine_transformation_matrix(n_dims, scaling=None, rotation=None, shearing=None, translation=None): + """Create a 4x4 affine transformation matrix from specified values + :param n_dims: integer, can either be 2 or 3. + :param scaling: list of 3 scaling values + :param rotation: list of 3 angles (degrees) for rotations around 1st, 2nd, 3rd axis + :param shearing: list of 6 shearing values + :param translation: list of 3 values + :return: 4x4 numpy matrix + """ + + T_scaling = np.eye(n_dims + 1) + T_shearing = np.eye(n_dims + 1) + T_translation = np.eye(n_dims + 1) + + if scaling is not None: + T_scaling[np.arange(n_dims + 1), np.arange(n_dims + 1)] = np.append(scaling, 1) + + if shearing is not None: + shearing_index = np.ones((n_dims + 1, n_dims + 1), dtype='bool') + shearing_index[np.eye(n_dims + 1, dtype='bool')] = False + shearing_index[-1, :] = np.zeros((n_dims + 1)) + shearing_index[:, -1] = np.zeros((n_dims + 1)) + T_shearing[shearing_index] = shearing + + if translation is not None: + T_translation[np.arange(n_dims), n_dims * np.ones(n_dims, dtype='int')] = translation + + if n_dims == 2: + if rotation is None: + rotation = np.zeros(1) + else: + rotation = np.asarray(rotation) * (math.pi / 180) + T_rot = np.eye(n_dims + 1) + T_rot[np.array([0, 1, 0, 1]), np.array([0, 0, 1, 1])] = [np.cos(rotation[0]), np.sin(rotation[0]), + np.sin(rotation[0]) * -1, np.cos(rotation[0])] + return T_translation @ T_rot @ T_shearing @ T_scaling + + else: + + if rotation is None: + rotation = np.zeros(n_dims) + else: + rotation = np.asarray(rotation) * (math.pi / 180) + T_rot1 = np.eye(n_dims + 1) + T_rot1[np.array([1, 2, 1, 2]), np.array([1, 1, 2, 2])] = [np.cos(rotation[0]), np.sin(rotation[0]), + np.sin(rotation[0]) * -1, np.cos(rotation[0])] + T_rot2 = np.eye(n_dims + 1) + T_rot2[np.array([0, 2, 0, 2]), np.array([0, 0, 2, 2])] = [np.cos(rotation[1]), np.sin(rotation[1]) * -1, + np.sin(rotation[1]), np.cos(rotation[1])] + T_rot3 = np.eye(n_dims + 1) + T_rot3[np.array([0, 1, 0, 1]), np.array([0, 0, 1, 1])] = [np.cos(rotation[2]), np.sin(rotation[2]), + np.sin(rotation[2]) * -1, np.cos(rotation[2])] + return T_translation @ T_rot3 @ T_rot2 @ T_rot1 @ T_shearing @ T_scaling + + +def sample_affine_transform(batchsize, + n_dims, + rotation_bounds=False, + scaling_bounds=False, + shearing_bounds=False, + translation_bounds=False, + enable_90_rotations=False): + """build batchsize x 4 x 4 tensor representing an affine transformation in homogeneous coordinates. + If return_inv is True, also returns the inverse of the created affine matrix.""" + + if (rotation_bounds is not False) | (enable_90_rotations is not False): + if n_dims == 2: + if rotation_bounds is not False: + rotation = draw_value_from_distribution(rotation_bounds, + size=1, + default_range=15.0, + return_as_tensor=True, + batchsize=batchsize) + else: + rotation = tf.zeros(tf.concat([batchsize, tf.ones(1, dtype='int32')], axis=0)) + else: # n_dims = 3 + if rotation_bounds is not False: + rotation = draw_value_from_distribution(rotation_bounds, + size=n_dims, + default_range=15.0, + return_as_tensor=True, + batchsize=batchsize) + else: + rotation = tf.zeros(tf.concat([batchsize, 3 * tf.ones(1, dtype='int32')], axis=0)) + if enable_90_rotations: + rotation = tf.cast(tf.random.uniform(tf.shape(rotation), maxval=4, dtype='int32') * 90, 'float32') \ + + rotation + T_rot = create_rotation_transform(rotation, n_dims) + else: + T_rot = tf.tile(tf.expand_dims(tf.eye(n_dims), axis=0), + tf.concat([batchsize, tf.ones(2, dtype='int32')], axis=0)) + + if shearing_bounds is not False: + shearing = draw_value_from_distribution(shearing_bounds, + size=n_dims ** 2 - n_dims, + default_range=.01, + return_as_tensor=True, + batchsize=batchsize) + T_shearing = create_shearing_transform(shearing, n_dims) + else: + T_shearing = tf.tile(tf.expand_dims(tf.eye(n_dims), axis=0), + tf.concat([batchsize, tf.ones(2, dtype='int32')], axis=0)) + + if scaling_bounds is not False: + scaling = draw_value_from_distribution(scaling_bounds, + size=n_dims, + centre=1, + default_range=.15, + return_as_tensor=True, + batchsize=batchsize) + T_scaling = tf.linalg.diag(scaling) + else: + T_scaling = tf.tile(tf.expand_dims(tf.eye(n_dims), axis=0), + tf.concat([batchsize, tf.ones(2, dtype='int32')], axis=0)) + + T = tf.matmul(T_scaling, tf.matmul(T_shearing, T_rot)) + + if translation_bounds is not False: + translation = draw_value_from_distribution(translation_bounds, + size=n_dims, + default_range=5, + return_as_tensor=True, + batchsize=batchsize) + T = tf.concat([T, tf.expand_dims(translation, axis=-1)], axis=-1) + else: + T = tf.concat([T, tf.zeros(tf.concat([tf.shape(T)[:2], tf.ones(1, dtype='int32')], 0))], axis=-1) + + # build rigid transform + T_last_row = tf.expand_dims(tf.concat([tf.zeros((1, n_dims)), tf.ones((1, 1))], axis=1), 0) + T_last_row = tf.tile(T_last_row, tf.concat([batchsize, tf.ones(2, dtype='int32')], axis=0)) + T = tf.concat([T, T_last_row], axis=1) + + return T + + +def create_rotation_transform(rotation, n_dims): + """build rotation transform from 3d or 2d rotation coefficients. Angles are given in degrees.""" + rotation = rotation * np.pi / 180 + if n_dims == 3: + shape = tf.shape(tf.expand_dims(rotation[..., 0], -1)) + + Rx_row0 = tf.expand_dims(tf.tile(tf.expand_dims(tf.convert_to_tensor([1., 0., 0.]), 0), shape), axis=1) + Rx_row1 = tf.stack([tf.zeros(shape), tf.expand_dims(tf.cos(rotation[..., 0]), -1), + tf.expand_dims(-tf.sin(rotation[..., 0]), -1)], axis=-1) + Rx_row2 = tf.stack([tf.zeros(shape), tf.expand_dims(tf.sin(rotation[..., 0]), -1), + tf.expand_dims(tf.cos(rotation[..., 0]), -1)], axis=-1) + Rx = tf.concat([Rx_row0, Rx_row1, Rx_row2], axis=1) + + Ry_row0 = tf.stack([tf.expand_dims(tf.cos(rotation[..., 1]), -1), tf.zeros(shape), + tf.expand_dims(tf.sin(rotation[..., 1]), -1)], axis=-1) + Ry_row1 = tf.expand_dims(tf.tile(tf.expand_dims(tf.convert_to_tensor([0., 1., 0.]), 0), shape), axis=1) + Ry_row2 = tf.stack([tf.expand_dims(-tf.sin(rotation[..., 1]), -1), tf.zeros(shape), + tf.expand_dims(tf.cos(rotation[..., 1]), -1)], axis=-1) + Ry = tf.concat([Ry_row0, Ry_row1, Ry_row2], axis=1) + + Rz_row0 = tf.stack([tf.expand_dims(tf.cos(rotation[..., 2]), -1), + tf.expand_dims(-tf.sin(rotation[..., 2]), -1), tf.zeros(shape)], axis=-1) + Rz_row1 = tf.stack([tf.expand_dims(tf.sin(rotation[..., 2]), -1), + tf.expand_dims(tf.cos(rotation[..., 2]), -1), tf.zeros(shape)], axis=-1) + Rz_row2 = tf.expand_dims(tf.tile(tf.expand_dims(tf.convert_to_tensor([0., 0., 1.]), 0), shape), axis=1) + Rz = tf.concat([Rz_row0, Rz_row1, Rz_row2], axis=1) + + T_rot = tf.matmul(tf.matmul(Rx, Ry), Rz) + + elif n_dims == 2: + R_row0 = tf.stack([tf.expand_dims(tf.cos(rotation[..., 0]), -1), + tf.expand_dims(tf.sin(rotation[..., 0]), -1)], axis=-1) + R_row1 = tf.stack([tf.expand_dims(-tf.sin(rotation[..., 0]), -1), + tf.expand_dims(tf.cos(rotation[..., 0]), -1)], axis=-1) + T_rot = tf.concat([R_row0, R_row1], axis=1) + + else: + raise Exception('only supports 2 or 3D.') + + return T_rot + + +def create_shearing_transform(shearing, n_dims): + """build shearing transform from 2d/3d shearing coefficients""" + shape = tf.shape(tf.expand_dims(shearing[..., 0], -1)) + if n_dims == 3: + shearing_row0 = tf.stack([tf.ones(shape), tf.expand_dims(shearing[..., 0], -1), + tf.expand_dims(shearing[..., 1], -1)], axis=-1) + shearing_row1 = tf.stack([tf.expand_dims(shearing[..., 2], -1), tf.ones(shape), + tf.expand_dims(shearing[..., 3], -1)], axis=-1) + shearing_row2 = tf.stack([tf.expand_dims(shearing[..., 4], -1), tf.expand_dims(shearing[..., 5], -1), + tf.ones(shape)], axis=-1) + T_shearing = tf.concat([shearing_row0, shearing_row1, shearing_row2], axis=1) + + elif n_dims == 2: + shearing_row0 = tf.stack([tf.ones(shape), tf.expand_dims(shearing[..., 0], -1)], axis=-1) + shearing_row1 = tf.stack([tf.expand_dims(shearing[..., 1], -1), tf.ones(shape)], axis=-1) + T_shearing = tf.concat([shearing_row0, shearing_row1], axis=1) + else: + raise Exception('only supports 2 or 3D.') + return T_shearing + + +# --------------------------------------------------- miscellaneous ---------------------------------------------------- + + +def infer(x): + """ Try to parse input to float. If it fails, tries boolean, and otherwise keep it as string """ + try: + x = float(x) + except ValueError: + if x == 'False': + x = False + elif x == 'True': + x = True + elif not isinstance(x, str): + raise TypeError('input should be an int/float/boolean/str, had {}'.format(type(x))) + return x + + +class LoopInfo: + """ + Class to print the current iteration in a for loop, and optionally the estimated remaining time. + Instantiate just before the loop, and call the update method at the start of the loop. + The printed text has the following format: + processing i/total remaining time: hh:mm:ss + """ + + def __init__(self, n_iterations, spacing=10, text='processing', print_time=False): + """ + :param n_iterations: total number of iterations of the for loop. + :param spacing: frequency at which the update info will be printed on screen. + :param text: text to print. Default is processing. + :param print_time: whether to print the estimated remaining time. Default is False. + """ + + # loop parameters + self.n_iterations = n_iterations + self.spacing = spacing + + # text parameters + self.text = text + self.print_time = print_time + self.print_previous_time = False + self.align = len(str(self.n_iterations)) * 2 + 1 + 3 + + # timing parameters + self.iteration_durations = np.zeros((n_iterations,)) + self.start = time.time() + self.previous = time.time() + + def update(self, idx): + + # time iteration + now = time.time() + self.iteration_durations[idx] = now - self.previous + self.previous = now + + # print text + if idx == 0: + print(self.text + ' 1/{}'.format(self.n_iterations)) + elif idx % self.spacing == self.spacing - 1: + iteration = str(idx + 1) + '/' + str(self.n_iterations) + if self.print_time: + # estimate remaining time + max_duration = np.max(self.iteration_durations) + average_duration = np.mean(self.iteration_durations[self.iteration_durations > .01 * max_duration]) + remaining_time = int(average_duration * (self.n_iterations - idx)) + # print total remaining time only if it is greater than 1s or if it was previously printed + if (remaining_time > 1) | self.print_previous_time: + eta = str(timedelta(seconds=remaining_time)) + print(self.text + ' {:<{x}} remaining time: {}'.format(iteration, eta, x=self.align)) + self.print_previous_time = True + else: + print(self.text + ' {}'.format(iteration)) + else: + print(self.text + ' {}'.format(iteration)) + + +def get_mapping_lut(source, dest=None): + """This functions returns the look-up table to map a list of N values (source) to another list (dest). + If the second list is not given, we assume it is equal to [0, ..., N-1].""" + + # initialise + source = np.array(reformat_to_list(source), dtype='int32') + n_labels = source.shape[0] + + # build new label list if necessary + if dest is None: + dest = np.arange(n_labels, dtype='int32') + else: + assert len(source) == len(dest), 'label_list and new_label_list should have the same length' + dest = np.array(reformat_to_list(dest, dtype='int')) + + # build look-up table + lut = np.zeros(np.max(source) + 1, dtype='int32') + for source, dest in zip(source, dest): + lut[source] = dest + + return lut + + +def build_training_generator(gen, batchsize): + """Build generator for training a network.""" + while True: + inputs = next(gen) + if batchsize > 1: + target = np.concatenate([np.zeros((1, 1))] * batchsize, 0) + else: + target = np.zeros((1, 1)) + yield inputs, target + + +def find_closest_number_divisible_by_m(n, m, answer_type='lower'): + """Return the closest integer to n that is divisible by m. answer_type can either be 'closer', 'lower' (only returns + values lower than n), or 'higher' (only returns values higher than m).""" + if n % m == 0: + return n + else: + q = int(n / m) + lower = q * m + higher = (q + 1) * m + if answer_type == 'lower': + return lower + elif answer_type == 'higher': + return higher + elif answer_type == 'closer': + return lower if (n - lower) < (higher - n) else higher + else: + raise Exception('answer_type should be lower, higher, or closer, had : %s' % answer_type) + + +def build_binary_structure(connectivity, n_dims, shape=None): + """Return a dilation/erosion element with provided connectivity""" + if shape is None: + shape = [connectivity * 2 + 1] * n_dims + else: + shape = reformat_to_list(shape, length=n_dims) + dist = np.ones(shape) + center = tuple([tuple([int(s / 2)]) for s in shape]) + dist[center] = 0 + dist = distance_transform_edt(dist) + struct = (dist <= connectivity) * 1 + return struct + + +def draw_value_from_distribution(hyperparameter, + size=1, + distribution='uniform', + centre=0., + default_range=10.0, + positive_only=False, + return_as_tensor=False, + batchsize=None): + """Sample values from a uniform, or normal distribution of given hyperparameters. + These hyperparameters are to the number of 2 in both uniform and normal cases. + :param hyperparameter: values of the hyperparameters. Can either be: + 1) None, in each case the two hyperparameters are given by [center-default_range, center+default_range], + 2) a number, where the two hyperparameters are given by [centre-hyperparameter, centre+hyperparameter], + 3) a sequence of length 2, directly defining the two hyperparameters: [min, max] if the distribution is uniform, + [mean, std] if the distribution is normal. + 4) a numpy array, with size (2, m). In this case, the function returns a 1d array of size m, where each value has + been sampled independently with the specified hyperparameters. If the distribution is uniform, rows correspond to + its lower and upper bounds, and if the distribution is normal, rows correspond to its mean and std deviation. + 5) a numpy array of size (2*n, m). Same as 4) but we first randomly select a block of two rows among the + n possibilities. + 6) the path to a numpy array corresponding to case 4 or 5. + 7) False, in which case this function returns None. + :param size: (optional) number of values to sample. All values are sampled independently. + Used only if hyperparameter is not a numpy array. + :param distribution: (optional) the distribution type. Can be 'uniform' or 'normal'. Default is 'uniform'. + :param centre: (optional) default centre to use if hyperparameter is None or a number. + :param default_range: (optional) default range to use if hyperparameter is None. + :param positive_only: (optional) whether to reset all negative values to zero. + :param return_as_tensor: (optional) whether to return the result as a tensorflow tensor + :param batchsize: (optional) if return_as_tensor is true, then you can sample a tensor of a given batchsize. Give + this batchsize as a tensorflow tensor here. + :return: a float, or a numpy 1d array if size > 1, or hyperparameter is itself a numpy array. + Returns None if hyperparameter is False. + """ + + # return False is hyperparameter is False + if hyperparameter is False: + return None + + # reformat parameter_range + hyperparameter = load_array_if_path(hyperparameter, load_as_numpy=True) + if not isinstance(hyperparameter, np.ndarray): + if hyperparameter is None: + hyperparameter = np.array([[centre - default_range] * size, [centre + default_range] * size]) + elif isinstance(hyperparameter, (int, float)): + hyperparameter = np.array([[centre - hyperparameter] * size, [centre + hyperparameter] * size]) + elif isinstance(hyperparameter, (list, tuple)): + assert len(hyperparameter) == 2, 'if list, parameter_range should be of length 2.' + hyperparameter = np.transpose(np.tile(np.array(hyperparameter), (size, 1))) + else: + raise ValueError('parameter_range should either be None, a number, a sequence, or a numpy array.') + elif isinstance(hyperparameter, np.ndarray): + assert hyperparameter.shape[0] % 2 == 0, 'number of rows of parameter_range should be divisible by 2' + n_modalities = int(hyperparameter.shape[0] / 2) + modality_idx = 2 * np.random.randint(n_modalities) + hyperparameter = hyperparameter[modality_idx: modality_idx + 2, :] + + # draw values as tensor + if return_as_tensor: + shape = KL.Lambda(lambda x: tf.convert_to_tensor(hyperparameter.shape[1], 'int32'))([]) + if batchsize is not None: + shape = KL.Lambda(lambda x: tf.concat([x[0], tf.expand_dims(x[1], axis=0)], axis=0))([batchsize, shape]) + if distribution == 'uniform': + parameter_value = KL.Lambda(lambda x: tf.random.uniform(shape=x, + minval=hyperparameter[0, :], + maxval=hyperparameter[1, :]))(shape) + elif distribution == 'normal': + parameter_value = KL.Lambda(lambda x: tf.random.normal(shape=x, + mean=hyperparameter[0, :], + stddev=hyperparameter[1, :]))(shape) + else: + raise ValueError("Distribution not supported, should be 'uniform' or 'normal'.") + + if positive_only: + parameter_value = KL.Lambda(lambda x: K.clip(x, 0, None))(parameter_value) + + # draw values as numpy array + else: + if distribution == 'uniform': + parameter_value = np.random.uniform(low=hyperparameter[0, :], high=hyperparameter[1, :]) + elif distribution == 'normal': + parameter_value = np.random.normal(loc=hyperparameter[0, :], scale=hyperparameter[1, :]) + else: + raise ValueError("Distribution not supported, should be 'uniform' or 'normal'.") + + if positive_only: + parameter_value[parameter_value < 0] = 0 + + return parameter_value + + +def build_exp(x, first, last, fix_point): + # first = f(0), last = f(+inf), fix_point = [x0, f(x0))] + a = last + b = first - last + c = - (1 / fix_point[0]) * np.log((fix_point[1] - last) / (first - last)) + return a + b * np.exp(-c * x)