scripts/earlyfusion_survival.py
import argparse
import inspect
import os
import sys
from datetime import datetime

import numpy as np
import pandas as pd
from joblib import delayed
from sksurv.util import Surv
from tqdm import tqdm

from _init_scripts import PredictionTask
from _utils import read_yaml, write_yaml, ProgressParallel

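# Add the parent directory of scripts/ to sys.path so that the multipit package can
# be imported when this script is run directly.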
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)

from multipit.multi_model.earlyfusion import EarlyFusionSurvival
from multipit.utils.custom.cv import CensoredKFold


def main(params):
    """
    Repeated cross-validation experiment for survival prediction with early fusion
    """

    # 0. Read the config file and save a copy in the results directory
    config = read_yaml(params.config)
    save_name = config["save_name"]
    if save_name is None:
        run_id = datetime.now().strftime(r"%m%d_%H%M%S")
        save_name = "exp_" + run_id

    save_dir = os.path.join(params.save_path, save_name)
    os.mkdir(save_dir)
    write_yaml(config, os.path.join(save_dir, "config.yaml"))

    # 1. Fix the random seed for reproducibility
    seed = config["earlyfusion"]["seed"]
    np.random.seed(seed)

    # 2. Load data and define pipelines for each modality
    ptask = PredictionTask(config, survival=True, integration="early")
    ptask.load_data()
    X = ptask.data_concat.values
    y = Surv.from_arrays(
        event=ptask.labels.loc[ptask.data_concat.index, "event"].values,
        time=ptask.labels.loc[ptask.data_concat.index, "time"].values,
    )
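    # y is a structured array with an "event" field (bool) and a "time" field (float),
    # the survival label format expected by scikit-survival estimators.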
    ptask.init_pipelines_earlyfusion()

    # 3. Perform repeated cross-validation (one call to _fun_repeats per repeat)
    parallel = ProgressParallel(
        n_jobs=config["parallelization"]["n_jobs_repeats"],
        total=config["earlyfusion"]["n_repeats"],
    )
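    # ProgressParallel (from _utils) presumably wraps joblib.Parallel with a progress
    # bar; n_jobs_repeats sets how many repeats are executed in parallel.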
    results_parallel = parallel(
        delayed(_fun_repeats)(
            ptask,
            X,
            y,
            r,
            disable_infos=(config["parallelization"]["n_jobs_repeats"] is not None)
            and (config["parallelization"]["n_jobs_repeats"] > 1),
        )
        for r in range(config["earlyfusion"]["n_repeats"])
    )
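    # Each element of results_parallel is the (df_pred, df_thrs) pair returned by
    # _fun_repeats for one repeat.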

    # 4. Save results
    list_data_preds, list_data_thrs = [], []
    for res in results_parallel:
        list_data_preds.append(res[0])
        list_data_thrs.append(res[1])
    data_preds = pd.concat(list_data_preds, axis=0)
    data_preds.to_csv(os.path.join(save_dir, "predictions.csv"))
    if config["collect_thresholds"]:
        data_thrs = pd.concat(list_data_thrs, axis=0)
        data_thrs.to_csv(os.path.join(save_dir, "thresholds.csv"))


def _fun_repeats(prediction_task, X, y, r, disable_infos):
    """
    Train and test an early fusion model for survival task with cross-validation

    Parameters
    ----------
    prediction_task: PredictionTask object

    X: 2D array of shape (n_samples, n_features)
        Concatenation of the different modalities

    y: Structured array of size (n_samples,)
        Event indicator and observed time for each sample

    r: int
        Repeat number

    disable_infos: bool
        If True, disable the tqdm progress bar over the cross-validation folds

    Returns
    -------
    df_pred: pd.DataFrame of shape (n_samples, n_models+4)
        Predictions collected over the test sets of the cross-validation scheme for each multimodal combination

    df_thrs: pd.DataFrame of shape (n_samples, n_models+2) or None
        Thresholds that optimize the log-rank test on the training set for each fold and each multimodal combination.
    """
    cv = CensoredKFold(n_splits=10, shuffle=True)
    X_preds = np.zeros((len(y), 4 + len(prediction_task.names)))
    X_thresholds = (
        np.zeros((len(y), 2 + len(prediction_task.names)))
        if prediction_task.config["collect_thresholds"]
        else None
    )
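    # Column layout: one column per multimodal combination (prediction_task.names),
    # followed by fold_index, repeat, label.time and label.event in X_preds
    # (fold_index and repeat only in X_thresholds). Rows follow
    # prediction_task.data_concat.index.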
    for fold_index, (train_index, test_index) in tqdm(
        enumerate(cv.split(np.zeros(len(y)), y)),
        leave=False,
        total=cv.get_n_splits(np.zeros(len(y))),
        disable=disable_infos,
    ):
        X_train, y_train, X_test, y_test = (
            X[train_index, :],
            y[train_index],
            X[test_index, :],
            y[test_index],
        )

        cv_inner = CensoredKFold(n_splits=10, shuffle=True, random_state=r)
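        # cv_inner is handed to EarlyFusionSurvival below; it presumably drives the
        # model's internal tuning/feature selection on the training fold, seeded per
        # repeat for reproducibility.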
        for c, models in enumerate(prediction_task.names):
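            # Each entry of prediction_task.names is a "+"-joined combination of
            # modalities (the exact names are defined by PredictionTask); one early
            # fusion model is fitted per combination on the training fold.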
            t = {
                model: prediction_task.early_transformers[model]
                for model in models.split("+")
            }
            early_surv = EarlyFusionSurvival(
                estimator=prediction_task.early_estimator,
                transformers=t,
                modalities={
                    model: prediction_task.dic_modalities[model]
                    for model in models.split("+")
                },
                cv=cv_inner,
                **prediction_task.config["earlyfusion"]["args"]
            )
            if len(models.split("+")) == 1:
                early_surv.set_params(select_features=False)

            early_surv.fit(X_train, y_train)
            X_preds[test_index, c] = early_surv.predict(X_test)
            if prediction_task.config["collect_thresholds"]:
                X_thresholds[test_index, c] = early_surv.find_logrank_threshold(
                    X_train, y_train
                )
        X_preds[test_index, -4] = fold_index
        if prediction_task.config["collect_thresholds"]:
            X_thresholds[test_index, -2] = fold_index

    X_preds[:, -3] = r
    if prediction_task.config["collect_thresholds"]:
        X_thresholds[:, -1] = r
    X_preds[:, -2] = y["time"]
    X_preds[:, -1] = y["event"]
    df_pred = (
        pd.DataFrame(
            X_preds,
            columns=prediction_task.names
            + ["fold_index", "repeat", "label.time", "label.event"],
            index=prediction_task.data_concat.index,
        )
        .reset_index()
        .rename(columns={"index": "samples"})
        .set_index(["repeat", "samples"])
    )
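    # The chained reset_index/rename/set_index turns the sample index into a "samples"
    # column and builds a (repeat, samples) MultiIndex, so the per-repeat DataFrames
    # can be concatenated in main().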

    if prediction_task.config["collect_thresholds"]:
        df_thrs = (
            pd.DataFrame(
                X_thresholds,
                columns=prediction_task.names + ["fold_index", "repeat"],
                index=prediction_task.data_concat.index,
            )
            .reset_index()
            .rename(columns={"index": "samples"})
            .set_index(["repeat", "samples"])
        )
    else:
        df_thrs = None

    return df_pred, df_thrs


if __name__ == "__main__":
    args = argparse.ArgumentParser(description="Early fusion")
    args.add_argument(
        "-c",
        "--config",
        type=str,
        help="config file path",
    )
    args.add_argument(
        "-s",
        "--save_path",
        type=str,
        help="save path",
    )
    main(params=args.parse_args())
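
# Example invocation (paths are illustrative placeholders):
#   python scripts/earlyfusion_survival.py -c config/earlyfusion_survival.yaml -s results/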