Switch to unified view

a b/src/scpanel/select_gene.py
1
# import anndata
2
import itertools
3
import os
4
import pickle
5
import time
6
7
import matplotlib.pyplot as plt
8
from joblib import Parallel, delayed
9
10
# from sklearn.preprocessing import LabelEncoder
11
from sklearn import svm
12
13
# import scanpy as sc
14
# import numpy as np
15
# import pandas as pd
16
from sklearn.model_selection import StratifiedKFold
17
18
from .SVMRFECV import RFE, RFECV
19
from .utils_func import *
20
from anndata._core.anndata import AnnData
21
from matplotlib.axes._axes import Axes
22
from scpanel.SVMRFECV import RFECV
23
from typing import List, Optional, Tuple
24
25
26
def split_n_folds(adata_train: AnnData, nfold: int, out_dir: Optional[str] = None, random_state: int = 2349) -> Tuple[List[List[int]], List[List[int]], List[List[float]]]:
    """Build patient-stratified cross-validation folds at cell level.

    Patients (not cells) are split with a shuffled ``StratifiedKFold`` on the
    per-patient label ``y``; every cell then follows its patient, so a patient
    never contributes cells to both sides of the same fold.

    Parameters
    ----------
    adata_train
        Annotated data whose ``obs`` provides the columns ``patient_id``,
        ``y`` and ``ct``.
    nfold
        Number of folds.
    out_dir
        When given, the directory (created if missing) that receives
        ``Data_X_y.pkl``, ``Data_<nfold>fold_index.pkl`` and
        ``split_nfold_info.csv``. Nothing is written when ``None``.
    random_state
        Seed passed to ``StratifiedKFold``.

    Returns
    -------
    train_index_list, val_index_list, weight_list
        One entry per fold: positional indices of the training cells,
        positional indices of the validation cells, and the weights that
        ``compute_cell_weight`` assigns to the training cells. Positions
        refer to the rows left after patients with a zero ``ct`` count have
        been dropped.
    """
    ## exclude patients without selected cell type
    n_cell_pat = adata_train.obs.groupby(["patient_id"])["ct"].count()
    empty_patients = n_cell_pat[n_cell_pat == 0].index
    exclude_pat = adata_train.obs["patient_id"].isin(empty_patients)
    adata_train = adata_train[~exclude_pat]

    if exclude_pat.any():
        print(
            empty_patients.tolist(),
            "get excluded since no selected cell type appears",
        )

    ## split patients
    # One row per (label, patient); the fresh RangeIndex lets the positions
    # yielded by StratifiedKFold double as .loc labels further down.
    pat_meta_temp = adata_train.obs[["y", "patient_id"]].drop_duplicates().reset_index()
    # RangeIndex over cells, so ".index" below is the positional cell index.
    cell_meta_temp = adata_train.obs.reset_index()

    patient_class = pat_meta_temp["y"].to_numpy()
    patient = pat_meta_temp["patient_id"].to_numpy()

    skf = StratifiedKFold(n_splits=nfold, shuffle=True, random_state=random_state)

    patient_train_id_list = []
    patient_val_id_list = []

    train_patient_list = []
    val_patient_list = []

    train_index_list = []
    val_index_list = []

    weight_list = []

    for train_index, val_index in skf.split(patient, patient_class):
        train_patient_list.append(train_index)
        val_patient_list.append(val_index)

        patient_train_id = patient[train_index]
        patient_val_id = patient[val_index]
        patient_train_id_list.append(patient_train_id)
        patient_val_id_list.append(patient_val_id)

        cell_meta_fold_train = cell_meta_temp[
            cell_meta_temp["patient_id"].isin(patient_train_id)
        ]
        cell_meta_fold_val = cell_meta_temp[
            cell_meta_temp["patient_id"].isin(patient_val_id)
        ]

        # Every fold is recorded unconditionally so that the three returned
        # lists always stay aligned fold by fold.
        weight_list.append(compute_cell_weight(cell_meta_fold_train).tolist())
        train_index_list.append(cell_meta_fold_train.index.tolist())
        val_index_list.append(cell_meta_fold_val.index.tolist())

    if out_dir is not None:
        os.makedirs(out_dir, exist_ok=True)

        ## Data and index
        X_train, y_train = get_X_y_from_ann(adata_train)
        with open(os.path.join(out_dir, "Data_X_y.pkl"), "wb") as f:
            pickle.dump(
                {"features": X_train, "labels": y_train},
                f,
                protocol=pickle.HIGHEST_PROTOCOL,
            )

        with open(
            os.path.join(out_dir, "Data_" + str(nfold) + "fold_index.pkl"), "wb"
        ) as f:
            # Store the weights of every fold (not only the last one) so the
            # entry lines up with the "train" / "val" index lists.
            pickle.dump(
                {
                    "train": train_index_list,
                    "val": val_index_list,
                    "sample_weight": weight_list,
                },
                f,
                protocol=pickle.HIGHEST_PROTOCOL,
            )

        ## nfold splitting information
        # n_cells, patient ids and class proportions for each train/val set
        all_list = train_index_list + val_index_list
        n_cells = [len(sublist) for sublist in all_list]

        patient_ids = patient_train_id_list + patient_val_id_list

        class_prop = [
            np.unique(y_train[index], return_counts=True)[1] for index in all_list
        ]

        patient_prop = [
            pat_meta_temp.loc[fold].y.value_counts().tolist()
            for fold in train_patient_list + val_patient_list
        ]

        nfold_info = pd.DataFrame(
            {
                "n_cells": n_cells,
                "patient_ids": patient_ids,
                "class_prop": class_prop,
                "pt_prop": patient_prop,
            }
        )

        train_col = ["train_f" + str(i) for i in range(1, nfold + 1)]
        val_col = ["val_f" + str(i) for i in range(1, nfold + 1)]
        nfold_info.index = train_col + val_col
        nfold_info.to_csv(f"{out_dir}/split_nfold_info.csv")

    return train_index_list, val_index_list, weight_list
