--- a
+++ b/demo/scripts/nested_cv_cox.py
@@ -0,0 +1,165 @@
+"""
+@author: gbello & lisuru6
+How to run the code
+python demo_validateDL.py -c /path-to-conf
+
+Default conf uses demo/scripts/default_validate_DL.conf
+
+"""
+import json
+import shutil
+from datetime import timedelta
+import pickle
+import numpy as np
+from pathlib import Path
+from argparse import ArgumentParser
+from lifelines.utils import concordance_index
+
+from survival4D.cox_reg import hypersearch_cox, train_cox_reg
+from survival4D.config import CoxExperimentConfig, HypersearchConfig
+from matplotlib import pyplot as plt
+from sklearn.model_selection import KFold
+
+DEFAULT_CONF_PATH = Path(__file__).parent.joinpath("default_cox.conf")
+
+
+def parse_args():
+    parser = ArgumentParser()
+    parser.add_argument(
+        "-c", "--conf-path", dest="conf_path", type=str, default=None, help="Conf path."
+    )
+    return parser.parse_args()
+
+
+def main():
+    args = parse_args()
+    if args.conf_path is None:
+        conf_path = DEFAULT_CONF_PATH
+    else:
+        conf_path = Path(args.conf_path)
+    exp_config = CoxExperimentConfig.from_conf(conf_path)
+    exp_config.output_dir.mkdir(parents=True, exist_ok=True)
+    hypersearch_config = HypersearchConfig.from_conf(conf_path)
+
+    shutil.copy(str(conf_path), str(exp_config.output_dir.joinpath("cox.conf")))
+
+    # Load input data: x_full = motion descriptors for each patient (11,514-element vectors),
+    # y_full = outcome per patient, column 0 = censoring/event status, column 1 = survival time
+    # (this column order is assumed by the concordance_index calls below).
+
+    with open(str(exp_config.data_path), 'rb') as f:
+        c3 = pickle.load(f)
+    x_full = c3[0]
+    y_full = c3[1]
+    print(x_full.shape, y_full.shape)
+    del c3
+
+    # Lists to accumulate per-fold concordance indices
+    c_vals = []
+    c_trains = []
+
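+    # Outer CV loop. Together with the inner CV run inside hypersearch_cox, this forms the
+    # nested cross-validation. KFold defaults to shuffle=False, so folds follow the dataset
+    # ordering.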
+    kf = KFold(n_splits=exp_config.n_folds)
+    for i, (train_indices, test_indices) in enumerate(kf.split(x_full)):
+        print(train_indices.shape, test_indices.shape)
+
+        x_train, y_train = x_full[train_indices], y_full[train_indices]
+        x_val, y_val = x_full[test_indices], y_full[test_indices]
+
+        # STEP 1: find optimal hyperparameters using CV
+        print("Step 1a")
+        opars, osummary = hypersearch_cox(
+            x_data=x_train,
+            y_data=y_train,
+            method=exp_config.search_method,
+            nfolds=exp_config.n_folds,
+            nevals=exp_config.n_evals,
+            penalty_range=hypersearch_config.penalty_exp
+        )
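+        # hypersearch_cox searches over the penalty exponent (penalty_range=hypersearch_config.penalty_exp),
+        # hence the 10 ** opars['penalty'] when training below.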
+        print("Step b")
+        # (1b) using optimal hyperparameters, train a model and test its performance on the holdout validation set.
+        olog = train_cox_reg(
+            xtr=x_train,
+            ytr=y_train,
+            penalty=10 ** opars['penalty'],
+        )
+
+        # (1c) Compute Harrell's Concordance index
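+        # Partial hazards are negated because concordance_index expects higher scores to mean
+        # longer survival, whereas a higher hazard implies shorter survival.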
+        pred_val = olog.predict_partial_hazard(x_val)
+        c_val = concordance_index(y_val[:, 1], -pred_val, y_val[:, 0])
+
+        pred_train = olog.predict_partial_hazard(x_train)
+        c_train = concordance_index(y_train[:, 1], -pred_train, y_train[:, 0])
+        c_vals.append(c_val)
+        c_trains.append(c_train)
+        save_params(
+            opars, osummary, "cv_{}".format(i), exp_config.output_dir,
+            c_val=c_val, c_train=c_train,
+            c_val_mean=np.mean(c_vals), c_val_var=np.var(c_vals),
+            c_train_mean=np.mean(c_trains), c_train_var=np.var(c_trains)
+        )
+        print('Validation concordance index = {0:.4f}'.format(c_val))
+        plot_cs(c_trains, c_vals, exp_config.output_dir)
+    print('Mean Validation concordance index = {0:.4f}'.format(np.mean(c_vals)))
+    print('Variance = {0:.4f}'.format(np.var(c_vals)))
+
+
+def save_params(params: dict, search_log, name: str, output_dir: Path, **kwargs):
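+    """Write the searched hyperparameters, search statistics and any extra metrics to <output_dir>/<name>.json."""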
+    output_dir.mkdir(parents=True, exist_ok=True)
+    params["search_log_optimum_c_index"] = search_log.optimum
+    params["num_evals"] = search_log.stats["num_evals"]
+    params["time"] = str(timedelta(seconds=search_log.stats["time"]))
+    params["call_log"] = search_log.call_log
+    params.update(kwargs)
+    with open(str(output_dir.joinpath(name + ".json")), "w") as fp:
+        json.dump(params, fp, indent=4)
+
+
+def compute_bootstrap_adjusted_c_index(C_app, Cb_opts):
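+    """Bootstrap optimism correction for the concordance index.
+
+    C_app: apparent C-index of the model evaluated on the data it was fit on.
+    Cb_opts: one optimism estimate per bootstrap sample.
+    Returns the mean optimism, the adjusted C-index and a 95% percentile interval.
+    """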
+    # Compute bootstrap-estimated optimism (mean of optimism estimates across the B bootstrap samples)
+    C_opt = np.mean(Cb_opts)
+
+    # Adjust apparent C using bootstrap-estimated optimism
+    C_adj = C_app - C_opt
+
+    # compute confidence intervals for optimism-adjusted C
+    C_opt_95confint = np.percentile([C_app - o for o in Cb_opts], q=[2.5, 97.5])
+
+    return C_opt, C_adj, C_opt_95confint
+
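+# Hypothetical usage of compute_bootstrap_adjusted_c_index (main() does not currently call it),
+# where c_apparent and optimisms would come from a separate bootstrap loop:
+#
+#     C_opt, C_adj, ci95 = compute_bootstrap_adjusted_c_index(c_apparent, optimisms)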
+
+def plot_cs(c_trains, c_vals, output_dir):
+    plt.figure()
+    plt.title("CV validation, mean={:.4f}, var={:.4f}".format(np.mean(c_vals), np.var(c_vals)))
+    plt.plot(range(len(c_vals)), c_vals, 'rx-', label="validation")
+    plt.plot(range(len(c_trains)), c_trains, 'bx-', label="train")
+    plt.xlabel("fold")
+    plt.ylabel("concordance index")
+    plt.legend()
+    plt.savefig(str(output_dir.joinpath("c_train_val.png")))
+    plt.close()  # close the figure; plot_cs is called once per fold and would otherwise leak figures
+
+
+if __name__ == '__main__':
+    main()