"""
@author: gbello & lisuru6

How to run the code

python nested_cv_nn.py -c /path-to-conf

Default conf uses demo/scripts/default_nn.conf
"""
import json
import pickle
import shutil
from argparse import ArgumentParser
from datetime import timedelta
from pathlib import Path

import numpy as np
from lifelines.utils import concordance_index
from matplotlib import pyplot as plt
from sklearn.model_selection import KFold

from survival4D.config import HypersearchConfig, ModelConfig, NNExperimentConfig
from survival4D.nn import hypersearch_nn, train_nn

# Conf file used when no -c/--conf-path argument is given (lives next to this script).
DEFAULT_CONF_PATH = Path(__file__).parent.joinpath("default_nn.conf")
def parse_args():
    """Parse command-line arguments.

    Returns the parsed namespace with a single attribute:
    ``conf_path`` (str or None) — path to the conf file, from -c/--conf-path.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "-c", "--conf-path", dest="conf_path", type=str, default=None, help="Conf path."
    )
    return parser.parse_args()
def main():
    """Run nested cross-validation for the neural-network survival model.

    For each outer KFold split: (1a) hyperparameter search on the training
    split, (1b) retrain with the best hyperparameters, (1c) score Harrell's
    concordance index on both splits. Per-fold parameters, metrics, and a
    progress plot are written to the configured output directory.
    """
    args = parse_args()
    if args.conf_path is None:
        conf_path = DEFAULT_CONF_PATH
    else:
        conf_path = Path(args.conf_path)
    exp_config = NNExperimentConfig.from_conf(conf_path)
    exp_config.output_dir.mkdir(parents=True, exist_ok=True)
    hypersearch_config = HypersearchConfig.from_conf(conf_path)
    model_config = ModelConfig.from_conf(conf_path)

    # Keep a copy of the conf alongside the results for reproducibility.
    shutil.copy(str(conf_path), str(exp_config.output_dir.joinpath("nn.conf")))

    # Import input data: y_full=censoring status and survival times for
    # patients, x_full=input data for patients (i.e. motion descriptors
    # [11,514-element vector]).
    # NOTE(review): pickle.load is unsafe on untrusted files; data_path comes
    # from the local conf, so this is accepted as-is.
    with open(str(exp_config.data_path), 'rb') as f:
        c3 = pickle.load(f)
    x_full = c3[0]
    y_full = c3[1]
    print(x_full.shape, y_full.shape)

    del c3  # release the raw pickle payload before training starts

    # Per-fold concordance indices (validation and training).
    c_vals = []
    c_trains = []

    kf = KFold(n_splits=exp_config.n_folds)
    for i, (train_indices, test_indices) in enumerate(kf.split(x_full)):
        x_train, y_train = x_full[train_indices], y_full[train_indices]
        x_val, y_val = x_full[test_indices], y_full[test_indices]

        # STEP 1: find optimal hyperparameters using CV
        print("Step 1a")
        opars, osummary = hypersearch_nn(
            x_data=x_train,
            y_data=y_train,
            method=exp_config.search_method,
            nfolds=exp_config.n_folds,
            nevals=exp_config.n_evals,
            batch_size=exp_config.batch_size,
            num_epochs=exp_config.n_epochs,
            backend=exp_config.backend,
            model_kwargs=model_config.to_dict(),
            **hypersearch_config.to_dict(),
        )
        print("Step 1b")
        # (1b) using optimal hyperparameters, train a model and test its
        # performance on the holdout validation set.
        olog = train_nn(
            backend=exp_config.backend,
            xtr=x_train,
            ytr=y_train,
            batch_size=exp_config.batch_size,
            n_epochs=exp_config.n_epochs,
            **model_config.to_dict(),
            **opars,
        )

        # (1c) Compute Harrell's Concordance index.
        # y[:, 1] is the survival time, y[:, 0] the event/censoring status;
        # predictions are negated — presumably higher model output means
        # higher risk (shorter survival) — TODO confirm against train_nn.
        pred_val = olog.predict(x_val, batch_size=1)[1]
        c_val = concordance_index(y_val[:, 1], -pred_val, y_val[:, 0])

        pred_train = olog.predict(x_train, batch_size=1)[1]
        c_train = concordance_index(y_train[:, 1], -pred_train, y_train[:, 0])
        c_vals.append(c_val)
        c_trains.append(c_train)
        # Persist after every fold so partial results survive an interruption.
        save_params(
            opars, osummary, "cv_{}".format(i), exp_config.output_dir,
            c_val=c_val, c_train=c_train,
            c_val_mean=np.mean(c_vals), c_val_var=np.var(c_vals),
            c_train_mean=np.mean(c_trains), c_train_var=np.var(c_trains)
        )
        print('Validation concordance index = {0:.4f}'.format(c_val))
        plot_cs(c_trains, c_vals, exp_config.output_dir)
    print('Mean Validation concordance index = {0:.4f}'.format(np.mean(c_vals)))
    print('Variance = {0:.4f}'.format(np.var(c_vals)))
def save_params(params: dict, search_log, name: str, output_dir: Path, **kwargs):
    """Write hyperparameters plus search statistics to ``<output_dir>/<name>.json``.

    Note: mutates ``params`` in place, adding the search-log fields and any
    extra keyword metrics before dumping.

    Args:
        params: hyperparameter dict to enrich and serialize.
        search_log: search summary exposing ``optimum``, ``stats`` (with
            "num_evals" and "time" in seconds), and ``call_log``.
        name: output file stem (".json" is appended).
        output_dir: directory to write into; created if missing.
        **kwargs: extra scalar metrics (e.g. c_val, c_train) stored verbatim.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    params["search_log_optimum_c_index"] = search_log.optimum
    params["num_evals"] = search_log.stats["num_evals"]
    # Human-readable elapsed time, e.g. "0:03:25".
    params["time"] = str(timedelta(seconds=search_log.stats["time"]))
    params["call_log"] = search_log.call_log
    params.update(kwargs)
    with open(str(output_dir.joinpath(name + ".json")), "w") as fp:
        json.dump(params, fp, indent=4)