159
160
161
def gene_score(
    adata_train: AnnData,
    train_index_list: List[List[int]],
    val_index_list: List[List[int]],
    sample_weight_list: List[List[float]],
    out_dir: Optional[str],
    ncpus: Optional[int],
    step: float = 0.03,
    metric: str = "average_precision",
    verbose: bool = False,
) -> Tuple[AnnData, RFECV]:
    """Score genes with cross-validated recursive feature elimination.

    A linear-kernel SVM is wrapped in the project's ``RFECV`` and fitted on
    the supplied folds. The cross-validation table is stored in
    ``adata_train.uns["rfecv_result"]`` and the scoring name in
    ``adata_train.uns["rfecv_result_metric"]``.

    Parameters
    ----------
    adata_train
        Training data; features and labels are read via ``get_X_y_from_ann``.
        Its ``uns`` mapping is modified in place.
    train_index_list, val_index_list, sample_weight_list
        Per-fold training indices, validation indices and training weights,
        forwarded to ``RFECV.fit``.
    out_dir
        When not ``None``, directory (created if missing) in which the fitted
        ``RFECV`` object is pickled.
    ncpus
        Forwarded to ``RFECV`` as ``n_jobs``.
    step
        Forwarded to ``RFECV`` as ``step`` and used here as the fraction of
        remaining features dropped per round when rebuilding the index.
    metric
        Scoring name forwarded to ``RFECV``; see
        https://scikit-learn.org/stable/modules/model_evaluation.html
    verbose
        Forwarded to the ``SVC`` estimator.

    Returns
    -------
    adata_train, rfecv
        The annotated input and the fitted ``RFECV`` object.
    """
    import math

    X, y = get_X_y_from_ann(adata_train)

    # Fill NaN in numpy
    X = np.nan_to_num(X)
    y = np.nan_to_num(y)

    model = svm.SVC(kernel="linear", verbose=verbose, random_state=123)

    rfecv = RFECV(
        estimator=model, step=step, scoring=metric, cv=10, n_jobs=ncpus, verbose=0
    )
    rfecv.fit(
        X, y, train_index_list, val_index_list, sample_weight_list=sample_weight_list
    )

    # organize dataframe for results
    n_gene = X.shape[1]
    cv_df = pd.DataFrame.from_dict(rfecv.cv_results_.copy())

    # Replay the elimination schedule: each round removes
    # ceil(step * remaining) features until at most one is left.
    # NOTE(review): assumes the project RFECV eliminates features the same way
    # and orders cv_results_ from fewest to most features - confirm against
    # SVMRFECV.
    nfeat = n_gene
    steps = [n_gene]
    while nfeat > 1:
        nfeat -= math.ceil(nfeat * step)
        steps.append(nfeat)

    cv_df.index = steps[::-1]

    adata_train.uns["rfecv_result"] = cv_df
    adata_train.uns["rfecv_result_metric"] = rfecv.scoring

    if out_dir is not None:
        # save tmp output
        os.makedirs(out_dir, exist_ok=True)

        model_file = f"{out_dir}/rfecv_ranking_by_{type(model).__name__}.sav"
        # Context manager so the handle is flushed and closed deterministically.
        with open(model_file, "wb") as f:
            pickle.dump(rfecv, f)

    return adata_train, rfecv
