Diff of /ext/lab2im/utils.py [000000] .. [e571d1]

Switch to unified view

a b/ext/lab2im/utils.py
1
"""
2
This file contains all the utilities used in that project. They are classified in 5 categories:
3
1- loading/saving functions:
4
    -load_volume
5
    -save_volume
6
    -get_volume_info
7
    -get_list_labels
8
    -load_array_if_path
9
    -write_pickle
10
    -read_pickle
11
    -write_model_summary
12
2- reformatting functions
13
    -reformat_to_list
14
    -reformat_to_n_channels_array
15
3- path related functions
16
    -list_images_in_folder
17
    -list_files
18
    -list_subfolders
19
    -strip_extension
20
    -strip_suffix
21
    -mkdir
22
    -mkcmd
23
4- shape-related functions
24
    -get_dims
25
    -get_resample_shape
26
    -add_axis
27
    -get_padding_margin
28
5- build affine matrices/tensors
29
    -create_affine_transformation_matrix
30
    -sample_affine_transform
31
    -create_rotation_transform
32
    -create_shearing_transform
33
6- miscellaneous
34
    -infer
35
    -LoopInfo
36
    -get_mapping_lut
37
    -build_training_generator
38
    -find_closest_number_divisible_by_m
39
    -build_binary_structure
40
    -draw_value_from_distribution
41
    -build_exp
42
43
44
If you use this code, please cite the first SynthSeg paper:
45
https://github.com/BBillot/lab2im/blob/master/bibtex.bib
46
47
Copyright 2020 Benjamin Billot
48
49
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
50
compliance with the License. You may obtain a copy of the License at
51
https://www.apache.org/licenses/LICENSE-2.0
52
Unless required by applicable law or agreed to in writing, software distributed under the License is
53
distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
54
implied. See the License for the specific language governing permissions and limitations under the
55
License.
56
"""
57
58
59
import os
60
import glob
61
import math
62
import time
63
import pickle
64
import numpy as np
65
import nibabel as nib
66
import tensorflow as tf
67
import keras.layers as KL
68
import keras.backend as K
69
from datetime import timedelta
70
from scipy.ndimage.morphology import distance_transform_edt
71
72
73
# ---------------------------------------------- loading/saving functions ----------------------------------------------
74
75
76
def load_volume(path_volume, im_only=True, squeeze=True, dtype=None, aff_ref=None):
77
    """
78
    Load volume file.
79
    :param path_volume: path of the volume to load. Can either be a nii, nii.gz, mgz, or npz format.
80
    If npz format, 1) the variable name is assumed to be 'vol_data',
81
    2) the volume is associated with an identity affine matrix and blank header.
82
    :param im_only: (optional) if False, the function also returns the affine matrix and header of the volume.
83
    :param squeeze: (optional) whether to squeeze the volume when loading.
84
    :param dtype: (optional) if not None, convert the loaded volume to this numpy dtype.
85
    :param aff_ref: (optional) If not None, the loaded volume is aligned to this affine matrix.
86
    The returned affine matrix is also given in this new space. Must be a numpy array of dimension 4x4.
87
    :return: the volume, with corresponding affine matrix and header if im_only is False.
88
    """
89
    assert path_volume.endswith(('.nii', '.nii.gz', '.mgz', '.npz')), 'Unknown data file: %s' % path_volume
90
91
    if path_volume.endswith(('.nii', '.nii.gz', '.mgz')):
92
        x = nib.load(path_volume)
93
        if squeeze:
94
            volume = np.squeeze(x.get_fdata())
95
        else:
96
            volume = x.get_fdata()
97
        aff = x.affine
98
        header = x.header
99
    else:  # npz
100
        volume = np.load(path_volume)['vol_data']
101
        if squeeze:
102
            volume = np.squeeze(volume)
103
        aff = np.eye(4)
104
        header = nib.Nifti1Header()
105
    if dtype is not None:
106
        if 'int' in dtype:
107
            volume = np.round(volume)
108
        volume = volume.astype(dtype=dtype)
109
110
    # align image to reference affine matrix
111
    if aff_ref is not None:
112
        from ext.lab2im import edit_volumes  # the import is done here to avoid import loops
113
        n_dims, _ = get_dims(list(volume.shape), max_channels=10)
114
        volume, aff = edit_volumes.align_volume_to_ref(volume, aff, aff_ref=aff_ref, return_aff=True, n_dims=n_dims)
115
116
    if im_only:
117
        return volume
118
    else:
119
        return volume, aff, header
120
121
122
def save_volume(volume, aff, header, path, res=None, dtype=None, n_dims=3):
123
    """
124
    Save a volume.
125
    :param volume: volume to save
126
    :param aff: affine matrix of the volume to save. If aff is None, the volume is saved with an identity affine matrix.
127
    aff can also be set to 'FS', in which case the volume is saved with the affine matrix of FreeSurfer outputs.
128
    :param header: header of the volume to save. If None, the volume is saved with a blank header.
129
    :param path: path where to save the volume.
130
    :param res: (optional) update the resolution in the header before saving the volume.
131
    :param dtype: (optional) numpy dtype for the saved volume.
132
    :param n_dims: (optional) number of dimensions, to avoid confusion in multi-channel case. Default is None, where
133
    n_dims is automatically inferred.
134
    """
135
136
    mkdir(os.path.dirname(path))
137
    if '.npz' in path:
138
        np.savez_compressed(path, vol_data=volume)
139
    else:
140
        if header is None:
141
            header = nib.Nifti1Header()
142
        if isinstance(aff, str):
143
            if aff == 'FS':
144
                aff = np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, -1, 0, 0], [0, 0, 0, 1]])
145
        elif aff is None:
146
            aff = np.eye(4)
147
        if dtype is not None:
148
            if 'int' in dtype:
149
                volume = np.round(volume)
150
            volume = volume.astype(dtype=dtype)
151
            nifty = nib.Nifti1Image(volume, aff, header)
152
            nifty.set_data_dtype(dtype)
153
        else:
154
            nifty = nib.Nifti1Image(volume, aff, header)
155
        if res is not None:
156
            if n_dims is None:
157
                n_dims, _ = get_dims(volume.shape)
158
            res = reformat_to_list(res, length=n_dims, dtype=None)
159
            nifty.header.set_zooms(res)
160
        nib.save(nifty, path)
