a b/multipit/multi_model/latefusion.py
1
from itertools import combinations
2
3
import numpy as np
4
import pandas as pd
5
from joblib import Parallel, delayed
6
from lifelines.statistics import logrank_test
7
from sklearn.base import clone, BaseEstimator
8
from sklearn.linear_model import LogisticRegression
9
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV, cross_val_score
10
11
12
class LateFusionClassifier(BaseEstimator):
13
    """
14
    Late fusion classifier for multimodal integration.
15
16
    Parameters
17
    ----------
18
    estimators: list of (str, estimator, list, dict) tuples.
19
        List of unimodal estimators to fit and fuse. Each unimodal estimator is associated with a tuple
        (`name`, `estimator`, `features`, `tune_dict`) where `name` is a string and corresponds to the name of the
        estimator, `estimator` is a scikit-learn estimator inheriting from BaseEstimator (can be a Pipeline), `features`
        is a list of indexes corresponding to the columns of the data associated with the modality of interest, and
        `tune_dict` is either a dictionary (GridSearch strategy) or a tuple (n_iterations, dict) (RandomSearch
        strategy) used for hyperparameter tuning.
25
26
    cv: cross-validation generator
27
        cross-validation scheme for hyperparameter tuning (if `tuning` is not None) and/or calibration (if `calibration`
28
        is not None). The default is None
29
30
    score: str or callable.
31
        Score to use for tuning the unimodal models or weighting them at the late fusion step (i.e. sum of the unimodal
32
        predictions weighted by the performance of each unimodal model estimated with cross-validation).
33
        See sklearn.model_selection.cross_val_score for more details. The default is None.
34
35
    random_score: float.
36
        Random score for classification. Used when weighting the unimodal models with their estimated score. Weights
37
        will be max(score - random_score, 0). Unimodal models whose estimated performance is below the random_score
38
        will not be taken into account. The default is 0.5
39
40
    sup_weights: bool.
41
        Whether to use weights associated with the cross-validation performance of each unimodal model. If false no
42
        weights are used when fusing the unimodal predictions. The default is False.
43
44
    missing_threshold: float in ]0, 1].
45
        Minimum frequency of missing values to consider a whole modality missing (e.g., if `missing_threshold = 0.9` it
46
        means that for each sample and each modality at least 90% of the features associated with this modality must be
47
        missing to consider the whole modality missing). The default is 0.9.
48
49
    tuning: str or None.
50
        Strategy for tuning each model. Either 'gridsearch' for GridSearchCV or 'randomsearch' for
        RandomizedSearchCV. If None no hyperparameter tuning will be performed. The default is None.
52
53
    n_jobs: int.
54
        Number of jobs to run in parallel for hyperparameter tuning, collecting the predictions for calibration, or
55
        estimating the performance of each unimodal model with cross-validation. The default is None.
56
57
    calibration: str or None.
58
        Calibration strategy.
59
            * `calibration = 'late'` means that the fusion is made before calibration. The predictions of each
60
            multimodal combination are collected with cross-validation and a univariate logistic regression model is
61
            fitted to these predictions.
62
            * `calibration = 'early'` means that each unimodal model is calibrated prior to the late fusion. The
63
            unimodal predictions are collected with a cross-validation scheme and univariate logistic regression models
64
            are fitted.
65
            * `calibration = None` means that no calibration is performed.
66
67
    Attributes
68
    ----------
69
    best_params_: list of dict or empty list.
70
        List of best parameters for each unimodal predictor (output of GridSearchCV or RandomSearchCV). It follows the
71
        same order as the one of `estimators` list. If `tuning` is None returns an empty list (i.e., no hyperparameter
72
        tuning is performed).
73
74
    weights_: list of float.
75
        List of the weights associated to each modality and used at the late fusion stage for weighted sum.
76
77
    fitted_estimators_: list of estimators.
78
        List of fitted unimodal estimators.
79
80
    fitted_meta_estimators_: dictionary of estimators.
81
        Dictionary of meta-estimators for calibration. If `calibration = "early"` the keys correspond to the indexes of
82
        each unimodal estimator (i.e., from 0 to n_estimators-1) and the values correspond to the logistic regression
83
        estimators fitted to calibrate the unimodal models. If `calibration = "late"` the keys correspond to tuples
84
        characterizing each multimodal combination (e.g., (1, 3, 5)) and the values correspond to the logistic
        regression estimators fitted to calibrate the multimodal models.
86
    """
