[f3902c]: / tests / test_scpanel.py

Download this file

173 lines (142 with data), 6.4 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import sys
from scpanel.utils_func import *
from scpanel.split_patient import *
from scpanel.select_cell import *
from scpanel.select_gene import *
from scpanel.train import *
from scpanel.settings import *
import anndata
## Run configuration: dataset identifier and the directory receiving all outputs.
dataset = 'wilk2020covid'
out_dir = './tests/test_result/'
# makedirs(exist_ok=True) also creates missing parent directories and does not
# fail when the directory already exists, unlike an exists()/mkdir() pair.
os.makedirs(out_dir, exist_ok=True)
fastmode = True  # NOTE(review): not referenced anywhere in the visible script - confirm it is still needed

## Load the AnnData object stored in the .h5ad file.
adata = anndata.read_h5ad('./tests/test_wilk2020covid_processed_rna_assay_2.h5ad')
# Bare expressions only display inside a notebook; print explicitly so the
# inspection output is also visible when this file runs as a script.
print(adata.obs.columns)
print(adata)
## Inspect the distinct condition labels present in the data, then define the
## label -> integer encoding handed to preprocess().
label_col = 'disease_status_standard'
print(adata.obs[[label_col]].drop_duplicates())
class_map = {'healthy': 0, 'COVID-19': 1}

## Standardize the data. The keyword arguments name the obs columns to use
## (presumably cell type, condition label and patient/sample id - verify
## against scpanel.utils_func.preprocess).
adata = preprocess(
    adata,
    ct_col='cell.type.fine',
    y_col=label_col,
    pt_col='sample',
    class_map=class_map,
)

## Split the data into training and test parts. A fixed random_state is passed
## so the split is repeatable between runs.
# NOTE(review): test_pt_size=0.2 looks like the held-out patient fraction and
# min_cells=20 like a per-group cell threshold - confirm in split_train_test.
adata_train_dict, adata_test_dict = split_train_test(
    adata,
    min_cells=20,
    test_pt_size=0.2,
    out_dir=out_dir,
    random_state=3467,
)
# Cell type selection
## 1. Calculate the responsiveness score from the training data.
AUC, AUC_all = cell_type_score(
    adata_train_dict,
    out_dir=out_dir,
    ncpus=16,
    n_iterations=30,
    sample_n_cell=20,
)

## 2. Plot the responsiveness score, restrict the x-axis to [0.4, 1] and save it.
axes = plot_cell_type_score(AUC, AUC_all)
axes.set_xlim([0.4, 1])
cell_type_score_pdf = f'{out_dir}/cell_type_score.pdf'
plt.savefig(cell_type_score_pdf, bbox_inches="tight")

## 3. Select the most responsive cell type for the downstream steps.
# The last row of AUC is taken; presumably AUC is sorted by ascending score,
# so the last entry is the top cell type - verify against cell_type_score.
celltype_column = AUC['celltype']
top_ct = celltype_column.iloc[-1]
adata_train = select_celltype(
    adata_train_dict,
    celltype_selected=top_ct,
    out_dir=out_dir,
)
# Gene selection
## 1. Split the training data into 5 folds (fixed random_state for repeatability).
fold_splits = split_n_folds(
    adata_train,
    nfold=5,
    out_dir=out_dir,
    random_state=2349,
)
train_index_list, val_index_list, sample_weight_list = fold_splits

## 2. Score genes using the fold indices and sample weights from step 1.
adata_train, rfecv = gene_score(
    adata_train,
    train_index_list,
    val_index_list,
    sample_weight_list=sample_weight_list,
    step=0.03,
    out_dir=out_dir,
    ncpus=16,
    verbose=False,
)

## 3. Plot the gene scores for 200 genes and save the figure.
plot_gene_score(adata_train, n_genes_plot=200)
gene_score_pdf = f'{out_dir}/gene_score.pdf'
plt.savefig(gene_score_pdf, bbox_inches="tight")

## 4. Find the optimal number of informative genes (k), then re-plot with k
##    passed to the plotting function and save that figure as well.
k = decide_k(adata_train, n_genes_plot=100)
plot_gene_score(adata_train, n_genes_plot=100, k=k)
decide_k_pdf = f'{out_dir}/decide_k.pdf'
plt.savefig(decide_k_pdf, bbox_inches="tight")

## 5. Return the list of informative genes (top k features, same step as scoring).
adata_train = select_gene(
    adata_train,
    top_n_feat=k,
    step=0.03,
    out_dir=out_dir,
)

