|
a |
|
b/src/scpanel/select_gene.py |
|
|
1 |
# import anndata |
|
|
2 |
import itertools |
|
|
3 |
import os |
|
|
4 |
import pickle |
|
|
5 |
import time |
|
|
6 |
|
|
|
7 |
import matplotlib.pyplot as plt |
|
|
8 |
from joblib import Parallel, delayed |
|
|
9 |
|
|
|
10 |
# from sklearn.preprocessing import LabelEncoder |
|
|
11 |
from sklearn import svm |
|
|
12 |
|
|
|
13 |
# import scanpy as sc |
|
|
14 |
# import numpy as np |
|
|
15 |
# import pandas as pd |
|
|
16 |
from sklearn.model_selection import StratifiedKFold |
|
|
17 |
|
|
|
18 |
from .SVMRFECV import RFE, RFECV |
|
|
19 |
from .utils_func import * |
|
|
20 |
from anndata._core.anndata import AnnData |
|
|
21 |
from matplotlib.axes._axes import Axes |
|
|
22 |
from scpanel.SVMRFECV import RFECV |
|
|
23 |
from typing import List, Optional, Tuple |
|
|
24 |
|
|
|
25 |
|
|
|
26 |
def split_n_folds(adata_train: AnnData, nfold: int, out_dir: Optional[str]=None, random_state: int=2349) -> Tuple[List[List[int]], List[List[int]], List[List[float]]]: |
|
|
27 |
## add: exclude patients without selected cell type |
|
|
28 |
n_cell_pat = adata_train.obs.groupby(["patient_id"])["ct"].count() |
|
|
29 |
exclude_pat = adata_train.obs["patient_id"].isin(n_cell_pat[n_cell_pat == 0].index) |
|
|
30 |
adata_train = adata_train[~exclude_pat] |
|
|
31 |
|
|
|
32 |
if sum(exclude_pat) > 0: |
|
|
33 |
print( |
|
|
34 |
n_cell_pat[n_cell_pat == 0].index.tolist(), |
|
|
35 |
"get excluded since no selected cell type appears", |
|
|
36 |
) |
|
|
37 |
|
|
|
38 |
## split patients |
|
|
39 |
pat_meta_temp = adata_train.obs[["y", "patient_id"]].drop_duplicates().reset_index() |
|
|
40 |
cell_meta_temp = adata_train.obs.reset_index() |
|
|
41 |
|
|
|
42 |
patient_class = pat_meta_temp["y"].to_numpy() |
|
|
43 |
patient = pat_meta_temp["patient_id"].to_numpy() |
|
|
44 |
|
|
|
45 |
skf = StratifiedKFold(n_splits=nfold, shuffle=True, random_state=random_state) |
|
|
46 |
# sss = StratifiedShuffleSplit(n_splits=nfold, test_size = test_size) |
|
|
47 |
|
|
|
48 |
patient_train_id_list = [] |
|
|
49 |
patient_val_id_list = [] |
|
|
50 |
|
|
|
51 |
train_patient_list = [] |
|
|
52 |
val_patient_list = [] |
|
|
53 |
|
|
|
54 |
train_index_list = [] |
|
|
55 |
val_index_list = [] |
|
|
56 |
|
|
|
57 |
weight_list = [] |
|
|
58 |
|
|
|
59 |
for train_index, val_index in skf.split(patient, patient_class): |
|
|
60 |
|
|
|
61 |
train_patient_list.append(train_index) |
|
|
62 |
val_patient_list.append(val_index) |
|
|
63 |
|
|
|
64 |
patient_train_id = patient[train_index] |
|
|
65 |
patient_val_id = patient[val_index] |
|
|
66 |
|
|
|
67 |
patient_train_id_list.append(patient_train_id) |
|
|
68 |
patient_val_id_list.append(patient_val_id) |
|
|
69 |
|
|
|
70 |
cell_meta_fold_train = cell_meta_temp[ |
|
|
71 |
cell_meta_temp["patient_id"].isin(patient_train_id) |
|
|
72 |
] |
|
|
73 |
cell_meta_fold_test = cell_meta_temp[ |
|
|
74 |
cell_meta_temp["patient_id"].isin(patient_val_id) |
|
|
75 |
] |
|
|
76 |
|
|
|
77 |
# compute weight for each cell in each fold's training set |
|
|
78 |
w_fold_train = compute_cell_weight(cell_meta_fold_train) |
|
|
79 |
weight_list.append(w_fold_train.tolist()) |
|
|
80 |
|
|
|
81 |
# get positional index for train and test set in each fold |
|
|
82 |
cell_train_id = cell_meta_fold_train.index.tolist() |
|
|
83 |
cell_val_id = cell_meta_fold_test.index.tolist() |
|
|
84 |
|
|
|
85 |
# cell_train_id.sort() |
|
|
86 |
# cell_val_id.sort() |
|
|
87 |
|
|
|
88 |
if cell_train_id not in train_index_list: |
|
|
89 |
train_index_list.append(cell_train_id) |
|
|
90 |
|
|
|
91 |
if cell_val_id not in val_index_list: |
|
|
92 |
val_index_list.append(cell_val_id) |
|
|
93 |
|
|
|
94 |
## check if weights (np.Series) have the same order as train_index_list |
|
|
95 |
# np.array_equiv([idx for fold in train_index_list for idx in fold], |
|
|
96 |
# w_fold_train.index.values) |
|
|
97 |
|
|
|
98 |
if out_dir is not None: |
|
|
99 |
# Output |
|
|
100 |
if not os.path.exists(out_dir): |
|
|
101 |
os.makedirs(out_dir) |
|
|
102 |
|
|
|
103 |
## Data and index |
|
|
104 |
X_train, y_train = get_X_y_from_ann(adata_train) |
|
|
105 |
with open(os.path.join(out_dir, "Data_X_y.pkl"), "wb") as f: |
|
|
106 |
d = {"features": X_train, "labels": y_train} |
|
|
107 |
pickle.dump(d, f, protocol=pickle.HIGHEST_PROTOCOL) |
|
|
108 |
f.close() |
|
|
109 |
del d |
|
|
110 |
|
|
|
111 |
with open( |
|
|
112 |
os.path.join(out_dir, "Data_" + str(nfold) + "fold_index.pkl"), "wb" |
|
|
113 |
) as f: |
|
|
114 |
d = { |
|
|
115 |
"train": train_index_list, |
|
|
116 |
"val": val_index_list, |
|
|
117 |
"sample_weight": w_fold_train, |
|
|
118 |
} |
|
|
119 |
pickle.dump(d, f, protocol=pickle.HIGHEST_PROTOCOL) |
|
|
120 |
f.close() |
|
|
121 |
del d |
|
|
122 |
|
|
|
123 |
## nfold splitting information |
|
|
124 |
# get n_cells, patient_id and class prop for each set |
|
|
125 |
all_list = train_index_list + val_index_list |
|
|
126 |
n_cells = [len(sublist) for sublist in all_list] |
|
|
127 |
|
|
|
128 |
patient_ids = patient_train_id_list + patient_val_id_list |
|
|
129 |
|
|
|
130 |
class_prop = [ |
|
|
131 |
np.unique(y_train[index], return_counts=True)[1] for index in all_list |
|
|
132 |
] |
|
|
133 |
|
|
|
134 |
patient_prop_train = [ |
|
|
135 |
pat_meta_temp.loc[fold].y.value_counts().tolist() |
|
|
136 |
for fold in train_patient_list |
|
|
137 |
] |
|
|
138 |
patient_prop_val = [ |
|
|
139 |
pat_meta_temp.loc[fold].y.value_counts().tolist() |
|
|
140 |
for fold in val_patient_list |
|
|
141 |
] |
|
|
142 |
patient_prop = patient_prop_train + patient_prop_val |
|
|
143 |
|
|
|
144 |
nfold_info = pd.DataFrame( |
|
|
145 |
{ |
|
|
146 |
"n_cells": n_cells, |
|
|
147 |
"patient_ids": patient_ids, |
|
|
148 |
"class_prop": class_prop, |
|
|
149 |
"pt_prop": patient_prop, |
|
|
150 |
} |
|
|
151 |
) |
|
|
152 |
|
|
|
153 |
train_col = ["train_f" + str(i) for i in range(1, nfold + 1)] |
|
|
154 |
val_col = ["val_f" + str(i) for i in range(1, nfold + 1)] |
|
|
155 |
nfold_info.index = train_col + val_col |
|
|
156 |
nfold_info.to_csv(f"{out_dir}/split_nfold_info.csv") |
|
|
157 |
|
|
|
158 |
return train_index_list, val_index_list, weight_list |
|
|
159 |
|
|
|
160 |
|
|
|
161 |
def gene_score( |
|
|
162 |
adata_train: AnnData, |
|
|
163 |
train_index_list: List[List[int]], |
|
|
164 |
val_index_list: List[List[int]], |
|
|
165 |
sample_weight_list: List[List[float]], |
|
|
166 |
out_dir: str, |
|
|
167 |
ncpus: int, |
|
|
168 |
step: float=0.03, |
|
|
169 |
metric: str="average_precision", |
|
|
170 |
verbose: bool=False, |
|
|
171 |
) -> Tuple[AnnData, RFECV]: |
|
|
172 |
|
|
|
173 |
# metric: https://scikit-learn.org/stable/modules/model_evaluation.html |
|
|
174 |
|
|
|
175 |
X, y = get_X_y_from_ann(adata_train) |
|
|
176 |
|
|
|
177 |
# Fill NaN in numpy |
|
|
178 |
X = np.nan_to_num(X) |
|
|
179 |
y = np.nan_to_num(y) |
|
|
180 |
|
|
|
181 |
# model------------ |
|
|
182 |
# model = svm.SVC(kernel="linear", class_weight = 'balanced', verbose=verbose, random_state = 123) |
|
|
183 |
model = svm.SVC(kernel="linear", verbose=verbose, random_state=123) |
|
|
184 |
|
|
|
185 |
rfecv = RFECV( |
|
|
186 |
estimator=model, step=step, scoring=metric, cv=10, n_jobs=ncpus, verbose=0 |
|
|
187 |
) |
|
|
188 |
# X = StandardScaler().fit_transform(X) |
|
|
189 |
rfecv.fit( |
|
|
190 |
X, y, train_index_list, val_index_list, sample_weight_list=sample_weight_list |
|
|
191 |
) |
|
|
192 |
|
|
|
193 |
# organize dataframe for results |
|
|
194 |
n_gene = X.shape[1] |
|
|
195 |
cv_dict = rfecv.cv_results_.copy() |
|
|
196 |
# cv_dict.pop('mean_feature_ranking') |
|
|
197 |
cv_df = pd.DataFrame.from_dict(cv_dict) |
|
|
198 |
|
|
|
199 |
# find number of features selected in each iteration |
|
|
200 |
import math |
|
|
201 |
|
|
|
202 |
nfeat = n_gene |
|
|
203 |
step = step |
|
|
204 |
steps = [n_gene] |
|
|
205 |
while nfeat > 1: |
|
|
206 |
nstep = math.ceil(nfeat * step) |
|
|
207 |
nfeat = nfeat - nstep |
|
|
208 |
steps.append(nfeat) |
|
|
209 |
|
|
|
210 |
cv_df.index = steps[::-1] |
|
|
211 |
|
|
|
212 |
adata_train.uns["rfecv_result"] = cv_df |
|
|
213 |
adata_train.uns["rfecv_result_metric"] = rfecv.scoring |
|
|
214 |
|
|
|
215 |
if out_dir is not None: |
|
|
216 |
# save tmp output------------------ |
|
|
217 |
if not os.path.exists(out_dir): |
|
|
218 |
os.makedirs(out_dir) |
|
|
219 |
|
|
|
220 |
model_file = f"{out_dir}/rfecv_ranking_by_{type(model).__name__}.sav" |
|
|
221 |
pickle.dump(rfecv, open(model_file, "wb")) |
|
|
222 |
|
|
|
223 |
return adata_train, rfecv |
|
|
224 |
|
|
|
225 |
|
|
|
226 |
def plot_gene_score(adata_train: AnnData, n_genes_plot: int=200, width: int=5, height: int=4, k: Optional[int]=None) -> Axes: |
|
|
227 |
|
|
|
228 |
cv_df = adata_train.uns["rfecv_result"].filter(regex="mean|split") |
|
|
229 |
cv_df = cv_df.loc[:n_genes_plot,] |
|
|
230 |
cv_df.columns = cv_df.columns.str.rstrip("_test_score") |
|
|
231 |
|
|
|
232 |
scoring_metrics = adata_train.uns["rfecv_result_metric"] |
|
|
233 |
if scoring_metrics == "average_precision": |
|
|
234 |
ylabel = "AUPRC" |
|
|
235 |
elif scoring_metrics == "roc_auc": |
|
|
236 |
ylabel = "AUROC" |
|
|
237 |
else: |
|
|
238 |
ylabel = scoring_metrics |
|
|
239 |
|
|
|
240 |
fig, axes = plt.subplots(figsize=(width, height)) |
|
|
241 |
for columnName, columnData in cv_df.items(): |
|
|
242 |
if "mean" in columnName: |
|
|
243 |
axes.plot(columnData, label=columnName) |
|
|
244 |
else: |
|
|
245 |
axes.plot(columnData, label=columnName, linestyle="dashed", alpha=0.6) |
|
|
246 |
|
|
|
247 |
axes.spines[["right", "top"]].set_visible(False) |
|
|
248 |
|
|
|
249 |
plt.xlabel("Number of Genes") |
|
|
250 |
plt.ylabel(ylabel) |
|
|
251 |
plt.legend() |
|
|
252 |
|
|
|
253 |
if k is not None: |
|
|
254 |
k_score = cv_df.loc[k, "mean"] |
|
|
255 |
y_label_adjust = (cv_df["mean"].max() - cv_df["mean"].min()) / 2 |
|
|
256 |
|
|
|
257 |
plt.axvline(x=k, color="r", linestyle=":") |
|
|
258 |
plt.text( |
|
|
259 |
x=k + 4, y=k_score - y_label_adjust, s=f"n={k}\n{ylabel}={k_score:.3f}" |
|
|
260 |
) |
|
|
261 |
|
|
|
262 |
return axes |
|
|
263 |
|
|
|
264 |
|
|
|
265 |
def decide_k(adata_train: AnnData, n_genes_plot: int=100) -> int: |
|
|
266 |
cv_df = adata_train.uns["rfecv_result"] |
|
|
267 |
cv_df = cv_df.loc[:n_genes_plot, :] |
|
|
268 |
|
|
|
269 |
data = cv_df.reset_index()[["index", "mean_test_score"]].to_numpy() |
|
|
270 |
A = data[0] |
|
|
271 |
B = data[-1] |
|
|
272 |
# 利用ABC三点坐标计算三角形面积,利用AB边长倒推三角形的高 |
|
|
273 |
Dist = dict() |
|
|
274 |
for i in range(1, len(data)): |
|
|
275 |
C = data[i] |
|
|
276 |
ngene = C[0] |
|
|
277 |
D = np.append(np.vstack((A, B, C)), [[1], [1], [1]], axis=1) |
|
|
278 |
S = 1 / 2 * np.linalg.det(D) |
|
|
279 |
Dist[ngene] = 2 * S / np.linalg.norm(A - B) |
|
|
280 |
|
|
|
281 |
top_n_feat = int(max(Dist, key=Dist.get)) |
|
|
282 |
top_n_feat_auc = cv_df.loc[max(Dist, key=Dist.get), "mean_test_score"] |
|
|
283 |
|
|
|
284 |
# print(f'Number of genes to select = {top_n_feat}') |
|
|
285 |
|
|
|
286 |
return top_n_feat |
|
|
287 |
|
|
|
288 |
|
|
|
289 |
def select_gene( |
|
|
290 |
adata_train: AnnData, out_dir: Optional[str]=None, step: float=0.03, top_n_feat: int=5, n_genes_plot: int=100, verbose: int=0 |
|
|
291 |
) -> AnnData: |
|
|
292 |
|
|
|
293 |
# retrieve top_n_feat from one SVM-RFE run |
|
|
294 |
X, y = get_X_y_from_ann(adata_train) |
|
|
295 |
|
|
|
296 |
# Fill NaN in numpy |
|
|
297 |
X = np.nan_to_num(X) |
|
|
298 |
y = np.nan_to_num(y) |
|
|
299 |
|
|
|
300 |
# model------------ |
|
|
301 |
model = svm.SVC(kernel="linear", random_state=123) |
|
|
302 |
|
|
|
303 |
## get ranking of all selected features |
|
|
304 |
selector = RFE(model, n_features_to_select=1, step=step, verbose=verbose) |
|
|
305 |
|
|
|
306 |
sample_weight = compute_cell_weight(adata_train) |
|
|
307 |
selector.fit(X, y, sample_weight=sample_weight) |
|
|
308 |
|
|
|
309 |
feature_ranking = pd.DataFrame( |
|
|
310 |
{"ranking": selector.ranking_}, index=adata_train.var_names |
|
|
311 |
).sort_values(by="ranking") |
|
|
312 |
sig_list_ranked = feature_ranking.index[:top_n_feat].tolist() |
|
|
313 |
# print(sig_list_ranked) |
|
|
314 |
|
|
|
315 |
adata_train.uns["svm_rfe_genes"] = sig_list_ranked |
|
|
316 |
adata_train.var["ranking"] = selector.ranking_ |
|
|
317 |
|
|
|
318 |
if out_dir is not None: |
|
|
319 |
# output gene list |
|
|
320 |
if not os.path.exists(out_dir): |
|
|
321 |
os.makedirs(out_dir) |
|
|
322 |
|
|
|
323 |
with open(f"{out_dir}/sig_svm.txt", "w") as f: |
|
|
324 |
for item in sig_list_ranked: |
|
|
325 |
f.write("%s\n" % item) |
|
|
326 |
|
|
|
327 |
# output adata_train_s with gene scores |
|
|
328 |
adata_train.write_h5ad(f"{out_dir}/adata_train_s.h5ad") |
|
|
329 |
|
|
|
330 |
return adata_train |
|
|
331 |
|
|
|
332 |
|
|
|
333 |
def select_gene_stable( |
|
|
334 |
adata_train, |
|
|
335 |
n_iter=20, |
|
|
336 |
nfold=2, |
|
|
337 |
downsample_prop_list=[0.6, 0.8], |
|
|
338 |
num_cores=1, |
|
|
339 |
out_dir=None, |
|
|
340 |
): |
|
|
341 |
|
|
|
342 |
def _single_fit(downsample_prop, i, adata_train, nfold, out_dir): |
|
|
343 |
|
|
|
344 |
downsample_size = round(adata_train.n_obs * downsample_prop) |
|
|
345 |
i = i + 1 |
|
|
346 |
|
|
|
347 |
# create folder to output results for each iteration |
|
|
348 |
out_dir = f"{out_dir}/{downsample_size}/{i}" |
|
|
349 |
if not os.path.exists(out_dir): |
|
|
350 |
os.makedirs(out_dir) |
|
|
351 |
|
|
|
352 |
# metadata for stratified downsampling |
|
|
353 |
adata_train.obs["downsample_stratify"] = adata_train.obs[["patient_id"]].astype( |
|
|
354 |
"category" |
|
|
355 |
) |
|
|
356 |
|
|
|
357 |
down_index_i = resample( |
|
|
358 |
adata_train.obs_names, |
|
|
359 |
replace=False, |
|
|
360 |
n_samples=downsample_size, |
|
|
361 |
stratify=adata_train.obs["downsample_stratify"], |
|
|
362 |
random_state=i, |
|
|
363 |
) |
|
|
364 |
# downsampling |
|
|
365 |
adata_train_i = adata_train[adata_train.obs_names.isin(down_index_i),].copy() |
|
|
366 |
|
|
|
367 |
# QC for downsampled traninig data |
|
|
368 |
# 1. for each cell type, remove samples with <20 cells |
|
|
369 |
# 2. remove cell types with < 2 samples |
|
|
370 |
# 3. Remove 0-expressed genes |
|
|
371 |
# 4. Update training data |
|
|
372 |
|
|
|
373 |
min_cells = 20 |
|
|
374 |
## Number of cells in each patient |
|
|
375 |
n_cell_pt = adata_train_i.obs.groupby( |
|
|
376 |
["patient_id"], observed=True, as_index=False |
|
|
377 |
).size() |
|
|
378 |
# Remove paients with cells less than min_cells |
|
|
379 |
pt_keep = n_cell_pt.patient_id[n_cell_pt["size"] >= min_cells].tolist() |
|
|
380 |
|
|
|
381 |
## Cell types with 0 patient has cells >= min_cells |
|
|
382 |
if len(pt_keep) > 0: |
|
|
383 |
adata_train_i = adata_train_i[ |
|
|
384 |
adata_train_i.obs["patient_id"].isin(pt_keep), |
|
|
385 |
] |
|
|
386 |
n_cell_pt = adata_train_i.obs.groupby( |
|
|
387 |
["y", "patient_id"], observed=True, as_index=False |
|
|
388 |
).size() |
|
|
389 |
## Skip cell types with less than 2 patients in at least one condition |
|
|
390 |
if (n_cell_pt.y.nunique() >= 2) & ((n_cell_pt.y.value_counts() >= 2).all()): |
|
|
391 |
print("we have >= 2 samples in each condition...") |
|
|
392 |
|
|
|
393 |
## Remove 0-expressed genes |
|
|
394 |
sc.pp.filter_genes(adata_train_i, min_cells=1) |
|
|
395 |
|
|
|
396 |
# Split downsampled train data into folds |
|
|
397 |
train_index_list, val_index_list, sample_weight_list = split_n_folds( |
|
|
398 |
adata_train_i, nfold=nfold, out_dir=out_dir, random_state=2349 |
|
|
399 |
) |
|
|
400 |
|
|
|
401 |
adata_train_i, rfecv_i = gene_score( |
|
|
402 |
adata_train_i, |
|
|
403 |
train_index_list, |
|
|
404 |
val_index_list, |
|
|
405 |
sample_weight_list=sample_weight_list, |
|
|
406 |
step=0.03, |
|
|
407 |
out_dir=out_dir, |
|
|
408 |
ncpus=None, |
|
|
409 |
verbose=False, |
|
|
410 |
) |
|
|
411 |
|
|
|
412 |
k = decide_k(adata_train_i, n_genes_plot=100) |
|
|
413 |
adata_train_i = select_gene( |
|
|
414 |
adata_train_i, top_n_feat=k, step=0.03, out_dir=out_dir |
|
|
415 |
) |
|
|
416 |
sig_svm_i = adata_train_i.uns["svm_rfe_genes"] |
|
|
417 |
|
|
|
418 |
res_i = pd.DataFrame(sig_svm_i, columns=["gene"]) |
|
|
419 |
res_i["downsample_prop"] = downsample_prop |
|
|
420 |
res_i["downsample_size"] = downsample_size |
|
|
421 |
res_i["n_iter"] = i |
|
|
422 |
|
|
|
423 |
return res_i |
|
|
424 |
|
|
|
425 |
start = time.time() |
|
|
426 |
|
|
|
427 |
paramlist = itertools.product(downsample_prop_list, range(n_iter)) # 2 nested loops |
|
|
428 |
res = Parallel(n_jobs=num_cores)( |
|
|
429 |
delayed(_single_fit)( |
|
|
430 |
downsample_prop, i, adata_train=adata_train, nfold=nfold, out_dir=out_dir |
|
|
431 |
) |
|
|
432 |
for downsample_prop, i in paramlist |
|
|
433 |
) |
|
|
434 |
end = time.time() |
|
|
435 |
|
|
|
436 |
res_df = pd.concat(res) |
|
|
437 |
gene_freq_df = res_df.groupby( |
|
|
438 |
["downsample_prop", "downsample_size", "gene"], as_index=False |
|
|
439 |
).size() |
|
|
440 |
adata_train.uns["scPanel_stable_rfecv_result"] = gene_freq_df |
|
|
441 |
|
|
|
442 |
gene_mean = gene_freq_df.groupby("gene")["size"].mean() |
|
|
443 |
gene_mean = gene_mean[gene_mean > round(n_iter * 0.5)].sort_values(ascending=False) |
|
|
444 |
adata_train.uns["svm_rfe_genes_stable"] = gene_mean.index.tolist() |
|
|
445 |
adata_train.uns["svm_rfe_genes_stable_time"] = end - start |
|
|
446 |
|
|
|
447 |
return adata_train |