224
225
226
def plot_gene_score(adata_train: AnnData, n_genes_plot: int = 200, width: int = 5, height: int = 4, k: Optional[int] = None) -> Axes:
    """Plot cross-validation scores against the number of genes.

    Reads ``adata_train.uns["rfecv_result"]`` and
    ``adata_train.uns["rfecv_result_metric"]``. Columns whose name contains
    ``mean`` or ``split`` are drawn; columns containing ``mean`` as solid
    lines, the others dashed.

    Parameters
    ----------
    adata_train
        Object whose ``uns`` holds the two entries named above.
    n_genes_plot
        Upper label bound of the index slice that is plotted.
    width, height
        Figure size passed to ``plt.subplots``.
    k
        Optional gene count to highlight with a vertical line and a text
        label. It must be an index label inside the plotted range.

    Returns
    -------
    Axes
        The axes the curves were drawn on.

    Raises
    ------
    KeyError
        If ``k`` is given but is not an index label of the plotted table.
    """
    cv_df = adata_train.uns["rfecv_result"].filter(regex="mean|split")
    cv_df = cv_df.loc[:n_genes_plot].copy()
    # Remove the literal "_test_score" suffix. str.rstrip() would treat its
    # argument as a set of characters and could eat into the column name.
    cv_df.columns = cv_df.columns.str.replace(r"_test_score$", "", regex=True)

    # Validate before any figure is created so a bad k leaves no orphan figure.
    if k is not None and k not in cv_df.index:
        raise KeyError(
            f"k={k} is not among the plotted gene counts (n_genes_plot={n_genes_plot})"
        )

    scoring_metrics = adata_train.uns["rfecv_result_metric"]
    if scoring_metrics == "average_precision":
        ylabel = "AUPRC"
    elif scoring_metrics == "roc_auc":
        ylabel = "AUROC"
    else:
        ylabel = scoring_metrics

    _, axes = plt.subplots(figsize=(width, height))
    for column_name, column_data in cv_df.items():
        if "mean" in column_name:
            axes.plot(column_data, label=column_name)
        else:
            axes.plot(column_data, label=column_name, linestyle="dashed", alpha=0.6)

    axes.spines[["right", "top"]].set_visible(False)

    axes.set_xlabel("Number of Genes")
    axes.set_ylabel(ylabel)
    axes.legend()

    if k is not None:
        k_score = cv_df.loc[k, "mean"]
        # Place the label half the score range below the highlighted point.
        y_label_adjust = (cv_df["mean"].max() - cv_df["mean"].min()) / 2

        axes.axvline(x=k, color="r", linestyle=":")
        axes.text(
            x=k + 4, y=k_score - y_label_adjust, s=f"n={k}\n{ylabel}={k_score:.3f}"
        )

    return axes
263
264
265
def decide_k(adata_train: AnnData, n_genes_plot: int = 100) -> int:
    """Pick the number of genes at the knee of the mean score curve.

    The curve is ``mean_test_score`` against the index of
    ``adata_train.uns["rfecv_result"]``, restricted to index labels up to
    ``n_genes_plot``. Let A be the first and B the last point of that curve.
    For every other point C (B included) the signed distance from C to the
    line AB is computed from the signed area of triangle ABC, and the index
    label of the point with the largest signed distance is returned.

    Parameters
    ----------
    adata_train
        Object whose ``uns["rfecv_result"]`` is a table with a
        ``mean_test_score`` column and an unnamed index.
    n_genes_plot
        Upper label bound of the index slice taken into account.

    Returns
    -------
    int
        Index label of the selected point.

    Raises
    ------
    ValueError
        If fewer than two points fall into the requested range.
    """
    cv_df = adata_train.uns["rfecv_result"].loc[:n_genes_plot, :]

    data = cv_df.reset_index()[["index", "mean_test_score"]].to_numpy()
    if len(data) < 2:
        raise ValueError(
            "decide_k needs at least two points of the score curve, "
            f"got {len(data)} for n_genes_plot={n_genes_plot}"
        )

    A = data[0]
    B = data[-1]
    chord_length = np.linalg.norm(A - B)

    # det([[Ax, Ay, 1], [Bx, By, 1], [Cx, Cy, 1]]) is twice the signed area of
    # triangle ABC; dividing by |AB| gives the signed height of C over AB.
    Dist = {}
    for C in data[1:]:
        D = np.column_stack((np.vstack((A, B, C)), np.ones(3)))
        Dist[C[0]] = np.linalg.det(D) / chord_length

    return int(max(Dist, key=Dist.get))
