Diff of /tests/test_scpanel.py [000000] .. [d90ecf]

Switch to unified view

a b/tests/test_scpanel.py
1
import sys
2
from scpanel.utils_func import *
3
from scpanel.split_patient import *
4
from scpanel.select_cell import *
5
from scpanel.select_gene import *
6
from scpanel.train import *
7
from scpanel.settings import *
8
import anndata
9
10
dataset = 'wilk2020covid'
11
out_dir = './tests/test_result/'
12
13
if not os.path.exists(out_dir):
14
    os.mkdir(out_dir)
15
16
fastmode = True
17
18
adata = anndata.read_h5ad('./tests/test_wilk2020covid_processed_rna_assay_2.h5ad')
19
adata.obs.columns
20
adata
21
22
## Check what are condition labels in the data and encode it
23
print(adata.obs[['disease_status_standard']].drop_duplicates())
24
class_map = {'healthy': 0, 'COVID-19': 1}
25
26
## standardize data
27
adata = preprocess(adata,
28
                   ct_col='cell.type.fine',
29
                   y_col='disease_status_standard',
30
                   pt_col='sample',
31
                   class_map=class_map)
32
33
## split data
34
adata_train_dict, adata_test_dict = split_train_test(adata, min_cells=20,
35
                                                     test_pt_size=0.2,
36
                                                     out_dir=out_dir,
37
                                                     random_state=3467)
38
39
# Cell type selection
40
## 1. calculate responsiveness score
41
AUC, AUC_all = cell_type_score(adata_train_dict,
42
                               out_dir=out_dir,
43
                               ncpus=16,
44
                               n_iterations=30,
45
                               sample_n_cell=20)
46
47
## 2. plot responsiveness score
48
axes = plot_cell_type_score(AUC, AUC_all)
49
axes.set_xlim([0.4, 1])
50
plt.savefig(f'{out_dir}/cell_type_score.pdf', bbox_inches="tight")
51
52
## 3. select the most responsive cell type for downstream
53
top_ct = AUC['celltype'].iloc[-1]
54
adata_train = select_celltype(adata_train_dict,
55
                              celltype_selected=top_ct,
56
                              out_dir=out_dir)
57
58
# Gene selection
59
## 1. split training data
60
train_index_list, val_index_list, sample_weight_list = split_n_folds(adata_train,
61
                                                                     nfold=5,
62
                                                                     out_dir=out_dir,
63
                                                                     random_state=2349)
64
65
## 2. score genes
66
adata_train, rfecv = gene_score(adata_train,
67
                                train_index_list,
68
                                val_index_list,
69
                                sample_weight_list=sample_weight_list,
70
                                step=0.03,
71
                                out_dir=out_dir,
72
                                ncpus=16,
73
                                verbose=False)
74
75
## 3. plot gene scores
76
plot_gene_score(adata_train, n_genes_plot=200)
77
plt.savefig(f'{out_dir}/gene_score.pdf', bbox_inches="tight")
78
79
## 4. find the optimal number of informative genes
80
k = decide_k(adata_train, n_genes_plot=100)
81
plot_gene_score(adata_train, n_genes_plot=100, k=k)
82
plt.savefig(f'{out_dir}/decide_k.pdf', bbox_inches="tight")
83
84
## 5. return the list of informative genes
85
adata_train = select_gene(adata_train,
86
                          top_n_feat=k,
87
                          step=0.03,
88
                          out_dir=out_dir)
89
90
## 6. view the list of informative genes
91
sig_svm = adata_train.uns['svm_rfe_genes']
92
print(sig_svm)
93
94
# Classification
95
## 1. Subset training and testing set with selected cell type and genes
96
adata_train_final, adata_test_final = transform_adata(adata_train,
97
                                                      adata_test_dict,
98
                                                      selected_gene=sig_svm)
99
100
## 2. models training
101
# overwrite default parameters for models
102
param_grid = {'LR': {'max_iter': 600}}
103
104
clfs = models_train(adata_train_final,
105
                    search_grid=False,
106
                    out_dir=out_dir, param_grid=param_grid)
107
108
## 3. models testing
109
### cell-level prediction
110
adata_test_final, y_pred_list, y_pred_score_list = models_predict(clfs, adata_test_final, out_dir=out_dir)
111
112
### sample-level prediction and evaluation
113
all_pred_scores = ['LR_pred_score', 'RF_pred_score', 'SVM_pred_score',
114
                   'KNN_pred_score', 'GAT_pred_score', 'median_pred_score']
115
116
for i in tqdm(range(len(all_pred_scores))):
117
    adata_test_final = pt_pred(adata_test_final,
118
                               cell_pred_col=all_pred_scores[i],
119
                               num_bootstrap=1000)
120
121
### visualize patient-level AUC scores
122
sample_id_list = adata_test_final.obs[['patient_id', 'median_pred_score_sample_auc']].drop_duplicates().sort_values(
123
    by='median_pred_score_sample_auc')['patient_id']
124
plot_roc_curve(adata_test_final,
125
               sample_id=sample_id_list,
126
               cell_pred_col='median_pred_score',
127
               hspace=0.4, ncols=4,
128
               scatter_kws={'s': 10}, legend_kws={'prop': {'size': 11}})
129
plt.savefig(f'{out_dir}/ROC_curve_sample.pdf', bbox_inches="tight")
130
131
### visualize cell-level prediction probabilities and patient-level AUC scores
132
fig, axes = plt.subplots(6, figsize=(13, 30))
133
for ax, pred in zip(axes, all_pred_scores):
134
    print('Plotting for ', pred)
135
    plot_violin(adata_test_final,
136
                cell_pred_col=pred,
137
                ax=ax,
138
                palette={'healthy': 'C0', 'COVID-19': 'C1'})
139
140
plt.subplots_adjust(left=0.1,
141
                    bottom=0.1,
142
                    right=0.9,
143
                    top=0.9,
144
                    wspace=0.4,
145
                    hspace=1)
146
147
plt.savefig(f"{out_dir}/Violin_pt_pred_prob.pdf", bbox_inches="tight")
148
plt.show()
149
150
### calculate patient-level precision, recall, f1score and accuracy and visualize
151
for i in range(len(all_pred_scores)):
152
    adata_test_final = pt_score(adata_test_final, cell_pred_col=all_pred_scores[i])
153
154
plt.figure(figsize=(30, 10))
155
# Create a color palette:
156
my_palette = plt.cm.get_cmap("Set2")
157
# Loop to plot
158
for metric_idx in range(0, len(adata_test_final.uns['sample_score'].columns)):
159
    make_single_spider(adata_test_final,
160
                       metric_idx=metric_idx,
161
                       color='grey',
162
                       nrow=1, ncol=6)
163
164
plt.subplots_adjust(left=0.1,
165
                    bottom=0.1,
166
                    right=0.9,
167
                    top=0.9,
168
                    wspace=1.5,
169
                    hspace=1)
170
171
plt.savefig(f"{out_dir}/SpiderPlot_pt_score.pdf", bbox_inches="tight")
172
plt.show()