--- a
+++ b/scripts/latefusion.py
@@ -0,0 +1,268 @@
+import argparse
+import inspect
+import os
+import sys
+
+# import warnings
+from datetime import datetime
+
+import numpy as np
+import pandas as pd
+from joblib import delayed
+from sklearn.base import clone
+from sklearn.model_selection import StratifiedKFold
+from sklearn.utils import check_random_state
+from tqdm import tqdm
+
+from _init_scripts import PredictionTask
+from _utils import read_yaml, write_yaml, ProgressParallel
+
# Make the repository root importable so that the local `multipit` package can
# be found when this script is executed directly from the scripts/ directory
# (i.e. without the package being pip-installed).
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)

from multipit.multi_model.latefusion import LateFusionClassifier
+
+
def main(params):
    """
    Run a repeated cross-validation experiment for classification with late fusion.

    Parameters
    ----------
    params: argparse.Namespace
        Parsed command-line arguments carrying `config` (path to the YAML
        configuration file) and `save_path` (directory under which a
        per-experiment results folder is created).
    """

    # Uncomment for disabling ConvergenceWarning
    # warnings.simplefilter("ignore")
    # os.environ["PYTHONWARNINGS"] = 'ignore'

    # 0. Read config file and save it in the results
    config = read_yaml(params.config)
    save_name = config["save_name"]
    if save_name is None:
        # Fall back to a timestamped experiment name when none is configured
        run_id = datetime.now().strftime(r"%m%d_%H%M%S")
        save_name = "exp_" + run_id
    save_dir = os.path.join(params.save_path, save_name)
    # makedirs (vs mkdir) also creates missing parent directories; it still
    # raises FileExistsError if save_dir itself already exists, so an existing
    # experiment is never silently overwritten.
    os.makedirs(save_dir)
    write_yaml(config, os.path.join(save_dir, "config.yaml"))

    # 1. fix random seeds for reproducibility
    seed = config["latefusion"]["seed"]
    np.random.seed(seed)

    # 2. Load data and define pipelines for each modality
    ptask = PredictionTask(config, survival=False, integration="late")
    ptask.load_data()
    X, y = ptask.data_concat.values, ptask.labels.loc[ptask.data_concat.index].values
    ptask.init_pipelines_latefusion()

    # 3. Perform repeated cross-validation (one parallel job per repeat)
    n_jobs_repeats = config["parallelization"]["n_jobs_repeats"]
    n_repeats = config["latefusion"]["n_repeats"]
    parallel = ProgressParallel(n_jobs=n_jobs_repeats, total=n_repeats)
    results_parallel = parallel(
        delayed(_fun_repeats)(
            ptask,
            X,
            y,
            r,
            # Silence the per-fold progress bars when repeats run in parallel,
            # otherwise the interleaved tqdm output is unreadable.
            disable_infos=(n_jobs_repeats is not None) and (n_jobs_repeats > 1),
        )
        for r in range(n_repeats)
    )

    # 4. Save results. Predictions/thresholds are collected identically whether
    # or not a permutation test was run; only the permutation arrays are extra.
    list_data_preds = [res[0] for res in results_parallel]
    list_data_thrs = [res[1] for res in results_parallel]

    if config["permutation_test"]:
        # Stack per-repeat permutation predictions into a 4D array of shape
        # (n_samples, n_models, n_permutations, n_repeats).
        perm_predictions = np.stack(
            [res[2] for res in results_parallel], axis=-1
        )
        # Permuted labels are generated from the permutation index (not the
        # repeat), so every repeat holds the same labels; keep the last copy.
        perm_labels = results_parallel[-1][3]
        np.save(os.path.join(save_dir, "permutation_labels.npy"), perm_labels)
        np.save(
            os.path.join(save_dir, "permutation_predictions.npy"), perm_predictions
        )

    data_preds = pd.concat(list_data_preds, axis=0)
    data_preds.to_csv(os.path.join(save_dir, "predictions.csv"))
    if config["collect_thresholds"]:
        data_thrs = pd.concat(list_data_thrs, axis=0)
        data_thrs.to_csv(os.path.join(save_dir, "thresholds.csv"))