87
88
    def __init__(
        self,
        estimators,
        cv=None,
        score=None,
        random_score=0.5,
        sup_weights=False,
        missing_threshold=0.9,
        tuning=None,
        n_jobs=None,
        calibration="late",
    ):
        # Unimodal models and the way their outputs are fused.
        self.estimators = estimators
        self.sup_weights = sup_weights
        self.random_score = random_score
        self.missing_threshold = missing_threshold
        self.calibration = calibration

        # Model selection / cross-validation settings.
        self.cv = cv
        self.score = score
        self.tuning = tuning
        self.n_jobs = n_jobs

        # Containers populated while fitting: per-estimator best parameters,
        # fusion weights, fitted unimodal models and calibration models.
        self.fitted_estimators_ = []
        self.fitted_meta_estimators_ = {}
        self.weights_ = []
        self.best_params_ = []
114
115
    def fit(self, X, y):
        """
        Fit the latefusion classifier.

        Each unimodal estimator is fitted on the samples for which its modality is available (i.e. the fraction of
        missing features does not exceed `missing_threshold`). When `calibration` is not None, out-of-fold predictions
        are collected for every modality and used to fit the calibration models.

        Parameters
        ----------
        X: array of shape (n_samples, n_features)
            Multimodal array, concatenation of the features from the different modalities. Missing modalities are filled
            with NaNs values for each sample.

        y: array of shape (n_samples,)
            Target to predict.

        Returns
        -------
        self : object
            Returns the instance itself.

        Raises
        ------
        ValueError
            If `calibration` is not one of 'early', 'late' or None.
        """
        # Local import: `check_cv` is not part of the module-level imports.
        from sklearn.model_selection import check_cv

        # Reject an invalid configuration before doing any expensive fitting.
        if self.calibration not in (None, "early", "late"):
            raise ValueError(
                "'early', 'late' or None are the only values available for calibration parameter"
            )

        # Start from a clean state so that calling `fit` again neither accumulates estimators from the previous
        # call nor fails on `weights_`, which is converted to an array at the end of this method.
        self.best_params_ = []
        self.weights_ = []
        self.fitted_estimators_ = []
        self.fitted_meta_estimators_ = {}

        calibrate = self.calibration is not None
        # `cv` may be None or an int; normalise it to a splitter exposing `.split` for the calibration folds.
        cv = check_cv(self.cv, y, classifier=True) if calibrate else None

        n_samples, n_estimators = X.shape[0], len(self.estimators)
        predictions = np.zeros((n_samples, n_estimators))
        weights = np.zeros((n_samples, n_estimators))
        for i, (name, estim, features, grid) in enumerate(self.estimators):
            Xnew = X[:, features]
            # True for samples whose modality is considered available.
            bool_mask = ~(
                np.sum(pd.isnull(Xnew), axis=1) > self.missing_threshold * len(features)
            )
            # Fit unimodal estimator
            self._fit_estim(
                Xnew[bool_mask, :],
                y[bool_mask],
                estim=estim,
                features=features,
                grid=grid,
                name=name,
            )

            if not calibrate:
                continue

            # Collect predictions and weights for further calibration
            weights[bool_mask, i] = max(self.weights_[-1] - self.random_score, 0)
            # NOTE(review): the out-of-fold predictions come from clones of the *untuned* `estim`, not from the
            # tuned estimator selected in `_fit_estim` - confirm this is intended.
            collected_predictions = Parallel(n_jobs=self.n_jobs)(
                delayed(_collect)(
                    Xdata=X,
                    ydata=y,
                    estimator=estim,
                    bmask=bool_mask,
                    feat=features,
                    train=train,
                    test=test,
                )
                for train, test in cv.split(X, y)
            )
            for indexes, preds in collected_predictions:
                predictions[indexes, i] = preds

        # Calibrate models
        if self.calibration == "early":
            self._fit_early_calibration(predictions=predictions, y=y)
        elif self.calibration == "late":
            self._fit_late_calibration(predictions=predictions, weights=weights, y=y)

        # Final fusion weights: max(score - random_score, 0) for each modality.
        self.weights_ = np.array(self.weights_) - self.random_score
        self.weights_ = np.where(self.weights_ > 0, self.weights_, 0)
        return self
