Diff of /scripts/collect_shap.py [000000] .. [efd906]

Switch to side-by-side view

--- a
+++ b/scripts/collect_shap.py
@@ -0,0 +1,218 @@
+import argparse
+import inspect
+import os
+import sys
+
+# import warnings
+from datetime import datetime
+
+import numpy as np
+import pandas as pd
+import shap
+from joblib import delayed
+from sklearn.base import clone
+from sklearn.model_selection import StratifiedKFold
+from tqdm import tqdm
+
+from _init_scripts import PredictionTask
+from _utils import read_yaml, write_yaml, ProgressParallel
+
# Make the repository root importable when this script is run directly from
# scripts/: resolve this file's directory, step one level up, and prepend it
# to sys.path so the `multipit` package (which lives at the repo root) resolves.
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)

# Imported after the sys.path manipulation above, hence not in the top import block.
from multipit.multi_model.latefusion import LateFusionClassifier
+
+
def main(params):
    """
    Run the repeated cross-validated SHAP collection experiment.

    Reads the YAML configuration, creates a timestamped results directory,
    collects SHAP values for each modality across `n_repeats` parallel
    cross-validation repetitions, and writes one CSV of SHAP values per
    modality (plus, for logistic-regression classifiers, one .npy file of
    stacked coefficients per modality).

    Parameters
    ----------
    params: argparse.Namespace
        Parsed command-line arguments with attributes `config` (path to the
        YAML configuration file) and `save_path` (directory under which the
        experiment folder is created).
    """
    # 0. Read config file and save a copy alongside the results
    config = read_yaml(params.config)
    save_name = config["save_name"]
    if save_name is None:
        save_name = "exp_" + datetime.now().strftime(r"%m%d_%H%M%S")
    save_dir = os.path.join(params.save_path, save_name)
    # makedirs creates missing parent directories (os.mkdir would fail if
    # save_path does not exist yet); exist_ok=False still fails loudly
    # rather than silently reusing a previous experiment directory.
    os.makedirs(save_dir, exist_ok=False)
    write_yaml(config, os.path.join(save_dir, "config.yaml"))

    # 1. Fix random seed for reproducibility
    np.random.seed(config["latefusion"]["seed"])

    # 2. Load data and define pipelines for each modality
    ptask = PredictionTask(config, survival=False, integration="late")
    ptask.load_data()
    X, y = ptask.data_concat.values, ptask.labels.loc[ptask.data_concat.index].values
    ptask.init_pipelines_latefusion()

    # 3. Run the n_repeats CV repetitions, possibly in parallel
    n_jobs = config["parallelization"]["n_jobs_repeats"]
    n_repeats = config["latefusion"]["n_repeats"]
    parallel = ProgressParallel(n_jobs=n_jobs, total=n_repeats)
    list_shap = parallel(
        delayed(_fun_parallel)(
            ptask,
            X,
            y,
            r,
            # Silence the per-repeat progress bars when repeats run in parallel
            disable_infos=(n_jobs is not None) and (n_jobs > 1),
        )
        for r in range(n_repeats)
    )

    # 4. Concatenate SHAP values across repeats and save one CSV per modality
    shap_explain = {"clinical": [], "radiomics": [], "pathomics": [], "RNA": []}
    for results in list_shap:
        for moda, shapley in results[0].items():
            shap_explain[moda].append(shapley)
    for key, val in shap_explain.items():
        df_shap = pd.concat(val, axis=0, join="outer")
        df_shap.to_csv(os.path.join(save_dir, "Shap_" + key + ".csv"))

    # 5. For logistic regression only, also save the stacked coefficients
    if config["classifier"]["type"] == "LR":
        coefs_LR = {"clinical": [], "radiomics": [], "pathomics": [], "RNA": []}
        for results in list_shap:
            for moda, coefs in results[1].items():
                coefs_LR[moda].append(coefs)
        for key, val in coefs_LR.items():
            # Shape: (n_folds, n_features, n_repeats)
            coefficients = np.stack(val, axis=-1)
            np.save(os.path.join(save_dir, "coef_LR_" + key + ".npy"), coefficients)