+
+
def _fun_repeats(prediction_task, X, y, r, disable_infos):
    """
    Train and test a late fusion model for classification with cross-validation

    Parameters
    ----------
    prediction_task: PredictionTask object

    X: 2D array of shape (n_samples, n_features)
        Concatenation of the different modalities

    y: 1D array of shape (n_samples,)
        Binary outcome

    r: int
        Repeat number; also used to seed the cross-validation splitters so
        that each repeat is reproducible yet uses different splits.

    disable_infos: bool
        Disable the per-fold tqdm progress bars (used when repeats run in
        parallel).

    Returns
    -------
    df_pred: pd.DataFrame of shape (n_samples, n_models+3)
        Predictions collected over the test sets of the cross-validation scheme for each multimodal combination

    df_thrs: pd.DataFrame of shape (n_samples, n_models+2), None
        Thresholds that optimize the log-rank test on the training set for each fold and each multimodal combination.

    permut_predictions: 3D array of shape (n_samples, n_models, n_permutations)
        Predictions collected over the test sets of the cross_validation scheme for each multimodal combination and each random permutation of the labels.

    permut_labels: 2D array of shape (n_samples, n_permutations)
        Permuted labels
    """
    # Seed the outer CV with the repeat number: reproducible runs, distinct
    # splits per repeat.
    cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=r)
    X_preds = np.zeros((len(y), 3 + len(prediction_task.names)))
    X_thresholds = (
        np.zeros((len(y), 2 + len(prediction_task.names)))
        if prediction_task.config["collect_thresholds"]
        else None
    )

    late_clf = LateFusionClassifier(
        estimators=prediction_task.late_estimators,
        # BUG FIX: the original passed random_state=np.random.seed(r), but
        # np.random.seed() returns None, leaving the inner CV unseeded and
        # the experiment non-reproducible. Pass the repeat number directly.
        cv=StratifiedKFold(n_splits=10, shuffle=True, random_state=r),
        **prediction_task.config["latefusion"]["args"]
    )

    # 1. Cross-validation scheme
    for fold_index, (train_index, test_index) in tqdm(
        enumerate(cv.split(np.zeros(len(y)), y)),
        leave=False,
        total=cv.get_n_splits(np.zeros(len(y))),
        disable=disable_infos,
    ):
        X_train, y_train, X_test = (
            X[train_index, :],
            y[train_index],
            X[test_index, :],
        )
        target_surv_train = prediction_task.target_surv[train_index]
        # Fit late fusion on the training set of the fold
        clf = clone(late_clf)
        clf.fit(X_train, y_train)
        # Collect predictions on the test set of the fold for each multimodal combination
        for c, idx in enumerate(prediction_task.indices):
            X_preds[test_index, c] = clf.predict_proba(X_test, estim_ind=idx)[:, 1]
            # Collect the threshold that optimizes log-rank test on the training set
            if prediction_task.config["collect_thresholds"]:
                X_thresholds[test_index, c] = clf.find_logrank_threshold(
                    X_train, target_surv_train, estim_ind=idx
                )
        X_preds[test_index, -3] = fold_index
        if prediction_task.config["collect_thresholds"]:
            X_thresholds[test_index, -2] = fold_index

    # Bookkeeping columns: repeat number and true labels
    X_preds[:, -2] = r
    if prediction_task.config["collect_thresholds"]:
        X_thresholds[:, -1] = r
    X_preds[:, -1] = y

    df_pred = (
        pd.DataFrame(
            X_preds,
            columns=prediction_task.names + ["fold_index", "repeat", "label"],
            index=prediction_task.data_concat.index,
        )
        .reset_index()
        .rename(columns={"index": "samples"})
        .set_index(["repeat", "samples"])
    )

    if prediction_task.config["collect_thresholds"]:
        df_thrs = (
            pd.DataFrame(
                X_thresholds,
                columns=prediction_task.names + ["fold_index", "repeat"],
                index=prediction_task.data_concat.index,
            )
            .reset_index()
            .rename(columns={"index": "samples"})
            .set_index(["repeat", "samples"])
        )
    else:
        df_thrs = None

    # 2. Perform permutation test
    permut_predictions = None
    permut_labels = None
    if prediction_task.config["permutation_test"]:
        permut_labels = np.zeros((len(y), prediction_task.config["n_permutations"]))
        permut_predictions = np.zeros(
            (
                len(y),
                len(prediction_task.names),
                prediction_task.config["n_permutations"],
            )
        )
        for prm in range(prediction_task.config["n_permutations"]):
            X_perm = np.zeros((len(y), len(prediction_task.names)))
            # Seed each permutation with its index so the shuffled labels are
            # identical across repeats (main() relies on this to save a single
            # copy of the permuted labels).
            random_state = check_random_state(prm)
            sh_ind = random_state.permutation(len(y))
            yshuffle = np.copy(y)[sh_ind]
            permut_labels[:, prm] = yshuffle
            for fold_index, (train_index, test_index) in tqdm(
                enumerate(cv.split(np.zeros(len(y)), y)),
                leave=False,
                total=cv.get_n_splits(np.zeros(len(y))),
                disable=disable_infos,
            ):
                X_train, yshuffle_train, X_test = (
                    X[train_index, :],
                    yshuffle[train_index],
                    X[test_index, :],
                )
                # Refit on shuffled labels to build the null distribution
                clf = clone(late_clf)
                clf.fit(X_train, yshuffle_train)

                for c, idx in enumerate(prediction_task.indices):
                    X_perm[test_index, c] = clf.predict_proba(X_test, estim_ind=idx)[
                        :, 1
                    ]
            permut_predictions[:, :, prm] = X_perm
    return df_pred, df_thrs, permut_predictions, permut_labels
+
+
+if __name__ == "__main__":
+    args = argparse.ArgumentParser(description="Late fusion")
+    args.add_argument(
+        "-c",
+        "--config",
+        type=str,
+        help="config file path",
+    )
+    args.add_argument(
+        "-s",
+        "--save_path",
+        type=str,
+        help="save path",
+    )
+    main(params=args.parse_args())