184
185
    def _fit_estim(self, X, y, estim, features, grid, name):
        """
        Fit a unimodal estimator.

        A clone of `estim` is fitted on (X, y), optionally after hyperparameter search. The raw weight of the
        modality (cross-validated score if `sup_weights` is True, 1.0 otherwise) is appended to `weights_`, the
        fitted model is appended to `fitted_estimators_` as a (`name`, fitted estimator, `features`) tuple, and the
        best parameters are appended to `best_params_` when a search was run.

        Parameters
        ----------
        X: array
            Features of the modality, restricted to the samples for which it is available.

        y: array
            Target associated with X.

        estim: estimator
            Unimodal estimator; it is cloned, never fitted in place.

        features: list
            Column indexes of the modality, stored alongside the fitted estimator.

        grid: dict, tuple or None
            Search space. A dictionary for 'gridsearch'; for 'randomsearch', `grid[0]` is used as `n_iter` and
            `grid[1]` as `param_distributions`. An empty or None grid disables tuning for this estimator.

        name: str
            Name of the estimator.

        Raises
        ------
        ValueError
            If `tuning` is not one of 'gridsearch', 'randomsearch' or None.
        """
        if self.tuning is not None and grid is not None and len(grid) > 0:
            if self.tuning == "gridsearch":
                search = GridSearchCV(
                    estimator=clone(estim),
                    param_grid=grid,
                    cv=self.cv,
                    scoring=self.score,
                    n_jobs=self.n_jobs,
                )
            elif self.tuning == "randomsearch":
                search = RandomizedSearchCV(
                    estimator=clone(estim),
                    param_distributions=grid[1],
                    n_iter=grid[0],
                    scoring=self.score,
                    n_jobs=self.n_jobs,
                    cv=self.cv,
                )
            else:
                # Previously `search` was left unbound here and the failure surfaced as an obscure
                # UnboundLocalError on the next line.
                raise ValueError(
                    "'gridsearch', 'randomsearch' or None are the only values available for tuning parameter"
                )

            search.fit(X, y)

            weight = search.best_score_ if self.sup_weights else 1.0
            fitted = search.best_estimator_
            self.best_params_.append(search.best_params_)
        else:
            if self.sup_weights:
                # Weight the modality by its mean cross-validated score.
                weight = np.mean(
                    cross_val_score(
                        estimator=clone(estim),
                        X=X,
                        y=y,
                        cv=self.cv,
                        scoring=self.score,
                        n_jobs=self.n_jobs,
                    )
                )
            else:
                weight = 1.0
            fitted = clone(estim).fit(X, y)

        self.weights_.append(weight)
        self.fitted_estimators_.append((name, fitted, features))
        return
239
240
    def _fit_early_calibration(self, predictions, y):
        """
        Calibrate only each unimodal predictor.

        For every unimodal estimator, a class-balanced univariate logistic regression is fitted on its collected
        predictions and stored in `fitted_meta_estimators_` under the index of the estimator.
        """
        for idx in range(len(self.estimators)):
            column = predictions[:, idx]
            # Only strictly positive entries are used (the prediction matrix is initialised with zeros).
            observed = (column > 0).reshape(-1)
            calibrator = LogisticRegression(class_weight="balanced")
            calibrator.fit(column[observed].reshape(-1, 1), y[observed])
            self.fitted_meta_estimators_[idx] = calibrator
        return
251
252
    def _fit_late_calibration(self, predictions, weights, y):
        """
        Calibrate each combination of modalities.

        For every non-empty combination of unimodal estimators, the collected predictions are fused (weighted sum
        with weights normalised per sample) and a class-balanced univariate logistic regression is fitted on the
        fused score. The model is stored in `fitted_meta_estimators_` under the tuple of estimator indexes.
        """
        n_estimators = len(self.estimators)
        for size in range(1, n_estimators + 1):
            for comb in combinations(range(n_estimators), size):
                columns = np.array(comb)
                probas = predictions[:, columns]
                if size == 1:
                    mask = (probas > 0).reshape(-1)
                else:
                    # Keep samples with at least one strictly positive prediction in the combination.
                    mask = np.any(probas > 0, axis=1).reshape(-1)
                    # Integer-array indexing returns a copy, so `weights` itself is left untouched.
                    w = weights[:, columns]
                    totals = np.sum(w, axis=1)
                    positive = totals > 0
                    w[positive] = w[positive] / (totals[positive].reshape(-1, 1))
                    probas = np.sum(probas * w, axis=1)

                calibrator = LogisticRegression(class_weight="balanced")
                calibrator.fit(probas[mask].reshape(-1, 1), y[mask])
                self.fitted_meta_estimators_[comb] = calibrator
        return