161
162
163
def get_volume_info(path_volume, return_volume=False, aff_ref=None, max_channels=10):
164
    """
165
    Gather information about a volume: shape, affine matrix, number of dimensions and channels, header, and resolution.
166
    :param path_volume: path of the volume to get information form.
167
    :param return_volume: (optional) whether to return the volume along with the information.
168
    :param aff_ref: (optional) If not None, the loaded volume is aligned to this affine matrix.
169
    All info relative to the volume is then given in this new space. Must be a numpy array of dimension 4x4.
170
    :param max_channels: maximum possible number of channels for the input volume.
171
    :return: volume (if return_volume is true), and corresponding info. If aff_ref is not None, the returned aff is
172
    the original one, i.e. the affine of the image before being aligned to aff_ref.
173
    """
174
    # read image
175
    im, aff, header = load_volume(path_volume, im_only=False)
176
177
    # understand if image is multichannel
178
    im_shape = list(im.shape)
179
    n_dims, n_channels = get_dims(im_shape, max_channels=max_channels)
180
    im_shape = im_shape[:n_dims]
181
182
    # get labels res
183
    if '.nii' in path_volume:
184
        data_res = np.array(header['pixdim'][1:n_dims + 1])
185
    elif '.mgz' in path_volume:
186
        data_res = np.array(header['delta'])  # mgz image
187
    else:
188
        data_res = np.array([1.0] * n_dims)
189
190
    # align to given affine matrix
191
    if aff_ref is not None:
192
        from ext.lab2im import edit_volumes  # the import is done here to avoid import loops
193
        ras_axes = edit_volumes.get_ras_axes(aff, n_dims=n_dims)
194
        ras_axes_ref = edit_volumes.get_ras_axes(aff_ref, n_dims=n_dims)
195
        im = edit_volumes.align_volume_to_ref(im, aff, aff_ref=aff_ref, n_dims=n_dims)
196
        im_shape = np.array(im_shape)
197
        data_res = np.array(data_res)
198
        im_shape[ras_axes_ref] = im_shape[ras_axes]
199
        data_res[ras_axes_ref] = data_res[ras_axes]
200
        im_shape = im_shape.tolist()
201
202
    # return info
203
    if return_volume:
204
        return im, im_shape, aff, n_dims, n_channels, header, data_res
205
    else:
206
        return im_shape, aff, n_dims, n_channels, header, data_res
207
208
209
def get_list_labels(label_list=None, labels_dir=None, save_label_list=None, FS_sort=False):
210
    """This function reads or computes a list of all label values used in a set of label maps.
211
    It can also sort all labels according to FreeSurfer lut.
212
    :param label_list: (optional) already computed label_list. Can be a sequence, a 1d numpy array, or the path to
213
    a numpy 1d array.
214
    :param labels_dir: (optional) if path_label_list is None, the label list is computed by reading all the label maps
215
    in the given folder. Can also be the path to a single label map.
216
    :param save_label_list: (optional) path where to save the label list.
217
    :param FS_sort: (optional) whether to sort label values according to the FreeSurfer classification.
218
    If true, the label values will be ordered as follows: neutral labels first (i.e. non-sided), left-side labels,
219
    and right-side labels. If FS_sort is True, this function also returns the number of neutral labels in label_list.
220
    :return: the label list (numpy 1d array), and the number of neutral (i.e. non-sided) labels if FS_sort is True.
221
    If one side of the brain is not represented at all in label_list, all labels are considered as neutral, and
222
    n_neutral_labels = len(label_list).
223
    """
224
225
    # load label list if previously computed
226
    if label_list is not None:
227
        label_list = np.array(reformat_to_list(label_list, load_as_numpy=True, dtype='int'))
228
229
    # compute label list from all label files
230
    elif labels_dir is not None:
231
        print('Compiling list of unique labels')
232
        # go through all labels files and compute unique list of labels
233
        labels_paths = list_images_in_folder(labels_dir)
234
        label_list = np.empty(0)
235
        loop_info = LoopInfo(len(labels_paths), 10, 'processing', print_time=True)
236
        for lab_idx, path in enumerate(labels_paths):
237
            loop_info.update(lab_idx)
238
            y = load_volume(path, dtype='int32')
239
            y_unique = np.unique(y)
240
            label_list = np.unique(np.concatenate((label_list, y_unique))).astype('int')
241
242
    else:
243
        raise Exception('either label_list, path_label_list or labels_dir should be provided')
244
245
    # sort labels in neutral/left/right according to FS labels
246
    n_neutral_labels = 0
247
    if FS_sort:
248
        neutral_FS_labels = [0, 14, 15, 16, 21, 22, 23, 24, 72, 77, 80, 85, 100, 101, 102, 103, 104, 105, 106, 107, 108,
249
                             109, 165, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210,
250
                             251, 252, 253, 254, 255, 258, 259, 260, 331, 332, 333, 334, 335, 336, 337, 338, 339, 340,
251
                             502, 506, 507, 508, 509, 511, 512, 514, 515, 516, 517, 530,
252
                             531, 532, 533, 534, 535, 536, 537]
253
        neutral = list()
254
        left = list()
255
        right = list()
256
        for la in label_list:
257
            if la in neutral_FS_labels:
258
                if la not in neutral:
259
                    neutral.append(la)
260
            elif (0 < la < 14) | (16 < la < 21) | (24 < la < 40) | (135 < la < 139) | (1000 <= la <= 1035) | \
261
                    (la == 865) | (20100 < la < 20110):
262
                if la not in left:
263
                    left.append(la)
264
            elif (39 < la < 72) | (162 < la < 165) | (2000 <= la <= 2035) | (20000 < la < 20010) | (la == 139) | \
265
                    (la == 866):
266
                if la not in right:
267
                    right.append(la)
268
            else:
269
                raise Exception('label {} not in our current FS classification, '
270
                                'please update get_list_labels in utils.py'.format(la))
271
        label_list = np.concatenate([sorted(neutral), sorted(left), sorted(right)])
272
        if ((len(left) > 0) & (len(right) > 0)) | ((len(left) == 0) & (len(right) == 0)):
273
            n_neutral_labels = len(neutral)
274
        else:
275
            n_neutral_labels = len(label_list)
276
277
    # save labels if specified
278
    if save_label_list is not None:
279
        np.save(save_label_list, np.int32(label_list))
280
281
    if FS_sort:
282
        return np.int32(label_list), n_neutral_labels
283
    else:
284
        return np.int32(label_list), None