287
288
289
def select_gene(
    adata_train: AnnData, out_dir: Optional[str] = None, step: float = 0.03, top_n_feat: int = 5, n_genes_plot: int = 100, verbose: int = 0
) -> AnnData:
    """Rank all genes with one SVM-RFE run and keep the best ``top_n_feat``.

    A linear-kernel SVM is eliminated down to a single feature with the
    project's ``RFE``, weighted by ``compute_cell_weight(adata_train)``. The
    resulting ranking is stored in ``adata_train.var["ranking"]`` and the
    names of the ``top_n_feat`` best-ranked genes in
    ``adata_train.uns["svm_rfe_genes"]``.

    Parameters
    ----------
    adata_train
        Training data; modified in place as described above.
    out_dir
        When not ``None``, directory (created if missing) that receives
        ``sig_svm.txt`` (one gene per line) and ``adata_train_s.h5ad``.
    step
        Forwarded to ``RFE`` as ``step``.
    top_n_feat
        Number of best-ranked genes to keep.
    n_genes_plot
        Not used by this function.
    verbose
        Forwarded to ``RFE``.

    Returns
    -------
    AnnData
        The annotated input object.
    """
    X, y = get_X_y_from_ann(adata_train)

    # Fill NaN in numpy
    X = np.nan_to_num(X)
    y = np.nan_to_num(y)

    model = svm.SVC(kernel="linear", random_state=123)

    ## get ranking of all features
    selector = RFE(model, n_features_to_select=1, step=step, verbose=verbose)
    sample_weight = compute_cell_weight(adata_train)
    selector.fit(X, y, sample_weight=sample_weight)

    # NOTE(review): genes that share a rank keep whatever relative order
    # sort_values' default sort yields - confirm this is acceptable when the
    # top_n_feat cut falls inside a group of equally ranked genes.
    feature_ranking = pd.DataFrame(
        {"ranking": selector.ranking_}, index=adata_train.var_names
    ).sort_values(by="ranking")
    sig_list_ranked = feature_ranking.index[:top_n_feat].tolist()

    adata_train.uns["svm_rfe_genes"] = sig_list_ranked
    adata_train.var["ranking"] = selector.ranking_

    if out_dir is not None:
        os.makedirs(out_dir, exist_ok=True)

        # output gene list
        with open(f"{out_dir}/sig_svm.txt", "w") as f:
            f.writelines(f"{gene}\n" for gene in sig_list_ranked)

        # output adata_train_s with gene scores
        adata_train.write_h5ad(f"{out_dir}/adata_train_s.h5ad")

    return adata_train
