Diff of /scripts/latefusion.py [000000] .. [efd906]

Switch to unified view

a b/scripts/latefusion.py
1
import argparse
2
import inspect
3
import os
4
import sys
5
6
# import warnings
7
from datetime import datetime
8
9
import numpy as np
10
import pandas as pd
11
from joblib import delayed
12
from sklearn.base import clone
13
from sklearn.model_selection import StratifiedKFold
14
from sklearn.utils import check_random_state
15
from tqdm import tqdm
16
17
from _init_scripts import PredictionTask
18
from _utils import read_yaml, write_yaml, ProgressParallel
19
20
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
21
parentdir = os.path.dirname(currentdir)
22
sys.path.insert(0, parentdir)
23
24
from multipit.multi_model.latefusion import LateFusionClassifier
25
26
27
def main(params):
28
    """
29
    Repeated cross-validation experiment for classification with late fusion
30
    """
31
32
    # Uncomment for disabling ConvergenceWarning
33
    # warnings.simplefilter("ignore")
34
    # os.environ["PYTHONWARNINGS"] = 'ignore'
35
36
    # 0. Read config file and save it in the results
37
    config = read_yaml(params.config)
38
    save_name = config["save_name"]
39
    if save_name is None:
40
        run_id = datetime.now().strftime(r"%m%d_%H%M%S")
41
        save_name = "exp_" + run_id
42
    save_dir = os.path.join(params.save_path, save_name)
43
    os.mkdir(save_dir)
44
    write_yaml(config, os.path.join(save_dir, "config.yaml"))
45
46
    # 1. fix random seeds for reproducibility
47
    seed = config["latefusion"]["seed"]
48
    np.random.seed(seed)
49
50
    # 2. Load data and define pipelines for each modality
51
    ptask = PredictionTask(config, survival=False, integration="late")
52
    ptask.load_data()
53
    X, y = ptask.data_concat.values, ptask.labels.loc[ptask.data_concat.index].values
54
    ptask.init_pipelines_latefusion()
55
56
    # 3. Perform repeated cross-validation
57
    parallel = ProgressParallel(
58
        n_jobs=config["parallelization"]["n_jobs_repeats"],
59
        total=config["latefusion"]["n_repeats"],
60
    )
61
    results_parallel = parallel(
62
        delayed(_fun_repeats)(
63
            ptask,
64
            X,
65
            y,
66
            r,
67
            disable_infos=(config["parallelization"]["n_jobs_repeats"] is not None)
68
            and (config["parallelization"]["n_jobs_repeats"] > 1),
69
        )
70
        for r in range(config["latefusion"]["n_repeats"])
71
    )
72
73
    # 4. Save results
74
    if config["permutation_test"]:
75
        perm_predictions = np.zeros(
76
            (
77
                len(y),
78
                len(ptask.names),
79
                config["n_permutations"],
80
                config["latefusion"]["n_repeats"],
81
            )
82
        )
83
        list_data_preds, list_data_thrs = [], []
84
        for p, res in enumerate(results_parallel):
85
            list_data_preds.append(res[0])
86
            list_data_thrs.append(res[1])
87
            perm_predictions[:, :, :, p] = res[2]
88
        perm_labels = results_parallel[-1][3]
89
90
        np.save(os.path.join(save_dir, "permutation_labels.npy"), perm_labels)
91
        np.save(os.path.join(save_dir, "permutation_predictions.npy"), perm_predictions)
92
        data_preds = pd.concat(list_data_preds, axis=0)
93
        data_preds.to_csv(os.path.join(save_dir, "predictions.csv"))
94
        if config["collect_thresholds"]:
95
            data_thrs = pd.concat(list_data_thrs, axis=0)
96
            data_thrs.to_csv(os.path.join(save_dir, "thresholds.csv"))
97
    else:
98
        list_data_preds, list_data_thrs = [], []
99
        for p, res in enumerate(results_parallel):
100
            list_data_preds.append(res[0])
101
            list_data_thrs.append(res[1])
102
        data_preds = pd.concat(list_data_preds, axis=0)
103
        data_preds.to_csv(os.path.join(save_dir, "predictions.csv"))
104
        if config["collect_thresholds"]:
105
            data_thrs = pd.concat(list_data_thrs, axis=0)
106
            data_thrs.to_csv(os.path.join(save_dir, "thresholds.csv"))
107
108
109
def _fun_repeats(prediction_task, X, y, r, disable_infos):
110
    """
111
    Train and test a late fusion model for classification with cross-validation
112
113
    Parameters
114
    ----------
115
    prediction_task: PredictionTask object
116
117
    X: 2D array of shape (n_samples, n_features)
118
        Concatenation of the different modalities
119
120
    y: 1D array of shape (n_samples,)
121
        Binary outcome
122
123
    r: int
124
        Repeat number
125
126
    disable_infos: bool
127
128
    Returns
129
    -------
130
    df_pred: pd.DataFrame of shape (n_samples, n_models+3)
131
        Predictions collected over the test sets of the cross-validation scheme for each multimodal combination
132
133
    df_thrs: pd.DataFrame of shape (n_samples, n_models+2), None
134
        Thresholds that optimize the log-rank test on the training set for each fold and each multimodal combination.
135
136
    permut_predictions: 3D array of shape (n_samples, n_models, n_permutations)
137
        Predictions collected over the test sets of the cross_validation scheme for each multimodal combination and each random permutation of the labels.
138
139
    permut_labels: 2D array of shape (n_samples, n_permutations)
140
        Permuted labels
141
    """