285
286
287
def load_array_if_path(var, load_as_numpy=True):
288
    """If var is a string and load_as_numpy is True, this function loads the array writen at the path indicated by var.
289
    Otherwise it simply returns var as it is."""
290
    if (isinstance(var, str)) & load_as_numpy:
291
        assert os.path.isfile(var), 'No such path: %s' % var
292
        var = np.load(var)
293
    return var
294
295
296
def write_pickle(filepath, obj):
297
    """ write a python object with a pickle at a given path"""
298
    with open(filepath, 'wb') as file:
299
        pickler = pickle.Pickler(file)
300
        pickler.dump(obj)
301
302
303
def read_pickle(filepath):
304
    """ read a python object with a pickle"""
305
    with open(filepath, 'rb') as file:
306
        unpickler = pickle.Unpickler(file)
307
        return unpickler.load()
308
309
310
def write_model_summary(model, filepath='./model_summary.txt', line_length=150):
311
    """Write the summary of a keras model at a given path, with a given length for each line"""
312
    with open(filepath, 'w') as fh:
313
        model.summary(print_fn=lambda x: fh.write(x + '\n'), line_length=line_length)
314
315
316
# ----------------------------------------------- reformatting functions -----------------------------------------------
317
318
319
def reformat_to_list(var, length=None, load_as_numpy=False, dtype=None):
320
    """This function takes a variable and reformat it into a list of desired
321
    length and type (int, float, bool, str).
322
    If variable is a string, and load_as_numpy is True, it will be loaded as a numpy array.
323
    If variable is None, this function returns None.
324
    :param var: a str, int, float, list, tuple, or numpy array
325
    :param length: (optional) if var is a single item, it will be replicated to a list of this length
326
    :param load_as_numpy: (optional) whether var is the path to a numpy array
327
    :param dtype: (optional) convert all item to this type. Can be 'int', 'float', 'bool', or 'str'
328
    :return: reformatted list
329
    """
330
331
    # convert to list
332
    if var is None:
333
        return None
334
    var = load_array_if_path(var, load_as_numpy=load_as_numpy)
335
    if isinstance(var, (int, float, np.int, np.int32, np.int64, np.float, np.float32, np.float64)):
336
        var = [var]
337
    elif isinstance(var, tuple):
338
        var = list(var)
339
    elif isinstance(var, np.ndarray):
340
        if var.shape == (1,):
341
            var = [var[0]]
342
        else:
343
            var = np.squeeze(var).tolist()
344
    elif isinstance(var, str):
345
        var = [var]
346
    elif isinstance(var, bool):
347
        var = [var]
348
    if isinstance(var, list):
349
        if length is not None:
350
            if len(var) == 1:
351
                var = var * length
352
            elif len(var) != length:
353
                raise ValueError('if var is a list/tuple/numpy array, it should be of length 1 or {0}, '
354
                                 'had {1}'.format(length, var))
355
    else:
356
        raise TypeError('var should be an int, float, tuple, list, numpy array, or path to numpy array')
357
358
    # convert items type
359
    if dtype is not None:
360
        if dtype == 'int':
361
            var = [int(v) for v in var]
362
        elif dtype == 'float':
363
            var = [float(v) for v in var]
364
        elif dtype == 'bool':
365
            var = [bool(v) for v in var]
366
        elif dtype == 'str':
367
            var = [str(v) for v in var]
368
        else:
369
            raise ValueError("dtype should be 'str', 'float', 'int', or 'bool'; had {}".format(dtype))
370
    return var
371
372
373
def reformat_to_n_channels_array(var, n_dims=3, n_channels=1):
374
    """This function takes an int, float, list or tuple and reformat it to an array of shape (n_channels, n_dims).
375
    If resolution is a str, it will be assumed to be the path of a numpy array.
376
    If resolution is a numpy array, it will be checked to have shape (n_channels, n_dims).
377
    Finally if resolution is None, this function returns None as well."""
378
    if var is None:
379
        return [None] * n_channels
380
    if isinstance(var, str):
381
        var = np.load(var)
382
    # convert to numpy array
383
    if isinstance(var, (int, float, list, tuple)):
384
        var = reformat_to_list(var, n_dims)
385
        var = np.tile(np.array(var), (n_channels, 1))
386
    # check shape if numpy array
387
    elif isinstance(var, np.ndarray):
388
        if n_channels == 1:
389
            var = var.reshape((1, n_dims))
390
        else:
391
            if np.squeeze(var).shape == (n_dims,):
392
                var = np.tile(var.reshape((1, n_dims)), (n_channels, 1))
393
            elif var.shape != (n_channels, n_dims):
394
                raise ValueError('if array, var should be {0} or {1}'.format((1, n_dims), (n_channels, n_dims)))
395
    else:
396
        raise TypeError('var should be int, float, list, tuple or ndarray')
397
    return np.round(var, 3)
398
399
400
# ----------------------------------------------- path-related functions -----------------------------------------------
401
402
403
def list_images_in_folder(path_dir, include_single_image=True, check_if_empty=True):
404
    """List all files with extension nii, nii.gz, mgz, or npz within a folder."""
405
    basename = os.path.basename(path_dir)
406
    if include_single_image & \
407
            (('.nii.gz' in basename) | ('.nii' in basename) | ('.mgz' in basename) | ('.npz' in basename)):
408
        assert os.path.isfile(path_dir), 'file %s does not exist' % path_dir
409
        list_images = [path_dir]
410
    else:
411
        if os.path.isdir(path_dir):
412
            list_images = sorted(glob.glob(os.path.join(path_dir, '*nii.gz')) +
413
                                 glob.glob(os.path.join(path_dir, '*nii')) +
414
                                 glob.glob(os.path.join(path_dir, '*.mgz')) +
415
                                 glob.glob(os.path.join(path_dir, '*.npz')))
416
        else:
417
            raise Exception('Folder does not exist: %s' % path_dir)
418
        if check_if_empty:
419
            assert len(list_images) > 0, 'no .nii, .nii.gz, .mgz or .npz image could be found in %s' % path_dir
420
    return list_images
421
422
423
def list_files(path_dir, whole_path=True, expr=None, cond_type='or'):
424
    """This function returns a list of files contained in a folder, with possible regexp.
425
    :param path_dir: path of a folder
426
    :param whole_path: (optional) whether to return whole path or just the filenames.
427
    :param expr: (optional) regexp for files to list. Can be a str or a list of str.
428
    :param cond_type: (optional) if exp is a list, specify the logical link between expressions in exp.
429
    Can be 'or', or 'and'.
430
    :return: a list of files
431
    """
