--- /dev/null
+++ b/src/scpanel/train.py
@@ -0,0 +1,766 @@
+import os
+import pickle
+import time
+
+from sklearn.ensemble import RandomForestClassifier
+from sklearn.linear_model import LogisticRegression
+
+# import sklearn.linear_model as lm
+from sklearn.metrics import (
+    accuracy_score,
+    auc,
+    average_precision_score,
+    balanced_accuracy_score,
+    classification_report,
+    confusion_matrix,
+    matthews_corrcoef,
+    recall_score,
+    roc_auc_score,
+)
+from sklearn.model_selection import GridSearchCV
+from sklearn.neighbors import KNeighborsClassifier
+from sklearn.svm import SVC
+
+from .GATclassifier import GATclassifier
+
+import numpy as np
+import pandas as pd
+
+from .utils_func import *
+from anndata import AnnData
+from matplotlib.axes import Axes
+from matplotlib.figure import Figure
+from matplotlib.gridspec import GridSpec
+from numpy import float64, ndarray
+from pandas import DataFrame, Series
+from torch import Tensor
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+
+# Any of the five classifier types trained below; when `search_grid=True`,
+# non-GAT entries are GridSearchCV wrappers around these estimators.
+Classifier = Union[
+    LogisticRegression,
+    RandomForestClassifier,
+    SVC,
+    KNeighborsClassifier,
+    GATclassifier,
+]
+
+
+def transform_adata(
+    adata_train: AnnData,
+    adata_test_dict: Dict[str, AnnData],
+    selected_gene: Optional[List[str]] = None,
+) -> Tuple[AnnData, AnnData]:
+    ## Transform train and test sets drawn from the same dataset (batch-effect free):
+    ## subset adata_train to the selected genes,
+    ## subset adata_test_dict to the selected cell type and genes.
+    ## WATCH OUT: X in adata_test_dict is log-normalized and still needs scaling
+    if selected_gene is None:
+        selected_gene = adata_train.uns["svm_rfe_genes"]
+
+    adata_train_final = adata_train[:, selected_gene]
+
+    mean = adata_train_final.var["mean"].values
+    std = adata_train_final.var["std"].values
+
+    # the single selected cell type present in the training data
+    ct_selected = adata_train_final.obs.ct.unique()[0]
+
+    # transform test data with selected gene, celltype and scaling
+    adata_test = adata_test_dict[ct_selected].copy()
+    adata_test_final = adata_test[:, selected_gene].copy()
+
+    if isinstance(adata_test_final.X, np.ndarray):
+        test_X = adata_test_final.X
+    else:
+        test_X = adata_test_final.X.toarray()
+    test_X -= mean
+    test_X /= std
+
+    # cap scaled values at max_value to limit the influence of outliers
+    max_value = 10
+    test_X[test_X > max_value] = max_value
+    adata_test_final.X = test_X
+
+    return adata_train_final, adata_test_final
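+
+# Example usage (a sketch; assumes `adata_train` carries the panel selected by
+# SVM-RFE in `uns["svm_rfe_genes"]` and that `adata_test_dict` maps cell-type
+# names to AnnData objects):
+#   adata_train_final, adata_test_final = transform_adata(adata_train, adata_test_dict)
+#   # or with an explicit, purely illustrative gene panel:
+#   adata_train_final, adata_test_final = transform_adata(
+#       adata_train, adata_test_dict, selected_gene=["CD3D", "NKG7"]
+#   )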
+
+
+def models_train(
+    adata_train_final: AnnData,
+    search_grid: bool,
+    out_dir: Optional[str] = None,
+    param_grid: Optional[Dict[str, Dict[str, Any]]] = None,
+) -> List[Tuple[str, Classifier]]:
+
+    X_tr, y_tr, adj_tr = get_X_y_from_ann(
+        adata_train_final, return_adj=True, n_neigh=10
+    )
+    sample_weight = compute_cell_weight(adata_train_final)
+
+    # make sure there are no NaNs in the matrix
+    X_tr = np.nan_to_num(X_tr)
+
+    grid_search = search_grid
+    models = [
+        ("LR", LogisticRegression(solver="saga", max_iter=500, random_state=42)),
+        ("RF", RandomForestClassifier(random_state=42)),
+        ("SVM", SVC(probability=True, random_state=42)),
+        ("KNN", KNeighborsClassifier()),
+        (
+            "GAT",
+            GATclassifier(
+                nFeatures=adata_train_final.n_vars, NumParts=10, nEpochs=1000, verbose=1
+            ),
+        ),
+    ]
+
+    # Parameter tuning grids-------------------------
+    LR_params = [{"C": [10, 1.0, 0.1, 0.01], "max_iter": [10, 50, 200, 500]}]
+    RF_params = [
+        {"max_depth": [2, 5, 10, 15, 20, 30, None], "n_estimators": [50, 100, 500]}
+    ]
+    SVM_params = [{"C": [100, 10, 1.0, 0.1, 0.001], "gamma": [1, 0.1, 0.01, 0.001]}]
+    KNN_params = [{"n_neighbors": [3, 5, 10, 20, 50], "p": [1, 2]}]
+
+    my_grid = {"LR": LR_params, "RF": RF_params, "SVM": SVM_params, "KNN": KNN_params}
+
+    clfs = []
+    names = []
+    runtimes = []
+    best_params = []
+
+    for name, model in models:
+        start_time = time.time()
+
+        if grid_search:
+            if name != "GAT":
+                clf = GridSearchCV(
+                    model, my_grid[name], cv=5, scoring="roc_auc", n_jobs=10
+                )
+            else:
+                clf = model
+        else:
+            clf = model
+            if param_grid is not None:
+                if name in param_grid:
+                    clf.set_params(**param_grid[name])
+
+        if name == "GAT":
+            clf.fit(X_tr, y_tr, adj_tr)
+        elif name == "KNN":
+            clf.fit(X_tr, y_tr)
+        else:
+            clf.fit(X_tr, y_tr, sample_weight=sample_weight)
+
+        runtime = time.time() - start_time
+
+        # save outputs
+        clfs.append((name, clf))
+        names.append(name)
+        runtimes.append(runtime)
+
+        print("---%s finished in %s seconds ---" % (name, runtime))
+
+    # save models
+    if out_dir is not None:
+        os.makedirs(out_dir, exist_ok=True)
+
+        with open(f"{out_dir}/clfs.pkl", "wb") as f:
+            pickle.dump(clfs, f, protocol=pickle.HIGHEST_PROTOCOL)
+
+        with open(f"{out_dir}/adata_train_final.pkl", "wb") as f:
+            pickle.dump(adata_train_final, f, protocol=pickle.HIGHEST_PROTOCOL)
+
+    return clfs
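+
+# Example usage (a sketch; `param_grid` keys match the model names above and
+# are only applied when `search_grid=False`):
+#   clfs = models_train(
+#       adata_train_final,
+#       search_grid=False,
+#       out_dir="./output/models",
+#       param_grid={"KNN": {"n_neighbors": 10}, "RF": {"n_estimators": 500}},
+#   )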
+
+
+def models_predict(
+    clfs: List[Tuple[str, Classifier]],
+    adata_test_final: AnnData,
+    out_dir: Optional[str] = None,
+) -> Tuple[AnnData, List[Tuple[str, Union[ndarray, Tensor]]], List[Tuple[str, ndarray]]]:
+    X_test, y_test, adj_test = get_X_y_from_ann(
+        adata_test_final, return_adj=True, n_neigh=10
+    )
+    X_test = np.nan_to_num(X_test)
+
+    ## Predicting---------------
+    y_pred_list = []
+    y_pred_score_list = []
+
+    for name, clf in clfs:
+        if name == "GAT":
+            y_pred = clf.predict(X_test, y_test, adj_test)
+            y_pred_score = clf.predict_proba(X_test, y_test, adj_test)
+        else:
+            y_pred = clf.predict(X_test)
+            y_pred_score = clf.predict_proba(X_test)
+
+        y_pred_list.append((name, y_pred))
+        y_pred_score_list.append((name, y_pred_score))
+
+    # add prediction result to adata_test_final
+    y_pred = pd.DataFrame({name + "_pred": pred for name, pred in y_pred_list})
+    y_pred_score = pd.DataFrame(
+        {name + "_pred_score": score[:, 1] for name, score in y_pred_score_list}
+    )
+
+    y_pred_df = pd.concat([y_pred, y_pred_score], axis=1)
+    y_pred_df.index = adata_test_final.obs.index
+
+    if set(y_pred_df.columns).issubset(set(adata_test_final.obs.columns)):
+        print("Prediction result already exits in test adata, overwrite it...")
+        adata_test_final.obs.update(y_pred_df)
+    else:
+        adata_test_final.obs = pd.concat([adata_test_final.obs, y_pred_df], axis=1)
+
+    # calculate the median prediction score across the 5 classifiers
+    pred_col = [
+        col for col in adata_test_final.obs.columns if col.endswith("_pred_score")
+    ]
+    adata_test_final.obs["median_pred_score"] = adata_test_final.obs[pred_col].median(
+        axis=1
+    )
+
+    return adata_test_final, y_pred_list, y_pred_score_list
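+
+# Example usage (a sketch): each classifier's predictions are written into
+# `adata_test_final.obs` as `<name>_pred` / `<name>_pred_score` columns, and
+# `median_pred_score` holds the per-cell median across the five classifiers:
+#   adata_test_final, y_pred_list, y_pred_score_list = models_predict(
+#       clfs, adata_test_final
+#   )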
+
+
+def models_score(
+    adata_test_final: AnnData,
+    y_pred_list: List[Tuple[str, Union[ndarray, Tensor]]],
+    y_pred_score_list: List[Tuple[str, ndarray]],
+    out_dir: Optional[str] = None,
+) -> DataFrame:
+    X_test, y_test = get_X_y_from_ann(adata_test_final)
+
+    ## Scoring-------------------------------------
+    ## define scoring metrics (from sklearn)
+    scorers = {
+        "accuracy": (accuracy_score, {}),
+        "balanced_accuracy": (balanced_accuracy_score, {}),
+        "MCC": (matthews_corrcoef, {}),
+    }  # each metric name maps to (scoring function, extra kwargs)
+
+    scorers_prob = {
+        "AUROC": (roc_auc_score, {}),
+        "AUPRC": (average_precision_score, {}),
+    }
+
+    ## calculate
+    eval_res_1 = pd.DataFrame()
+    for name, y_pred in y_pred_list:
+        eval_res_dict = dict(
+            [
+                (score_name, score_func(y_test, y_pred, **score_para))
+                for score_name, (score_func, score_para) in scorers.items()
+            ]
+        )
+        eval_res_i = pd.DataFrame(eval_res_dict, index=[name])
+
+        eval_res_1 = pd.concat(objs=[eval_res_1, eval_res_i], axis=0)
+
+    eval_res_2 = pd.DataFrame()
+    for name, y_pred_score in y_pred_score_list:
+        eval_res_dict = dict(
+            [
+                (score_name, score_func(y_test, y_pred_score[:, 1], **score_para))
+                for score_name, (score_func, score_para) in scorers_prob.items()
+            ]
+        )
+        eval_res_i = pd.DataFrame(eval_res_dict, index=[name])
+
+        eval_res_2 = pd.concat(objs=[eval_res_2, eval_res_i], axis=0)
+
+    eval_res = pd.concat(objs=[eval_res_2, eval_res_1], axis=1)
+
+    if out_dir is not None:
+        os.makedirs(out_dir, exist_ok=True)
+
+        eval_res.to_csv(f"{out_dir}/eval_res.csv")
+
+    return eval_res
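+
+# Example usage (a sketch): returns one row per classifier with AUROC, AUPRC,
+# accuracy, balanced accuracy and MCC on the test cells:
+#   eval_res = models_score(adata_test_final, y_pred_list, y_pred_score_list)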
+
+
+def cal_sample_auc(df: DataFrame, score_col: str) -> float64:
+    # sort the cell probabilities ascendingly, then rank and normalize to (0, 1]
+    cell_prob = df[score_col].sort_values()
+    cell_rank = cell_prob.rank(method="first")
+    cell_rank = cell_rank / cell_rank.max()
+    sample_auc = auc(cell_rank, cell_prob)
+    return sample_auc
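+
+# Worked example (a sketch): for three cells scoring [0.7, 0.8, 0.9] the curve
+# passes through (1/3, 0.7), (2/3, 0.8), (1, 0.9), giving a trapezoidal area of
+# ~0.53; with many confidently positive cells the value approaches 1, so 0.5
+# acts as the sample-level decision threshold:
+#   df = pd.DataFrame({"median_pred_score": [0.7, 0.8, 0.9]})
+#   cal_sample_auc(df, "median_pred_score")  # ~0.53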
+
+
+def auc_pvalue(row: Series) -> float:
+    # row.name is (patient_id, sample_pred); test bootstrap AUCs against 0.5
+    if row.name[1] == 1:
+        p_value = np.mean(row < 0.5)
+    else:
+        p_value = np.mean(row > 0.5)
+
+    # with finitely many bootstraps, clip p = 0 to 1/num_bootstrap
+    if p_value == 0:
+        p_value = 1 / row.size
+    return p_value
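+
+# Worked example (a sketch): `row` holds one sample's AUCs across bootstrap
+# iterations, indexed by (patient_id, sample_pred). For a sample predicted
+# positive, the p-value is the fraction of bootstrap AUCs below 0.5:
+#   row = pd.Series([0.8, 0.75, 0.9], name=("P1", 1))
+#   auc_pvalue(row)  # 0/3 AUCs below 0.5 -> clipped to 1/3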
+
+
+def pt_pred(
+    adata_test_final: AnnData,
+    cell_pred_col: str = "median_pred_score",
+    num_bootstrap: Optional[int] = None,
+) -> AnnData:
+    sample_auc = adata_test_final.obs.groupby("patient_id").apply(
+        lambda df: cal_sample_auc(df, cell_pred_col)
+    )
+    adata_test_final.obs[cell_pred_col + "_sample_auc"] = (
+        adata_test_final.obs["patient_id"].map(sample_auc).astype(float)
+    )
+    adata_test_final.obs[cell_pred_col + "_sample_pred"] = (
+        adata_test_final.obs[cell_pred_col + "_sample_auc"] >= 0.5
+    ).astype(int)
+
+    if num_bootstrap is not None:
+        auc_df = pd.DataFrame()
+        for i in range(num_bootstrap):
+            df = adata_test_final.obs.groupby("patient_id").sample(
+                frac=1, replace=True, random_state=i
+            )
+            # named auc_i to avoid shadowing sklearn.metrics.auc
+            auc_i = (
+                df.groupby(["patient_id", cell_pred_col + "_sample_pred"])
+                .apply(lambda df: cal_sample_auc(df, cell_pred_col))
+                .to_frame(name=i)
+            )
+            auc_df = pd.concat([auc_df, auc_i], axis=1)
+
+        auc_df[cell_pred_col + "_sample_auc_pvalue"] = auc_df.apply(
+            lambda row: auc_pvalue(row), axis=1
+        )
+        # store auc from each bootstrap iteration in adata.uns
+        adata_test_final.uns[cell_pred_col + "_auc_df"] = auc_df
+        # store auc_pvalue for each sample in adata.obs
+        auc_df = auc_df.droplevel(cell_pred_col + "_sample_pred")
+        adata_test_final.obs[cell_pred_col + "_sample_auc_pvalue"] = (
+            adata_test_final.obs["patient_id"].map(
+                auc_df[cell_pred_col + "_sample_auc_pvalue"]
+            )
+        )
+
+    return adata_test_final
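+
+# Example usage (a sketch): adds `<cell_pred_col>_sample_auc` and
+# `<cell_pred_col>_sample_pred` to `obs`, plus bootstrap p-values when
+# `num_bootstrap` is given:
+#   adata_test_final = pt_pred(adata_test_final, num_bootstrap=100)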
+
+
+def pt_score(
+    adata_test_final: AnnData, cell_pred_col: str = "median_pred_score"
+) -> AnnData:
+    ## Calculate precision, recall, f1score and accuracy at patient level
+    from sklearn.metrics import precision_recall_fscore_support
+
+    pred_col = cell_pred_col
+    res_prefix = cell_pred_col
+
+    pt_pred_res = (
+        adata_test_final.obs[["label", "patient_id", f"{res_prefix}_sample_pred"]]
+        .drop_duplicates()
+        .set_index("patient_id")
+    )
+
+    # precision, recall, f1score
+    pt_score_res = precision_recall_fscore_support(
+        pt_pred_res["label"],
+        pt_pred_res[f"{res_prefix}_sample_pred"],
+        average="weighted",
+    )
+    # accuracy
+    pt_acc_res = accuracy_score(
+        pt_pred_res["label"], pt_pred_res[f"{res_prefix}_sample_pred"]
+    )
+    # specificity
+    pt_spec_res = recall_score(
+        pt_pred_res["label"], pt_pred_res[f"{res_prefix}_sample_pred"], pos_label=0
+    )
+
+    pt_score_res = pd.DataFrame(list(pt_score_res) + [pt_acc_res] + [pt_spec_res])
+    # drop row 3 (support is None with average="weighted"); keep
+    # precision, recall/sensitivity, f1score, accuracy, specificity
+    pt_score_res = pt_score_res.iloc[[0, 1, 2, 4, 5], :]
+    pt_score_res.index = [
+        "precision",
+        "sensitivity",
+        "f1score",
+        "accuracy",
+        "specificity",
+    ]
+    pt_score_res.columns = [res_prefix]
+
+    if "sample_score" not in adata_test_final.uns:
+        adata_test_final.uns["sample_score"] = pt_score_res
+    else:
+        adata_test_final.uns["sample_score"] = adata_test_final.uns[
+            "sample_score"
+        ].merge(pt_score_res, left_index=True, right_index=True, suffixes=("_x", ""))
+
+        adata_test_final.uns["sample_score"].drop(
+            adata_test_final.uns["sample_score"].filter(regex="_x$").columns,
+            axis=1,
+            inplace=True,
+        )
+
+    return adata_test_final
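+
+# Example usage (a sketch): collects patient-level precision, sensitivity,
+# f1score, accuracy and specificity into `adata.uns["sample_score"]`:
+#   adata_test_final = pt_score(adata_test_final)
+#   adata_test_final.uns["sample_score"]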
+
+
+# Plot functions
+from math import pi
+
+import matplotlib.pyplot as plt
+import seaborn as sns
+from matplotlib import rcParams
+
+
+def _panel_grid(hspace: float, wspace: float, ncols: int, num_panels: int) -> Tuple[Figure, GridSpec]:
+    from matplotlib import gridspec
+
+    n_panels_x = min(ncols, num_panels)
+    n_panels_y = np.ceil(num_panels / n_panels_x).astype(int)
+    # each panel will have the size of rcParams['figure.figsize']
+    fig = plt.figure(
+        figsize=(
+            n_panels_x * rcParams["figure.figsize"][0] * (1 + wspace),
+            n_panels_y * rcParams["figure.figsize"][1],
+        ),
+    )
+    left = 0.2 / n_panels_x
+    bottom = 0.13 / n_panels_y
+    gs = gridspec.GridSpec(
+        nrows=n_panels_y,
+        ncols=n_panels_x,
+        left=left,
+        right=1 - (n_panels_x - 1) * left - 0.01 / n_panels_x,
+        bottom=bottom,
+        top=1 - (n_panels_y - 1) * bottom - 0.1 / n_panels_y,
+        hspace=hspace,
+        wspace=wspace,
+    )
+    return fig, gs
+
+
+def plot_roc_curve(
+    adata_test_final: AnnData,
+    sample_id: Union[str, Sequence[str]],
+    cell_pred_col: str,
+    ncols: int = 4,
+    hspace: float = 0.25,
+    wspace: Optional[float] = None,
+    ax: Optional[Axes] = None,
+    scatter_kws: Optional[Dict[str, Any]] = None,
+    legend_kws: Optional[Dict[str, Any]] = None,
+) -> Union[Axes, List[Axes]]:
+    """
+    Parameters
+    ----------
+    - adata_test_final: AnnData,
+    - sample_id: str | Sequence,
+    - cell_pred_col: str = 'median_pred_score',
+    - ncols: int = 4,
+    - hspace: float =0.25,
+    - wspace: float | None = None,
+    - ax: Axes | None = None,
+    - scatter_kws: dict | None = None, Arguments to pass to matplotlib.pyplot.scatter()
+    - legend_kws: dict | None = None, Arguments to pass to matplotlib.axes.Axes.legend()
+
+    Returns
+    -------
+    Axes, or a list of Axes when several sample_ids are plotted
+
+    Examples
+    --------
+    plot_roc_curve(adata_test_final,
+               sample_id = ['C3','C6','H1'],
+               cell_pred_col = 'median_pred_score',
+               scatter_kws={'s':10})
+
+    """
+
+    # turn sample_id into a python list
+    ## if sample_id is a string or None, wrap it in []
+    ## if sample_id is already a sequence, convert it to a list
+    sample_id = (
+        [sample_id]
+        if isinstance(sample_id, str) or sample_id is None
+        else list(sample_id)
+    )
+
+    ##########
+    # Layout #
+    ##########
+    if scatter_kws is None:
+        scatter_kws = {}
+
+    if legend_kws is None:
+        legend_kws = {}
+
+    if wspace is None:
+        #  try to set a wspace that is not too large or too small given the
+        #  current figure size
+        wspace = 0.75 / rcParams["figure.figsize"][0] + 0.02
+
+    # if plotting multiple panels for elements in sample_id
+    if len(sample_id) > 1:
+        if ax is not None:
+            raise ValueError(
+                "Cannot specify `ax` when plotting multiple panels "
+                "(each for a given value of 'color')."
+            )
+        fig, grid = _panel_grid(hspace, wspace, ncols, len(sample_id))
+    else:
+        grid = None
+        if ax is None:
+            fig = plt.figure()
+            ax = fig.add_subplot(111)
+
+    ############
+    # Plotting #
+    ############
+    axs = []
+    for count, _sample_id in enumerate(sample_id):
+        if grid:
+            ax = plt.subplot(grid[count])
+            axs.append(ax)
+
+        # prediction probability of class 1 for this sample
+        cell_prob = adata_test_final.obs.loc[
+            adata_test_final.obs["patient_id"] == _sample_id
+        ][cell_pred_col]
+        cell_prob = cell_prob.sort_values(ascending=True)
+        # rank of cell_prob, normalized to (0, 1]
+        cell_rank = (
+            cell_prob.rank(method="first") / cell_prob.rank(method="first").max()
+        )
+        # sample-level AUC
+        sample_auc = adata_test_final.obs.loc[
+            adata_test_final.obs["patient_id"] == _sample_id
+        ][cell_pred_col + "_sample_auc"].unique()[0]
+        # bootstrap p-value of the sample-level AUC
+        sample_auc_pvalue = adata_test_final.obs.loc[
+            adata_test_final.obs["patient_id"] == _sample_id
+        ][cell_pred_col + "_sample_auc_pvalue"].unique()[0]
+
+        ax.scatter(x=cell_rank, y=cell_prob, c=".3", **scatter_kws)
+        ax.plot(
+            cell_rank,
+            cell_prob,
+            label=f"AUC = {sample_auc:.3f} \np-value = {sample_auc_pvalue:.1e}",
+            zorder=0,
+        )
+        ax.plot(
+            [0, 1], [0, 1], linestyle="--", color=".5", zorder=0, label="Random guess"
+        )
+        ax.spines[["right", "top"]].set_visible(False)
+        ax.set_xlabel("Rank")
+        ax.set_ylabel("Prediction Probability (Cell)")
+        ax.set_title(f"{_sample_id}")
+        ax.set_aspect("equal")
+        if not legend_kws:
+            ax.legend(prop=dict(size=8 * rcParams["figure.figsize"][0] / ncols))
+        else:
+            ax.legend(**legend_kws)
+
+    axs = axs if grid else ax
+
+    return axs
+
+
+def convert_pvalue_to_asterisks(pvalue: float) -> str:
+    if pvalue <= 0.0001:
+        return "****"
+    elif pvalue <= 0.001:
+        return "***"
+    elif pvalue <= 0.01:
+        return "**"
+    elif pvalue <= 0.05:
+        return "*"
+    return "ns"
+
+
+# plot cell level probabilities for each patient
+def plot_violin(
+    adata: AnnData,
+    cell_pred_col: str = "median_pred_score",
+    dot_size: int = 2,
+    ax: Optional[Axes] = None,
+    palette: Optional[Dict[str, str]] = None,
+    xticklabels_color: bool = False,
+    text_kws: Optional[Dict[str, Any]] = None,
+) -> Axes:
+    """
+    Violin Plots for cell-level prediction probabilities in each sample.
+
+    Parameters:
+    - adata: AnnData Object
+
+    - cell_pred_col: string, name of the column with cell-level prediction probabilities
+    in adata.obs (default: 'median_pred_score')
+
+    - pt_stat: string, a test for the null hypothesis that the distribution of probabilities
+    in this sample is different from the population (default: 'perm')
+        Options:
+        - 'perm': permutation test
+        - 't-test': one-sample t-test
+
+    - fig_size: tuple, size of figure (default: (10, 3))
+    - dot_size: float, Radius of the markers in stripplot.
+
+    Returns:
+        ax
+
+    """
+
+    if text_kws is None:
+        text_kws = {}
+
+    # A. organize input data for plotting--------------
+    res_prefix = cell_pred_col
+    ## cell-level data
+    pred_score_df = adata.obs[
+        [
+            cell_pred_col,
+            "y",
+            "label",
+            "patient_id",
+            f"{res_prefix}_sample_auc",
+            f"{res_prefix}_sample_auc_pvalue",
+        ]
+    ].copy()
+
+    ## sample-level data
+    sample_pData = pred_score_df.groupby(
+        [
+            "y",
+            "label",
+            "patient_id",
+            f"{res_prefix}_sample_auc",
+            f"{res_prefix}_sample_auc_pvalue",
+        ],
+        observed=True,
+        as_index=False,
+    )[cell_pred_col].max()
+    sample_pData.rename(columns={cell_pred_col: "y_pos"}, inplace=True)
+    sample_pData = sample_pData.sort_values(by=f"{res_prefix}_sample_auc").reset_index(
+        drop=True
+    )
+
+    sample_order = sample_pData.patient_id.tolist()
+
+    # index of the first sample (in ascending AUC order) with AUC >= 0.5;
+    # the vertical threshold line is drawn just before it
+    sample_threshold_index = (
+        sample_pData[f"{res_prefix}_sample_auc"]
+        .where(sample_pData[f"{res_prefix}_sample_auc"] >= 0.5)
+        .first_valid_index()
+    )
+    if sample_threshold_index is None:
+        if (sample_pData[f"{res_prefix}_sample_auc"] >= 0.5).all():
+            sample_threshold = -0.5
+        else:
+            sample_threshold = len(sample_pData[f"{res_prefix}_sample_auc"]) - 0.5
+    else:
+        sample_threshold = sample_threshold_index - 0.5
+
+    # B. plot--------------------------------------------
+    if ax is None:
+        ax = plt.gca()
+
+    # Hide the right and top spines
+    ax.spines[["right", "top"]].set_visible(False)
+
+    # Violin plot (note: `scale=` was renamed `density_norm=` in seaborn >= 0.13)
+    sns.violinplot(
+        y=cell_pred_col,
+        x="patient_id",
+        data=pred_score_df,
+        order=sample_order,
+        color="0.8",
+        scale="width",
+        ax=ax,
+        cut=0,
+    )
+
+    # Strip plot
+    sns.stripplot(
+        y=cell_pred_col,
+        x="patient_id",
+        hue="y",
+        data=pred_score_df,
+        order=sample_order,
+        dodge=False,
+        jitter=True,
+        size=dot_size,
+        ax=ax,
+        palette=palette,
+    )
+
+    ax.axhline(y=0.5, color="0.8", linestyle="--")
+    ax.axvline(x=sample_threshold, color="0.8", linestyle="--")
+
+    # Add statistical significance (asterisks) on top of each violin
+    ## y positions: just above each violin's maximum probability
+    yposlist = (sample_pData["y_pos"] + 0.03).tolist()
+    ## x positions: one per violin
+    xposlist = range(len(yposlist))
+    ## get text
+    pvalue_list = sample_pData[f"{res_prefix}_sample_auc_pvalue"].tolist()
+    asterisks_list = [convert_pvalue_to_asterisks(pvalue) for pvalue in pvalue_list]
+    auc_text_list = [
+        "%.3f" % sample_auc
+        for sample_auc in sample_pData[f"{res_prefix}_sample_auc"].tolist()
+    ]
+    text_list = [
+        auc_text + "\n" + asterisk
+        for auc_text, asterisk in zip(auc_text_list, asterisks_list)
+    ]
+
+    for k in range(len(asterisks_list)):
+        ax.text(x=xposlist[k], y=yposlist[k], s=text_list[k], ha="center", **text_kws)
+
+    ax.set_title(cell_pred_col, pad=30)
+    ax.set_xlabel(None)
+    ax.set_ylabel("Prediction Probablity (Cell)", fontsize=13)
+    ax.plot()
+
+    ax.set_xticks(ax.get_xticks(), ax.get_xticklabels(), rotation=45, ha="right")
+    if xticklabels_color:
+        for xtick in ax.get_xticklabels():
+            x_label = xtick.get_text()
+            x_label_cate = sample_pData["y"][
+                sample_pData["patient_id"] == x_label
+            ].values[0]
+            xtick.set_color(palette[x_label_cate])
+
+    ax.legend(loc="upper left", title="Patient Label", bbox_to_anchor=(1.04, 1))
+
+    return ax
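+
+# Example usage (a sketch; the palette keys are illustrative and must match the
+# categories in `adata.obs["y"]`):
+#   fig, ax = plt.subplots(figsize=(10, 3))
+#   plot_violin(adata_test_final, dot_size=2, ax=ax,
+#               palette={"Control": "tab:blue", "Case": "tab:red"},
+#               xticklabels_color=True)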
+
+
+### Plot patient level prediction scores
+def make_single_spider(
+    adata_test_final: AnnData, metric_idx: int, color: str, nrow: int, ncol: int
+) -> None:
+    # number of variables (one spider axis per metric in sample_score)
+    categories = adata_test_final.uns["sample_score"].index.tolist()
+    N = len(adata_test_final.uns["sample_score"].index)
+
+    # Take the column of the selected metric; repeat the first value
+    # to close the circular graph:
+    values = (
+        adata_test_final.uns["sample_score"]
+        .iloc[:, metric_idx]
+        .values.flatten()
+        .tolist()
+    )
+    values += values[:1]
+
+    # Angle of each axis in the plot (divide the circle by the number of variables)
+    angles = [n / float(N) * 2 * pi for n in range(N)]
+    angles += angles[:1]
+
+    # Initialise the spider plot
+    ax = plt.subplot(nrow, ncol, metric_idx + 1, polar=True)
+
+    # If you want the first axis to be on top:
+    ax.set_theta_offset(pi / 2)
+    ax.set_theta_direction(-1)
+
+    # Draw one axis per variable and add the category labels
+    plt.xticks(angles[:-1], categories, color="grey", size=15)
+
+    for i, label in enumerate(ax.get_xticklabels()):
+        if i < len(angles) / 2:
+            angle_text = angles[i] * (-180 / pi) + 90
+            label.set_horizontalalignment("left")
+
+        else:
+            angle_text = angles[i] * (-180 / pi) - 90
+            label.set_horizontalalignment("right")
+        label.set_rotation(angle_text)
+
+    # Draw ylabels
+    ax.set_rlabel_position(0)
+    plt.yticks([0.1, 0.3, 0.6], ["0.1", "0.3", "0.6"], color="grey", size=8)
+    plt.ylim(0, 1.05)
+
+    # Plot data
+    ax.plot(angles, values, color=color, linewidth=2, linestyle="solid")
+    ax.fill(angles, values, color=color, alpha=0.4)
+    ax.grid(color="white")
+    for ti, di in zip(angles, values):
+        ax.text(
+            ti, di - 0.02, "{0:.2f}".format(di), color="black", ha="center", va="center"
+        )
+
+    # Add a title
+    t = adata_test_final.uns["sample_score"].columns[metric_idx]
+    t = t.replace("_pred_score", "")
+    plt.title(t, color="black", y=1.2, size=22)
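+
+
+# Example usage (a sketch): draw one spider panel per metric column stored in
+# `uns["sample_score"]`, laid out on a single row:
+#   n_metrics = adata_test_final.uns["sample_score"].shape[1]
+#   plt.figure(figsize=(4 * n_metrics, 5))
+#   for i in range(n_metrics):
+#       make_single_spider(adata_test_final, i, color="tab:blue", nrow=1,
+#                          ncol=n_metrics)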