|
a |
|
b/survival4D/nn/__init__.py |
|
|
1 |
import optunity |
|
|
2 |
import numpy as np |
|
|
3 |
from lifelines.utils import concordance_index |
|
|
4 |
|
|
|
5 |
|
|
|
6 |
def prepare_data(x, e, t):
    """Cast the three inputs to the fixed dtypes used downstream.

    The argument names suggest covariates (``x``), event indicators (``e``)
    and times (``t``) -- presumably NumPy arrays; confirm against callers.

    Args:
        x: Object exposing ``astype``; cast to ``"float32"``.
        e: Object exposing ``astype``; cast to ``"int32"``.
        t: Object exposing ``astype``; cast to ``"float32"``.

    Returns:
        Tuple ``(x, e, t)`` holding the three cast results, in that order.
    """
    return x.astype("float32"), e.astype("int32"), t.astype("float32")
|
|
8 |
|
|
|
9 |
|
|
|
10 |
def sort4minibatches(xvals, evals, tvals, batchsize):
    """Shuffle the samples, then order each mini-batch by descending ``tvals``.

    The sample indices are shuffled with ``np.random.shuffle`` (global NumPy
    RNG), cut into consecutive chunks of ``batchsize`` indices (the last chunk
    may be shorter), and every chunk is reordered so that its ``tvals`` are
    non-increasing.  The concatenated chunk orderings are then used to index
    all three inputs.

    NOTE(review): the per-batch descending-time order presumably serves a
    loss that needs risk sets to be contiguous -- confirm against the
    training code.

    Args:
        xvals: Array indexable by a list of integer indices; ``len(xvals)``
            defines the number of samples.
        evals: Array indexed with the same ordering as ``xvals``.
        tvals: Array whose values decide the within-batch order.
        batchsize: Number of samples per mini-batch; must be >= 1.

    Returns:
        Tuple ``(xvals[order], evals[order], tvals[order], order)`` where
        ``order`` is a list of the reordered sample indices.

    Raises:
        ValueError: If ``batchsize`` is smaller than 1.
    """
    if batchsize < 1:
        raise ValueError("batchsize must be a positive integer, got {}".format(batchsize))

    ntot = len(xvals)
    indices = np.arange(ntot)
    np.random.shuffle(indices)

    esall = []
    # Stepping the start index covers every full batch plus the trailing
    # partial one, without visiting an empty slice when ntot is a multiple
    # of batchsize.
    for start_idx in range(0, ntot, batchsize):
        excerpt = indices[start_idx:start_idx + batchsize]
        # argsort is ascending; reversing it yields descending tvals.
        sort_idx = np.argsort(tvals[excerpt])[::-1]
        esall.extend(excerpt[sort_idx])
    return xvals[esall], evals[esall], tvals[esall], esall
|
|
23 |
|
|
|
24 |
|
|
|
25 |
# 1. Hyperparameter search for Deep Learning model |
|
|
26 |
def hypersearch_nn(x_data, y_data, method, nfolds, nevals, batch_size, num_epochs, backend: str,
                   model_kwargs: dict, **hypersearch):
    """Tune network hyperparameters by maximising the cross-validated C-index.

    An objective function is wrapped with ``optunity.cross_validated`` over
    ``x_data`` / ``y_data`` and handed to ``optunity.maximize``.  For each
    fold it trains a model through ``train_nn`` and scores the held-out
    predictions with ``lifelines.utils.concordance_index``.

    Args:
        x_data: Inputs passed to ``optunity.cross_validated`` as ``x``.
        y_data: Targets passed to ``optunity.cross_validated`` as ``y``.
            Column 1 is used as the event times and column 0 as the
            event-observed flags when scoring.
        method: Solver name forwarded to ``optunity.maximize`` as
            ``solver_name``.
        nfolds: Number of cross-validation folds.
        nevals: Number of objective evaluations (``num_evals``).
        batch_size: Batch size forwarded to ``train_nn``.
        num_epochs: Forwarded to ``train_nn`` as ``n_epochs``.
        backend: Backend identifier forwarded to ``train_nn``.
        model_kwargs: Extra keyword arguments forwarded to ``train_nn`` on
            every evaluation.
        **hypersearch: Forwarded verbatim to ``optunity.maximize``
            (presumably the per-hyperparameter search ranges -- confirm
            against callers).

    Returns:
        Tuple ``(optimal_pars, searchlog)``: the first two values returned by
        ``optunity.maximize``.
    """

    @optunity.cross_validated(x=x_data, y=y_data, num_folds=nfolds)
    def modelrun(x_train, y_train, x_test, y_test, **sampled_pars):
        # ``sampled_pars`` are the values chosen for this evaluation; the
        # name is kept distinct from the outer ``hypersearch`` kwargs.
        cv_log = train_nn(
            backend=backend, xtr=x_train, ytr=y_train, batch_size=batch_size, n_epochs=num_epochs,
            **model_kwargs, **sampled_pars
        )
        # Only the second element of the predict() result is scored.
        cv_preds = cv_log.predict(x_test, batch_size=1)[1]
        # The predictions are negated before scoring -- presumably they are
        # risk scores (higher means shorter survival); TODO confirm.
        cv_C = concordance_index(y_test[:, 1], -cv_preds, y_test[:, 0])
        return cv_C

    optimal_pars, searchlog, _ = optunity.maximize(
        modelrun, num_evals=nevals, solver_name=method, **hypersearch
    )
    print('Optimal hyperparameters : ' + str(optimal_pars))
    print('Cross-validated C after tuning: %1.3f' % searchlog.optimum)
    return optimal_pars, searchlog
|
|
44 |
|
|
|
45 |
|
|
|
46 |
def train_nn(backend: str, xtr, ytr, batch_size, n_epochs, model_name, lr_exp, alpha, weight_decay_exp, **model_kwargs):
    """Dispatch training to the ``train_nn`` implementation of one backend.

    The backend module is imported inside the function, so only the module
    for the requested backend is ever imported.

    Args:
        backend: ``"tf"`` selects ``survival4D.nn.tf.train_nn``; ``"torch"``
            selects ``survival4D.nn.torch.train_nn``.
        xtr: Training inputs, forwarded unchanged.
        ytr: Training targets, forwarded unchanged.
        batch_size: Forwarded unchanged.
        n_epochs: Forwarded unchanged.
        model_name: Forwarded unchanged.
        lr_exp: Forwarded unchanged.
        alpha: Forwarded unchanged.
        weight_decay_exp: Forwarded unchanged.
        **model_kwargs: Extra keyword arguments forwarded unchanged.

    Returns:
        Whatever the selected backend's ``train_nn`` returns.

    Raises:
        ValueError: If ``backend`` is neither ``"tf"`` nor ``"torch"``.
    """
    # The import is aliased so it does not rebind this function's own name
    # inside its body.
    if backend == "tf":
        from survival4D.nn.tf import train_nn as backend_train_nn
    elif backend == "torch":
        from survival4D.nn.torch import train_nn as backend_train_nn
    else:
        raise ValueError("Backend {} not supported. Only tf or torch. ".format(backend))
    return backend_train_nn(
        xtr, ytr, batch_size, n_epochs, model_name, lr_exp, alpha, weight_decay_exp, **model_kwargs
    )