432
    assert isinstance(whole_path, bool), "whole_path should be bool"
433
    assert cond_type in ['or', 'and'], "cond_type should be either 'or', or 'and'"
434
    if whole_path:
435
        files_list = sorted([os.path.join(path_dir, f) for f in os.listdir(path_dir)
436
                             if os.path.isfile(os.path.join(path_dir, f))])
437
    else:
438
        files_list = sorted([f for f in os.listdir(path_dir) if os.path.isfile(os.path.join(path_dir, f))])
439
    if expr is not None:  # assumed to be either str or list of str
440
        if isinstance(expr, str):
441
            expr = [expr]
442
        elif not isinstance(expr, (list, tuple)):
443
            raise Exception("if specified, 'expr' should be a string or list of strings.")
444
        matched_list_files = list()
445
        for match in expr:
446
            tmp_matched_files_list = sorted([f for f in files_list if match in os.path.basename(f)])
447
            if cond_type == 'or':
448
                files_list = [f for f in files_list if f not in tmp_matched_files_list]
449
                matched_list_files += tmp_matched_files_list
450
            elif cond_type == 'and':
451
                files_list = tmp_matched_files_list
452
                matched_list_files = tmp_matched_files_list
453
        files_list = sorted(matched_list_files)
454
    return files_list
455
456
457
def list_subfolders(path_dir, whole_path=True, expr=None, cond_type='or'):
458
    """This function returns a list of subfolders contained in a folder, with possible regexp.
459
    :param path_dir: path of a folder
460
    :param whole_path: (optional) whether to return whole path or just the subfolder names.
461
    :param expr: (optional) regexp for files to list. Can be a str or a list of str.
462
    :param cond_type: (optional) if exp is a list, specify the logical link between expressions in exp.
463
    Can be 'or', or 'and'.
464
    :return: a list of subfolders
465
    """
466
    assert isinstance(whole_path, bool), "whole_path should be bool"
467
    assert cond_type in ['or', 'and'], "cond_type should be either 'or', or 'and'"
468
    if whole_path:
469
        subdirs_list = sorted([os.path.join(path_dir, f) for f in os.listdir(path_dir)
470
                               if os.path.isdir(os.path.join(path_dir, f))])
471
    else:
472
        subdirs_list = sorted([f for f in os.listdir(path_dir) if os.path.isdir(os.path.join(path_dir, f))])
473
    if expr is not None:  # assumed to be either str or list of str
474
        if isinstance(expr, str):
475
            expr = [expr]
476
        elif not isinstance(expr, (list, tuple)):
477
            raise Exception("if specified, 'expr' should be a string or list of strings.")
478
        matched_list_subdirs = list()
479
        for match in expr:
480
            tmp_matched_list_subdirs = sorted([f for f in subdirs_list if match in os.path.basename(f)])
481
            if cond_type == 'or':
482
                subdirs_list = [f for f in subdirs_list if f not in tmp_matched_list_subdirs]
483
                matched_list_subdirs += tmp_matched_list_subdirs
484
            elif cond_type == 'and':
485
                subdirs_list = tmp_matched_list_subdirs
486
                matched_list_subdirs = tmp_matched_list_subdirs
487
        subdirs_list = sorted(matched_list_subdirs)
488
    return subdirs_list
489
490
491
def get_image_extension(path):
492
    name = os.path.basename(path)
493
    if name[-7:] == '.nii.gz':
494
        return 'nii.gz'
495
    elif name[-4:] == '.mgz':
496
        return 'mgz'
497
    elif name[-4:] == '.nii':
498
        return 'nii'
499
    elif name[-4:] == '.npz':
500
        return 'npz'
501
502
503
def strip_extension(path):
504
    """Strip classical image extensions (.nii.gz, .nii, .mgz, .npz) from a filename."""
505
    return path.replace('.nii.gz', '').replace('.nii', '').replace('.mgz', '').replace('.npz', '')
506
507
508
def strip_suffix(path):
509
    """Strip classical image suffix from a filename."""
510
    path = path.replace('_aseg', '')
511
    path = path.replace('aseg', '')
512
    path = path.replace('.aseg', '')
513
    path = path.replace('_aseg_1', '')
514
    path = path.replace('_aseg_2', '')
515
    path = path.replace('aseg_1_', '')
516
    path = path.replace('aseg_2_', '')
517
    path = path.replace('_orig', '')
518
    path = path.replace('orig', '')
519
    path = path.replace('.orig', '')
520
    path = path.replace('_norm', '')
521
    path = path.replace('norm', '')
522
    path = path.replace('.norm', '')
523
    path = path.replace('_talairach', '')
524
    path = path.replace('GSP_FS_4p5', 'GSP')
525
    path = path.replace('.nii_crispSegmentation', '')
526
    path = path.replace('_crispSegmentation', '')
527
    path = path.replace('_seg', '')
528
    path = path.replace('.seg', '')
529
    path = path.replace('seg', '')
530
    path = path.replace('_seg_1', '')
531
    path = path.replace('_seg_2', '')
532
    path = path.replace('seg_1_', '')
533
    path = path.replace('seg_2_', '')
534
    return path
535
536
537
def mkdir(path_dir):
538
    """Recursively creates the current dir as well as its parent folders if they do not already exist."""
539
    if path_dir[-1] == '/':
540
        path_dir = path_dir[:-1]
541
    if not os.path.isdir(path_dir):
542
        list_dir_to_create = [path_dir]
543
        while not os.path.isdir(os.path.dirname(list_dir_to_create[-1])):
544
            list_dir_to_create.append(os.path.dirname(list_dir_to_create[-1]))
545
        for dir_to_create in reversed(list_dir_to_create):
546
            os.mkdir(dir_to_create)
547
548
549
def mkcmd(*args):
550
    """Creates terminal command with provided inputs.
551
    Example: mkcmd('mv', 'source', 'dest') will give 'mv source dest'."""
552
    return ' '.join([str(arg) for arg in args])