+
+
def _fun_parallel(prediction_task, X, y, r, disable_infos):
    """
    Collect SHAP values for several unimodal classifiers with cross-validation

    Parameters
    ----------
    prediction_task: PredictionTask object

    X: 2D array of shape (n_samples, n_features)
        Concatenation of the different modalities

    y: 1D array of shape (n_samples,)
        Binary outcome

    r: int
        Repeat number, also used to seed the cross-validation shuffling

    disable_infos: bool
        Disable the tqdm progress bar (useful when repeats run in parallel)

    Returns
    -------
    shap_dict: dictionary
        Dictionary whose keys correspond to the different modalities (e.g., "RNA", "clinical") and the items correspond
        to pandas dataframe of size (n_samples, n_features) that contain the SHAP values collected across the test sets
        of the cross-validation scheme.

    coefs_dict: dictionary or None
        Dictionary whose keys correspond to the different modalities (e.g., "RNA", "clinical") and the items correspond
        to arrays of size (n_folds, n_features) that contain the linear coefficients collected across the different
        folds of the cross-validation scheme. None if the classifier type is not linear.
    """
    # Seed the global NumPy RNG for this repeat. The original code relied on
    # the side effect of `np.random.seed(r)` for reproducibility because it
    # passed `random_state=np.random.seed(r)`, which evaluates to None
    # (np.random.seed returns None). Here we keep the global seeding AND pass
    # the repeat index as an explicit random_state so both CV schemes are
    # deterministic per repeat.
    np.random.seed(r)
    cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=r)
    late_clf = LateFusionClassifier(
        estimators=prediction_task.late_estimators,
        cv=StratifiedKFold(n_splits=10, shuffle=True, random_state=r),
        **prediction_task.config["latefusion"]["args"]
    )

    shap_dict = {name: [] for name, *_ in late_clf.estimators}

    collect_coefs = prediction_task.config["classifier"]["type"] == "LR"
    if collect_coefs:
        coef_dict = {name: [] for name, *_ in late_clf.estimators}

    for fold_index, (train_index, test_index) in tqdm(
        enumerate(cv.split(np.zeros(len(y)), y)),
        leave=False,
        total=cv.get_n_splits(np.zeros(len(y))),
        disable=disable_infos,
    ):
        X_train, y_train, X_test, y_test = (
            X[train_index, :],
            y[train_index],
            X[test_index, :],
            y[test_index],
        )
        # Fit late fusion on the training set of the fold
        clf = clone(late_clf)
        clf.fit(X_train, y_train)
        # Collect SHAP values on the test set of the fold for each unimodal classifier
        for ind, (name, estim, features) in enumerate(clf.fitted_estimators_):
            # Keep only samples whose fraction of missing features for this
            # modality is below the classifier's missing threshold.
            X_background = X_train[:, features]
            bool_mask = ~(
                np.sum(np.isnan(X_background), axis=1)
                > clf.missing_threshold * len(features)
            )
            X_explain = X_test[:, features]
            bool_mask_explain = ~(
                np.sum(np.isnan(X_explain), axis=1)
                > clf.missing_threshold * len(features)
            )
            # estim/ind are bound as lambda defaults to avoid Python's
            # late-binding closure pitfall (defensive; the explainer is also
            # consumed within the same loop iteration).
            if clf.calibration is not None:
                # Explain the calibrated probability: pipe the unimodal
                # probability through the fold's fitted meta-estimator.
                explainer = shap.Explainer(
                    lambda x, e=estim, i=ind: (
                        clf.fitted_meta_estimators_[(i,)].predict_proba(
                            e.predict_proba(x)[:, 1].reshape(-1, 1)
                        )
                    ),
                    X_background[bool_mask, :],
                )
            else:
                explainer = shap.Explainer(
                    lambda x, e=estim: e.predict_proba(x), X_background[bool_mask, :]
                )
            shap_values = explainer(X_explain[bool_mask_explain, :])
            # Keep the SHAP values of the positive class only
            shap_df = pd.DataFrame(
                shap_values.values[:, :, 1],
                columns=prediction_task.data_concat.columns[features],
                index=prediction_task.data_concat.index.values[
                    test_index[bool_mask_explain]
                ],
            )
            shap_df["fold_index"] = fold_index
            shap_df["repeat"] = r
            shap_dict[name].append(shap_df)
            # Also collect coefficients for logistic regression
            if collect_coefs:
                if name == "RNA":
                    # The RNA pipeline may select a variable number of
                    # features per fold; pad the coefficient row to a fixed
                    # width of 40 so folds/repeats can be stacked.
                    # NOTE(review): confirm 40 matches the configured maximum
                    # number of selected RNA features.
                    temp = np.zeros((1, 40))
                    temp[:, : estim[-1].coef_.shape[1]] = estim[-1].coef_
                    coef_dict[name].append(temp)
                else:
                    coef_dict[name].append(estim[-1].coef_)

    if collect_coefs:
        coefs_dict = {name: np.vstack(value) for name, value in coef_dict.items()}
    else:
        coefs_dict = None

    shap_dict = {
        name: pd.concat(value, axis=0, join="outer")
        for name, value in shap_dict.items()
    }

    return shap_dict, coefs_dict
+
+
if __name__ == "__main__":
    # Both arguments are mandatory: without required=True a missing flag
    # would silently become None and crash later with an opaque error
    # inside the config reader instead of a clear argparse usage message.
    parser = argparse.ArgumentParser(description="Collect Shap")
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        required=True,
        help="config file path",
    )
    parser.add_argument(
        "-s",
        "--save_path",
        type=str,
        required=True,
        help="save path",
    )
    main(params=parser.parse_args())