142
    cv = StratifiedKFold(n_splits=10, shuffle=True)  # , random_state=np.random.seed(i))
143
    X_preds = np.zeros((len(y), 3 + len(prediction_task.names)))
144
    X_thresholds = (
145
        np.zeros((len(y), 2 + len(prediction_task.names)))
146
        if prediction_task.config["collect_thresholds"]
147
        else None
148
    )
149
150
    late_clf = LateFusionClassifier(
151
        estimators=prediction_task.late_estimators,
152
        cv=StratifiedKFold(n_splits=10, shuffle=True, random_state=np.random.seed(r)),
153
        **prediction_task.config["latefusion"]["args"]
154
    )
155
156
    # 1. Cross-validation scheme
157
    for fold_index, (train_index, test_index) in tqdm(
158
        enumerate(cv.split(np.zeros(len(y)), y)),
159
        leave=False,
160
        total=cv.get_n_splits(np.zeros(len(y))),
161
        disable=disable_infos,
162
    ):
163
        X_train, y_train, X_test = (
164
            X[train_index, :],
165
            y[train_index],
166
            X[test_index, :],
167
        )
168
        target_surv_train = prediction_task.target_surv[train_index]
169
        # Fit late fusion on the training set of the fold
170
        clf = clone(late_clf)
171
        clf.fit(X_train, y_train)
172
        # Collect predictions on the test set of the fold for each multimodal combination
173
        for c, idx in enumerate(prediction_task.indices):
174
            X_preds[test_index, c] = clf.predict_proba(X_test, estim_ind=idx)[:, 1]
175
            # Collect the threshold that optimizes log-rank test on the training set
176
            if prediction_task.config["collect_thresholds"]:
177
                X_thresholds[test_index, c] = clf.find_logrank_threshold(
178
                    X_train, target_surv_train, estim_ind=idx
179
                )
180
        X_preds[test_index, -3] = fold_index
181
        if prediction_task.config["collect_thresholds"]:
182
            X_thresholds[test_index, -2] = fold_index
183
184
    X_preds[:, -2] = r
185
    if prediction_task.config["collect_thresholds"]:
186
        X_thresholds[:, -1] = r
187
    X_preds[:, -1] = y
188
189
    df_pred = (
190
        pd.DataFrame(
191
            X_preds,
192
            columns=prediction_task.names + ["fold_index", "repeat", "label"],
193
            index=prediction_task.data_concat.index,
194
        )
195
        .reset_index()
196
        .rename(columns={"index": "samples"})
197
        .set_index(["repeat", "samples"])
198
    )
199
200
    if prediction_task.config["collect_thresholds"]:
201
        df_thrs = (
202
            pd.DataFrame(
203
                X_thresholds,
204
                columns=prediction_task.names + ["fold_index", "repeat"],
205
                index=prediction_task.data_concat.index,
206
            )
207
            .reset_index()
208
            .rename(columns={"index": "samples"})
209
            .set_index(["repeat", "samples"])
210
        )
211
    else:
212
        df_thrs = None
213
214
    # 2. Perform permutation test
215
    permut_predictions = None
216
    permut_labels = None
217
    if prediction_task.config["permutation_test"]:
218
        permut_labels = np.zeros((len(y), prediction_task.config["n_permutations"]))
219
        permut_predictions = np.zeros(
220
            (
221
                len(y),
222
                len(prediction_task.names),
223
                prediction_task.config["n_permutations"],
224
            )
225
        )
226
        for prm in range(prediction_task.config["n_permutations"]):
227
            X_perm = np.zeros((len(y), len(prediction_task.names)))
228
            random_state = check_random_state(prm)
229
            sh_ind = random_state.permutation(len(y))
230
            yshuffle = np.copy(y)[sh_ind]
231
            permut_labels[:, prm] = yshuffle
232
            for fold_index, (train_index, test_index) in tqdm(
233
                enumerate(cv.split(np.zeros(len(y)), y)),
234
                leave=False,
235
                total=cv.get_n_splits(np.zeros(len(y))),
236
                disable=disable_infos,
237
            ):
238
                X_train, yshuffle_train, X_test = (
239
                    X[train_index, :],
240
                    yshuffle[train_index],
241
                    X[test_index, :],
242
                )
243
                clf = clone(late_clf)
244
                clf.fit(X_train, yshuffle_train)
245
246
                for c, idx in enumerate(prediction_task.indices):
247
                    X_perm[test_index, c] = clf.predict_proba(X_test, estim_ind=idx)[
248
                        :, 1
249
                    ]
250
            permut_predictions[:, :, prm] = X_perm
251
    return df_pred, df_thrs, permut_predictions, permut_labels
252
253
254
if __name__ == "__main__":
255
    args = argparse.ArgumentParser(description="Late fusion")
256
    args.add_argument(
257
        "-c",
258
        "--config",
259
        type=str,
260
        help="config file path",
261
    )
262
    args.add_argument(
263
        "-s",
264
        "--save_path",
265
        type=str,
266
        help="save path",
267
    )
268
    main(params=args.parse_args())