def compute_bootstrap_adjusted_c_index(C_app, Cb_opts):
    """Optimism-adjust an apparent concordance index via bootstrap estimates.

    Args:
        C_app: apparent C-index measured on the full training data.
        Cb_opts: per-bootstrap-sample optimism estimates (length B).

    Returns:
        (C_opt, C_adj, C_opt_95confint): mean optimism, optimism-adjusted
        C-index, and the [2.5, 97.5] percentile interval of the adjusted
        values across bootstrap samples.
    """
    # Bootstrap-estimated optimism (mean across the B bootstrap samples).
    C_opt = np.mean(Cb_opts)

    # Adjust apparent C using bootstrap-estimated optimism.
    C_adj = C_app - C_opt

    # Confidence interval for the optimism-adjusted C.
    C_opt_95confint = np.percentile([C_app - o for o in Cb_opts], q=[2.5, 97.5])

    return C_opt, C_adj, C_opt_95confint
def plot_cs(c_trains, c_vals, output_dir):
    """Plot per-fold train (blue) and validation (red) C-indices to a PNG.

    Saves/overwrites ``<output_dir>/c_train_val.png``; the title shows the
    running mean/variance of the validation scores.
    """
    plt.figure()
    plt.title("CV validation, mean={:.4f}, var={:.4f}".format(np.mean(c_vals), np.var(c_vals)))
    plt.plot(range(len(c_vals)), c_vals, 'rx-')
    plt.plot(range(len(c_trains)), c_trains, 'bx-')
    plt.savefig(str(output_dir.joinpath("c_train_val.png")))
    # Close the figure: this is called once per CV fold, and leaving figures
    # open accumulates memory across folds.
    plt.close()
if __name__ == '__main__':
    main()