Diff of /scripts/earlyfusion.py [000000] .. [efd906]

Switch to unified view

a b/scripts/earlyfusion.py
1
import argparse
2
import inspect
3
import os
4
import sys
5
from datetime import datetime
6
7
import numpy as np
8
import pandas as pd
9
from joblib import delayed
10
from sklearn.model_selection import StratifiedKFold
11
from tqdm import tqdm
12
13
from _init_scripts import PredictionTask
14
from _utils import read_yaml, write_yaml, ProgressParallel
15
16
# Resolve the directory that contains this script, then put its parent directory first on the
# import path (presumably so that the `multipit` import below resolves from the checkout - TODO confirm).
currentdir = os.path.abspath(inspect.getfile(inspect.currentframe()))
currentdir = os.path.dirname(currentdir)
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)
19
20
from multipit.multi_model.earlyfusion import EarlyFusionClassifier
21
22
23
def main(params):
    """
    Repeated cross-validation experiment for classification with early fusion

    Parameters
    ----------
    params: argparse.Namespace
        Parsed command-line arguments. ``params.config`` is the path of the YAML configuration file and
        ``params.save_path`` is the directory in which the results folder is created.

    Notes
    -----
    Writes ``config.yaml`` and ``predictions.csv`` (plus ``thresholds.csv`` when ``collect_thresholds`` is set in the
    configuration) into ``<save_path>/<save_name>``.
    """

    # 0. Read config file and save it in the results
    config = read_yaml(params.config)
    save_name = config["save_name"]
    if save_name is None:
        # No explicit name: fall back to a timestamped folder name
        run_id = datetime.now().strftime(r"%m%d_%H%M%S")
        save_name = "exp_" + run_id
    save_dir = os.path.join(params.save_path, save_name)
    # makedirs also creates a missing `save_path`. exist_ok is left off on purpose: an already existing
    # results folder still raises FileExistsError, so earlier results are never silently overwritten.
    os.makedirs(save_dir)
    write_yaml(config, os.path.join(save_dir, "config.yaml"))

    # 1. fix random seeds for reproducibility
    seed = config["earlyfusion"]["seed"]
    np.random.seed(seed)

    # 2. Load data and define pipeline
    ptask = PredictionTask(config, survival=False, integration="early")
    ptask.load_data()
    X, y = ptask.data_concat.values, ptask.labels.loc[ptask.data_concat.index].values
    ptask.init_pipelines_earlyfusion()

    # 3. Perform repeated cross-validation
    n_jobs = config["parallelization"]["n_jobs_repeats"]
    n_repeats = config["earlyfusion"]["n_repeats"]
    # Per-fold progress bars are disabled as soon as more than one job is requested for the repeats
    disable_infos = (n_jobs is not None) and (n_jobs > 1)
    parallel = ProgressParallel(n_jobs=n_jobs, total=n_repeats)
    results_parallel = parallel(
        delayed(_fun_repeats)(ptask, X, y, r, disable_infos=disable_infos)
        for r in range(n_repeats)
    )

    # 4. Save results
    list_data_preds, list_data_thrs = [], []
    for df_pred, df_thrs in results_parallel:
        list_data_preds.append(df_pred)
        list_data_thrs.append(df_thrs)
    data_preds = pd.concat(list_data_preds, axis=0)
    data_preds.to_csv(os.path.join(save_dir, "predictions.csv"))
    if config["collect_thresholds"]:
        # The threshold frames are None when collection is disabled, so only concatenate them here
        data_thrs = pd.concat(list_data_thrs, axis=0)
        data_thrs.to_csv(os.path.join(save_dir, "thresholds.csv"))