272
273
    def predict_proba(self, X, estim_ind=None):
274
        """
275
        Late fusion probability estimates
276
277
        Parameters
278
        ----------
279
        X: array of shape (n_samples, n_features)
280
            Multimodal array, concatenation of the features from the different modalities. Missing modalities are filled
281
            with NaNs values for each sample.
282
283
        estim_ind: tuple of integers.
284
            Tuple representing a multimodal combination (e.g. (i, j, k) corresponds to the combination of the ith, the
285
            jth and the kth estimators in self.fitted_estimators_). If None all the multimodal combination with all the
286
            fitted unimodal predictors is considered.
287
288
        Returns
289
        -------
290
        probas: array of shape (n_samples, 2).
291
            Probability of the samples for each class. If no modality are availbale for the sample, returns 0.5 for
292
            both classes.
293
        """
294
        fitted_estimators = (
295
            [self.fitted_estimators_[i] for i in estim_ind]
296
            if estim_ind is not None
297
            else self.fitted_estimators_
298
        )
299
        fitted_weights = (
300
            np.array([self.weights_[i] for i in estim_ind])
301
            if estim_ind is not None
302
            else self.weights_
303
        )
304
305
        # Collect predictions for each modality
306
        preds = np.zeros((X.shape[0], len(fitted_estimators)))
307
        weights = np.zeros((X.shape[0], len(fitted_weights)))
308
        for j, item in enumerate(fitted_estimators):
309
            Xpred = X[:, item[2]].copy()
310
            # bool_mask = ~(np.sum(np.isnan(Xpred), axis=1) > self.missing_threshold * len(item[2]))
311
            bool_mask = ~(
312
                np.sum(pd.isnull(Xpred), axis=1) > self.missing_threshold * len(item[2])
313
            )
314
            weights[:, j] = np.where(bool_mask, fitted_weights[j], 0)
315
            preds[bool_mask, j] = item[1].predict_proba(Xpred[bool_mask, :])[:, 1]
316
317
        # Calibrate the predictions and predict probas
318
        if self.calibration is not None:
319
            if self.calibration == "late":
320
                probas = self._predict_calibrate_late(preds, weights, estim_ind)
321
            elif self.calibration == "early":
322
                probas = self._predict_calibrate_early(preds, weights, estim_ind)
323
            else:
324
                raise ValueError(
325
                    "'early', 'late' or None are the only values available for calibration parameter"
326
                )
327
        else:
328
            probas = self._predict_uncalibrated(preds, weights)
329
        return np.hstack([1 - probas, probas])
330
331
    @staticmethod
332
    def _predict_uncalibrated(preds, weights):
333
        """
334
        Return weighted sum of available unimodal predictions
335
        """
336
        temp = np.sum(weights, axis=1)
337
        weights[temp > 0] = weights[temp > 0] / (temp[temp > 0].reshape(-1, 1))
338
        probas = np.sum(preds * weights, axis=1)
339
        return np.where(temp == 0, 0.5, probas).reshape(-1, 1)
340
341
    def _predict_calibrate_early(self, preds, weights, estim_ind):
342
        """
343
        Return weighted sum of available and calibrated unimodal predictions
344
        """
345
        temp = np.sum(weights, axis=1)
346
        weights[temp > 0] = weights[temp > 0] / (temp[temp > 0].reshape(-1, 1))
347
        list_meta_estimators = (
348
            [self.fitted_meta_estimators_[i] for i in estim_ind]
349
            if estim_ind is not None
350
            else list(self.fitted_meta_estimators_.values())
351
        )
352
        for j, meta in enumerate(list_meta_estimators):
353
            preds[:, j] = np.where(
354
                weights[:, j] != 0,
355
                meta.predict_proba(preds[:, j].reshape(-1, 1))[:, 1],
356
                0,
357
            )
358
        probas = np.sum(preds * weights, axis=1)
