a b/survival4D/nn/__init__.py
1
import optunity
2
import numpy as np
3
from lifelines.utils import concordance_index
4
5
6
def prepare_data(x, e, t):
    """Cast the three training arrays to the dtypes used downstream.

    Args:
        x: Feature array; returned as float32.
        e: Event array; returned as int32 (presumably event/censoring
            indicators -- confirm against callers).
        t: Time array; returned as float32.

    Returns:
        Tuple ``(x, e, t)`` of freshly cast copies, in the same order.
    """
    features = x.astype("float32")
    events = e.astype("int32")
    times = t.astype("float32")
    return features, events, times
8
9
10
def sort4minibatches(xvals, evals, tvals, batchsize, rng=None):
    """Shuffle samples, cut them into minibatches, and sort each batch by time.

    The samples are randomly permuted, then split into consecutive chunks of
    ``batchsize`` (the last chunk may be smaller). Inside each chunk the
    samples are ordered by ``tvals`` in descending order. Presumably this
    ordering is what a per-batch (Cox-style, cumulative risk set) loss
    expects -- confirm against the training code.

    Args:
        xvals: Indexable array of samples (NumPy fancy indexing is used).
        evals: Array aligned with ``xvals``.
        tvals: Array aligned with ``xvals``; the per-batch sort key.
        batchsize: Positive number of samples per minibatch.
        rng: Optional object with a ``shuffle`` method, e.g.
            ``np.random.default_rng(seed)`` or ``np.random.RandomState``.
            Pass one for reproducible batches. When ``None`` (the default),
            the global ``np.random`` state is used.

    Returns:
        ``(xvals[order], evals[order], tvals[order], order)``, where ``order``
        is a Python list of the original sample indices in the new order.

    Raises:
        ValueError: If ``batchsize`` is not positive.
    """
    if batchsize <= 0:
        # Without this check, 0 raises ZeroDivisionError and a negative value
        # silently collapses everything into one batch.
        raise ValueError("batchsize must be positive, got {}".format(batchsize))

    ntot = len(xvals)
    indices = np.arange(ntot)
    if rng is None:
        np.random.shuffle(indices)
    else:
        rng.shuffle(indices)

    esall = []
    # Start offsets 0, batchsize, 2*batchsize, ...; slicing clips the final
    # partial batch, and no empty trailing batch is produced.
    for start_idx in range(0, ntot, batchsize):
        excerpt = indices[start_idx:start_idx + batchsize]
        # Ascending argsort reversed -> descending by time within the batch.
        sort_idx = np.argsort(tvals[excerpt])[::-1]
        esall.extend(excerpt[sort_idx])
    return xvals[esall], evals[esall], tvals[esall], esall
23
24
25
# 1. Hyperparameter search for the deep-learning model
26
def hypersearch_nn(x_data, y_data, method, nfolds, nevals, batch_size, num_epochs, backend: str,
                   model_kwargs: dict, **hypersearch):
    """Tune network hyperparameters by maximizing cross-validated concordance.

    Each candidate hyperparameter set is scored with optunity's k-fold
    cross-validation. On every fold a model is trained with ``train_nn`` and
    scored on the held-out part with lifelines' ``concordance_index``.

    Args:
        x_data: Inputs handed to ``optunity.cross_validated`` as ``x``.
        y_data: Targets handed to ``optunity.cross_validated`` as ``y``.
            Column 1 is used as event times and column 0 as the event
            indicator when scoring.
        method: optunity solver name (``solver_name``).
        nfolds: Number of cross-validation folds.
        nevals: Number of objective evaluations the solver may spend.
        batch_size: Training batch size forwarded to ``train_nn``.
        num_epochs: Number of epochs forwarded to ``train_nn`` as ``n_epochs``.
        backend: ``"tf"`` or ``"torch"``, forwarded to ``train_nn``.
        model_kwargs: Fixed keyword arguments forwarded to ``train_nn`` on
            every fold. A key that also appears in ``hypersearch`` makes the
            call fail with a duplicate-keyword TypeError.
        **hypersearch: Search-space bounds passed straight to
            ``optunity.maximize``. The sampled values are forwarded to
            ``train_nn`` under the same names.

    Returns:
        ``(optimal_pars, searchlog)`` as returned by ``optunity.maximize``.
        The best parameters and the optimum score are also printed.
    """

    @optunity.cross_validated(x=x_data, y=y_data, num_folds=nfolds)
    def _fold_concordance(x_train, y_train, x_test, y_test, **params):
        # optunity supplies the fold split and one sampled hyperparameter set.
        fitted = train_nn(
            backend=backend,
            xtr=x_train,
            ytr=y_train,
            batch_size=batch_size,
            n_epochs=num_epochs,
            **model_kwargs,
            **params,
        )
        # Element [1] of predict's output is used as the score.
        # NOTE(review): assumed to be a risk score -- confirm per backend.
        risk = fitted.predict(x_test, batch_size=1)[1]
        times = y_test[:, 1]
        events = y_test[:, 0]
        # Negated because lifelines treats higher scores as longer survival.
        return concordance_index(times, -risk, events)

    optimal_pars, searchlog, _ = optunity.maximize(
        _fold_concordance, num_evals=nevals, solver_name=method, **hypersearch
    )
    print('Optimal hyperparameters : ' + str(optimal_pars))
    print('Cross-validated C after tuning: %1.3f' % searchlog.optimum)
    return optimal_pars, searchlog
44
45
46
def train_nn(backend: str, xtr, ytr, batch_size, n_epochs, model_name, lr_exp, alpha, weight_decay_exp, **model_kwargs):
    """Dispatch training to the TensorFlow or PyTorch implementation.

    The backend module is imported lazily, so only the selected framework
    has to be installed. All remaining arguments are forwarded positionally
    (plus ``**model_kwargs``) to that backend's ``train_nn``, and its return
    value is returned unchanged.

    Args:
        backend: ``"tf"`` for ``survival4D.nn.tf`` or ``"torch"`` for
            ``survival4D.nn.torch``.

    Raises:
        ValueError: If ``backend`` is neither ``"tf"`` nor ``"torch"``.
    """
    if backend not in ("tf", "torch"):
        raise ValueError("Backend {} not supported. Only tf or torch. ".format(backend))

    if backend == "tf":
        from survival4D.nn.tf import train_nn as backend_train
    else:
        from survival4D.nn.torch import train_nn as backend_train

    return backend_train(
        xtr, ytr, batch_size, n_epochs, model_name, lr_exp, alpha, weight_decay_exp, **model_kwargs
    )