|
a |
|
b/tests/test_scpanel.py |
|
|
1 |
import sys |
|
|
2 |
from scpanel.utils_func import * |
|
|
3 |
from scpanel.split_patient import * |
|
|
4 |
from scpanel.select_cell import * |
|
|
5 |
from scpanel.select_gene import * |
|
|
6 |
from scpanel.train import * |
|
|
7 |
from scpanel.settings import * |
|
|
8 |
import anndata |
|
|
9 |
|
|
|
10 |
dataset = 'wilk2020covid' |
|
|
11 |
out_dir = './tests/test_result/' |
|
|
12 |
|
|
|
13 |
if not os.path.exists(out_dir): |
|
|
14 |
os.mkdir(out_dir) |
|
|
15 |
|
|
|
16 |
fastmode = True |
|
|
17 |
|
|
|
18 |
adata = anndata.read_h5ad('./tests/test_wilk2020covid_processed_rna_assay_2.h5ad') |
|
|
19 |
adata.obs.columns |
|
|
20 |
adata |
|
|
21 |
|
|
|
22 |
## Check what are condition labels in the data and encode it |
|
|
23 |
print(adata.obs[['disease_status_standard']].drop_duplicates()) |
|
|
24 |
class_map = {'healthy': 0, 'COVID-19': 1} |
|
|
25 |
|
|
|
26 |
## standardize data |
|
|
27 |
adata = preprocess(adata, |
|
|
28 |
ct_col='cell.type.fine', |
|
|
29 |
y_col='disease_status_standard', |
|
|
30 |
pt_col='sample', |
|
|
31 |
class_map=class_map) |
|
|
32 |
|
|
|
33 |
## split data |
|
|
34 |
adata_train_dict, adata_test_dict = split_train_test(adata, min_cells=20, |
|
|
35 |
test_pt_size=0.2, |
|
|
36 |
out_dir=out_dir, |
|
|
37 |
random_state=3467) |
|
|
38 |
|
|
|
39 |
# Cell type selection |
|
|
40 |
## 1. calculate responsiveness score |
|
|
41 |
AUC, AUC_all = cell_type_score(adata_train_dict, |
|
|
42 |
out_dir=out_dir, |
|
|
43 |
ncpus=16, |
|
|
44 |
n_iterations=30, |
|
|
45 |
sample_n_cell=20) |
|
|
46 |
|
|
|
47 |
## 2. plot responsiveness score |
|
|
48 |
axes = plot_cell_type_score(AUC, AUC_all) |
|
|
49 |
axes.set_xlim([0.4, 1]) |
|
|
50 |
plt.savefig(f'{out_dir}/cell_type_score.pdf', bbox_inches="tight") |
|
|
51 |
|
|
|
52 |
## 3. select the most responsive cell type for downstream |
|
|
53 |
top_ct = AUC['celltype'].iloc[-1] |
|
|
54 |
adata_train = select_celltype(adata_train_dict, |
|
|
55 |
celltype_selected=top_ct, |
|
|
56 |
out_dir=out_dir) |
|
|
57 |
|
|
|
58 |
# Gene selection |
|
|
59 |
## 1. split training data |
|
|
60 |
train_index_list, val_index_list, sample_weight_list = split_n_folds(adata_train, |
|
|
61 |
nfold=5, |
|
|
62 |
out_dir=out_dir, |
|
|
63 |
random_state=2349) |
|
|
64 |
|
|
|
65 |
## 2. score genes |
|
|
66 |
adata_train, rfecv = gene_score(adata_train, |
|
|
67 |
train_index_list, |
|
|
68 |
val_index_list, |
|
|
69 |
sample_weight_list=sample_weight_list, |
|
|
70 |
step=0.03, |
|
|
71 |
out_dir=out_dir, |
|
|
72 |
ncpus=16, |
|
|
73 |
verbose=False) |
|
|
74 |
|
|
|
75 |
## 3. plot gene scores |
|
|
76 |
plot_gene_score(adata_train, n_genes_plot=200) |
|
|
77 |
plt.savefig(f'{out_dir}/gene_score.pdf', bbox_inches="tight") |
|
|
78 |
|
|
|
79 |
## 4. find the optimal number of informative genes |
|
|
80 |
k = decide_k(adata_train, n_genes_plot=100) |
|
|
81 |
plot_gene_score(adata_train, n_genes_plot=100, k=k) |
|
|
82 |
plt.savefig(f'{out_dir}/decide_k.pdf', bbox_inches="tight") |
|
|
83 |
|
|
|
84 |
## 5. return the list of informative genes |
|
|
85 |
adata_train = select_gene(adata_train, |
|
|
86 |
top_n_feat=k, |
|
|
87 |
step=0.03, |
|
|
88 |
out_dir=out_dir) |
|
|
89 |
|
|
|
90 |
## 6. view the list of informative genes |
|
|
91 |
sig_svm = adata_train.uns['svm_rfe_genes'] |
|
|
92 |
print(sig_svm) |
|
|
93 |
|
|
|
94 |
# Classification |
|
|
95 |
## 1. Subset training and testing set with selected cell type and genes |
|
|
96 |
adata_train_final, adata_test_final = transform_adata(adata_train, |
|
|
97 |
adata_test_dict, |
|
|
98 |
selected_gene=sig_svm) |
|
|
99 |
|
|
|
100 |
## 2. models training |
|
|
101 |
# overwrite default parameters for models |
|
|
102 |
param_grid = {'LR': {'max_iter': 600}} |
|
|
103 |
|
|
|
104 |
clfs = models_train(adata_train_final, |
|
|
105 |
search_grid=False, |
|
|
106 |
out_dir=out_dir, param_grid=param_grid) |
|
|
107 |
|
|
|
108 |
## 3. models testing |
|
|
109 |
### cell-level prediction |
|
|
110 |
adata_test_final, y_pred_list, y_pred_score_list = models_predict(clfs, adata_test_final, out_dir=out_dir) |
|
|
111 |
|
|
|
112 |
### sample-level prediction and evaluation |
|
|
113 |
all_pred_scores = ['LR_pred_score', 'RF_pred_score', 'SVM_pred_score', |
|
|
114 |
'KNN_pred_score', 'GAT_pred_score', 'median_pred_score'] |
|
|
115 |
|
|
|
116 |
for i in tqdm(range(len(all_pred_scores))): |
|
|
117 |
adata_test_final = pt_pred(adata_test_final, |
|
|
118 |
cell_pred_col=all_pred_scores[i], |
|
|
119 |
num_bootstrap=1000) |
|
|
120 |
|
|
|
121 |
### visualize patient-level AUC scores |
|
|
122 |
sample_id_list = adata_test_final.obs[['patient_id', 'median_pred_score_sample_auc']].drop_duplicates().sort_values( |
|
|
123 |
by='median_pred_score_sample_auc')['patient_id'] |
|
|
124 |
plot_roc_curve(adata_test_final, |
|
|
125 |
sample_id=sample_id_list, |
|
|
126 |
cell_pred_col='median_pred_score', |
|
|
127 |
hspace=0.4, ncols=4, |
|
|
128 |
scatter_kws={'s': 10}, legend_kws={'prop': {'size': 11}}) |
|
|
129 |
plt.savefig(f'{out_dir}/ROC_curve_sample.pdf', bbox_inches="tight") |
|
|
130 |
|
|
|
131 |
### visualize cell-level prediction probabilities and patient-level AUC scores |
|
|
132 |
fig, axes = plt.subplots(6, figsize=(13, 30)) |
|
|
133 |
for ax, pred in zip(axes, all_pred_scores): |
|
|
134 |
print('Plotting for ', pred) |
|
|
135 |
plot_violin(adata_test_final, |
|
|
136 |
cell_pred_col=pred, |
|
|
137 |
ax=ax, |
|
|
138 |
palette={'healthy': 'C0', 'COVID-19': 'C1'}) |
|
|
139 |
|
|
|
140 |
plt.subplots_adjust(left=0.1, |
|
|
141 |
bottom=0.1, |
|
|
142 |
right=0.9, |
|
|
143 |
top=0.9, |
|
|
144 |
wspace=0.4, |
|
|
145 |
hspace=1) |
|
|
146 |
|
|
|
147 |
plt.savefig(f"{out_dir}/Violin_pt_pred_prob.pdf", bbox_inches="tight") |
|
|
148 |
plt.show() |
|
|
149 |
|
|
|
150 |
### calculate patient-level precision, recall, f1score and accuracy and visualize |
|
|
151 |
for i in range(len(all_pred_scores)): |
|
|
152 |
adata_test_final = pt_score(adata_test_final, cell_pred_col=all_pred_scores[i]) |
|
|
153 |
|
|
|
154 |
plt.figure(figsize=(30, 10)) |
|
|
155 |
# Create a color palette: |
|
|
156 |
my_palette = plt.cm.get_cmap("Set2") |
|
|
157 |
# Loop to plot |
|
|
158 |
for metric_idx in range(0, len(adata_test_final.uns['sample_score'].columns)): |
|
|
159 |
make_single_spider(adata_test_final, |
|
|
160 |
metric_idx=metric_idx, |
|
|
161 |
color='grey', |
|
|
162 |
nrow=1, ncol=6) |
|
|
163 |
|
|
|
164 |
plt.subplots_adjust(left=0.1, |
|
|
165 |
bottom=0.1, |
|
|
166 |
right=0.9, |
|
|
167 |
top=0.9, |
|
|
168 |
wspace=1.5, |
|
|
169 |
hspace=1) |
|
|
170 |
|
|
|
171 |
plt.savefig(f"{out_dir}/SpiderPlot_pt_score.pdf", bbox_inches="tight") |
|
|
172 |
plt.show() |