359
        return np.where(temp == 0, 0.5, probas).reshape(-1, 1)
360
361
    def _predict_calibrate_late(self, preds, weights, estim_ind):
362
        """
363
        Return calibrated weighted sum of availbale unimodal predictions
364
        """
365
        temp = np.sum(weights, axis=1)
366
        weights[temp > 0] = weights[temp > 0] / (temp[temp > 0].reshape(-1, 1))
367
        probas = np.sum(preds * weights, axis=1)
368
        meta_estimator = (
369
            self.fitted_meta_estimators_[estim_ind]
370
            if estim_ind is not None
371
            else list(self.fitted_meta_estimators_.values())[-1]
372
        )
373
        return np.where(
374
            temp == 0, 0.5, meta_estimator.predict_proba(probas.reshape(-1, 1))[:, 1]
375
        ).reshape(-1, 1)
376
377
    def find_logrank_threshold(
        self, X, ysurv, estim_ind, percentile_min=30, percentile_max=70
    ):
        """
        Find the best cutoff that optimizes the stratification of samples with respect to survival data (using the
        log-rank test).

        Parameters
        ----------
        X: array of shape (n_samples, n_features)
            Multimodal array, concatenation of the features from the different modalities. Missing modalities are filled
            with NaNs values for each sample.

        ysurv: structured array of shape (n_samples,) see sksurv.util.Surv (from scikit-survival)
            Structured array for survival data associated with X. The fields "time" and "event" are read.

        estim_ind: tuple of integers.
            Tuple representing a multimodal combination (e.g. (i, j, k) corresponds to the combination of the ith, the
            jth and the kth estimators in self.fitted_estimators_). If None, the combination of all the fitted
            unimodal predictors is considered.

        percentile_min: int in [0, 100]
            Minimum value of the percentile range used to explore various cutoff values for predicted probabilities

        percentile_max: int in [0, 100]
            Maximum value of the percentile range used to explore various cutoff values for predicted probabilities

        Returns
        -------
        cutoff: float.
            Cutoff for the predicted probabilities with the smallest log-rank p-value.
        """
        scores = self.predict_proba(X, estim_ind=estim_ind)[:, 1]
        # Discard samples whose fused probability is exactly 0.5, the value returned when no weighted modality is
        # available for a sample.
        informative = ~(scores == 0.5)
        scores, surv = scores[informative], ysurv[informative]

        candidates = []
        for percentile in np.arange(percentile_min, percentile_max + 1):
            cutoff = np.percentile(scores, percentile)
            low, high = scores <= cutoff, scores > cutoff
            result = logrank_test(
                durations_A=surv[low]["time"],
                durations_B=surv[high]["time"],
                event_observed_A=1 * (surv[low]["event"]),
                event_observed_B=1 * (surv[high]["event"]),
            )
            candidates.append((cutoff, result.summary["p"].values[0]))

        best = np.argmin([pvalue for _, pvalue in candidates])
        return candidates[best][0]
426
427
428
def _collect(Xdata, ydata, estimator, bmask, feat, train, test):
    """
    Fit a clone of `estimator` on one training fold and predict the matching test fold.

    Only the samples flagged in `bmask` are used. Returns a tuple made of the indexes of the test samples that were
    predicted and the predicted probabilities of the positive class (second column of `predict_proba`).
    """
    available = np.where(bmask)[0]
    train_idx = np.intersect1d(available, train)
    test_idx = np.intersect1d(available, test)

    model = clone(estimator).fit(Xdata[train_idx, :][:, feat], ydata[train_idx])
    return test_idx, model.predict_proba(Xdata[test_idx, :][:, feat])[:, 1]