76
77
78
def _fun_repeats(prediction_task, X, y, r, disable_infos):
    """
    Train and test an early fusion model for classification with cross-validation

    Parameters
    ----------
    prediction_task: PredictionTask object

    X: 2D array of shape (n_samples, n_features)
        Concatenation of the different modalities

    y: 1D array of shape (n_samples,)
        Binary outcome

    r: int
        Repeat number. Also used as the seed of the inner cross-validation splitter and of NumPy's global RNG.

    disable_infos: bool
        If True the per-fold progress bar is disabled.

    Returns
    -------
    df_pred: pd.DataFrame indexed by (repeat, samples)
        Predictions collected over the test sets of the cross-validation scheme, with one column per multimodal
        combination plus the columns "fold_index" and "label".

    df_thrs: pd.DataFrame indexed by (repeat, samples), None
        Values returned by ``find_logrank_threshold`` on the training set for each fold, with one column per
        multimodal combination plus the column "fold_index". None when ``collect_thresholds`` is disabled in the
        configuration.
    """
    names = prediction_task.names
    collect_thresholds = prediction_task.config["collect_thresholds"]
    n_samples = len(y)

    # NOTE(review): no random_state is passed here, so the outer shuffle is not tied to `r`; per the
    # scikit-learn documentation it then relies on NumPy's global RNG - confirm this is intended.
    cv = StratifiedKFold(n_splits=10, shuffle=True)

    # Trailing columns hold bookkeeping values: fold_index, repeat, label (predictions) and
    # fold_index, repeat (thresholds).
    X_preds = np.zeros((n_samples, 3 + len(names)))
    X_thresholds = np.zeros((n_samples, 2 + len(names))) if collect_thresholds else None

    for fold_index, (train_index, test_index) in tqdm(
        enumerate(cv.split(np.zeros(n_samples), y)),
        leave=False,
        total=cv.get_n_splits(np.zeros(n_samples)),
        disable=disable_infos,
    ):
        X_train, y_train, X_test = X[train_index, :], y[train_index], X[test_index, :]
        target_surv_train = prediction_task.target_surv[train_index]

        # np.random.seed() returns None, so it cannot serve as a random_state value. Reseed the global
        # RNG explicitly for this repeat and hand the integer seed itself to the inner splitter.
        np.random.seed(r)
        cv_inner = StratifiedKFold(n_splits=10, shuffle=True, random_state=r)

        for c, models in enumerate(names):
            # A combination name lists its modalities separated by "+"
            modalities = models.split("+")
            early_clf = EarlyFusionClassifier(
                estimator=prediction_task.early_estimator,
                transformers={
                    model: prediction_task.early_transformers[model]
                    for model in modalities
                },
                modalities={
                    model: prediction_task.dic_modalities[model]
                    for model in modalities
                },
                cv=cv_inner,
                **prediction_task.config["earlyfusion"]["classifier_args"]
            )
            if len(modalities) == 1:
                # Feature selection is switched off for single-modality models
                early_clf.set_params(select_features=False)

            early_clf.fit(X_train, y_train)
            X_preds[test_index, c] = early_clf.predict_proba(X_test)[:, 1]
            if collect_thresholds:
                X_thresholds[test_index, c] = early_clf.find_logrank_threshold(
                    X_train, target_surv_train
                )

        X_preds[test_index, -3] = fold_index
        if collect_thresholds:
            X_thresholds[test_index, -2] = fold_index

    X_preds[:, -2] = r
    X_preds[:, -1] = y
    df_pred = _results_to_frame(
        X_preds,
        columns=names + ["fold_index", "repeat", "label"],
        index=prediction_task.data_concat.index,
    )

    if not collect_thresholds:
        return df_pred, None

    X_thresholds[:, -1] = r
    df_thrs = _results_to_frame(
        X_thresholds,
        columns=names + ["fold_index", "repeat"],
        index=prediction_task.data_concat.index,
    )
    return df_pred, df_thrs


def _results_to_frame(values, columns, index):
    """
    Wrap a per-repeat result array into a DataFrame indexed by (repeat, samples)

    Parameters
    ----------
    values: 2D array of shape (n_samples, len(columns))

    columns: list of str
        Column names, must contain "repeat"

    index: pd.Index
        Sample identifiers. After ``reset_index`` the resulting "index" column is renamed to "samples".

    Returns
    -------
    pd.DataFrame
    """
    return (
        pd.DataFrame(values, columns=columns, index=index)
        .reset_index()
        .rename(columns={"index": "samples"})
        .set_index(["repeat", "samples"])
    )
187
188
189
if __name__ == "__main__":
    # Command-line entry point. Both options are consumed unconditionally by main(), so they are
    # marked as required: a missing one is reported by argparse as a usage error.
    args = argparse.ArgumentParser(description="Early fusion")
    args.add_argument(
        "-c",
        "--config",
        type=str,
        required=True,
        help="config file path",
    )
    args.add_argument(
        "-s",
        "--save_path",
        type=str,
        required=True,
        help="save path",
    )
    main(params=args.parse_args())