"""
@author: gbello & lisuru6
How to run the code
python demo_validateDL.py -c /path-to-conf
Default conf uses demo/scripts/default_validate_DL.conf
"""
import json
import shutil
from datetime import timedelta
import pickle
import numpy as np
from pathlib import Path
from argparse import ArgumentParser
from lifelines.utils import concordance_index
from survival4D.nn import hypersearch_nn, train_nn
from survival4D.config import NNExperimentConfig, HypersearchConfig, ModelConfig
from matplotlib import pyplot as plt
DEFAULT_CONF_PATH = Path(__file__).parent.joinpath("default_nn.conf")
def parse_args():
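    """Parse command-line arguments (-c/--conf-path: path to the experiment conf file)."""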
parser = ArgumentParser()
parser.add_argument(
"-c", "--conf-path", dest="conf_path", type=str, default=None, help="Conf path."
)
return parser.parse_args()
def main():
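    """
    Bootstrap internal validation of the neural-network survival model, reporting an
    optimism-corrected Harrell's concordance index.

    Step 1: tune hyperparameters, train on the full sample and compute the apparent C-index.
    Steps 2-3 (repeated for each bootstrap sample): re-tune and re-train on the bootstrap
    sample, then compute its optimism (in-bag C-index minus full-sample C-index).
    Step 5: subtract the mean optimism across bootstrap samples from the apparent C-index.
    """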
args = parse_args()
if args.conf_path is None:
conf_path = DEFAULT_CONF_PATH
else:
conf_path = Path(args.conf_path)
exp_config = NNExperimentConfig.from_conf(conf_path)
exp_config.output_dir.mkdir(parents=True, exist_ok=True)
hypersearch_config = HypersearchConfig.from_conf(conf_path)
model_config = ModelConfig.from_conf(conf_path)
shutil.copy(str(conf_path), str(exp_config.output_dir.joinpath("nn.conf")))
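    # keep a copy of the configuration used for this run alongside the outputs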
    # import input data: x_full = input data for patients (motion descriptors,
    # an 11,514-element vector per patient), y_full = censoring status and survival times
with open(str(exp_config.data_path), 'rb') as f:
c3 = pickle.load(f)
x_full = c3[0]
y_full = c3[1]
del c3
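    # As used by the concordance_index calls below, column 0 of y_full is the
    # event/censoring indicator and column 1 is the survival time.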
# Initialize lists to store predictions
preds_bootfull = []
inds_inbag = []
Cb_opts = []
# STEP 1
# (1a) find optimal hyperparameters
print("Step 1a")
opars, osummary = hypersearch_nn(
x_data=x_full,
y_data=y_full,
method=exp_config.search_method,
nfolds=exp_config.n_folds,
nevals=exp_config.n_evals,
batch_size=exp_config.batch_size,
num_epochs=exp_config.n_epochs,
backend=exp_config.backend,
model_kwargs=model_config.to_dict(),
**hypersearch_config.to_dict(),
)
# save opars
print("Step b")
# (1b) using optimal hyperparameters, train a model on full sample
olog = train_nn(
backend=exp_config.backend,
xtr=x_full,
ytr=y_full,
batch_size=exp_config.batch_size,
n_epochs=exp_config.n_epochs,
**model_config.to_dict(),
**opars,
)
# (1c) Compute Harrell's Concordance index
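    # Element [1] of the model's predict() output is used as the predicted risk score.
    # Scores are negated because lifelines' concordance_index expects higher values to
    # indicate longer survival (lower risk).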
predfull = olog.predict(x_full, batch_size=1)[1]
C_app = concordance_index(y_full[:, 1], -predfull, y_full[:, 0])
save_params(opars, osummary, "step_1a", exp_config.output_dir, c_app=C_app)
print('Apparent concordance index = {0:.4f}'.format(C_app))
# BOOTSTRAP SAMPLING
# define useful variables
nsmp = len(x_full)
    rowids = list(range(nsmp))
B = exp_config.n_bootstraps
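    # running values accumulated across bootstrap iterations for the progress plots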
plot_c_opts = []
plot_c_adjs = []
plot_bs_samples = []
plot_c_adjs_lb = []
plot_c_adjs_up = []
for b in range(B):
print('Current bootstrap sample:', b, 'of', B-1)
print('-------------------------------------')
# STEP 2: Generate a bootstrap sample by doing n random selections with replacement (where n is the sample size)
b_inds = np.random.choice(rowids, size=nsmp, replace=True)
xboot = x_full[b_inds]
yboot = y_full[b_inds]
# (2a) find optimal hyperparameters
print("Step 2a")
bpars, bsummary = hypersearch_nn(
backend=exp_config.backend,
x_data=xboot,
y_data=yboot,
method=exp_config.search_method,
nfolds=exp_config.n_folds,
nevals=exp_config.n_evals,
batch_size=exp_config.batch_size,
num_epochs=exp_config.n_epochs,
model_kwargs=model_config.to_dict(),
**hypersearch_config.to_dict(),
)
# (2b) using optimal hyperparameters, train a model on bootstrap sample
blog = train_nn(
backend=exp_config.backend,
xtr=xboot,
ytr=yboot,
batch_size=exp_config.batch_size,
n_epochs=exp_config.n_epochs,
**model_config.to_dict(),
**bpars
)
# (2c[i]) Using bootstrap-trained model, compute predictions on bootstrap sample.
# Evaluate accuracy of predictions (Harrell's Concordance index)
predboot = blog.predict(xboot, batch_size=1)[1]
Cb_boot = concordance_index(yboot[:, 1], -predboot, yboot[:, 0])
# (2c[ii]) Using bootstrap-trained model, compute predictions on FULL sample.
# Evaluate accuracy of predictions (Harrell's Concordance index)
predbootfull = blog.predict(x_full, batch_size=1)[1]
Cb_full = concordance_index(y_full[:, 1], -predbootfull, y_full[:, 0])
# STEP 3: Compute optimism for bth bootstrap sample, as difference between results from 2c[i] and 2c[ii]
Cb_opt = Cb_boot - Cb_full
# store data on current bootstrap sample (predictions, C-indices)
preds_bootfull.append(predbootfull)
inds_inbag.append(b_inds)
Cb_opts.append(Cb_opt)
print('Current bootstrap sample:', b, 'of', B-1)
print('-------------------------------------')
c_opt, c_adj, c_opt_95confint = compute_bootstrap_adjusted_c_index(C_app, Cb_opts)
print('Optimism bootstrap estimate = {0:.4f}'.format(c_opt))
print('Optimism-adjusted concordance index = {0:.4f}, and 95% CI = {1}'.format(c_adj, c_opt_95confint))
save_params(
bpars, bsummary, "bootstrap_{}".format(b), exp_config.output_dir,
c_opt=c_opt, c_adj=c_adj, c_opt_95confint=c_opt_95confint.tolist(),
cb_boot=Cb_boot, cb_full=Cb_full, cb_opt=Cb_opt, c_app=C_app,
)
# plot c_opt, c_adj with c_app as title
plot_c_opts.append(c_opt)
plot_bs_samples.append(b)
plot_c_adjs.append(c_adj)
plot_c_adjs_lb.append(c_opt_95confint[0])
plot_c_adjs_up.append(c_opt_95confint[1])
plot_c_indices(plot_bs_samples, plot_c_opts, plot_c_adjs, plot_c_adjs_lb, plot_c_adjs_up, C_app, exp_config.output_dir)
# STEP 5
# Compute bootstrap-estimated optimism (mean of optimism estimates across the B bootstrap samples)
c_opt, c_adj, c_opt_95confint = compute_bootstrap_adjusted_c_index(C_app, Cb_opts)
print('Optimism bootstrap estimate = {0:.4f}'.format(c_opt))
print('Optimism-adjusted concordance index = {0:.4f}, and 95% CI = {1}'.format(c_adj, c_opt_95confint))
def save_params(params: dict, search_log, name: str, output_dir: Path, **kwargs):
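    """
    Write the selected hyperparameters, a summary of the hyperparameter search
    (optimum C-index, number of evaluations, elapsed time, call log) and any extra
    keyword metrics to <output_dir>/<name>.json. Note: `params` is updated in place.
    """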
output_dir.mkdir(parents=True, exist_ok=True)
params["search_log_optimum_c_index"] = search_log.optimum
params["num_evals"] = search_log.stats["num_evals"]
params["time"] = str(timedelta(seconds=search_log.stats["time"]))
params["call_log"] = search_log.call_log
    params.update(kwargs)
with open(str(output_dir.joinpath(name + ".json")), "w") as fp:
json.dump(params, fp, indent=4)
def compute_bootstrap_adjusted_c_index(C_app, Cb_opts):
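    """
    Harrell's bootstrap optimism correction:

        C_opt = mean(Cb_opts)     (mean optimism over the B bootstrap samples)
        C_adj = C_app - C_opt     (optimism-adjusted concordance index)

    The 95% confidence interval is the 2.5th-97.5th percentile range of the
    per-bootstrap adjusted values C_app - Cb_opts[b].
    """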
# Compute bootstrap-estimated optimism (mean of optimism estimates across the B bootstrap samples)
C_opt = np.mean(Cb_opts)
# Adjust apparent C using bootstrap-estimated optimism
C_adj = C_app - C_opt
# compute confidence intervals for optimism-adjusted C
C_opt_95confint = np.percentile([C_app - o for o in Cb_opts], q=[2.5, 97.5])
return C_opt, C_adj, C_opt_95confint
def plot_c_indices(bs_samples, c_opts, c_adjs, c_adjs_lb, c_adjs_up, c_app, output_dir: Path):
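    """
    Plot the running optimism-adjusted C-index (with its 95% CI band) and the running
    optimism estimate against bootstrap sample index; both figures are saved to output_dir.
    """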
plt.figure()
plt.title("c_adj, c_app={:.4f}".format(c_app))
    plt.fill_between(bs_samples, c_adjs_lb, c_adjs_up, facecolor='red', alpha=0.5, interpolate=True)
plt.plot(bs_samples, c_adjs, 'rx-')
plt.savefig(str(output_dir.joinpath("c_adj.png")))
plt.figure()
plt.title("c_opt, c_app={:.4f}".format(c_app))
    plt.plot(bs_samples, c_opts, 'rx-')
    plt.savefig(str(output_dir.joinpath("c_opt.png")))
    plt.close()
if __name__ == '__main__':
main()