Diff of /src/scpanel/train.py [000000] .. [d90ecf]

import os
import pickle
import time

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    accuracy_score,
    auc,
    average_precision_score,
    balanced_accuracy_score,
    classification_report,
    confusion_matrix,
    matthews_corrcoef,
    recall_score,
    roc_auc_score,
)
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC

from .GATclassifier import GATclassifier
from .utils_func import *

# private sklearn modules imported only for the type annotations below
import sklearn.ensemble._forest
import sklearn.linear_model._logistic
import sklearn.neighbors._classification
import sklearn.svm._classes
from anndata._core.anndata import AnnData
from matplotlib.axes._axes import Axes
from matplotlib.figure import Figure
from matplotlib.gridspec import GridSpec
from numpy import float64, ndarray
from pandas.core.frame import DataFrame
from pandas.core.series import Series
from torch import Tensor
from typing import Any, Dict, List, Optional, Tuple, Union

def transform_adata(adata_train: AnnData, adata_test_dict: Dict[str, AnnData], selected_gene: Optional[List[str]] = None) -> Tuple[AnnData, AnnData]:
    ## Transform a train/test split drawn from the same dataset (batch-effect free):
    ## subset adata_train to the selected genes,
    ## subset adata_test_dict to the selected cell type and genes.
    ## WATCH OUT: the X matrix in adata_test_dict is only log-normalized; it still needs scaling
    if selected_gene is None:
        selected_gene = adata_train.uns["svm_rfe_genes"]

    adata_train_final = adata_train[:, selected_gene]

    mean = adata_train_final.var["mean"].values
    std = adata_train_final.var["std"].values

    ct_selected = adata_train_final.obs.ct.unique()[0]

    # transform test data: selected genes, selected cell type, then scale with train statistics
    adata_test = adata_test_dict[ct_selected].copy()
    adata_test_final = adata_test[:, selected_gene].copy()

    if isinstance(adata_test_final.X, np.ndarray):
        test_X = adata_test_final.X
    else:
        test_X = adata_test_final.X.toarray()
    test_X -= mean
    test_X /= std

    # clip extreme z-scores to max_value
    max_value = 10
    test_X[test_X > max_value] = max_value
    adata_test_final.X = test_X

    return adata_train_final, adata_test_final
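
# Usage sketch (hypothetical, not part of the original module): assumes the
# training AnnData carries per-gene "mean"/"std" in .var and the selected
# panel in .uns["svm_rfe_genes"], and that adata_test_dict maps cell-type
# names to log-normalized AnnData objects.
#
#   adata_train_final, adata_test_final = transform_adata(
#       adata_train,
#       adata_test_dict,
#       selected_gene=None,  # None falls back to adata_train.uns["svm_rfe_genes"]
#   )
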

def models_train(
    adata_train_final: AnnData,
    search_grid: bool,
    out_dir: Optional[str] = None,
    param_grid: Optional[Dict[str, Dict[str, int]]] = None,
) -> List[
    Union[
        Tuple[str, sklearn.linear_model._logistic.LogisticRegression],
        Tuple[str, sklearn.ensemble._forest.RandomForestClassifier],
        Tuple[str, sklearn.svm._classes.SVC],
        Tuple[str, sklearn.neighbors._classification.KNeighborsClassifier],
        Tuple[str, GATclassifier],
    ]
]:

    X_tr, y_tr, adj_tr = get_X_y_from_ann(
        adata_train_final, return_adj=True, n_neigh=10
    )
    sample_weight = compute_cell_weight(adata_train_final)

    # Make sure there are no NaNs in the matrix
    X_tr = np.nan_to_num(X_tr)

    grid_search = search_grid
    models = [
        ("LR", LogisticRegression(solver="saga", max_iter=500, random_state=42)),
        ("RF", RandomForestClassifier(random_state=42)),
        ("SVM", SVC(probability=True, random_state=42)),
        ("KNN", KNeighborsClassifier()),
        (
            "GAT",
            GATclassifier(
                nFeatures=adata_train_final.n_vars, NumParts=10, nEpochs=1000, verbose=1
            ),
        ),
    ]

    # Parameter tuning grids-------------------------
    LR_params = [{"C": [10, 1.0, 0.1, 0.01], "max_iter": [10, 50, 200, 500]}]
    RF_params = [
        {"max_depth": [2, 5, 10, 15, 20, 30, None], "n_estimators": [50, 100, 500]}
    ]
    SVM_params = [{"C": [100, 10, 1.0, 0.1, 0.001], "gamma": [1, 0.1, 0.01, 0.001]}]
    KNN_params = [{"n_neighbors": [3, 5, 10, 20, 50], "p": [1, 2]}]

    my_grid = {"LR": LR_params, "RF": RF_params, "SVM": SVM_params, "KNN": KNN_params}

    clfs = []
    names = []
    runtimes = []
    best_params = []

    for name, model in models:
        start_time = time.time()

        if grid_search:
            if name != "GAT":
                clf = GridSearchCV(
                    model, my_grid[name], cv=5, scoring="roc_auc", n_jobs=10
                )
            else:
                clf = model
        else:
            clf = model
            if param_grid is not None and name in param_grid:
                clf.set_params(**param_grid[name])

        if name == "GAT":
            clf.fit(X_tr, y_tr, adj_tr)
        elif name == "KNN":
            # KNeighborsClassifier does not accept sample_weight
            clf.fit(X_tr, y_tr)
        else:
            clf.fit(X_tr, y_tr, sample_weight=sample_weight)

        runtime = time.time() - start_time

        # save outputs
        clfs.append((name, clf))
        names.append(name)
        runtimes.append(runtime)

        print("--- %s finished in %.2f seconds ---" % (name, runtime))

    # save models
    if out_dir is not None:
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

        with open(f"{out_dir}/clfs.pkl", "wb") as f:
            pickle.dump(clfs, f, protocol=pickle.HIGHEST_PROTOCOL)

        with open(f"{out_dir}/adata_train_final.pkl", "wb") as f:
            pickle.dump(adata_train_final, f, protocol=pickle.HIGHEST_PROTOCOL)

    return clfs
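
# Usage sketch (hypothetical): train all five classifiers with the default
# hyperparameter grids and persist them to disk; the output directory name
# here is an example.
#
#   clfs = models_train(
#       adata_train_final,
#       search_grid=True,         # GridSearchCV for LR/RF/SVM/KNN; GAT is fit as-is
#       out_dir="results/models", # writes clfs.pkl and adata_train_final.pkl
#   )
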

def models_predict(
    clfs: List[
        Union[
            Tuple[str, sklearn.linear_model._logistic.LogisticRegression],
            Tuple[str, sklearn.ensemble._forest.RandomForestClassifier],
            Tuple[str, sklearn.svm._classes.SVC],
            Tuple[str, sklearn.neighbors._classification.KNeighborsClassifier],
            Tuple[str, GATclassifier],
        ]
    ],
    adata_test_final: AnnData,
    out_dir: Optional[str] = None,
) -> Tuple[
    AnnData,
    List[Union[Tuple[str, ndarray], Tuple[str, Tensor]]],
    List[Tuple[str, ndarray]],
]:
    X_test, y_test, adj_test = get_X_y_from_ann(
        adata_test_final, return_adj=True, n_neigh=10
    )
    X_test = np.nan_to_num(X_test)

    ## Predicting---------------
    y_pred_list = []
    y_pred_score_list = []

    for name, clf in clfs:
        if name == "GAT":
            y_pred = clf.predict(X_test, y_test, adj_test)
            y_pred_score = clf.predict_proba(X_test, y_test, adj_test)
        else:
            y_pred = clf.predict(X_test)
            y_pred_score = clf.predict_proba(X_test)

        y_pred_list.append((name, y_pred))
        y_pred_score_list.append((name, y_pred_score))

    # add prediction results to adata_test_final
    y_pred = pd.DataFrame(dict([(name + "_pred", pred) for name, pred in y_pred_list]))
    y_pred_score = pd.DataFrame(
        dict([(name + "_pred_score", pred[:, 1]) for name, pred in y_pred_score_list])
    )

    y_pred_df = pd.concat([y_pred, y_pred_score], axis=1)
    y_pred_df.index = adata_test_final.obs.index

    if set(y_pred_df.columns).issubset(set(adata_test_final.obs.columns)):
        print("Prediction result already exists in test adata, overwriting it...")
        adata_test_final.obs.update(y_pred_df)
    else:
        adata_test_final.obs = pd.concat([adata_test_final.obs, y_pred_df], axis=1)

    # calculate the median prediction score across the 5 classifiers
    pred_col = [
        col for col in adata_test_final.obs.columns if col.endswith("_pred_score")
    ]
    adata_test_final.obs["median_pred_score"] = adata_test_final.obs[pred_col].median(
        axis=1
    )

    return adata_test_final, y_pred_list, y_pred_score_list
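
# Usage sketch (hypothetical), continuing from models_train():
#
#   adata_test_final, y_pred_list, y_pred_score_list = models_predict(
#       clfs, adata_test_final
#   )
#   # adata_test_final.obs now holds "<model>_pred", "<model>_pred_score"
#   # and the ensemble "median_pred_score" columns.
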

def models_score(adata_test_final, y_pred_list, y_pred_score_list, out_dir=None):
    X_test, y_test = get_X_y_from_ann(adata_test_final)

    ## Scoring-------------------------------------
    ## define scoring metrics (from sklearn);
    ## each metric maps to (score_func, extra_kwargs)
    scorers = {
        "accuracy": (accuracy_score, {}),
        "balanced_accuracy": (balanced_accuracy_score, {}),
        "MCC": (matthews_corrcoef, {}),
    }

    scorers_prob = {
        "AUROC": (roc_auc_score, {}),
        "AUPRC": (average_precision_score, {}),
    }

    ## calculate label-based metrics
    eval_res_1 = pd.DataFrame()
    for name, y_pred in y_pred_list:
        eval_res_dict = dict(
            [
                (score_name, score_func(y_test, y_pred, **score_para))
                for score_name, (score_func, score_para) in scorers.items()
            ]
        )
        eval_res_i = pd.DataFrame(eval_res_dict, index=[name])

        eval_res_1 = pd.concat(objs=[eval_res_1, eval_res_i], axis=0)

    ## calculate probability-based metrics
    eval_res_2 = pd.DataFrame()
    for name, y_pred_score in y_pred_score_list:
        eval_res_dict = dict(
            [
                (score_name, score_func(y_test, y_pred_score[:, 1], **score_para))
                for score_name, (score_func, score_para) in scorers_prob.items()
            ]
        )
        eval_res_i = pd.DataFrame(eval_res_dict, index=[name])

        eval_res_2 = pd.concat(objs=[eval_res_2, eval_res_i], axis=0)

    eval_res = pd.concat(objs=[eval_res_2, eval_res_1], axis=1)

    if out_dir is not None:
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)

        eval_res.to_csv(f"{out_dir}/eval_res.csv")

    return eval_res
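
# Usage sketch (hypothetical): one row per classifier, columns AUROC/AUPRC
# plus accuracy, balanced accuracy and MCC.
#
#   eval_res = models_score(adata_test_final, y_pred_list, y_pred_score_list)
#   print(eval_res.round(3))
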

def cal_sample_auc(df: DataFrame, score_col: str) -> float64:
    cell_prob = df[score_col].sort_values()
    # rank the cell probabilities in ascending order and normalize to (0, 1]
    cell_rank = cell_prob.rank(method="first") / cell_prob.rank(method="first").max()
    # area under the normalized-rank vs. probability curve
    sample_auc = auc(cell_rank, cell_prob)
    return sample_auc
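
# Worked example (hypothetical): if a patient's cells score
# [0.2, 0.9, 0.8, 0.7], the sorted probabilities are [0.2, 0.7, 0.8, 0.9],
# the normalized ranks are [0.25, 0.5, 0.75, 1.0], and the trapezoidal area
# under (rank, probability) is ~0.51 — just above 0.5, so the sample leans
# towards the positive class.
#
#   df = pd.DataFrame({"median_pred_score": [0.2, 0.9, 0.8, 0.7]})
#   cal_sample_auc(df, "median_pred_score")
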

def auc_pvalue(row: Series) -> float:
    # fraction of bootstrap AUCs on the "wrong" side of 0.5 for the
    # predicted class (1 = positive, 0 = negative)
    if row.name[1] == 1:
        p_value = np.mean(row < 0.5)
    elif row.name[1] == 0:
        p_value = np.mean(row > 0.5)

    # avoid reporting exactly 0: bound the p-value by the bootstrap resolution
    if p_value == 0:
        p_value = 1 / row.size
    return p_value
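
# Example (hypothetical): for a sample predicted positive (row.name[1] == 1),
# the p-value is the fraction of bootstrap AUCs below 0.5. With AUCs
# [0.9, 0.8, 0.45, 0.7], that is 1/4 = 0.25; with 100 replicates all above
# 0.5, it is floored at 1/100 = 0.01 rather than reported as exactly 0.
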

def pt_pred(adata_test_final: AnnData, cell_pred_col: str = "median_pred_score", num_bootstrap: Optional[int] = None) -> AnnData:
    sample_auc = adata_test_final.obs.groupby("patient_id").apply(
        lambda df: cal_sample_auc(df, cell_pred_col)
    )
    adata_test_final.obs[cell_pred_col + "_sample_auc"] = (
        adata_test_final.obs["patient_id"].map(sample_auc).astype(float)
    )
    adata_test_final.obs[cell_pred_col + "_sample_pred"] = (
        adata_test_final.obs[cell_pred_col + "_sample_auc"] >= 0.5
    ).astype(int)

    if num_bootstrap is not None:
        auc_df = pd.DataFrame()
        for i in range(num_bootstrap):
            # resample cells within each patient with replacement
            df = adata_test_final.obs.groupby("patient_id").sample(
                frac=1, replace=True, random_state=i
            )
            # per-iteration AUC (named auc_i to avoid shadowing sklearn.metrics.auc)
            auc_i = (
                df.groupby(["patient_id", cell_pred_col + "_sample_pred"])
                .apply(lambda df: cal_sample_auc(df, cell_pred_col))
                .to_frame(name=i)
            )
            auc_df = pd.concat([auc_df, auc_i], axis=1)

        auc_df[cell_pred_col + "_sample_auc_pvalue"] = auc_df.apply(
            lambda row: auc_pvalue(row), axis=1
        )
        # store the AUC from each bootstrap iteration in adata.uns
        adata_test_final.uns[cell_pred_col + "_auc_df"] = auc_df
        # store the AUC p-value for each sample in adata.obs
        auc_df = auc_df.droplevel(cell_pred_col + "_sample_pred")
        adata_test_final.obs[cell_pred_col + "_sample_auc_pvalue"] = (
            adata_test_final.obs["patient_id"].map(
                auc_df[cell_pred_col + "_sample_auc_pvalue"]
            )
        )

    return adata_test_final
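
# Usage sketch (hypothetical): aggregate cell-level scores into patient-level
# calls, with 100 bootstrap replicates for the AUC p-values.
#
#   adata_test_final = pt_pred(
#       adata_test_final,
#       cell_pred_col="median_pred_score",
#       num_bootstrap=100,
#   )
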

def pt_score(adata_test_final: AnnData, cell_pred_col: str = "median_pred_score") -> AnnData:
    ## Calculate precision, recall, F1 score and accuracy at the patient level
    from sklearn.metrics import precision_recall_fscore_support

    res_prefix = cell_pred_col

    pt_pred_res = (
        adata_test_final.obs[["label", "patient_id", f"{res_prefix}_sample_pred"]]
        .drop_duplicates()
        .set_index("patient_id")
    )

    # precision, recall, F1 score
    pt_score_res = precision_recall_fscore_support(
        pt_pred_res["label"],
        pt_pred_res[f"{res_prefix}_sample_pred"],
        average="weighted",
    )
    # accuracy
    pt_acc_res = accuracy_score(
        pt_pred_res["label"], pt_pred_res[f"{res_prefix}_sample_pred"]
    )
    # specificity (recall of the negative class)
    pt_spec_res = recall_score(
        pt_pred_res["label"], pt_pred_res[f"{res_prefix}_sample_pred"], pos_label=0
    )

    # drop the unused support entry (index 3) and keep
    # precision, sensitivity, f1score, accuracy, specificity
    pt_score_res = pd.DataFrame(list(pt_score_res) + [pt_acc_res] + [pt_spec_res])
    pt_score_res = pt_score_res.iloc[[0, 1, 2, 4, 5], :]
    pt_score_res.index = [
        "precision",
        "sensitivity",
        "f1score",
        "accuracy",
        "specificity",
    ]
    pt_score_res.columns = [res_prefix]

    if "sample_score" not in adata_test_final.uns:
        adata_test_final.uns["sample_score"] = pt_score_res
    else:
        # merge with existing scores, letting the new column overwrite any
        # previous one with the same name
        adata_test_final.uns["sample_score"] = adata_test_final.uns[
            "sample_score"
        ].merge(pt_score_res, left_index=True, right_index=True, suffixes=("_x", ""))

        adata_test_final.uns["sample_score"].drop(
            adata_test_final.uns["sample_score"].filter(regex="_x$").columns,
            axis=1,
            inplace=True,
        )

    return adata_test_final
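
# Usage sketch (hypothetical), after pt_pred() has added the
# "<cell_pred_col>_sample_pred" column:
#
#   adata_test_final = pt_score(adata_test_final)
#   print(adata_test_final.uns["sample_score"])
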

from math import pi

# Plot functions
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rcParams


def _panel_grid(hspace: float, wspace: float, ncols: int, num_panels: int) -> Tuple[Figure, GridSpec]:
    from matplotlib import gridspec

    n_panels_x = min(ncols, num_panels)
    n_panels_y = np.ceil(num_panels / n_panels_x).astype(int)
    # each panel will have the size of rcParams['figure.figsize']
    fig = plt.figure(
        figsize=(
            n_panels_x * rcParams["figure.figsize"][0] * (1 + wspace),
            n_panels_y * rcParams["figure.figsize"][1],
        ),
    )
    left = 0.2 / n_panels_x
    bottom = 0.13 / n_panels_y
    gs = gridspec.GridSpec(
        nrows=n_panels_y,
        ncols=n_panels_x,
        left=left,
        right=1 - (n_panels_x - 1) * left - 0.01 / n_panels_x,
        bottom=bottom,
        top=1 - (n_panels_y - 1) * bottom - 0.1 / n_panels_y,
        hspace=hspace,
        wspace=wspace,
    )
    return fig, gs

def plot_roc_curve(
    adata_test_final: AnnData,
    sample_id: Series,
    cell_pred_col: str,
    ncols: int = 4,
    hspace: float = 0.25,
    wspace: Optional[float] = None,
    ax: Optional[Axes] = None,
    scatter_kws: Optional[Dict[str, int]] = None,
    legend_kws: Optional[Dict[str, Dict[str, int]]] = None,
) -> List[Axes]:
    """
    Parameters
    ----------
    - adata_test_final: AnnData
    - sample_id: str | Sequence
    - cell_pred_col: str = 'median_pred_score'
    - ncols: int = 4
    - hspace: float = 0.25
    - wspace: float | None = None
    - ax: Axes | None = None
    - scatter_kws: dict | None = None, arguments passed to matplotlib.pyplot.scatter()
    - legend_kws: dict | None = None, arguments passed to the axes legend()

    Returns
    -------
    A list of Axes when plotting multiple panels, otherwise a single Axes.

    Examples
    --------
    plot_roc_curve(adata_test_final,
               sample_id = ['C3','C6','H1'],
               cell_pred_col = 'median_pred_score',
               scatter_kws={'s':10})

    """

    # turn sample_id into a python list:
    ## if sample_id is a string or None, wrap it in []
    ## if sample_id is already a sequence, turn it into a list
    sample_id = (
        [sample_id]
        if isinstance(sample_id, str) or sample_id is None
        else list(sample_id)
    )

    ##########
    # Layout #
    ##########
    if scatter_kws is None:
        scatter_kws = {}

    if legend_kws is None:
        legend_kws = {}

    if wspace is None:
        #  try to set a wspace that is not too large or too small given the
        #  current figure size
        wspace = 0.75 / rcParams["figure.figsize"][0] + 0.02

    # if plotting multiple panels for elements in sample_id
    if len(sample_id) > 1:
        if ax is not None:
            raise ValueError(
                "Cannot specify `ax` when plotting multiple panels "
                "(each for a given value of 'color')."
            )
        fig, grid = _panel_grid(hspace, wspace, ncols, len(sample_id))
    else:
        grid = None
        if ax is None:
            fig = plt.figure()
            ax = fig.add_subplot(111)

    ############
    # Plotting #
    ############
    axs = []
    for count, _sample_id in enumerate(sample_id):
        if grid:
            ax = plt.subplot(grid[count])
            axs.append(ax)

        # prediction probability of class 1 for this sample
        cell_prob = adata_test_final.obs.loc[
            adata_test_final.obs["patient_id"] == _sample_id
        ][cell_pred_col]
        cell_prob = cell_prob.sort_values(ascending=True)
        # rank of cell_prob, normalized
        cell_rank = (
            cell_prob.rank(method="first") / cell_prob.rank(method="first").max()
        )
        # AUC
        sample_auc = adata_test_final.obs.loc[
            adata_test_final.obs["patient_id"] == _sample_id
        ][cell_pred_col + "_sample_auc"].unique()[0]
        # AUC p-value
        sample_auc_pvalue = adata_test_final.obs.loc[
            adata_test_final.obs["patient_id"] == _sample_id
        ][cell_pred_col + "_sample_auc_pvalue"].unique()[0]

        ax.scatter(x=cell_rank, y=cell_prob, c=".3", **scatter_kws)
        ax.plot(
            cell_rank,
            cell_prob,
            label=f"AUC = {sample_auc:.3f} \np-value = {sample_auc_pvalue:.1e}",
            zorder=0,
        )
        ax.plot(
            [0, 1], [0, 1], linestyle="--", color=".5", zorder=0, label="Random guess"
        )
        ax.spines[["right", "top"]].set_visible(False)
        ax.set_xlabel("Rank")
        ax.set_ylabel("Prediction Probability (Cell)")
        ax.set_title(f"{_sample_id}")
        ax.set_aspect("equal")
        if not bool(legend_kws):
            ax.legend(prop=dict(size=8 * rcParams["figure.figsize"][0] / ncols))
        else:
            ax.legend(**legend_kws)

    axs = axs if grid else ax

    return axs

def convert_pvalue_to_asterisks(pvalue: float) -> str:
    if pvalue <= 0.0001:
        return "****"
    elif pvalue <= 0.001:
        return "***"
    elif pvalue <= 0.01:
        return "**"
    elif pvalue <= 0.05:
        return "*"
    return "ns"

# plot cell-level probabilities for each patient
def plot_violin(
    adata: AnnData,
    cell_pred_col: str = "median_pred_score",
    dot_size: int = 2,
    ax: Optional[Axes] = None,
    palette: Optional[Dict[str, str]] = None,
    xticklabels_color: bool = False,
    text_kws: Dict[Any, Any] = {},
) -> Axes:
    """
    Violin plots of cell-level prediction probabilities in each sample.

    Parameters:
    - adata: AnnData object

    - cell_pred_col: string, name of the column in adata.obs holding the
    cell-level prediction probabilities (default: 'median_pred_score')

    - dot_size: int, size of the markers in the strip plot (default: 2)

    - ax: matplotlib Axes to draw on (default: current axes)

    - palette: dict mapping patient labels to colors, passed to seaborn

    - xticklabels_color: bool, color the x tick labels by patient label
    (requires palette)

    - text_kws: dict, extra keyword arguments for the per-violin text labels

    Returns:
        ax

    """

    # A. organize input data for plotting--------------
    res_prefix = cell_pred_col
    ## cell-level data
    pred_score_df = adata.obs[
        [
            cell_pred_col,
            "y",
            "label",
            "patient_id",
            f"{res_prefix}_sample_auc",
            f"{res_prefix}_sample_auc_pvalue",
        ]
    ].copy()

    ## sample-level data
    sample_pData = pred_score_df.groupby(
        [
            "y",
            "label",
            "patient_id",
            f"{res_prefix}_sample_auc",
            f"{res_prefix}_sample_auc_pvalue",
        ],
        observed=True,
        as_index=False,
    )[cell_pred_col].max()
    sample_pData.rename(columns={cell_pred_col: "y_pos"}, inplace=True)
    sample_pData = sample_pData.sort_values(by=f"{res_prefix}_sample_auc").reset_index(
        drop=True
    )

    sample_order = sample_pData.patient_id.tolist()

    # index of the first sample with AUC >= 0.5; the vertical threshold line
    # is drawn just before it
    sample_threshold_index = (
        sample_pData[f"{res_prefix}_sample_auc"]
        .where(sample_pData[f"{res_prefix}_sample_auc"] >= 0.5)
        .first_valid_index()
    )
    if sample_threshold_index is None:
        if (sample_pData[f"{res_prefix}_sample_auc"] >= 0.5).all():
            sample_threshold = -0.5
        else:
            sample_threshold = len(sample_pData[f"{res_prefix}_sample_auc"]) - 0.5
    else:
        sample_threshold = sample_threshold_index - 0.5

    # B. plot--------------------------------------------
    if ax is None:
        ax = plt.gca()

    # Hide the right and top spines
    ax.spines[["right", "top"]].set_visible(False)

    # Violin plot
    sns.violinplot(
        y=cell_pred_col,
        x="patient_id",
        data=pred_score_df,
        order=sample_order,
        color="0.8",
        scale="width",
        fontsize=15,
        ax=ax,
        cut=0,
    )

    # Strip plot
    sns.stripplot(
        y=cell_pred_col,
        x="patient_id",
        hue="y",
        data=pred_score_df,
        order=sample_order,
        dodge=False,
        jitter=True,
        size=dot_size,
        ax=ax,
        palette=palette,
    )

    ax.axhline(y=0.5, color="0.8", linestyle="--")
    ax.axvline(x=sample_threshold, color="0.8", linestyle="--")

    # Add statistical significance (asterisks) on top of each violin
    ## y positions of the text labels
    yposlist = (sample_pData["y_pos"] + 0.03).tolist()
    ## x positions of the text labels
    xposlist = range(len(yposlist))
    ## text: sample AUC plus significance asterisks
    pvalue_list = sample_pData[f"{res_prefix}_sample_auc_pvalue"].tolist()
    asterisks_list = [convert_pvalue_to_asterisks(pvalue) for pvalue in pvalue_list]
    perm_stat_list = [
        "%.3f" % perm_stat
        for perm_stat in sample_pData[f"{res_prefix}_sample_auc"].tolist()
    ]
    text_list = [
        perm_stat + "\n" + asterisk
        for perm_stat, asterisk in zip(perm_stat_list, asterisks_list)
    ]

    for k in range(len(asterisks_list)):
        ax.text(x=xposlist[k], y=yposlist[k], s=text_list[k], ha="center", **text_kws)

    ax.set_title(cell_pred_col, pad=30)
    ax.set_xlabel(None)
    ax.set_ylabel("Prediction Probability (Cell)", fontsize=13)
    ax.plot()

    ax.set_xticks(ax.get_xticks(), ax.get_xticklabels(), rotation=45, ha="right")
    if xticklabels_color:
        for xtick in ax.get_xticklabels():
            x_label = xtick.get_text()
            x_label_cate = sample_pData["y"][
                sample_pData["patient_id"] == x_label
            ].values[0]
            xtick.set_color(palette[x_label_cate])

    ax.legend(loc="upper left", title="Patient Label", bbox_to_anchor=(1.04, 1))

    return ax
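
# Usage sketch (hypothetical): one violin per patient, ordered by sample AUC,
# with dots colored by the true label. The palette keys must match the values
# in adata.obs["y"]; the labels below are examples.
#
#   fig, ax = plt.subplots(figsize=(10, 3))
#   plot_violin(
#       adata_test_final,
#       cell_pred_col="median_pred_score",
#       palette={"Healthy": "tab:blue", "Disease": "tab:red"},
#       ax=ax,
#   )
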

### Plot patient-level prediction scores
def make_single_spider(adata_test_final: AnnData, metric_idx: int, color: str, nrow: int, ncol: int) -> None:
    # number of variables
    categories = adata_test_final.uns["sample_score"].index.tolist()
    N = len(adata_test_final.uns["sample_score"].index)

    # values for this metric; repeat the first value to close the circular graph
    values = (
        adata_test_final.uns["sample_score"]
        .iloc[:, metric_idx]
        .values.flatten()
        .tolist()
    )
    values += values[:1]

    # angle of each axis in the plot (divide the circle by the number of variables)
    angles = [n / float(N) * 2 * pi for n in range(N)]
    angles += angles[:1]

    # Initialise the spider plot
    ax = plt.subplot(nrow, ncol, metric_idx + 1, polar=True)

    # Put the first axis on top:
    ax.set_theta_offset(pi / 2)
    ax.set_theta_direction(-1)

    # Draw one axis per variable and add labels
    plt.xticks(angles[:-1], categories, color="grey", size=15)

    for label, i in zip(ax.get_xticklabels(), range(0, len(angles))):
        if i < len(angles) / 2:
            angle_text = angles[i] * (-180 / pi) + 90
            label.set_horizontalalignment("left")
        else:
            angle_text = angles[i] * (-180 / pi) - 90
            label.set_horizontalalignment("right")
        label.set_rotation(angle_text)

    # Draw y labels
    ax.set_rlabel_position(0)
    plt.yticks([0.1, 0.3, 0.6], ["0.1", "0.3", "0.6"], color="grey", size=8)
    plt.ylim(0, 1.05)

    # Plot data
    ax.plot(angles, values, color=color, linewidth=2, linestyle="solid")
    ax.fill(angles, values, color=color, alpha=0.4)
    ax.grid(color="white")
    for ti, di in zip(angles, values):
        ax.text(
            ti, di - 0.02, "{0:.2f}".format(di), color="black", ha="center", va="center"
        )

    # Add a title
    t = adata_test_final.uns["sample_score"].columns[metric_idx]
    t = t.replace("_pred_score", "")
    plt.title(t, color="black", y=1.2, size=22)
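
# Usage sketch (hypothetical): draw one spider panel per column of
# adata_test_final.uns["sample_score"].
#
#   n_metrics = adata_test_final.uns["sample_score"].shape[1]
#   plt.figure(figsize=(6 * n_metrics, 6))
#   for idx in range(n_metrics):
#       make_single_spider(adata_test_final, idx, color="tab:blue",
#                          nrow=1, ncol=n_metrics)
#   plt.tight_layout()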