553
554
555
# ---------------------------------------------- shape-related functions -----------------------------------------------
556
557
558
def get_dims(shape, max_channels=10):
559
    """Get the number of dimensions and channels from the shape of an array.
560
    The number of dimensions is assumed to be the length of the shape, as long as the shape of the last dimension is
561
    inferior or equal to max_channels (default 3).
562
    :param shape: shape of an array. Can be a sequence or a 1d numpy array.
563
    :param max_channels: maximum possible number of channels.
564
    :return: the number of dimensions and channels associated with the provided shape.
565
    example 1: get_dims([150, 150, 150], max_channels=10) = (3, 1)
566
    example 2: get_dims([150, 150, 150, 3], max_channels=10) = (3, 3)
567
    example 3: get_dims([150, 150, 150, 15], max_channels=10) = (4, 1), because 5>3"""
568
    if shape[-1] <= max_channels:
569
        n_dims = len(shape) - 1
570
        n_channels = shape[-1]
571
    else:
572
        n_dims = len(shape)
573
        n_channels = 1
574
    return n_dims, n_channels
575
576
577
def get_resample_shape(patch_shape, factor, n_channels=None):
578
    """Compute the shape of a resampled array given a shape factor.
579
    :param patch_shape: size of the initial array (without number of channels).
580
    :param factor: resampling factor. Can be a number, sequence, or 1d numpy array.
581
    :param n_channels: (optional) if not None, add a number of channel at the end of the computed shape.
582
    :return: list containing the shape of the input array after being resampled by the given factor.
583
    """
584
    factor = reformat_to_list(factor, length=len(patch_shape))
585
    shape = [math.ceil(patch_shape[i] * factor[i]) for i in range(len(patch_shape))]
586
    if n_channels is not None:
587
        shape += [n_channels]
588
    return shape
589
590
591
def add_axis(x, axis=0):
592
    """Add axis to a numpy array.
593
    :param x: input array
594
    :param axis: index of the new axis to add. Can also be a list of indices to add several axes at the same time."""
595
    axis = reformat_to_list(axis)
596
    for ax in axis:
597
        x = np.expand_dims(x, axis=ax)
598
    return x
599
600
601
def get_padding_margin(cropping, loss_cropping):
602
    """Compute padding margin"""
603
    if (cropping is not None) & (loss_cropping is not None):
604
        cropping = reformat_to_list(cropping)
605
        loss_cropping = reformat_to_list(loss_cropping)
606
        n_dims = max(len(cropping), len(loss_cropping))
607
        cropping = reformat_to_list(cropping, length=n_dims)
608
        loss_cropping = reformat_to_list(loss_cropping, length=n_dims)
609
        padding_margin = [int((cropping[i] - loss_cropping[i]) / 2) for i in range(n_dims)]
610
        if len(padding_margin) == 1:
611
            padding_margin = padding_margin[0]
612
    else:
613
        padding_margin = None
614
    return padding_margin
615
616
617
# -------------------------------------------- build affine matrices/tensors -------------------------------------------
618
619
620
def create_affine_transformation_matrix(n_dims, scaling=None, rotation=None, shearing=None, translation=None):
621
    """Create a 4x4 affine transformation matrix from specified values
622
    :param n_dims: integer, can either be 2 or 3.
623
    :param scaling: list of 3 scaling values
624
    :param rotation: list of 3 angles (degrees) for rotations around 1st, 2nd, 3rd axis
625
    :param shearing: list of 6 shearing values
626
    :param translation: list of 3 values
627
    :return: 4x4 numpy matrix
628
    """
629
630
    T_scaling = np.eye(n_dims + 1)
631
    T_shearing = np.eye(n_dims + 1)
632
    T_translation = np.eye(n_dims + 1)
633
634
    if scaling is not None:
635
        T_scaling[np.arange(n_dims + 1), np.arange(n_dims + 1)] = np.append(scaling, 1)
636
637
    if shearing is not None:
638
        shearing_index = np.ones((n_dims + 1, n_dims + 1), dtype='bool')
639
        shearing_index[np.eye(n_dims + 1, dtype='bool')] = False
640
        shearing_index[-1, :] = np.zeros((n_dims + 1))
641
        shearing_index[:, -1] = np.zeros((n_dims + 1))
642
        T_shearing[shearing_index] = shearing
643
644
    if translation is not None:
645
        T_translation[np.arange(n_dims), n_dims * np.ones(n_dims, dtype='int')] = translation
646
647
    if n_dims == 2:
648
        if rotation is None:
649
            rotation = np.zeros(1)
650
        else:
651
            rotation = np.asarray(rotation) * (math.pi / 180)
652
        T_rot = np.eye(n_dims + 1)
653
        T_rot[np.array([0, 1, 0, 1]), np.array([0, 0, 1, 1])] = [np.cos(rotation[0]), np.sin(rotation[0]),
654
                                                                 np.sin(rotation[0]) * -1, np.cos(rotation[0])]
655
        return T_translation @ T_rot @ T_shearing @ T_scaling
656
657
    else:
658
659
        if rotation is None:
660
            rotation = np.zeros(n_dims)
661
        else:
662
            rotation = np.asarray(rotation) * (math.pi / 180)
663
        T_rot1 = np.eye(n_dims + 1)
664
        T_rot1[np.array([1, 2, 1, 2]), np.array([1, 1, 2, 2])] = [np.cos(rotation[0]), np.sin(rotation[0]),
665
                                                                  np.sin(rotation[0]) * -1, np.cos(rotation[0])]
666
        T_rot2 = np.eye(n_dims + 1)
667
        T_rot2[np.array([0, 2, 0, 2]), np.array([0, 0, 2, 2])] = [np.cos(rotation[1]), np.sin(rotation[1]) * -1,
668
                                                                  np.sin(rotation[1]), np.cos(rotation[1])]
669
        T_rot3 = np.eye(n_dims + 1)
670
        T_rot3[np.array([0, 1, 0, 1]), np.array([0, 0, 1, 1])] = [np.cos(rotation[2]), np.sin(rotation[2]),
671
                                                                  np.sin(rotation[2]) * -1, np.cos(rotation[2])]
672
        return T_translation @ T_rot3 @ T_rot2 @ T_rot1 @ T_shearing @ T_scaling
673
674
675
def sample_affine_transform(batchsize,
676
                            n_dims,
677
                            rotation_bounds=False,
678
                            scaling_bounds=False,
679
                            shearing_bounds=False,
680
                            translation_bounds=False,
681
                            enable_90_rotations=False):
682
    """build batchsize x 4 x 4 tensor representing an affine transformation in homogeneous coordinates.
683
    If return_inv is True, also returns the inverse of the created affine matrix."""
684
685
    if (rotation_bounds is not False) | (enable_90_rotations is not False):
686
        if n_dims == 2:
687
            if rotation_bounds is not False:
688
                rotation = draw_value_from_distribution(rotation_bounds,
689
                                                        size=1,
690
                                                        default_range=15.0,
691
                                                        return_as_tensor=True,
692
                                                        batchsize=batchsize)
693
            else:
694
                rotation = tf.zeros(tf.concat([batchsize, tf.ones(1, dtype='int32')], axis=0))
695
        else:  # n_dims = 3
696
            if rotation_bounds is not False:
697
                rotation = draw_value_from_distribution(rotation_bounds,
698
                                                        size=n_dims,
699
                                                        default_range=15.0,
700
                                                        return_as_tensor=True,
701
                                                        batchsize=batchsize)
702
            else:
703
                rotation = tf.zeros(tf.concat([batchsize, 3 * tf.ones(1, dtype='int32')], axis=0))
704
        if enable_90_rotations:
705
            rotation = tf.cast(tf.random.uniform(tf.shape(rotation), maxval=4, dtype='int32') * 90, 'float32') \
706
                       + rotation
707
        T_rot = create_rotation_transform(rotation, n_dims)
708
    else:
709
        T_rot = tf.tile(tf.expand_dims(tf.eye(n_dims), axis=0),
710
                        tf.concat([batchsize, tf.ones(2, dtype='int32')], axis=0))
711
712
    if shearing_bounds is not False:
713
        shearing = draw_value_from_distribution(shearing_bounds,
714
                                                size=n_dims ** 2 - n_dims,
715
                                                default_range=.01,
716
                                                return_as_tensor=True,
717
                                                batchsize=batchsize)
718
        T_shearing = create_shearing_transform(shearing, n_dims)
719
    else:
720
        T_shearing = tf.tile(tf.expand_dims(tf.eye(n_dims), axis=0),
721
                             tf.concat([batchsize, tf.ones(2, dtype='int32')], axis=0))
722
723
    if scaling_bounds is not False:
724
        scaling = draw_value_from_distribution(scaling_bounds,
725
                                               size=n_dims,
726
                                               centre=1,
727
                                               default_range=.15,
728
                                               return_as_tensor=True,
729
                                               batchsize=batchsize)
730
        T_scaling = tf.linalg.diag(scaling)
731
    else:
732
        T_scaling = tf.tile(tf.expand_dims(tf.eye(n_dims), axis=0),
733
                            tf.concat([batchsize, tf.ones(2, dtype='int32')], axis=0))
734
735
    T = tf.matmul(T_scaling, tf.matmul(T_shearing, T_rot))
736
737
    if translation_bounds is not False:
738
        translation = draw_value_from_distribution(translation_bounds,
739
                                                   size=n_dims,
740
                                                   default_range=5,
741
                                                   return_as_tensor=True,
742
                                                   batchsize=batchsize)
743
        T = tf.concat([T, tf.expand_dims(translation, axis=-1)], axis=-1)
744
    else:
745
        T = tf.concat([T, tf.zeros(tf.concat([tf.shape(T)[:2], tf.ones(1, dtype='int32')], 0))], axis=-1)
746
747
    # build rigid transform
748
    T_last_row = tf.expand_dims(tf.concat([tf.zeros((1, n_dims)), tf.ones((1, 1))], axis=1), 0)
749
    T_last_row = tf.tile(T_last_row, tf.concat([batchsize, tf.ones(2, dtype='int32')], axis=0))
750
    T = tf.concat([T, T_last_row], axis=1)
751
752
    return T
753
754
755
def create_rotation_transform(rotation, n_dims):
756
    """build rotation transform from 3d or 2d rotation coefficients. Angles are given in degrees."""
757
    rotation = rotation * np.pi / 180
758
    if n_dims == 3:
759
        shape = tf.shape(tf.expand_dims(rotation[..., 0], -1))
760
761
        Rx_row0 = tf.expand_dims(tf.tile(tf.expand_dims(tf.convert_to_tensor([1., 0., 0.]), 0), shape), axis=1)
762
        Rx_row1 = tf.stack([tf.zeros(shape), tf.expand_dims(tf.cos(rotation[..., 0]), -1),
763
                            tf.expand_dims(-tf.sin(rotation[..., 0]), -1)], axis=-1)
764
        Rx_row2 = tf.stack([tf.zeros(shape), tf.expand_dims(tf.sin(rotation[..., 0]), -1),
765
                            tf.expand_dims(tf.cos(rotation[..., 0]), -1)], axis=-1)
766
        Rx = tf.concat([Rx_row0, Rx_row1, Rx_row2], axis=1)
767
768
        Ry_row0 = tf.stack([tf.expand_dims(tf.cos(rotation[..., 1]), -1), tf.zeros(shape),
769
                            tf.expand_dims(tf.sin(rotation[..., 1]), -1)], axis=-1)
770
        Ry_row1 = tf.expand_dims(tf.tile(tf.expand_dims(tf.convert_to_tensor([0., 1., 0.]), 0), shape), axis=1)
771
        Ry_row2 = tf.stack([tf.expand_dims(-tf.sin(rotation[..., 1]), -1), tf.zeros(shape),
772
                            tf.expand_dims(tf.cos(rotation[..., 1]), -1)], axis=-1)
773
        Ry = tf.concat([Ry_row0, Ry_row1, Ry_row2], axis=1)
774
775
        Rz_row0 = tf.stack([tf.expand_dims(tf.cos(rotation[..., 2]), -1),
776
                            tf.expand_dims(-tf.sin(rotation[..., 2]), -1), tf.zeros(shape)], axis=-1)
777
        Rz_row1 = tf.stack([tf.expand_dims(tf.sin(rotation[..., 2]), -1),
778
                            tf.expand_dims(tf.cos(rotation[..., 2]), -1), tf.zeros(shape)], axis=-1)
779
        Rz_row2 = tf.expand_dims(tf.tile(tf.expand_dims(tf.convert_to_tensor([0., 0., 1.]), 0), shape), axis=1)
780
        Rz = tf.concat([Rz_row0, Rz_row1, Rz_row2], axis=1)
781
782
        T_rot = tf.matmul(tf.matmul(Rx, Ry), Rz)
783
784
    elif n_dims == 2:
785
        R_row0 = tf.stack([tf.expand_dims(tf.cos(rotation[..., 0]), -1),
786
                           tf.expand_dims(tf.sin(rotation[..., 0]), -1)], axis=-1)
