Switch to side-by-side view

--- a
+++ b/demo/scripts/demo_KMplot.py
@@ -0,0 +1,59 @@
+import pickle
+import numpy as np
+import pandas as pd
+from pathlib import Path
+import matplotlib.pyplot as plt
+from argparse import ArgumentParser
+from lifelines import KaplanMeierFitter
+from survival4D.paths import DATA_DIR
+
+
+def parse_args():
+    parser = ArgumentParser()
+    parser.add_argument(
+        "-d", "--data-dir", dest="data_dir", type=str, default=None, help="Directory where the data file is."
+    )
+    parser.add_argument(
+        "-f", "--file-name", dest="file_name", type=str, default="bootout_conv.pkl", help="Data file name."
+    )
+    return parser.parse_args()
+
+
+def main():
+    args = parse_args()
+    if args.data_dir is None:
+        data_dir = DATA_DIR
+    else:
+        data_dir = Path(args.data_dir)
+    with open(str(data_dir.joinpath(args.file_name)), 'rb') as f:
+        inputdata_list = pickle.load(f)
+    y_orig = inputdata_list[0]
+    preds_bootfull = inputdata_list[1]
+    inds_inbag = inputdata_list[2]
+    del inputdata_list
+
+    preds_bootfull_mat = np.concatenate(preds_bootfull, axis=1)
+    inds_inbag_mat = np.array(inds_inbag).T
+    inbag_mask = 1*np.array([np.any(inds_inbag_mat==_, axis=0) for _ in range(inds_inbag_mat.shape[0])])
+    preds_bootave_oob = np.divide(np.sum(np.multiply((1-inbag_mask), preds_bootfull_mat), axis=1), np.sum(1-inbag_mask, axis=1))
+    risk_groups = 1*(preds_bootave_oob > np.median(preds_bootave_oob))
+
+    wdf = pd.DataFrame(
+        np.concatenate((y_orig, preds_bootave_oob[:, np.newaxis],risk_groups[:, np.newaxis]), axis=-1),
+        columns=['status', 'time', 'preds', 'risk_groups'], index=[str(_) for _ in risk_groups]
+    )
+
+    kmf = KaplanMeierFitter()
+    ax = plt.subplot(111)
+    kmf.fit(durations=wdf.loc['0','time'], event_observed=wdf.loc['0','status'], label="Low Risk")
+    ax = kmf.plot(ax=ax)
+    kmf.fit(durations=wdf.loc['1','time'], event_observed=wdf.loc['1','status'], label="High Risk")
+    ax = kmf.plot(ax=ax)
+    plt.ylim(0,1)
+    plt.title("Kaplan-Meier Plots")
+    plt.xlabel('Time (days)')
+    plt.ylabel('Survival Probability')
+
+
+if __name__ == '__main__':
+    main()