## 6. View the list of informative genes stored under uns['svm_rfe_genes'].
sig_svm = adata_train.uns['svm_rfe_genes']
print(sig_svm)
# Classification
## 1. Subset the training and testing sets with the selected cell type and genes.
adata_train_final, adata_test_final = transform_adata(
    adata_train,
    adata_test_dict,
    selected_gene=sig_svm,
)

## 2. Train the models.
# Overwrite the default parameters: 'LR' gets max_iter=600.
param_grid = {'LR': {'max_iter': 600}}
clfs = models_train(
    adata_train_final,
    search_grid=False,
    out_dir=out_dir,
    param_grid=param_grid,
)

## 3. Test the models.
### Cell-level prediction
adata_test_final, y_pred_list, y_pred_score_list = models_predict(
    clfs,
    adata_test_final,
    out_dir=out_dir,
)

### Sample-level prediction and evaluation
# Names of the cell-level prediction-score columns processed one by one below.
all_pred_scores = [
    'LR_pred_score', 'RF_pred_score', 'SVM_pred_score',
    'KNN_pred_score', 'GAT_pred_score', 'median_pred_score',
]
# pt_pred returns the updated AnnData, so the result is re-bound on each pass;
# tqdm only adds a progress bar over the six score columns.
for score_col in tqdm(all_pred_scores):
    adata_test_final = pt_pred(
        adata_test_final,
        cell_pred_col=score_col,
        num_bootstrap=1000,
    )
### Visualize patient-level AUC scores
# Order the patient ids by their 'median_pred_score_sample_auc' value
# (ascending) so the ROC panels are drawn in that order.
auc_col = 'median_pred_score_sample_auc'
patient_auc = adata_test_final.obs[['patient_id', auc_col]].drop_duplicates()
sample_id_list = patient_auc.sort_values(by=auc_col)['patient_id']

plot_roc_curve(
    adata_test_final,
    sample_id=sample_id_list,
    cell_pred_col='median_pred_score',
    hspace=0.4,
    ncols=4,
    scatter_kws={'s': 10},
    legend_kws={'prop': {'size': 11}},
)
plt.savefig(f'{out_dir}/ROC_curve_sample.pdf', bbox_inches="tight")

### Visualize cell-level prediction probabilities and patient-level AUC scores
# Six stacked axes, one per entry of all_pred_scores.
fig, axes = plt.subplots(6, figsize=(13, 30))
condition_palette = {'healthy': 'C0', 'COVID-19': 'C1'}
for ax, pred in zip(axes, all_pred_scores):
    print('Plotting for ', pred)
    plot_violin(
        adata_test_final,
        cell_pred_col=pred,
        ax=ax,
        palette=condition_palette,
    )

# Figure margins and spacing between the stacked panels.
plt.subplots_adjust(
    left=0.1,
    bottom=0.1,
    right=0.9,
    top=0.9,
    wspace=0.4,
    hspace=1,
)
plt.savefig(f"{out_dir}/Violin_pt_pred_prob.pdf", bbox_inches="tight")
plt.show()
### Calculate patient-level precision, recall, f1score and accuracy and visualize
# pt_score returns the updated AnnData, so the result is re-bound for each
# prediction-score column.
for score_col in all_pred_scores:
    adata_test_final = pt_score(adata_test_final, cell_pred_col=score_col)

plt.figure(figsize=(30, 10))

# Create a color palette.
# plt.get_cmap is used instead of plt.cm.get_cmap: matplotlib.cm.get_cmap was
# deprecated in Matplotlib 3.7 and removed in 3.9, where it raises AttributeError.
# NOTE(review): my_palette is not used below (the spider plots are drawn in
# 'grey') - confirm whether it can be dropped.
my_palette = plt.get_cmap("Set2")

# Draw one spider (radar) panel per column of uns['sample_score'], laid out
# on a 1 x 6 grid.
n_metrics = len(adata_test_final.uns['sample_score'].columns)
for metric_idx in range(n_metrics):
    make_single_spider(
        adata_test_final,
        metric_idx=metric_idx,
        color='grey',
        nrow=1,
        ncol=6,
    )

# Figure margins and spacing between the panels.
plt.subplots_adjust(
    left=0.1,
    bottom=0.1,
    right=0.9,
    top=0.9,
    wspace=1.5,
    hspace=1,
)
plt.savefig(f"{out_dir}/SpiderPlot_pt_score.pdf", bbox_inches="tight")
plt.show()