--- a
+++ b/tests/test_scpanel.py
@@ -0,0 +1,172 @@
+import sys
+from scpanel.utils_func import *
+from scpanel.split_patient import *
+from scpanel.select_cell import *
+from scpanel.select_gene import *
+from scpanel.train import *
+from scpanel.settings import *
+import anndata
+
+# Dataset label (not referenced below) and output folder for this run.
+dataset = 'wilk2020covid'
+out_dir = './tests/test_result/'
+
+# makedirs(exist_ok=True) creates missing parents and tolerates reruns.
+os.makedirs(out_dir, exist_ok=True)
+
+fastmode = True  # not read anywhere else in this script
+
+# Load the processed input; it is standardized and split further below.
+adata = anndata.read_h5ad('./tests/test_wilk2020covid_processed_rna_assay_2.h5ad')
+
+## Show the distinct condition labels present in the data, then encode them
+condition_labels = adata.obs[['disease_status_standard']].drop_duplicates()
+print(condition_labels)
+class_map = {'healthy': 0, 'COVID-19': 1}
+
+## Standardize the data; ct_col / y_col / pt_col name the columns to use
+adata = preprocess(
+    adata, ct_col='cell.type.fine', y_col='disease_status_standard',
+    pt_col='sample', class_map=class_map,
+)
+
+## Split the data into training and testing parts with a fixed random_state
+adata_train_dict, adata_test_dict = split_train_test(
+    adata, min_cells=20, test_pt_size=0.2,
+    out_dir=out_dir, random_state=3467,
+)
+
+# Cell type selection
+## 1. Calculate the responsiveness score; the keyword arguments are gathered
+##    in one dict so the call itself stays on a single line
+cell_type_score_kwargs = dict(out_dir=out_dir, ncpus=16,
+                              n_iterations=30, sample_n_cell=20)
+
+AUC, AUC_all = cell_type_score(adata_train_dict, **cell_type_score_kwargs)
+
+## 2. Plot the responsiveness scores with the x-axis limited to [0.4, 1]
+score_ax = plot_cell_type_score(AUC, AUC_all)
+score_ax.set_xlim([0.4, 1])
+plt.savefig(f'{out_dir}/cell_type_score.pdf', bbox_inches="tight")
+
+## 3. Select the cell type from the last row of AUC for the downstream steps
+##    (assumes AUC is ordered by ascending score -- TODO confirm)
+top_ct = AUC['celltype'].iloc[-1]
+adata_train = select_celltype(
+    adata_train_dict, celltype_selected=top_ct, out_dir=out_dir)
+
+# Gene selection
+## 1. Split the training data with nfold=5; the three returned lists are
+##    unpacked on their own line and fed to gene scoring
+fold_splits = split_n_folds(
+    adata_train, nfold=5, out_dir=out_dir, random_state=2349)
+train_index_list, val_index_list, sample_weight_list = fold_splits
+
+## 2. Score genes using the train/validation index lists built in step 1
+gene_score_options = {
+    'sample_weight_list': sample_weight_list,
+    'step': 0.03,
+    'out_dir': out_dir,
+    'ncpus': 16,
+    'verbose': False,
+}
+adata_train, rfecv = gene_score(adata_train, train_index_list, val_index_list, **gene_score_options)
+
+## 3. Plot the gene scores (n_genes_plot=200) and save the figure
+plot_gene_score(adata_train, n_genes_plot=200)
+plt.savefig(f'{out_dir}/gene_score.pdf', bbox_inches="tight")
+## 4. Find the optimal number of informative genes, then re-plot passing k
+n_genes_for_k = 100
+k = decide_k(adata_train, n_genes_plot=n_genes_for_k)
+plot_gene_score(adata_train, n_genes_plot=n_genes_for_k, k=k)
+plt.savefig(f'{out_dir}/decide_k.pdf', bbox_inches="tight")
+
+## 5. Select the informative genes with top_n_feat=k (same step as scoring)
+adata_train = select_gene(
+    adata_train, top_n_feat=k, step=0.03, out_dir=out_dir,
+)
+
+## 6. Read the selected genes back from adata_train.uns and print them;
+##    sig_svm is reused below to subset the training and testing data
+sig_svm = adata_train.uns['svm_rfe_genes']
+print(sig_svm)
+
+# Classification
+## 1. Subset the training and testing sets to the selected cell type and genes
+adata_train_final, adata_test_final = transform_adata(
+    adata_train, adata_test_dict, selected_gene=sig_svm,
+)
+
+## 2. Train the models
+# Override the default parameters for 'LR': max_iter is set to 600
+param_grid = {'LR': {'max_iter': 600}}
+
+clfs = models_train(
+    adata_train_final, search_grid=False, out_dir=out_dir, param_grid=param_grid,
+)
+
+## 3. models testing
+### cell-level prediction
+adata_test_final, y_pred_list, y_pred_score_list = models_predict(clfs, adata_test_final, out_dir=out_dir)
+
+### sample-level prediction and evaluation
+# Cell-level score columns; each one is passed to pt_pred in turn.
+all_pred_scores = ['LR_pred_score', 'RF_pred_score', 'SVM_pred_score',
+                   'KNN_pred_score', 'GAT_pred_score', 'median_pred_score']
+
+# Iterate the column names directly; tqdm still reports one tick per column.
+for pred_col in tqdm(all_pred_scores):
+    adata_test_final = pt_pred(adata_test_final, cell_pred_col=pred_col, num_bootstrap=1000)
+
+### Plot ROC curves for patients ordered by ascending median-score AUC
+auc_col = 'median_pred_score_sample_auc'
+patient_auc = adata_test_final.obs[['patient_id', auc_col]].drop_duplicates()
+sample_id_list = patient_auc.sort_values(by=auc_col)['patient_id']
+plot_roc_curve(
+    adata_test_final, sample_id=sample_id_list, cell_pred_col='median_pred_score',
+    hspace=0.4, ncols=4, scatter_kws={'s': 10}, legend_kws={'prop': {'size': 11}},
+)
+plt.savefig(f'{out_dir}/ROC_curve_sample.pdf', bbox_inches="tight")
+
+### visualize cell-level prediction probabilities and patient-level AUC scores
+# One subplot row per score column: the row count follows all_pred_scores
+# rather than a hard-coded 6, so zip() below can never silently drop a column.
+# squeeze=False keeps `axes` a 2-D array even when there is a single row.
+fig, axes = plt.subplots(len(all_pred_scores), figsize=(13, 30), squeeze=False)
+# Fixed colour per condition label.
+violin_palette = {'healthy': 'C0', 'COVID-19': 'C1'}
+for ax, pred in zip(axes.ravel(), all_pred_scores):
+    print('Plotting for ', pred)
+    plot_violin(adata_test_final, cell_pred_col=pred, ax=ax, palette=violin_palette)
+
+# Margins and spacing between the stacked panels.
+plt.subplots_adjust(left=0.1, bottom=0.1, right=0.9, top=0.9,
+                    wspace=0.4, hspace=1)
+
+# Write the PDF first, then display the figure.
+plt.savefig(f"{out_dir}/Violin_pt_pred_prob.pdf", bbox_inches="tight")
+plt.show()
+
+### calculate patient-level precision, recall, f1score and accuracy and visualize
+# Iterate the score column names directly instead of indexing by position.
+for pred_col in all_pred_scores:
+    adata_test_final = pt_score(adata_test_final, cell_pred_col=pred_col)
+
+plt.figure(figsize=(30, 10))
+# No colormap is looked up here: every spider plot is drawn in grey, and
+# plt.cm.get_cmap is unavailable on Matplotlib >= 3.9 (removed API).
+# One spider plot per column of adata_test_final.uns['sample_score'].
+n_score_columns = len(adata_test_final.uns['sample_score'].columns)
+for metric_idx in range(n_score_columns):
+    make_single_spider(adata_test_final,
+                       metric_idx=metric_idx,
+                       color='grey',
+                       nrow=1, ncol=6)
+
+# Margins and spacing between the panels.
+plt.subplots_adjust(left=0.1, bottom=0.1, right=0.9, top=0.9,
+                    wspace=1.5, hspace=1)
+
+# Write the PDF first, then display the figure.
+plt.savefig(f"{out_dir}/SpiderPlot_pt_score.pdf", bbox_inches="tight")
+plt.show()