scripts/latefusion_survival.py
import argparse
import inspect
import os
import sys

# import warnings
from datetime import datetime

import numpy as np
import pandas as pd
from joblib import delayed
from sklearn.base import clone
from sklearn.utils import check_random_state
from sksurv.util import Surv
from tqdm import tqdm

from _init_scripts import PredictionTask
from _utils import read_yaml, write_yaml, ProgressParallel

currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)

from multipit.multi_model.latefusion import LateFusionSurvival
from multipit.utils.custom.cv import CensoredKFold


def main(params):
    """
    Repeated cross-validation experiment for survival prediction with late fusion
    """

    # Uncomment for disabling ConvergenceWarning
    # warnings.simplefilter("ignore")
    # os.environ["PYTHONWARNINGS"] = 'ignore'

    # 0. Read config file and save it in the results
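    # Illustrative sketch of the YAML layout this script expects, inferred from
    # the keys accessed below. Values are placeholders, not taken from the
    # original repository, and PredictionTask may require additional fields:
    #
    #   save_name: null            # or a string; null -> auto-generated "exp_<timestamp>"
    #   permutation_test: false
    #   n_permutations: 100
    #   collect_thresholds: true
    #   parallelization:
    #     n_jobs_repeats: 4        # number of joblib workers for the repeats
    #   latefusion:
    #     seed: 0
    #     n_repeats: 10
    #     args: {}                 # keyword arguments forwarded to LateFusionSurvival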
    config = read_yaml(params.config)
    save_name = config["save_name"]
    if save_name is None:
        run_id = datetime.now().strftime(r"%m%d_%H%M%S")
        save_name = "exp_" + run_id
    save_dir = os.path.join(params.save_path, save_name)
    os.mkdir(save_dir)
    write_yaml(config, os.path.join(save_dir, "config.yaml"))

    # 1. Fix random seeds for reproducibility
    seed = config["latefusion"]["seed"]
    np.random.seed(seed)

    # 2. Load data and define pipelines for each modality
    ptask = PredictionTask(config, survival=True, integration="late")
    ptask.load_data()
    X = ptask.data_concat.values
    y = Surv.from_arrays(
        event=ptask.labels.loc[ptask.data_concat.index, "event"].values,
        time=ptask.labels.loc[ptask.data_concat.index, "time"].values,
    )
    ptask.init_pipelines_latefusion()

    # 3. Perform repeated cross-validation
    parallel = ProgressParallel(
        n_jobs=config["parallelization"]["n_jobs_repeats"],
        total=config["latefusion"]["n_repeats"],
    )
    results_parallel = parallel(
        delayed(_fun_repeats)(
            ptask,
            X,
            y,
            r,
            disable_infos=(config["parallelization"]["n_jobs_repeats"] is not None)
            and (config["parallelization"]["n_jobs_repeats"] > 1),
        )
        for r in range(config["latefusion"]["n_repeats"])
    )

    # 4. Save results
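    # Files written to save_dir by the code below:
    #   - predictions.csv: out-of-fold predictions for every repeat and every
    #     multimodal combination
    #   - thresholds.csv: log-rank thresholds (only if collect_thresholds is set)
    #   - permutation_labels.npy / permutation_predictions.npy: permuted labels
    #     and the corresponding predictions (only if permutation_test is set)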
    if config["permutation_test"]:
        perm_predictions = np.zeros(
            (
                len(y),
                len(ptask.names),
                config["n_permutations"],
                config["latefusion"]["n_repeats"],
            )
        )
        list_data_preds, list_data_thrs = [], []
        for p, res in enumerate(results_parallel):
            list_data_preds.append(res[0])
            list_data_thrs.append(res[1])
            perm_predictions[:, :, :, p] = res[2]
        perm_labels = results_parallel[-1][3]
        np.save(os.path.join(save_dir, "permutation_labels.npy"), perm_labels)
        np.save(os.path.join(save_dir, "permutation_predictions.npy"), perm_predictions)
        data_preds = pd.concat(list_data_preds, axis=0)
        data_preds.to_csv(os.path.join(save_dir, "predictions.csv"))
        if config["collect_thresholds"]:
            data_thrs = pd.concat(list_data_thrs, axis=0)
            data_thrs.to_csv(os.path.join(save_dir, "thresholds.csv"))
    else:
        list_data_preds, list_data_thrs = [], []
        for p, res in enumerate(results_parallel):
            list_data_preds.append(res[0])
            list_data_thrs.append(res[1])
        data_preds = pd.concat(list_data_preds, axis=0)
        data_preds.to_csv(os.path.join(save_dir, "predictions.csv"))
        if config["collect_thresholds"]:
            data_thrs = pd.concat(list_data_thrs, axis=0)
            data_thrs.to_csv(os.path.join(save_dir, "thresholds.csv"))


def _fun_repeats(prediction_task, X, y, r, disable_infos):
    """
    Train and test a late fusion model for the survival task with cross-validation

    Parameters
    ----------
    prediction_task: PredictionTask object

    X: 2D array of shape (n_samples, n_features)
        Concatenation of the different modalities

    y: Structured array of size (n_samples,)
        Event indicator and observed time for each sample

    r: int
        Repeat number

    disable_infos: bool
        If True, disable the tqdm progress bars

    Returns
    -------
    df_pred: pd.DataFrame of shape (n_samples, n_models+4)
        Predictions collected over the test sets of the cross-validation scheme for each multimodal combination

    df_thrs: pd.DataFrame of shape (n_samples, n_models+2), or None
        Thresholds that optimize the log-rank test on the training set for each fold and each multimodal combination. None if collect_thresholds is disabled.

    permut_predictions: 3D array of shape (n_samples, n_models, n_permutations), or None
        Predictions collected over the test sets of the cross-validation scheme for each multimodal combination and each random permutation of the labels. None if permutation_test is disabled.

    permut_labels: 3D array of shape (n_samples, n_permutations, 2), or None
        Permuted event indicators and observed times. None if permutation_test is disabled.
    """
    cv = CensoredKFold(n_splits=10, shuffle=True)  # , random_state=np.random.seed(i))
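    # Out-of-fold containers. Column layout (matches the DataFrame columns built
    # below): one column per multimodal combination, then
    #   X_preds:      [..., fold_index, repeat, label.time, label.event]
    #   X_thresholds: [..., fold_index, repeat]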
    X_preds = np.zeros((len(y), 4 + len(prediction_task.names)))
    X_thresholds = (
        np.zeros((len(y), 2 + len(prediction_task.names)))
        if prediction_task.config["collect_thresholds"]
        else None
    )
    late_clf = LateFusionSurvival(
        estimators=prediction_task.late_estimators,
        # np.random.seed(r) returns None; pass the repeat index directly so the
        # inner CV is seeded deterministically for each repeat
        cv=CensoredKFold(n_splits=10, shuffle=True, random_state=r),
        **prediction_task.config["latefusion"]["args"]
    )

    for fold_index, (train_index, test_index) in tqdm(
        enumerate(cv.split(np.zeros(len(y)), y)),
        leave=False,
        total=cv.get_n_splits(np.zeros(len(y))),
        disable=disable_infos,
    ):
        X_train, y_train, X_test, y_test = (
            X[train_index, :],
            y[train_index],
            X[test_index, :],
            y[test_index],
        )

        clf = clone(late_clf)
        clf.fit(X_train, y_train)

        for c, idx in enumerate(prediction_task.indices):
            X_preds[test_index, c] = clf.predict(X_test, estim_ind=idx)
            if prediction_task.config["collect_thresholds"]:
                X_thresholds[test_index, c] = clf.find_logrank_threshold(
                    X_train, y_train, estim_ind=idx
                )

        X_preds[test_index, -4] = fold_index
        if prediction_task.config["collect_thresholds"]:
            X_thresholds[test_index, -2] = fold_index

    X_preds[:, -3] = r
    if prediction_task.config["collect_thresholds"]:
        X_thresholds[:, -1] = r
    X_preds[:, -2] = y["time"]
    X_preds[:, -1] = y["event"]
    df_pred = (
        pd.DataFrame(
            X_preds,
            columns=prediction_task.names
            + ["fold_index", "repeat", "label.time", "label.event"],
            index=prediction_task.data_concat.index,
        )
        .reset_index()
        .rename(columns={"index": "samples"})
        .set_index(["repeat", "samples"])
    )

    if prediction_task.config["collect_thresholds"]:
        df_thrs = (
            pd.DataFrame(
                X_thresholds,
                columns=prediction_task.names + ["fold_index", "repeat"],
                index=prediction_task.data_concat.index,
            )
            .reset_index()
            .rename(columns={"index": "samples"})
            .set_index(["repeat", "samples"])
        )
    else:
        df_thrs = None

    permut_predictions = None
    permut_labels = None
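    # Optional permutation test: the (event, time) labels are shuffled with a
    # fixed seed per permutation and the same cross-validation scheme is rerun,
    # giving a null distribution of predictions for each multimodal combination.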
    if prediction_task.config["permutation_test"]:
        permut_labels = np.zeros((len(y), prediction_task.config["n_permutations"], 2))
        permut_predictions = np.zeros(
            (
                len(y),
                len(prediction_task.names),
                prediction_task.config["n_permutations"],
            )
        )
        for prm in range(prediction_task.config["n_permutations"]):
            X_perm = np.zeros((len(y), len(prediction_task.names)))
            random_state = check_random_state(prm)
            sh_ind = random_state.permutation(len(y))
            yshuffle = np.copy(y)[sh_ind]
            permut_labels[:, prm, 0] = yshuffle["time"]
            permut_labels[:, prm, 1] = yshuffle["event"]
            for fold_index, (train_index, test_index) in tqdm(
                enumerate(cv.split(np.zeros(len(y)), y)),
                leave=False,
                total=cv.get_n_splits(np.zeros(len(y))),
                disable=disable_infos,
            ):
                X_train, y_train, X_test, y_test = (
                    X[train_index, :],
                    yshuffle[train_index],
                    X[test_index, :],
                    yshuffle[test_index],
                )
                clf = clone(late_clf)
                clf.fit(X_train, y_train)

                for c, idx in enumerate(prediction_task.indices):
                    X_perm[test_index, c] = clf.predict(X_test, estim_ind=idx)
            permut_predictions[:, :, prm] = X_perm
    return df_pred, df_thrs, permut_predictions, permut_labels


if __name__ == "__main__":
    args = argparse.ArgumentParser(description="Late fusion")
    args.add_argument(
        "-c",
        "--config",
        type=str,
        help="config file path",
    )
    args.add_argument(
        "-s",
        "--save_path",
        type=str,
        help="save path",
    )
    main(params=args.parse_args())
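# Example invocation (paths are illustrative, not from the original repository):
#   python scripts/latefusion_survival.py -c configs/latefusion_survival.yaml -s results/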