440
441
442
class LateFusionSurvival(BaseEstimator):
443
    """
444
    Late fusion survival model for multimodal integration.
445
446
    Parameters
447
    ----------
448
     estimators: list of (str, estimator, list, dict) tuples.
449
        List of unimodal estimators to fit and fuse. Each unimodal estimator is associated with a tuple
        (`name`, `estimator`, `features`, `tune_dict`) where `name` is a string and corresponds to the name of the
        estimator, `estimator` is a scikit-learn estimator inheriting from BaseEstimator (can be a Pipeline), `features`
        is a list of indexes corresponding to the columns of the data associated with the modality of interest, and
        `tune_dict` is either a dictionary (GridSearch strategy) or a tuple (n_iterations, dict) (RandomSearch
        strategy) used for hyperparameter tuning.
455
456
    cv: cross-validation generator
457
        cross-validation scheme for hyperparameter tuning (if `tuning` is not None) and/or calibration (if `calibration`
458
        is not None). The default is None
459
460
    score: str or callable.
461
        Score to use for tuning the unimodal models or weighting them at the late fusion step (i.e. sum of the unimodal
462
        predictions weighted by the performance of each unimodal model estimated with cross-validation).
463
        See sklearn.model_selection.cross_val_score for more details. The default is None.
464
465
    random_score: float.
466
        Random score for classification. Used when weighting the unimodal models with their estimated score. Weights
467
        will be max(score - random_score, 0). Unimodal models whose estimated performance is below the random_score
468
        will not be taken into account. The default is 0.5
469
470
    sup_weights: bool.
471
        Whether to use weights associated with the cross-validation performance of each unimodal model. If false no
472
        weights are used when fusing the unimodal predictions. The default is False.
473
474
    missing_threshold: float in ]0, 1].
475
        Minimum frequency of missing values to consider a whole modality missing (e.g., if `missing_threshold = 0.9` it
476
        means that for each sample and each modality at least 90% of the features associated with this modality must be
477
        missing to consider the whole modality missing). The default is 0.9.
478
479
    tuning: str or None.
480
        Strategy for tuning each model. Either 'gridsearch' for GridSearchCV or 'randomsearch' for RandomSearchCV. If
481
        None no hyperparameter tuning will be performed. The default is None.
482
483
    n_jobs: int.
484
        Number of jobs to run in parallel for hyperparameter tuning, collecting the predictions for calibration, or
485
        estimating the performance of each unimodal model with cross-validation. The default is None.
486
487
    calibration: bool.
488
        If True each unimodal model is associated with a tuple (mean, std) estimated on predictions collected with
489
        cross-validation. The predictions of each unimodal model are then standardized before the late fusion step.
490
491
    Attributes
492
    ----------
493
    best_params_: list of dict or empty list.
494
        List of best parameters for each unimodal predictor (output of GridSearchCV or RandomSearchCV). It follows the
495
        same order as the one of `estimators` list. If `tuning` is None returns an empty list (i.e., no hyperparameter
496
        tuning is performed).
497
498
    weights_: list of float.
499
        List of the weights associated to each modality and used at the late fusion stage for weighted sum.
500
501
    fitted_estimators_: list of estimators.
502
        List of fitted unimodal estimators.
503
    """
504
505
    def __init__(
        self,
        estimators,
        cv,
        score=None,
        random_score=0.5,
        sup_weights=True,
        missing_threshold=0.9,
        tuning=None,
        n_jobs=None,
        calibration=True,
    ):
        # Unimodal models and the way their outputs are fused.
        self.estimators = estimators
        self.sup_weights = sup_weights
        self.random_score = random_score
        self.missing_threshold = missing_threshold
        self.calibration = calibration

        # Model selection / cross-validation settings.
        self.cv = cv
        self.score = score
        self.tuning = tuning
        self.n_jobs = n_jobs

        # Containers populated while fitting.
        self.best_params_ = []
        self.fitted_estimators_ = []
        self.weights_ = []
