"""Repeated cross-validation experiment for survival prediction with early fusion."""

import argparse
import inspect
import os
import sys
from datetime import datetime

import numpy as np
import pandas as pd
from joblib import delayed
from sksurv.util import Surv
from tqdm import tqdm

from _init_scripts import PredictionTask
from _utils import read_yaml, write_yaml, ProgressParallel

currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)

from multipit.multi_model.earlyfusion import EarlyFusionSurvival
from multipit.utils.custom.cv import CensoredKFold


def main(params):
    """
    Run a repeated cross-validation experiment for survival prediction with
    early fusion, saving out-of-fold predictions (and, optionally, the
    log-rank-optimal thresholds) as CSV files.

    Parameters
    ----------
    params: argparse.Namespace
        Must provide ``config`` (path to a YAML config file) and
        ``save_path`` (directory under which the results folder is created).
    """
    # 0. Read config file and save a copy alongside the results
    config = read_yaml(params.config)
    save_name = config["save_name"]
    if save_name is None:
        run_id = datetime.now().strftime(r"%m%d_%H%M%S")
        save_name = "exp_" + run_id

    save_dir = os.path.join(params.save_path, save_name)
    # makedirs creates missing parent directories; it still raises if
    # save_dir itself already exists, so previous runs are never overwritten.
    os.makedirs(save_dir)
    write_yaml(config, os.path.join(save_dir, "config.yaml"))

    # 1. Fix random seed for reproducibility
    seed = config["earlyfusion"]["seed"]
    np.random.seed(seed)

    # 2. Load data and define pipelines for each modality.
    # Surv.from_arrays is a static constructor; no Surv() instance is needed.
    ptask = PredictionTask(config, survival=True, integration="early")
    ptask.load_data()
    X = ptask.data_concat.values
    y = Surv.from_arrays(
        event=ptask.labels.loc[ptask.data_concat.index, "event"].values,
        time=ptask.labels.loc[ptask.data_concat.index, "time"].values,
    )
    ptask.init_pipelines_earlyfusion()

    # 3. Perform repeated cross-validation (one repeat per parallel job)
    n_jobs_repeats = config["parallelization"]["n_jobs_repeats"]
    n_repeats = config["earlyfusion"]["n_repeats"]
    parallel = ProgressParallel(n_jobs=n_jobs_repeats, total=n_repeats)
    results_parallel = parallel(
        delayed(_fun_repeats)(
            ptask,
            X,
            y,
            r,
            # Silence the per-fold progress bars when repeats run in parallel
            disable_infos=(n_jobs_repeats is not None) and (n_jobs_repeats > 1),
        )
        for r in range(n_repeats)
    )

    # 4. Collect and save results
    list_data_preds, list_data_thrs = [], []
    for res in results_parallel:
        list_data_preds.append(res[0])
        list_data_thrs.append(res[1])
    data_preds = pd.concat(list_data_preds, axis=0)
    data_preds.to_csv(os.path.join(save_dir, "predictions.csv"))
    if config["collect_thresholds"]:
        data_thrs = pd.concat(list_data_thrs, axis=0)
        data_thrs.to_csv(os.path.join(save_dir, "thresholds.csv"))


def _fun_repeats(prediction_task, X, y, r, disable_infos):
    """
    Train and test an early fusion model for survival task with cross-validation

    Parameters
    ----------
    prediction_task: PredictionTask object

    X: 2D array of shape (n_samples, n_features)
        Concatenation of the different modalities

    y: Structured array of size (n_samples,)
        Event indicator and observed time for each sample

    r: int
        Repeat number

    disable_infos: bool
        Disable the tqdm progress bar over the folds.

    Returns
    -------
    df_pred: pd.DataFrame of shape (n_samples, n_models+4)
        Predictions collected over the test sets of the cross-validation
        scheme for each multimodal combination

    df_thrs: pd.DataFrame of shape (n_samples, n_models+2), None
        Thresholds that optimize the log-rank test on the training set for
        each fold and each multimodal combination. None when
        ``collect_thresholds`` is disabled.
    """
    cv = CensoredKFold(n_splits=10, shuffle=True)
    # Extra columns appended after the per-model predictions:
    # fold_index, repeat, label.time, label.event
    X_preds = np.zeros((len(y), 4 + len(prediction_task.names)))
    X_thresholds = (
        np.zeros((len(y), 2 + len(prediction_task.names)))
        if prediction_task.config["collect_thresholds"]
        else None
    )
    for fold_index, (train_index, test_index) in tqdm(
        enumerate(cv.split(np.zeros(len(y)), y)),
        leave=False,
        total=cv.get_n_splits(np.zeros(len(y))),
        disable=disable_infos,
    ):
        X_train, y_train, X_test, y_test = (
            X[train_index, :],
            y[train_index],
            X[test_index, :],
            y[test_index],
        )

        # Seed the inner CV with the repeat index directly.
        # (The original `random_state=np.random.seed(r)` passed None, since
        # np.random.seed returns None, silently leaving the inner CV unseeded.)
        cv_inner = CensoredKFold(n_splits=10, shuffle=True, random_state=r)
        for c, models in enumerate(prediction_task.names):
            t = {
                model: prediction_task.early_transformers[model]
                for model in models.split("+")
            }
            early_surv = EarlyFusionSurvival(
                estimator=prediction_task.early_estimator,
                transformers=t,
                modalities={
                    model: prediction_task.dic_modalities[model]
                    for model in models.split("+")
                },
                cv=cv_inner,
                **prediction_task.config["earlyfusion"]["args"]
            )
            # Feature selection is pointless with a single modality
            if len(models.split("+")) == 1:
                early_surv.set_params(**{"select_features": False})

            early_surv.fit(X_train, y_train)
            X_preds[test_index, c] = early_surv.predict(X_test)
            if prediction_task.config["collect_thresholds"]:
                X_thresholds[test_index, c] = early_surv.find_logrank_threshold(
                    X_train, y_train
                )
        X_preds[test_index, -4] = fold_index
        if prediction_task.config["collect_thresholds"]:
            X_thresholds[test_index, -2] = fold_index

    X_preds[:, -3] = r
    if prediction_task.config["collect_thresholds"]:
        X_thresholds[:, -1] = r
    X_preds[:, -2] = y["time"]
    X_preds[:, -1] = y["event"]
    df_pred = (
        pd.DataFrame(
            X_preds,
            columns=prediction_task.names
            + ["fold_index", "repeat", "label.time", "label.event"],
            index=prediction_task.data_concat.index,
        )
        .reset_index()
        .rename(columns={"index": "samples"})
        .set_index(["repeat", "samples"])
    )

    if prediction_task.config["collect_thresholds"]:
        df_thrs = (
            pd.DataFrame(
                X_thresholds,
                columns=prediction_task.names + ["fold_index", "repeat"],
                index=prediction_task.data_concat.index,
            )
            .reset_index()
            .rename(columns={"index": "samples"})
            .set_index(["repeat", "samples"])
        )
    else:
        df_thrs = None

    return df_pred, df_thrs


if __name__ == "__main__":
    args = argparse.ArgumentParser(description="Early fusion")
    # required=True gives a clear CLI error instead of a later crash
    # inside read_yaml(None) when -c/-s are omitted.
    args.add_argument(
        "-c",
        "--config",
        type=str,
        required=True,
        help="config file path",
    )
    args.add_argument(
        "-s",
        "--save_path",
        type=str,
        required=True,
        help="save path",
    )
    main(params=args.parse_args())