
--- a
+++ b/multipit/multi_model/latefusion.py
@@ -0,0 +1,742 @@
+from itertools import combinations
+
+import numpy as np
+import pandas as pd
+from joblib import Parallel, delayed
+from lifelines.statistics import logrank_test
+from sklearn.base import clone, BaseEstimator
+from sklearn.linear_model import LogisticRegression
+from sklearn.model_selection import GridSearchCV, RandomizedSearchCV, cross_val_score
+
+
+class LateFusionClassifier(BaseEstimator):
+    """
+    Late fusion classifier for multimodal integration.
+
+    Parameters
+    ----------
+    estimators: list of (str, estimator, list, dict) tuples.
+        List of unimodal estimators to fit and fuse. Each unimodal estimator is associated with a tuple
+        (`name`, `estimator`, `features`, `tune_dict`) where `name` is a string corresponding to the name of the
+        estimator, `estimator` is a scikit-learn estimator inheriting from BaseEstimator (it can be a Pipeline),
+        `features` is a list of indexes corresponding to the columns of the data associated with the modality of
+        interest, and `tune_dict` is either a dictionary (for the grid search strategy) or a tuple
+        (n_iterations, dict) (for the randomized search strategy) used for hyperparameter tuning.
+
+    cv: cross-validation generator
+        Cross-validation scheme for hyperparameter tuning (if `tuning` is not None) and/or calibration (if
+        `calibration` is not None); it must be provided whenever either is enabled. The default is None.
+
+    score: str or callable.
+        Score to use for tuning the unimodal models or weighting them at the late fusion step (i.e. sum of the unimodal
+        predictions weighted by the performance of each unimodal model estimated with cross-validation).
+        See sklearn.model_selection.cross_val_score for more details. The default is None.
+
+    random_score: float.
+        Score expected from a random predictor (e.g., 0.5 for ROC AUC). Used when weighting the unimodal models
+        with their estimated score: weights are max(score - random_score, 0), so unimodal models whose estimated
+        performance falls below random_score are discarded. The default is 0.5.
+
+    sup_weights: bool.
+        Whether to weight each unimodal model by its cross-validation performance. If False, all models receive the
+        same weight when fusing the unimodal predictions. The default is False.
+
+    missing_threshold: float in ]0, 1].
+        Fraction of missing values above which a whole modality is considered missing for a sample (e.g., with
+        `missing_threshold = 0.9`, a modality is considered missing for a sample when more than 90% of the features
+        associated with this modality are missing). The default is 0.9.
+
+    tuning: str or None.
+        Strategy for tuning each model: either 'gridsearch' for GridSearchCV or 'randomsearch' for RandomizedSearchCV.
+        If None, no hyperparameter tuning is performed. The default is None.
+
+    n_jobs: int.
+        Number of jobs to run in parallel for hyperparameter tuning, collecting the predictions for calibration, or
+        estimating the performance of each unimodal model with cross-validation. The default is None.
+
+    calibration: str or None.
+        Calibration strategy.
+            * `calibration = 'late'` means that the fusion is made before calibration. The predictions of each
+            multimodal combination are collected with cross-validation and a univariate logistic regression model is
+            fitted to these predictions.
+            * `calibration = 'early'` means that each unimodal model is calibrated prior to the late fusion. The
+            unimodal predictions are collected with a cross-validation scheme and univariate logistic regression models
+            are fitted.
+            * `calibration = None` means that no calibration is performed.
+
+    Attributes
+    ----------
+    best_params_: list of dict or empty list.
+        List of best parameters for each unimodal predictor (output of GridSearchCV or RandomizedSearchCV). It follows
+        the same order as the `estimators` list. If `tuning` is None, it remains an empty list (i.e., no
+        hyperparameter tuning is performed).
+
+    weights_: list of float.
+        List of the weights associated with each modality, used for the weighted sum at the late fusion stage.
+
+    fitted_estimators_: list of estimators.
+        List of fitted unimodal estimators.
+
+    fitted_meta_estimators_: dictionary of estimators.
+        Dictionary of meta-estimators for calibration. If `calibration = "early"` the keys correspond to the indexes of
+        each unimodal estimator (i.e., from 0 to n_estimators-1) and the values correspond to the logistic regression
+        estimators fitted to calibrate the unimodal models. If `calibration = "late"` the keys correspond to tuples
+        characterizing each multimodal combination (e.g., (1, 3, 5)) and the values correspond to the logistic
+        regression estimators fitted to calibrate the multimodal models.
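+
+    Examples
+    --------
+    A minimal usage sketch (the modality names, feature indexes, and CV scheme below are illustrative, and each
+    estimator is assumed to handle NaNs within its own modality, e.g., through an imputation Pipeline):
+
+    >>> from sklearn.ensemble import RandomForestClassifier
+    >>> from sklearn.linear_model import LogisticRegression
+    >>> from sklearn.model_selection import StratifiedKFold
+    >>> estimators = [
+    ...     ("clinical", LogisticRegression(), [0, 1, 2], {}),
+    ...     ("radiomics", RandomForestClassifier(), [3, 4, 5, 6], {}),
+    ... ]
+    >>> clf = LateFusionClassifier(
+    ...     estimators=estimators,
+    ...     cv=StratifiedKFold(n_splits=5),
+    ...     score="roc_auc",
+    ...     calibration="late",
+    ... )
+    >>> clf.fit(X_train, y_train)  # doctest: +SKIP
+    >>> probas = clf.predict_proba(X_test, estim_ind=(0, 1))  # doctest: +SKIP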
+    """
+
+    def __init__(
+        self,
+        estimators,
+        cv=None,
+        score=None,
+        random_score=0.5,
+        sup_weights=False,
+        missing_threshold=0.9,
+        tuning=None,
+        n_jobs=None,
+        calibration="late",
+    ):
+        self.estimators = estimators
+        self.cv = cv
+        self.score = score
+        self.random_score = random_score
+        self.sup_weights = sup_weights
+        self.missing_threshold = missing_threshold
+        self.tuning = tuning
+        self.n_jobs = n_jobs
+        self.calibration = calibration
+
+        self.best_params_ = []
+        self.weights_ = []
+        self.fitted_estimators_ = []
+        self.fitted_meta_estimators_ = {}
+
+    def fit(self, X, y):
+        """
+        Fit the late fusion classifier.
+
+        Parameters
+        ----------
+        X: array of shape (n_samples, n_features)
+            Multimodal array, concatenation of the features from the different modalities. Missing modalities are
+            filled with NaN values for each sample.
+
+        y: array of shape (n_samples,)
+            Target to predict.
+
+        Returns
+        -------
+        self : object
+            Returns the instance itself.
+        """
+        predictions = np.zeros((X.shape[0], len(self.estimators)))
+        weights = np.zeros((X.shape[0], len(self.estimators)))
+        i = 0
+        for name, estim, features, grid in self.estimators:
+            Xnew = X[:, features]
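+            # bool_mask is True for samples whose modality is available, i.e.,
+            # the fraction of missing features does not exceed missing_threshold.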
+            bool_mask = ~(
+                np.sum(pd.isnull(Xnew), axis=1) > self.missing_threshold * len(features)
+            )
+            Xnew, ynew = Xnew[bool_mask, :], y[bool_mask]
+            # Fit unimodal estimator
+            self._fit_estim(
+                Xnew, ynew, estim=estim, features=features, grid=grid, name=name
+            )
+
+            # Collect predictions and weights for further calibration
+            if self.calibration is not None:
+                weights[bool_mask, i] = max(self.weights_[-1] - self.random_score, 0)
+                parallel = Parallel(n_jobs=self.n_jobs)
+                collected_predictions = parallel(
+                    delayed(_collect)(
+                        Xdata=X,
+                        ydata=y,
+                        estimator=estim,
+                        bmask=bool_mask,
+                        feat=features,
+                        train=train,
+                        test=test,
+                    )
+                    for train, test in self.cv.split(X, y)
+                )
+                for indexes, preds in collected_predictions:
+                    predictions[indexes, i] = preds
+            i += 1
+
+        # Calibrate models
+        if self.calibration is not None:
+            if self.calibration == "early":
+                self._fit_early_calibration(predictions=predictions, y=y)
+            elif self.calibration == "late":
+                self._fit_late_calibration(
+                    predictions=predictions, weights=weights, y=y
+                )
+            else:
+                raise ValueError(
+                    "'early', 'late' or None are the only values available for calibration parameter"
+                )
+
+        self.weights_ = np.array(self.weights_) - self.random_score
+        self.weights_ = np.where(self.weights_ > 0, self.weights_, 0)
+        return self
+
+    def _fit_estim(self, X, y, estim, features, grid, name):
+        """
+        Fit a unimodal estimator.
+        """
+        if (self.tuning is not None) and (len(grid) > 0):
+            if self.tuning == "gridsearch":
+                search = GridSearchCV(
+                    estimator=clone(estim),
+                    param_grid=grid,
+                    cv=self.cv,
+                    scoring=self.score,
+                    n_jobs=self.n_jobs,
+                )
+            elif self.tuning == "randomsearch":
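+                # For randomized search, `grid` is a (n_iterations, param_distributions) tuple.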
+                search = RandomizedSearchCV(
+                    estimator=clone(estim),
+                    param_distributions=grid[1],
+                    n_iter=grid[0],
+                    scoring=self.score,
+                    n_jobs=self.n_jobs,
+                    cv=self.cv,
+                )
+
+            search.fit(X, y)
+
+            if self.sup_weights:
+                self.weights_.append(search.best_score_)
+            else:
+                self.weights_.append(1.0)
+
+            temp = search.best_estimator_
+            self.best_params_.append(search.best_params_)
+            # print("Best params " + name + " :", search.best_params_)
+            # print("Best score " + name + " :", search.best_score_)
+        else:
+            if self.sup_weights:
+                self.weights_.append(
+                    np.mean(
+                        cross_val_score(
+                            estimator=clone(estim),
+                            X=X,
+                            y=y,
+                            cv=self.cv,
+                            scoring=self.score,
+                            n_jobs=self.n_jobs,
+                        )
+                    )
+                )
+            else:
+                self.weights_.append(1.0)
+            temp = clone(estim).fit(X, y)
+
+        self.fitted_estimators_.append((name, temp, features))
+        return
+
+    def _fit_early_calibration(self, predictions, y):
+        """
+        Calibrate only each unimodal predictor.
+        """
+        for i in range(len(self.estimators)):
+            probas = predictions[:, i]
+            mask = (probas > 0).reshape(-1)
+            self.fitted_meta_estimators_[i] = LogisticRegression(
+                class_weight="balanced"
+            ).fit(probas[mask].reshape(-1, 1), y[mask])
+        return
+
+    def _fit_late_calibration(self, predictions, weights, y):
+        """
+        Calibrate each combination of modalities.
+        """
+        for i in range(1, len(self.estimators) + 1):
+            for comb in combinations(range(len(self.estimators)), i):
+                probas = predictions[:, np.array(comb)]
+                if len(comb) == 1:
+                    mask = (probas > 0).reshape(-1)
+                else:
+                    w = weights[:, np.array(comb)]
+                    mask = np.any(probas > 0, axis=1).reshape(-1)
+                    temp = np.sum(w, axis=1)
+                    w[temp > 0] = w[temp > 0] / (temp[temp > 0].reshape(-1, 1))
+                    probas = np.sum(probas * w, axis=1)
+
+                self.fitted_meta_estimators_[comb] = LogisticRegression(
+                    class_weight="balanced"
+                ).fit(probas[mask].reshape(-1, 1), y[mask])
+        return
+
+    def predict_proba(self, X, estim_ind=None):
+        """
+        Late fusion probability estimates
+
+        Parameters
+        ----------
+        X: array of shape (n_samples, n_features)
+            Multimodal array, concatenation of the features from the different modalities. Missing modalities are
+            filled with NaN values for each sample.
+
+        estim_ind: tuple of integers.
+            Tuple representing a multimodal combination (e.g., (i, j, k) corresponds to the combination of the ith,
+            jth, and kth estimators in self.fitted_estimators_). If None, the combination of all fitted unimodal
+            predictors is considered.
+
+        Returns
+        -------
+        probas: array of shape (n_samples, 2).
+            Probability of the samples for each class. If no modality is available for a sample, 0.5 is returned
+            for both classes.
+        """
+        fitted_estimators = (
+            [self.fitted_estimators_[i] for i in estim_ind]
+            if estim_ind is not None
+            else self.fitted_estimators_
+        )
+        fitted_weights = (
+            np.array([self.weights_[i] for i in estim_ind])
+            if estim_ind is not None
+            else self.weights_
+        )
+
+        # Collect predictions for each modality
+        preds = np.zeros((X.shape[0], len(fitted_estimators)))
+        weights = np.zeros((X.shape[0], len(fitted_weights)))
+        for j, item in enumerate(fitted_estimators):
+            Xpred = X[:, item[2]].copy()
+            bool_mask = ~(
+                np.sum(pd.isnull(Xpred), axis=1) > self.missing_threshold * len(item[2])
+            )
+            weights[:, j] = np.where(bool_mask, fitted_weights[j], 0)
+            preds[bool_mask, j] = item[1].predict_proba(Xpred[bool_mask, :])[:, 1]
+
+        # Calibrate the predictions and predict probas
+        if self.calibration is not None:
+            if self.calibration == "late":
+                probas = self._predict_calibrate_late(preds, weights, estim_ind)
+            elif self.calibration == "early":
+                probas = self._predict_calibrate_early(preds, weights, estim_ind)
+            else:
+                raise ValueError(
+                    "'early', 'late' or None are the only values available for calibration parameter"
+                )
+        else:
+            probas = self._predict_uncalibrated(preds, weights)
+        return np.hstack([1 - probas, probas])
+
+    @staticmethod
+    def _predict_uncalibrated(preds, weights):
+        """
+        Return weighted sum of available unimodal predictions
+        """
+        temp = np.sum(weights, axis=1)
+        weights[temp > 0] = weights[temp > 0] / (temp[temp > 0].reshape(-1, 1))
+        probas = np.sum(preds * weights, axis=1)
+        return np.where(temp == 0, 0.5, probas).reshape(-1, 1)
+
+    def _predict_calibrate_early(self, preds, weights, estim_ind):
+        """
+        Return weighted sum of available and calibrated unimodal predictions
+        """
+        temp = np.sum(weights, axis=1)
+        weights[temp > 0] = weights[temp > 0] / (temp[temp > 0].reshape(-1, 1))
+        list_meta_estimators = (
+            [self.fitted_meta_estimators_[i] for i in estim_ind]
+            if estim_ind is not None
+            else list(self.fitted_meta_estimators_.values())
+        )
+        for j, meta in enumerate(list_meta_estimators):
+            preds[:, j] = np.where(
+                weights[:, j] != 0,
+                meta.predict_proba(preds[:, j].reshape(-1, 1))[:, 1],
+                0,
+            )
+        probas = np.sum(preds * weights, axis=1)
+        return np.where(temp == 0, 0.5, probas).reshape(-1, 1)
+
+    def _predict_calibrate_late(self, preds, weights, estim_ind):
+        """
+        Return calibrated weighted sum of available unimodal predictions
+        """
+        temp = np.sum(weights, axis=1)
+        weights[temp > 0] = weights[temp > 0] / (temp[temp > 0].reshape(-1, 1))
+        probas = np.sum(preds * weights, axis=1)
+        meta_estimator = (
+            self.fitted_meta_estimators_[estim_ind]
+            if estim_ind is not None
+            else list(self.fitted_meta_estimators_.values())[-1]
+        )
+        return np.where(
+            temp == 0, 0.5, meta_estimator.predict_proba(probas.reshape(-1, 1))[:, 1]
+        ).reshape(-1, 1)
+
+    def find_logrank_threshold(
+        self, X, ysurv, estim_ind, percentile_min=30, percentile_max=70
+    ):
+        """
+        Find the cutoff that best stratifies the samples with respect to survival data (i.e., the cutoff that
+        minimizes the log-rank test p-value).
+
+        Parameters
+        ----------
+        X: array of shape (n_samples, n_features)
+            Multimodal array, concatenation of the features from the different modalities. Missing modalities are
+            filled with NaN values for each sample.
+
+        ysurv: structured array of shape (n_samples,), see sksurv.util.Surv (from scikit-survival).
+            Structured array for survival data associated with X.
+
+        estim_ind: tuple of integers.
+            Tuple representing a multimodal combination (e.g., (i, j, k) corresponds to the combination of the ith,
+            jth, and kth estimators in self.fitted_estimators_). If None, the combination of all fitted unimodal
+            predictors is considered.
+
+        percentile_min: int in [0, 100]
+            Minimum value of the percentile range used to explore various cutoff values for predicted probabilities
+
+        percentile_max: int in [0, 100]
+            Maximum value of the percentile range used to explore various cutoff values for predicted probabilities
+
+        Returns
+        -------
+        cutoff: float.
+            Best cutoff on the predicted probabilities, i.e., the one that minimizes the log-rank test p-value.
+        """
+        risk_score = self.predict_proba(X, estim_ind=estim_ind)[:, 1]
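+        # predict_proba returns 0.5 when no modality is available for a sample;
+        # exclude those samples from the cutoff search.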
+        bool_mask = risk_score == 0.5
+        cutoffs, pvals = [], []
+        risk_score_new, y_new = risk_score[~bool_mask], ysurv[~bool_mask]
+        for p in np.arange(percentile_min, percentile_max + 1):
+            c = np.percentile(risk_score_new, p)
+            group1 = risk_score_new <= c
+            group2 = risk_score_new > c
+            test = logrank_test(
+                durations_A=y_new[group1]["time"],
+                durations_B=y_new[group2]["time"],
+                event_observed_A=1 * (y_new[group1]["event"]),
+                event_observed_B=1 * (y_new[group2]["event"]),
+            )
+            cutoffs.append(c)
+            pvals.append(test.summary["p"].values[0])
+        return cutoffs[np.argmin(pvals)]
+
+
+def _collect(Xdata, ydata, estimator, bmask, feat, train, test):
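+    """
+    Fit a clone of `estimator` on the training-fold samples for which the modality is available and return the
+    test-fold indexes together with the predicted probabilities of the positive class.
+    """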
+    Xtrain, Xtest, ytrain, ytest = (
+        Xdata[np.intersect1d(np.where(bmask)[0], train), :],
+        Xdata[np.intersect1d(np.where(bmask)[0], test), :],
+        ydata[np.intersect1d(np.where(bmask)[0], train)],
+        ydata[np.intersect1d(np.where(bmask)[0], test)],
+    )
+    tempbis = clone(estimator).fit(Xtrain[:, feat], ytrain)
+    return (
+        np.intersect1d(np.where(bmask)[0], test),
+        tempbis.predict_proba(Xtest[:, feat])[:, 1],
+    )
+
+
+class LateFusionSurvival(BaseEstimator):
+    """
+    Late fusion survival model for multimodal integration.
+
+    Parameters
+    ----------
+    estimators: list of (str, estimator, list, dict) tuples.
+        List of unimodal estimators to fit and fuse. Each unimodal estimator is associated with a tuple
+        (`name`, `estimator`, `features`, `tune_dict`) where `name` is a string corresponding to the name of the
+        estimator, `estimator` is a scikit-learn estimator inheriting from BaseEstimator (it can be a Pipeline),
+        `features` is a list of indexes corresponding to the columns of the data associated with the modality of
+        interest, and `tune_dict` is either a dictionary (for the grid search strategy) or a tuple
+        (n_iterations, dict) (for the randomized search strategy) used for hyperparameter tuning.
+
+    cv: cross-validation generator
+        Cross-validation scheme for hyperparameter tuning (if `tuning` is not None) and/or calibration (if
+        `calibration` is True). This parameter is required.
+
+    score: str or callable.
+        Score to use for tuning the unimodal models or weighting them at the late fusion step (i.e. sum of the unimodal
+        predictions weighted by the performance of each unimodal model estimated with cross-validation).
+        See sklearn.model_selection.cross_val_score for more details. The default is None.
+
+    random_score: float.
+        Score expected from a random predictor (e.g., 0.5 for the concordance index). Used when weighting the
+        unimodal models with their estimated score: weights are max(score - random_score, 0), so unimodal models
+        whose estimated performance falls below random_score are discarded. The default is 0.5.
+
+    sup_weights: bool.
+        Whether to weight each unimodal model by its cross-validation performance. If False, all models receive the
+        same weight when fusing the unimodal predictions. The default is True.
+
+    missing_threshold: float in ]0, 1].
+        Fraction of missing values above which a whole modality is considered missing for a sample (e.g., with
+        `missing_threshold = 0.9`, a modality is considered missing for a sample when more than 90% of the features
+        associated with this modality are missing). The default is 0.9.
+
+    tuning: str or None.
+        Strategy for tuning each model: either 'gridsearch' for GridSearchCV or 'randomsearch' for RandomizedSearchCV.
+        If None, no hyperparameter tuning is performed. The default is None.
+
+    n_jobs: int.
+        Number of jobs to run in parallel for hyperparameter tuning, collecting the predictions for calibration, or
+        estimating the performance of each unimodal model with cross-validation. The default is None.
+
+    calibration: bool.
+        If True, each unimodal model is associated with a (mean, std) pair estimated on predictions collected with
+        cross-validation; the predictions of each unimodal model are then standardized before the late fusion step.
+        The default is True.
+
+    Attributes
+    ----------
+    best_params_: list of dict or empty list.
+        List of best parameters for each unimodal predictor (output of GridSearchCV or RandomizedSearchCV). It follows
+        the same order as the `estimators` list. If `tuning` is None, it remains an empty list (i.e., no
+        hyperparameter tuning is performed).
+
+    weights_: list of float.
+        List of the weights associated with each modality, used for the weighted sum at the late fusion stage.
+
+    fitted_estimators_: list of estimators.
+        List of fitted unimodal estimators.
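+
+    Examples
+    --------
+    A minimal usage sketch (the modality names, feature indexes, CV scheme, and estimators below are illustrative):
+
+    >>> from sklearn.model_selection import KFold
+    >>> from sksurv.linear_model import CoxPHSurvivalAnalysis
+    >>> estimators = [
+    ...     ("clinical", CoxPHSurvivalAnalysis(), [0, 1, 2], {}),
+    ...     ("radiomics", CoxPHSurvivalAnalysis(), [3, 4, 5, 6], {}),
+    ... ]
+    >>> model = LateFusionSurvival(estimators=estimators, cv=KFold(n_splits=5))
+    >>> model.fit(X_train, y_surv)  # doctest: +SKIP
+    >>> risk_scores = model.predict(X_test, estim_ind=(0, 1))  # doctest: +SKIP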
+    """
+
+    def __init__(
+        self,
+        estimators,
+        cv,
+        score=None,
+        random_score=0.5,
+        sup_weights=True,
+        missing_threshold=0.9,
+        tuning=None,
+        n_jobs=None,
+        calibration=True,
+    ):
+        self.estimators = estimators
+        self.cv = cv
+        self.score = score
+        self.random_score = random_score
+        self.sup_weights = sup_weights
+        self.missing_threshold = missing_threshold
+        self.tuning = tuning
+        self.n_jobs = n_jobs
+        self.calibration = calibration
+
+        self.weights_ = []
+        self.fitted_estimators_ = []
+        self.best_params_ = []
+
+    def _fit_estim(self, X, y, estim, features, grid, name):
+        """
+        Fit a unimodal estimator and return the fitted estimator.
+        """
+        if (self.tuning is not None) and (len(grid) > 0):
+            if self.tuning == "gridsearch":
+                search = GridSearchCV(
+                    estimator=clone(estim),
+                    param_grid=grid,
+                    cv=self.cv,
+                    scoring=self.score,
+                    n_jobs=self.n_jobs,
+                )
+
+            elif self.tuning == "randomsearch":
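+                # For randomized search, `grid` is a (n_iterations, param_distributions) tuple.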
+                search = RandomizedSearchCV(
+                    estimator=clone(estim),
+                    param_distributions=grid[1],
+                    n_iter=grid[0],
+                    scoring=self.score,
+                    n_jobs=self.n_jobs,
+                    cv=self.cv,
+                )
+
+            search.fit(X, y)
+
+            if self.sup_weights:
+                self.weights_.append(search.best_score_)
+            else:
+                self.weights_.append(1.0)
+
+            temp = search.best_estimator_
+            self.best_params_.append(search.best_params_)
+            # print("Best params " + name + " :", search.best_params_)
+            # print("Best score " + name + " :", search.best_score_)
+        else:
+            if self.sup_weights:
+                self.weights_.append(
+                    np.mean(
+                        cross_val_score(
+                            estimator=clone(estim),
+                            X=X,
+                            y=y,
+                            cv=self.cv,
+                            scoring=self.score,
+                        )
+                    )
+                )
+            else:
+                self.weights_.append(1.0)
+            temp = clone(estim).fit(X, y)
+
+        return temp
+
+    def fit(self, X, y):
+        """
+        Fit the late fusion survival model.
+
+        Parameters
+        ----------
+        X: array of shape (n_samples, n_features)
+            Multimodal array, concatenation of the features from the different modalities. Missing modalities are
+            filled with NaN values for each sample.
+
+        y: structured array of shape (n_samples,), see sksurv.util.Surv (from scikit-survival).
+            Structured array for the survival target/outcome.
+
+        Returns
+        -------
+        self : object
+            Returns the instance itself.
+        """
+        for name, estim, features, grid in self.estimators:
+            Xnew = X[:, features]
+            # pd.isnull is used, consistently with LateFusionClassifier, so object dtypes are handled too.
+            bool_mask = ~(
+                np.sum(pd.isnull(Xnew), axis=1) > self.missing_threshold * len(features)
+            )
+            Xnew, ynew = Xnew[bool_mask, :], y[bool_mask]
+
+            fitted_estim = self._fit_estim(
+                Xnew, ynew, estim=estim, features=features, grid=grid, name=name
+            )
+            if self.calibration:
+                parallel = Parallel(n_jobs=self.n_jobs)
+                collected_predictions = parallel(
+                    delayed(_collect_surv)(
+                        Xdata=X,
+                        ydata=y,
+                        estimator=estim,
+                        bmask=bool_mask,
+                        feat=features,
+                        train=train,
+                        test=test,
+                    )
+                    for train, test in self.cv.split(X, y)
+                )
+                temp = np.concatenate(collected_predictions)
+                mean, std = np.mean(temp), np.std(temp)
+            else:
+                mean, std = None, None
+            self.fitted_estimators_.append((name, fitted_estim, features, (mean, std)))
+
+        self.weights_ = np.array(self.weights_) - self.random_score
+        self.weights_ = np.where(self.weights_ > 0, self.weights_, 0)
+        return self
+
+    def predict(self, X, estim_ind=None):
+        """
+        Predict risk scores
+
+        Parameters
+        ----------
+        X: array of shape (n_samples, n_features)
+            Multimodal array, concatenation of the features from the different modalities. Missing modalities are
+            filled with NaN values for each sample.
+
+        estim_ind: tuple of integers.
+            Tuple representing a multimodal combination (e.g., (i, j, k) corresponds to the combination of the ith,
+            jth, and kth estimators in self.fitted_estimators_). If None, the combination of all fitted unimodal
+            predictors is considered.
+
+        Returns
+        -------
+        risk_scores: array of shape (n_samples,).
+            Predicted risk scores. If no modality is available for a sample, 0 is returned.
+        """
+        if estim_ind is not None:
+            fitted_estimators = [self.fitted_estimators_[i] for i in estim_ind]
+        else:
+            fitted_estimators = self.fitted_estimators_
+
+        preds = np.zeros((X.shape[0], len(fitted_estimators)))
+        weights = np.zeros((X.shape[0], len(fitted_estimators)))
+        for j, item in enumerate(fitted_estimators):
+            Xpred = X[:, item[2]].copy()
+            bool_mask = ~(
+                np.sum(pd.isnull(Xpred), axis=1) > self.missing_threshold * len(item[2])
+            )
+            weights[:, j] = np.where(bool_mask, self.weights_[j], 0)
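+            # Standardize each unimodal risk score with the mean/std collected
+            # during cross-validation (a zero std is replaced by 1 below).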
+            if self.calibration:
+                mean = item[3][0]
+                std = item[3][1] if item[3][1] != 0 else 1
+                preds[bool_mask, j] = (
+                    item[1].predict(Xpred[bool_mask, :]) - mean
+                ) / std
+            else:
+                preds[bool_mask, j] = item[1].predict(Xpred[bool_mask, :])
+        temp = np.sum(weights, axis=1)
+        weights[temp > 0] = weights[temp > 0] / (temp[temp > 0].reshape(-1, 1))
+        return np.sum(preds * weights, axis=1)
+
+    def find_logrank_threshold(
+        self, X, y, estim_ind, percentile_min=30, percentile_max=70
+    ):
+        """
+        Find the cutoff that best stratifies the samples with respect to survival data (i.e., the cutoff that
+        minimizes the log-rank test p-value).
+
+        Parameters
+        ----------
+        X: array of shape (n_samples, n_features)
+            Multimodal array, concatenation of the features from the different modalities. Missing modalities are
+            filled with NaN values for each sample.
+
+        y: structured array of shape (n_samples,), see sksurv.util.Surv (from scikit-survival).
+            Structured array for survival data associated with X.
+
+        estim_ind: tuple of integers.
+            Tuple representing a multimodal combination (e.g., (i, j, k) corresponds to the combination of the ith,
+            jth, and kth estimators in self.fitted_estimators_). If None, the combination of all fitted unimodal
+            predictors is considered.
+
+        percentile_min: int in [0, 100]
+            Minimum value of the percentile range used to explore various cutoff values for the predicted risk scores.
+
+        percentile_max: int in [0, 100]
+            Maximum value of the percentile range used to explore various cutoff values for the predicted risk scores.
+
+        Returns
+        -------
+        cutoff: float.
+            Best cutoff on the predicted risk scores, i.e., the one that minimizes the log-rank test p-value.
+        """
+        risk_score = self.predict(X, estim_ind=estim_ind)
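+        # predict returns 0 when no modality is available for a sample;
+        # exclude those samples from the cutoff search.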
+        bool_mask = risk_score == 0
+        cutoffs, pvals = [], []
+        risk_score_new, y_new = risk_score[~bool_mask], y[~bool_mask]
+        for p in np.arange(percentile_min, percentile_max + 1):
+            c = np.percentile(risk_score_new, p)
+            group1 = risk_score_new <= c
+            group2 = risk_score_new > c
+            test = logrank_test(
+                durations_A=y_new[group1]["time"],
+                durations_B=y_new[group2]["time"],
+                event_observed_A=1 * (y_new[group1]["event"]),
+                event_observed_B=1 * (y_new[group2]["event"]),
+            )
+            cutoffs.append(c)
+            pvals.append(test.summary["p"].values[0])
+        return cutoffs[np.argmin(pvals)]
+
+
+def _collect_surv(Xdata, ydata, estimator, bmask, feat, train, test):
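+    """
+    Fit a clone of `estimator` on the training-fold samples for which the modality is available and return the
+    risk scores predicted on the test fold.
+    """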
+    Xtrain, Xtest, ytrain, ytest = (
+        Xdata[np.intersect1d(np.where(bmask)[0], train), :],
+        Xdata[np.intersect1d(np.where(bmask)[0], test), :],
+        ydata[np.intersect1d(np.where(bmask)[0], train)],
+        ydata[np.intersect1d(np.where(bmask)[0], test)],
+    )
+    tempbis = clone(estimator).fit(Xtrain[:, feat], ytrain)
+    return tempbis.predict(Xtest[:, feat])