530
531
    def _fit_estim(self, X, y, estim, features, grid, name):
        """
        Fit a unimodal survival estimator and return it.

        A clone of `estim` is fitted on (X, y), optionally after hyperparameter search. The raw weight of the
        modality (cross-validated score if `sup_weights` is True, 1.0 otherwise) is appended to `weights_`, and the
        best parameters are appended to `best_params_` when a search was run. `features` and `name` are not used
        here; the caller stores them alongside the returned estimator.

        Parameters
        ----------
        grid: dict, tuple or None
            Search space. A dictionary for 'gridsearch'; for 'randomsearch', `grid[0]` is used as `n_iter` and
            `grid[1]` as `param_distributions`. An empty or None grid disables tuning for this estimator.

        Returns
        -------
        fitted: estimator
            The fitted unimodal estimator (best estimator of the search when tuning is performed).

        Raises
        ------
        ValueError
            If `tuning` is not one of 'gridsearch', 'randomsearch' or None.
        """
        if self.tuning is not None and grid is not None and len(grid) > 0:
            if self.tuning == "gridsearch":
                search = GridSearchCV(
                    estimator=clone(estim),
                    param_grid=grid,
                    cv=self.cv,
                    scoring=self.score,
                    n_jobs=self.n_jobs,
                )
            elif self.tuning == "randomsearch":
                search = RandomizedSearchCV(
                    estimator=clone(estim),
                    param_distributions=grid[1],
                    n_iter=grid[0],
                    scoring=self.score,
                    n_jobs=self.n_jobs,
                    cv=self.cv,
                )
            else:
                # Previously `search` was left unbound here and the failure surfaced as an obscure
                # UnboundLocalError on the next line.
                raise ValueError(
                    "'gridsearch', 'randomsearch' or None are the only values available for tuning parameter"
                )

            search.fit(X, y)

            weight = search.best_score_ if self.sup_weights else 1.0
            fitted = search.best_estimator_
            self.best_params_.append(search.best_params_)
        else:
            if self.sup_weights:
                # Weight the modality by its mean cross-validated score; `n_jobs` is forwarded as in
                # LateFusionClassifier._fit_estim.
                weight = np.mean(
                    cross_val_score(
                        estimator=clone(estim),
                        X=X,
                        y=y,
                        cv=self.cv,
                        scoring=self.score,
                        n_jobs=self.n_jobs,
                    )
                )
            else:
                weight = 1.0
            fitted = clone(estim).fit(X, y)

        self.weights_.append(weight)
        return fitted
583
584
    def fit(self, X, y):
        """
        Fit the latefusion survival model.

        Each unimodal estimator is fitted on the samples for which its modality is available (i.e. the fraction of
        missing features does not exceed `missing_threshold`). When `calibration` is truthy, out-of-fold predictions
        are collected for every modality and their mean and standard deviation are stored with the fitted estimator.

        Parameters
        ----------
        X: array of shape (n_samples, n_features)
            Multimodal array, concatenation of the features from the different modalities. Missing modalities are filled
            with NaNs values for each sample.

        y: structured array of shape (n_samples, ) see sksurv.util.Surv (from scikit-survival).
            Structured array for survival target/outcome

        Returns
        -------
        self : object
            Returns the instance itself.
        """
        # Start from a clean state so that calling `fit` again neither accumulates estimators from the previous
        # call nor fails on `weights_`, which is converted to an array at the end of this method.
        self.weights_ = []
        self.fitted_estimators_ = []
        self.best_params_ = []

        for name, estim, features, grid in self.estimators:
            Xnew = X[:, features]
            # True for samples whose modality is considered available. `pd.isnull` is used, as in
            # LateFusionClassifier, because it also accepts object arrays.
            bool_mask = ~(
                np.sum(pd.isnull(Xnew), axis=1) > self.missing_threshold * len(features)
            )

            fitted_estim = self._fit_estim(
                Xnew[bool_mask, :],
                y[bool_mask],
                estim=estim,
                features=features,
                grid=grid,
                name=name,
            )

            if self.calibration:
                # NOTE(review): the out-of-fold predictions come from clones of the *untuned* `estim`, not from
                # the tuned estimator returned by `_fit_estim` - confirm this is intended.
                collected_predictions = Parallel(n_jobs=self.n_jobs)(
                    delayed(_collect_surv)(
                        Xdata=X,
                        ydata=y,
                        estimator=estim,
                        bmask=bool_mask,
                        feat=features,
                        train=train,
                        test=test,
                    )
                    for train, test in self.cv.split(X, y)
                )
                pooled = np.concatenate(collected_predictions)
                mean, std = np.mean(pooled), np.std(pooled)
            else:
                mean, std = None, None
            self.fitted_estimators_.append((name, fitted_estim, features, (mean, std)))

        # Final fusion weights: max(score - random_score, 0) for each modality.
        self.weights_ = np.array(self.weights_) - self.random_score
        self.weights_ = np.where(self.weights_ > 0, self.weights_, 0)
        return self
637
638
    def predict(self, X, estim_ind=None):