787
        R_row1 = tf.stack([tf.expand_dims(-tf.sin(rotation[..., 0]), -1),
788
                           tf.expand_dims(tf.cos(rotation[..., 0]), -1)], axis=-1)
789
        T_rot = tf.concat([R_row0, R_row1], axis=1)
790
791
    else:
792
        raise Exception('only supports 2 or 3D.')
793
794
    return T_rot
795
796
797
def create_shearing_transform(shearing, n_dims):
798
    """build shearing transform from 2d/3d shearing coefficients"""
799
    shape = tf.shape(tf.expand_dims(shearing[..., 0], -1))
800
    if n_dims == 3:
801
        shearing_row0 = tf.stack([tf.ones(shape), tf.expand_dims(shearing[..., 0], -1),
802
                                  tf.expand_dims(shearing[..., 1], -1)], axis=-1)
803
        shearing_row1 = tf.stack([tf.expand_dims(shearing[..., 2], -1), tf.ones(shape),
804
                                  tf.expand_dims(shearing[..., 3], -1)], axis=-1)
805
        shearing_row2 = tf.stack([tf.expand_dims(shearing[..., 4], -1), tf.expand_dims(shearing[..., 5], -1),
806
                                  tf.ones(shape)], axis=-1)
807
        T_shearing = tf.concat([shearing_row0, shearing_row1, shearing_row2], axis=1)
808
809
    elif n_dims == 2:
810
        shearing_row0 = tf.stack([tf.ones(shape), tf.expand_dims(shearing[..., 0], -1)], axis=-1)
811
        shearing_row1 = tf.stack([tf.expand_dims(shearing[..., 1], -1), tf.ones(shape)], axis=-1)
812
        T_shearing = tf.concat([shearing_row0, shearing_row1], axis=1)
813
    else:
814
        raise Exception('only supports 2 or 3D.')
815
    return T_shearing
816
817
818
# --------------------------------------------------- miscellaneous ----------------------------------------------------
819
820
821
def infer(x):
822
    """ Try to parse input to float. If it fails, tries boolean, and otherwise keep it as string """
823
    try:
824
        x = float(x)
825
    except ValueError:
826
        if x == 'False':
827
            x = False
828
        elif x == 'True':
829
            x = True
830
        elif not isinstance(x, str):
831
            raise TypeError('input should be an int/float/boolean/str, had {}'.format(type(x)))
832
    return x
833
834
835
class LoopInfo:
836
    """
837
    Class to print the current iteration in a for loop, and optionally the estimated remaining time.
838
    Instantiate just before the loop, and call the update method at the start of the loop.
839
    The printed text has the following format:
840
    processing i/total    remaining time: hh:mm:ss
841
    """
842
843
    def __init__(self, n_iterations, spacing=10, text='processing', print_time=False):
844
        """
845
        :param n_iterations: total number of iterations of the for loop.
846
        :param spacing: frequency at which the update info will be printed on screen.
847
        :param text: text to print. Default is processing.
848
        :param print_time: whether to print the estimated remaining time. Default is False.
849
        """
850
851
        # loop parameters
852
        self.n_iterations = n_iterations
853
        self.spacing = spacing
854
855
        # text parameters
856
        self.text = text
857
        self.print_time = print_time
858
        self.print_previous_time = False
859
        self.align = len(str(self.n_iterations)) * 2 + 1 + 3
860
861
        # timing parameters
862
        self.iteration_durations = np.zeros((n_iterations,))
863
        self.start = time.time()
864
        self.previous = time.time()
865
866
    def update(self, idx):
867
868
        # time iteration
869
        now = time.time()
870
        self.iteration_durations[idx] = now - self.previous
871
        self.previous = now
872
873
        # print text
874
        if idx == 0:
875
            print(self.text + ' 1/{}'.format(self.n_iterations))
876
        elif idx % self.spacing == self.spacing - 1:
877
            iteration = str(idx + 1) + '/' + str(self.n_iterations)
878
            if self.print_time:
879
                # estimate remaining time
880
                max_duration = np.max(self.iteration_durations)
881
                average_duration = np.mean(self.iteration_durations[self.iteration_durations > .01 * max_duration])
882
                remaining_time = int(average_duration * (self.n_iterations - idx))
883
                # print total remaining time only if it is greater than 1s or if it was previously printed
884
                if (remaining_time > 1) | self.print_previous_time:
885
                    eta = str(timedelta(seconds=remaining_time))
886
                    print(self.text + ' {:<{x}} remaining time: {}'.format(iteration, eta, x=self.align))
887
                    self.print_previous_time = True
888
                else:
889
                    print(self.text + ' {}'.format(iteration))
890
            else:
891
                print(self.text + ' {}'.format(iteration))
892
893
894
def get_mapping_lut(source, dest=None):
895
    """This functions returns the look-up table to map a list of N values (source) to another list (dest).
896
    If the second list is not given, we assume it is equal to [0, ..., N-1]."""
897
898
    # initialise
899
    source = np.array(reformat_to_list(source), dtype='int32')
900
    n_labels = source.shape[0]
901
902
    # build new label list if necessary
903
    if dest is None:
904
        dest = np.arange(n_labels, dtype='int32')
905
    else:
906
        assert len(source) == len(dest), 'label_list and new_label_list should have the same length'
907
        dest = np.array(reformat_to_list(dest, dtype='int'))
908
909
    # build look-up table
910
    lut = np.zeros(np.max(source) + 1, dtype='int32')
911
    for source, dest in zip(source, dest):
912
        lut[source] = dest
913
914
    return lut
915
916
917
def build_training_generator(gen, batchsize):
918
    """Build generator for training a network."""
919
    while True:
920
        inputs = next(gen)
921
        if batchsize > 1:
922
            target = np.concatenate([np.zeros((1, 1))] * batchsize, 0)
923
        else:
924
            target = np.zeros((1, 1))
925
        yield inputs, target
926
927
928
def find_closest_number_divisible_by_m(n, m, answer_type='lower'):
929
    """Return the closest integer to n that is divisible by m. answer_type can either be 'closer', 'lower' (only returns
930
    values lower than n), or 'higher' (only returns values higher than m)."""
931
    if n % m == 0:
932
        return n
933
    else:
934
        q = int(n / m)
935
        lower = q * m
936
        higher = (q + 1) * m
937
        if answer_type == 'lower':
938
            return lower
939
        elif answer_type == 'higher':
940
            return higher
941
        elif answer_type == 'closer':
942
            return lower if (n - lower) < (higher - n) else higher