331
332
333
def select_gene_stable(
    adata_train,
    n_iter=20,
    nfold=2,
    downsample_prop_list=(0.6, 0.8),
    num_cores=1,
    out_dir=None,
):
    """Select genes that are picked repeatedly across downsampled runs.

    For every proportion in ``downsample_prop_list`` and every iteration the
    cells are downsampled (stratified by ``patient_id``), quality-filtered,
    split into folds, scored with ``gene_score`` and reduced to a gene list
    with ``decide_k`` / ``select_gene``. Runs whose downsampled data fails
    the quality filter contribute no genes.

    Results written to ``adata_train.uns``:

    * ``scPanel_stable_rfecv_result``: how often each gene was selected per
      downsample proportion/size.
    * ``svm_rfe_genes_stable``: genes whose mean selection count across
      proportions exceeds ``round(n_iter * 0.5)``, most frequent first.
    * ``svm_rfe_genes_stable_time``: elapsed seconds of the parallel phase.

    Parameters
    ----------
    adata_train
        Training data with ``patient_id`` and ``y`` in ``obs``. A
        ``downsample_stratify`` column is assigned to its ``obs`` by each run.
    n_iter
        Number of downsampling iterations per proportion.
    nfold
        Number of folds passed to ``split_n_folds``.
    downsample_prop_list
        Fractions of cells to keep.
    num_cores
        ``n_jobs`` for ``joblib.Parallel``.
    out_dir
        Base output directory; each run writes to
        ``<out_dir>/<downsample_size>/<iteration>``. When ``None`` the runs
        write nothing.

    Returns
    -------
    The annotated ``adata_train``.
    """

    def _single_fit(downsample_prop, i, adata_train, nfold, out_dir):
        """Run one downsample + selection pass and return its genes as a frame."""
        downsample_size = round(adata_train.n_obs * downsample_prop)
        i = i + 1  # 1-based iteration number, also used as the resampling seed

        # Folder for this iteration's results. None is passed through as-is so
        # that no literal "None/..." directory tree gets created.
        if out_dir is not None:
            out_dir = f"{out_dir}/{downsample_size}/{i}"
            os.makedirs(out_dir, exist_ok=True)

        # metadata for stratified downsampling
        # NOTE: this writes a column into the obs of the adata_train object
        # this worker received.
        adata_train.obs["downsample_stratify"] = adata_train.obs[["patient_id"]].astype(
            "category"
        )

        down_index_i = resample(
            adata_train.obs_names,
            replace=False,
            n_samples=downsample_size,
            stratify=adata_train.obs["downsample_stratify"],
            random_state=i,
        )
        # downsampling
        adata_train_i = adata_train[adata_train.obs_names.isin(down_index_i),].copy()

        # QC for the downsampled training data:
        # 1. remove patients with fewer than min_cells cells
        # 2. require at least 2 conditions with at least 2 patients each
        # 3. remove genes expressed in no cell
        min_cells = 20

        # Stays empty when the downsampled data does not pass QC, so the run
        # still returns a well-formed (empty) frame.
        sig_svm_i = []

        ## Number of cells in each patient
        n_cell_pt = adata_train_i.obs.groupby(
            ["patient_id"], observed=True, as_index=False
        ).size()
        pt_keep = n_cell_pt.patient_id[n_cell_pt["size"] >= min_cells].tolist()

        if len(pt_keep) > 0:
            adata_train_i = adata_train_i[
                adata_train_i.obs["patient_id"].isin(pt_keep),
            ]
            n_cell_pt = adata_train_i.obs.groupby(
                ["y", "patient_id"], observed=True, as_index=False
            ).size()

            enough_conditions = n_cell_pt.y.nunique() >= 2
            if enough_conditions and (n_cell_pt.y.value_counts() >= 2).all():
                print("we have >= 2 samples in each condition...")

                ## Remove 0-expressed genes
                sc.pp.filter_genes(adata_train_i, min_cells=1)

                # Split downsampled train data into folds
                train_index_list, val_index_list, sample_weight_list = split_n_folds(
                    adata_train_i, nfold=nfold, out_dir=out_dir, random_state=2349
                )

                adata_train_i, rfecv_i = gene_score(
                    adata_train_i,
                    train_index_list,
                    val_index_list,
                    sample_weight_list=sample_weight_list,
                    step=0.03,
                    out_dir=out_dir,
                    ncpus=None,
                    verbose=False,
                )

                k = decide_k(adata_train_i, n_genes_plot=100)
                adata_train_i = select_gene(
                    adata_train_i, top_n_feat=k, step=0.03, out_dir=out_dir
                )
                sig_svm_i = adata_train_i.uns["svm_rfe_genes"]

        res_i = pd.DataFrame(sig_svm_i, columns=["gene"])
        res_i["downsample_prop"] = downsample_prop
        res_i["downsample_size"] = downsample_size
        res_i["n_iter"] = i

        return res_i

    start = time.time()

    # every (proportion, iteration) combination
    paramlist = itertools.product(downsample_prop_list, range(n_iter))
    res = Parallel(n_jobs=num_cores)(
        delayed(_single_fit)(
            downsample_prop, i, adata_train=adata_train, nfold=nfold, out_dir=out_dir
        )
        for downsample_prop, i in paramlist
    )
    end = time.time()

    res_df = pd.concat(res)
    gene_freq_df = res_df.groupby(
        ["downsample_prop", "downsample_size", "gene"], as_index=False
    ).size()
    adata_train.uns["scPanel_stable_rfecv_result"] = gene_freq_df

    # Keep genes selected, on average over proportions, in more than half of
    # the iterations.
    gene_mean = gene_freq_df.groupby("gene")["size"].mean()
    gene_mean = gene_mean[gene_mean > round(n_iter * 0.5)].sort_values(ascending=False)
    adata_train.uns["svm_rfe_genes_stable"] = gene_mean.index.tolist()
    adata_train.uns["svm_rfe_genes_stable_time"] = end - start

    return adata_train