639
        """
640
        Predict risk scores
641
642
        Parameters
643
        ----------
644
        X: array of shape (n_samples, n_features)
645
            Multimodal array, concatenation of the features from the different modalities. Missing modalities are filled
646
            with NaNs values for each sample.
647
648
        estim_ind: tuple of integers.
649
            Tuple representing a multimodal combination (e.g. (i, j, k) corresponds to the combination of the ith, the
650
            jth and the kth estimators in self.fitted_estimators_). If None all the multimodal combination with all the
651
            fitted unimodal predictors is considered.
652
653
        Returns
654
        -------
655
        risk_scores: array of shape (n_samples,).
656
            Predictied risk scores. If no modality are availbale for the sample, returns 0.
657
        """
658
        if estim_ind is not None:
659
            fitted_estimators = [self.fitted_estimators_[i] for i in estim_ind]
660
        else:
661
            fitted_estimators = self.fitted_estimators_
662
663
        preds = np.zeros((X.shape[0], len(fitted_estimators)))
664
        weights = np.zeros((X.shape[0], len(fitted_estimators)))
665
        for j, item in enumerate(fitted_estimators):
666
            Xpred = X[:, item[2]].copy()
667
            bool_mask = ~(
668
                np.sum(np.isnan(Xpred), axis=1) > self.missing_threshold * len(item[2])
669
            )
670
            weights[:, j] = np.where(bool_mask, self.weights_[j], 0)
671
            if self.calibration:
672
                mean = item[3][0]
673
                std = item[3][1] if item[3][1] != 0 else 1
674
                preds[bool_mask, j] = (
675
                    item[1].predict(Xpred[bool_mask, :]) - mean
676
                ) / std
677
            else:
678
                preds[bool_mask, j] = item[1].predict(Xpred[bool_mask, :])
679
        temp = np.sum(weights, axis=1)
680
        weights[temp > 0] = weights[temp > 0] / (temp[temp > 0].reshape(-1, 1))
681
        return np.sum(preds * weights, axis=1)
682
683
    def find_logrank_threshold(
        self, X, y, estim_ind, percentile_min=30, percentile_max=70
    ):
        """
        Find the best cutoff that optimizes the stratification of samples with respect to survival data (using the
        log-rank test).

        Parameters
        ----------
        X: array of shape (n_samples, n_features)
            Multimodal array, concatenation of the features from the different modalities. Missing modalities are filled
            with NaNs values for each sample.

        y: structured array of shape (n_samples,) see sksurv.util.Surv (from scikit-survival)
            Structured array for survival data associated with X. The fields "time" and "event" are read.

        estim_ind: tuple of integers.
            Tuple representing a multimodal combination (e.g. (i, j, k) corresponds to the combination of the ith, the
            jth and the kth estimators in self.fitted_estimators_). If None, the combination of all the fitted
            unimodal predictors is considered.

        percentile_min: int in [0, 100]
            Minimum value of the percentile range used to explore various cutoff values for predicted risk scores

        percentile_max: int in [0, 100]
            Maximum value of the percentile range used to explore various cutoff values for predicted risk scores

        Returns
        -------
        cutoff: float.
            Cutoff for the predicted risk scores with the smallest log-rank p-value.
        """
        scores = self.predict(X, estim_ind=estim_ind)
        # Discard samples whose fused risk score is exactly 0, the value obtained when no weighted modality is
        # available for a sample.
        informative = ~(scores == 0)
        scores, surv = scores[informative], y[informative]

        candidates = []
        for percentile in np.arange(percentile_min, percentile_max + 1):
            cutoff = np.percentile(scores, percentile)
            low, high = scores <= cutoff, scores > cutoff
            result = logrank_test(
                durations_A=surv[low]["time"],
                durations_B=surv[high]["time"],
                event_observed_A=1 * (surv[low]["event"]),
                event_observed_B=1 * (surv[high]["event"]),
            )
            candidates.append((cutoff, result.summary["p"].values[0]))

        best = np.argmin([pvalue for _, pvalue in candidates])
        return candidates[best][0]
732
733
734
def _collect_surv(Xdata, ydata, estimator, bmask, feat, train, test):
    """
    Fit a clone of `estimator` on one training fold and predict the matching test fold.

    Only the samples flagged in `bmask` are used. Returns the output of `predict` for the retained test samples.
    """
    available = np.where(bmask)[0]
    train_idx = np.intersect1d(available, train)
    test_idx = np.intersect1d(available, test)

    model = clone(estimator).fit(Xdata[train_idx, :][:, feat], ydata[train_idx])
    return model.predict(Xdata[test_idx, :][:, feat])