943
        else:
944
            raise Exception('answer_type should be lower, higher, or closer, had : %s' % answer_type)
945
946
947
def build_binary_structure(connectivity, n_dims, shape=None):
948
    """Return a dilation/erosion element with provided connectivity"""
949
    if shape is None:
950
        shape = [connectivity * 2 + 1] * n_dims
951
    else:
952
        shape = reformat_to_list(shape, length=n_dims)
953
    dist = np.ones(shape)
954
    center = tuple([tuple([int(s / 2)]) for s in shape])
955
    dist[center] = 0
956
    dist = distance_transform_edt(dist)
957
    struct = (dist <= connectivity) * 1
958
    return struct
959
960
961
def draw_value_from_distribution(hyperparameter,
962
                                 size=1,
963
                                 distribution='uniform',
964
                                 centre=0.,
965
                                 default_range=10.0,
966
                                 positive_only=False,
967
                                 return_as_tensor=False,
968
                                 batchsize=None):
969
    """Sample values from a uniform, or normal distribution of given hyperparameters.
970
    These hyperparameters are to the number of 2 in both uniform and normal cases.
971
    :param hyperparameter: values of the hyperparameters. Can either be:
972
    1) None, in each case the two hyperparameters are given by [center-default_range, center+default_range],
973
    2) a number, where the two hyperparameters are given by [centre-hyperparameter, centre+hyperparameter],
974
    3) a sequence of length 2, directly defining the two hyperparameters: [min, max] if the distribution is uniform,
975
    [mean, std] if the distribution is normal.
976
    4) a numpy array, with size (2, m). In this case, the function returns a 1d array of size m, where each value has
977
    been sampled independently with the specified hyperparameters. If the distribution is uniform, rows correspond to
978
    its lower and upper bounds, and if the distribution is normal, rows correspond to its mean and std deviation.
979
    5) a numpy array of size (2*n, m). Same as 4) but we first randomly select a block of two rows among the
980
    n possibilities.
981
    6) the path to a numpy array corresponding to case 4 or 5.
982
    7) False, in which case this function returns None.
983
    :param size: (optional) number of values to sample. All values are sampled independently.
984
    Used only if hyperparameter is not a numpy array.
985
    :param distribution: (optional) the distribution type. Can be 'uniform' or 'normal'. Default is 'uniform'.
986
    :param centre: (optional) default centre to use if hyperparameter is None or a number.
987
    :param default_range: (optional) default range to use if hyperparameter is None.
988
    :param positive_only: (optional) whether to reset all negative values to zero.
989
    :param return_as_tensor: (optional) whether to return the result as a tensorflow tensor
990
    :param batchsize: (optional) if return_as_tensor is true, then you can sample a tensor of a given batchsize. Give
991
    this batchsize as a tensorflow tensor here.
992
    :return: a float, or a numpy 1d array if size > 1, or hyperparameter is itself a numpy array.
993
    Returns None if hyperparameter is False.
994
    """
995
996
    # return False is hyperparameter is False
997
    if hyperparameter is False:
998
        return None
999
1000
    # reformat parameter_range
1001
    hyperparameter = load_array_if_path(hyperparameter, load_as_numpy=True)
1002
    if not isinstance(hyperparameter, np.ndarray):
1003
        if hyperparameter is None:
1004
            hyperparameter = np.array([[centre - default_range] * size, [centre + default_range] * size])
1005
        elif isinstance(hyperparameter, (int, float)):
1006
            hyperparameter = np.array([[centre - hyperparameter] * size, [centre + hyperparameter] * size])
1007
        elif isinstance(hyperparameter, (list, tuple)):
1008
            assert len(hyperparameter) == 2, 'if list, parameter_range should be of length 2.'
1009
            hyperparameter = np.transpose(np.tile(np.array(hyperparameter), (size, 1)))
1010
        else:
1011
            raise ValueError('parameter_range should either be None, a number, a sequence, or a numpy array.')
1012
    elif isinstance(hyperparameter, np.ndarray):
1013
        assert hyperparameter.shape[0] % 2 == 0, 'number of rows of parameter_range should be divisible by 2'
1014
        n_modalities = int(hyperparameter.shape[0] / 2)
1015
        modality_idx = 2 * np.random.randint(n_modalities)
1016
        hyperparameter = hyperparameter[modality_idx: modality_idx + 2, :]
1017
1018
    # draw values as tensor
1019
    if return_as_tensor:
1020
        shape = KL.Lambda(lambda x: tf.convert_to_tensor(hyperparameter.shape[1], 'int32'))([])
1021
        if batchsize is not None:
1022
            shape = KL.Lambda(lambda x: tf.concat([x[0], tf.expand_dims(x[1], axis=0)], axis=0))([batchsize, shape])
1023
        if distribution == 'uniform':
1024
            parameter_value = KL.Lambda(lambda x: tf.random.uniform(shape=x,
1025
                                                                    minval=hyperparameter[0, :],
1026
                                                                    maxval=hyperparameter[1, :]))(shape)
1027
        elif distribution == 'normal':
1028
            parameter_value = KL.Lambda(lambda x: tf.random.normal(shape=x,
1029
                                                                   mean=hyperparameter[0, :],
1030
                                                                   stddev=hyperparameter[1, :]))(shape)
1031
        else:
1032
            raise ValueError("Distribution not supported, should be 'uniform' or 'normal'.")
1033
1034
        if positive_only:
1035
            parameter_value = KL.Lambda(lambda x: K.clip(x, 0, None))(parameter_value)
1036
1037
    # draw values as numpy array
1038
    else:
1039
        if distribution == 'uniform':
1040
            parameter_value = np.random.uniform(low=hyperparameter[0, :], high=hyperparameter[1, :])
1041
        elif distribution == 'normal':
1042
            parameter_value = np.random.normal(loc=hyperparameter[0, :], scale=hyperparameter[1, :])
1043
        else:
1044
            raise ValueError("Distribution not supported, should be 'uniform' or 'normal'.")
1045
1046
        if positive_only:
1047
            parameter_value[parameter_value < 0] = 0
1048
1049
    return parameter_value
1050
1051
1052
def build_exp(x, first, last, fix_point):
1053
    # first = f(0), last = f(+inf), fix_point = [x0, f(x0))]
1054
    a = last
1055
    b = first - last
1056
    c = - (1 / fix_point[0]) * np.log((fix_point[1] - last) / (first - last))
1057
    return a + b * np.exp(-c * x)