Diff of /tests/fig4_test.py [000000] .. [5d6472]

Switch to unified view

a b/tests/fig4_test.py
1
import os
2
import sys
3
4
current_path = os.path.dirname(__file__)
5
src_path = os.path.join(current_path, "..")
6
sys.path.append(src_path)
7
8
import pytest
9
import numpy as np
10
import multivelo as mv
11
import scanpy as sc
12
import scvelo as scv
13
import sys
14
15
sys.path.append("/..")
16
17
scv.settings.verbosity = 3
18
scv.settings.presenter_view = True
19
scv.set_figure_params('scvelo')
20
np.set_printoptions(suppress=True)
21
22
23
rna_path = "test_files/adata_postpro.h5ad"
24
atac_path = "test_files/adata_atac_postpro.h5ad"
25
26
27
@pytest.fixture(scope="session")
28
def result_data_4():
29
30
    # read in the original AnnData objects
31
    adata_rna = sc.read(rna_path)
32
    adata_atac = sc.read(atac_path)
33
34
    # subset genes to run faster
35
    gene_list = ["Shh", "Heg1", "Cux1", "Lef1"]
36
37
    # run our first function to test (recover_dynamics_chrom)
38
    adata_result = mv.recover_dynamics_chrom(adata_rna,
39
                                             adata_atac,
40
                                             gene_list=gene_list,
41
                                             max_iter=5,
42
                                             init_mode="invert",
43
                                             parallel=True,
44
                                             n_jobs=15,
45
                                             save_plot=False,
46
                                             rna_only=False,
47
                                             fit=True,
48
                                             n_anchors=500,
49
                                             extra_color_key='celltype')
50
51
    return adata_result
52
53
54
# the next three tests check to see if recover_dynamics_chrom calculated
55
# the correct parameters for each of our four genes
56
def test_alpha(result_data_4):
57
    alpha = result_data_4.var["fit_alpha"]
58
59
    assert alpha[0] == pytest.approx(0.45878197934025416)
60
    assert alpha[1] == pytest.approx(0.08032904996744818)
61
    assert alpha[2] == pytest.approx(1.5346878202804608)
62
    assert alpha[3] == pytest.approx(0.9652887906148591)
63
64
65
def test_beta(result_data_4):
66
    beta = result_data_4.var["fit_beta"]
67
68
    assert beta[0] == pytest.approx(0.28770367567423)
69
    assert beta[1] == pytest.approx(0.14497469719573167)
70
    assert beta[2] == pytest.approx(0.564865749852349)
71
    assert beta[3] == pytest.approx(0.2522643118709811)
72
73
74
def test_gamma(result_data_4):
75
    gamma = result_data_4.var["fit_gamma"]
76
77
    assert gamma[0] == pytest.approx(0.19648836445315102)
78
    assert gamma[1] == pytest.approx(0.07703610603664116)
79
    assert gamma[2] == pytest.approx(1.0079569101225154)
80
    assert gamma[3] == pytest.approx(0.7485734061079243)
81
82
83
def test_embedding_stream(result_data_4):
84
85
    mv.velocity_graph(result_data_4)
86
87
    ax = mv.velocity_embedding_stream(result_data_4, basis='umap',
88
                                      color='celltype', show=False)
89
90
    assert ax is not None
91
92
    assert ax.axis()[0] == pytest.approx(-2.0698418340618714)
93
    assert ax.axis()[1] == pytest.approx(8.961822542538197)
94
    assert ax.axis()[2] == pytest.approx(-14.418079041548095)
95
    assert ax.axis()[3] == pytest.approx(-7.789863798927619)
96
97
    assert ax.get_xlim()[0] == pytest.approx(-2.0698418340618714)
98
    assert ax.get_xlim()[1] == pytest.approx(8.961822542538197)
99
100
    assert ax.get_ylim()[0] == pytest.approx(-14.418079041548095)
101
    assert ax.get_ylim()[1] == pytest.approx(-7.789863798927619)
102
103
104
# tests the latent_time function
105
def test_latent_time(result_data_4):
106
107
    mv.velocity_graph(result_data_4)
108
    mv.latent_time(result_data_4)
109
110
    latent_time = result_data_4.obs["latent_time"]
111
112
    assert latent_time.shape[0] == 6436
113
114
115
# test the velocity_graph function
116
def test_velo_graph(result_data_4):
117
118
    mv.velocity_graph(result_data_4)
119
120
    digits = 8
121
122
    v_graph_mat = result_data_4.uns["velo_s_norm_graph"].tocoo()
123
124
    v_graph = v_graph_mat.data
125
    v_graph = v_graph.astype(float)
126
    v_graph = v_graph.round(decimals=digits)
127
128
    v_graph_rows = v_graph_mat.row
129
    v_graph_cols = v_graph_mat.col
130
131
    assert len(v_graph) == 1883599
132
    assert v_graph[0] == pytest.approx(1.0)
133
    assert v_graph[500000] == pytest.approx(1.0)
134
    assert v_graph[1005000] == pytest.approx(0.99999994)
135
    assert v_graph[1500000] == pytest.approx(1.0)
136
137
    assert v_graph_rows[0] == 0
138
    assert v_graph_rows[500000] == 1411
139
    assert v_graph_rows[1005000] == 2834
140
    assert v_graph_rows[1500000] == 4985
141
142
    assert v_graph_cols[0] == 7
143
    assert v_graph_cols[500000] == 2406
144
    assert v_graph_cols[1005000] == 2892
145
    assert v_graph_cols[1500000] == 2480
146
147
148
@pytest.fixture(scope="session")
149
def lrt_compute():
150
151
    # read in the original AnnData objects
152
    adata_rna = sc.read(rna_path)
153
    adata_atac = sc.read(atac_path)
154
155
    # subset genes to run faster
156
    gene_list = ["Shh", "Heg1", "Cux1", "Lef1"]
157
158
    # run our first function to test (LRT_decoupling)
159
    w_de, wo_de, res = mv.LRT_decoupling(adata_rna,
160
                                         adata_atac,
161
                                         gene_list=gene_list,
162
                                         max_iter=5,
163
                                         init_mode="invert",
164
                                         parallel=True,
165
                                         n_jobs=15,
166
                                         save_plot=False,
167
                                         rna_only=False,
168
                                         fit=True,
169
                                         n_anchors=500,
170
                                         extra_color_key='celltype')
171
172
    # w_de = with decoupling
173
    # wo_de = without decoupling
174
    # res = LRT stats
175
    return (w_de, wo_de, res)
176
177
178
def decouple_test(lrt_compute):
179
180
    w_decouple = lrt_compute[0]
181
182
    alpha_c = w_decouple.var["fit_alpha_c"]
183
184
    assert alpha_c[0] == pytest.approx(0.057961)
185
    assert alpha_c[1] == pytest.approx(0.039439)
186
    assert alpha_c[2] == pytest.approx(0.076731)
187
    assert alpha_c[3] == pytest.approx(0.063575)
188
189
    beta = w_decouple.var["fit_beta"]
190
191
    assert beta[0] == pytest.approx(0.287704)
192
    assert beta[1] == pytest.approx(0.144975)
193
    assert beta[2] == pytest.approx(0.564866)
194
    assert beta[3] == pytest.approx(0.252264)
195
196
    gamma = w_decouple.var["fit_gamma"]
197
198
    assert gamma[0] == pytest.approx(0.196488)
199
    assert gamma[1] == pytest.approx(0.077036)
200
    assert gamma[2] == pytest.approx(1.007957)
201
    assert gamma[3] == pytest.approx(0.748573)
202
203
204
def no_decouple_test(lrt_compute):
205
206
    print("No decouple test")
207
208
    wo_decouple = lrt_compute[1]
209
210
    alpha_c = wo_decouple.var["fit_alpha_c"]
211
212
    assert alpha_c[0] == pytest.approx(0.093752)
213
    assert alpha_c[1] == pytest.approx(0.041792)
214
    assert alpha_c[2] == pytest.approx(0.051228)
215
    assert alpha_c[3] == pytest.approx(0.050951)
216
217
    beta = wo_decouple.var["fit_beta"]
218
219
    assert beta[0] == pytest.approx(0.840938)
220
    assert beta[1] == pytest.approx(0.182773)
221
    assert beta[2] == pytest.approx(0.326623)
222
    assert beta[3] == pytest.approx(0.232073)
223
224
    gamma = wo_decouple.var["fit_gamma"]
225
226
    assert gamma[0] == pytest.approx(0.561730)
227
    assert gamma[1] == pytest.approx(0.106799)
228
    assert gamma[2] == pytest.approx(0.783257)
229
    assert gamma[3] == pytest.approx(0.705256)
230
231
232
def lrt_res_test(lrt_compute):
233
234
    res = lrt_compute[2]
235
236
    likelihood_c_w_decoupled = res["likelihood_c_w_decoupled"]
237
238
    assert likelihood_c_w_decoupled[0] == pytest.approx(0.279303)
239
    assert likelihood_c_w_decoupled[1] == pytest.approx(0.186213)
240
    assert likelihood_c_w_decoupled[2] == pytest.approx(0.295591)
241
    assert likelihood_c_w_decoupled[3] == pytest.approx(0.144158)
242
243
    likelihood_c_wo_decoupled = res["likelihood_c_wo_decoupled"]
244
245
    assert likelihood_c_wo_decoupled[0] == pytest.approx(0.270491)
246
    assert likelihood_c_wo_decoupled[1] == pytest.approx(0.180695)
247
    assert likelihood_c_wo_decoupled[2] == pytest.approx(0.294631)
248
    assert likelihood_c_wo_decoupled[3] == pytest.approx(0.175622)
249
250
    LRT_c = res["LRT_c"]
251
252
    assert LRT_c[0] == pytest.approx(412.637730)
253
    assert LRT_c[1] == pytest.approx(387.177688)
254
    assert LRT_c[2] == pytest.approx(41.850304)
255
    assert LRT_c[3] == pytest.approx(-2541.289231)
256
257
    pval_c = res["pval_c"]
258
259
    assert pval_c[0] == pytest.approx(9.771580e-92)
260
    assert pval_c[1] == pytest.approx(3.406463e-86)
261
    assert pval_c[2] == pytest.approx(9.853544e-11)
262
    assert pval_c[3] == pytest.approx(1.000000e+00)
263
264
    likelihood_w_decoupled = res["likelihood_w_decoupled"]
265
266
    assert likelihood_w_decoupled[0] == pytest.approx(0.177979)
267
    assert likelihood_w_decoupled[1] == pytest.approx(0.008453)
268
    assert likelihood_w_decoupled[2] == pytest.approx(0.140156)
269
    assert likelihood_w_decoupled[3] == pytest.approx(0.005029)
270
271
    likelihood_wo_decoupled = res["likelihood_wo_decoupled"]
272
273
    assert likelihood_wo_decoupled[0] == pytest.approx(0.181317)
274
    assert likelihood_wo_decoupled[1] == pytest.approx(0.009486)
275
    assert likelihood_wo_decoupled[2] == pytest.approx(0.141367)
276
    assert likelihood_wo_decoupled[3] == pytest.approx(0.008299)
277
278
    LRT = res["LRT"]
279
280
    assert LRT[0] == pytest.approx(-239.217562)
281
    assert LRT[1] == pytest.approx(-1485.199859)
282
    assert LRT[2] == pytest.approx(-110.788912)
283
    assert LRT[3] == pytest.approx(-6447.599212)
284
285
    pval = res["pval"]
286
287
    assert pval[0] == pytest.approx(1.0)
288
    assert pval[1] == pytest.approx(1.0)
289
    assert pval[2] == pytest.approx(1.0)
290
    assert pval[3] == pytest.approx(1.0)
291
292
    c_likelihood = res["likelihood_c_w_decoupled"]
293
294
    assert c_likelihood[0] == pytest.approx(0.279303)
295
    assert c_likelihood[1] == pytest.approx(0.186213)
296
    assert c_likelihood[2] == pytest.approx(0.295591)
297
    assert c_likelihood[3] == pytest.approx(0.144158)
298
299
    def test_qc_metrics():
300
        adata_rna = sc.read(rna_path)
301
302
        mv.calculate_qc_metrics(adata_rna)
303
304
        total_unspliced = adata_rna.obs["total_unspliced"]
305
306
        assert total_unspliced.shape == (6436,)
307
        assert total_unspliced[0] == pytest.approx(91.709404)
308
        assert total_unspliced[1500] == pytest.approx(115.21283)
309
        assert total_unspliced[3000] == pytest.approx(61.402004)
310
        assert total_unspliced[4500] == pytest.approx(84.03409)
311
        assert total_unspliced[6000] == pytest.approx(61.26761)
312
313
        total_spliced = adata_rna.obs["total_spliced"]
314
315
        assert total_spliced.shape == (6436,)
316
        assert total_spliced[0] == pytest.approx(91.514175)
317
        assert total_spliced[1500] == pytest.approx(66.045616)
318
        assert total_spliced[3000] == pytest.approx(87.05275)
319
        assert total_spliced[4500] == pytest.approx(83.82857)
320
        assert total_spliced[6000] == pytest.approx(62.019516)
321
322
        unspliced_ratio = adata_rna.obs["unspliced_ratio"]
323
324
        assert unspliced_ratio.shape == (6436,)
325
        assert unspliced_ratio[0] == pytest.approx(0.5005328)
326
        assert unspliced_ratio[1500] == pytest.approx(0.6356273)
327
        assert unspliced_ratio[3000] == pytest.approx(0.4136075)
328
        assert unspliced_ratio[4500] == pytest.approx(0.50061214)
329
        assert unspliced_ratio[6000] == pytest.approx(0.4969506)
330
331
        cell_cycle_score = adata_rna.obs["cell_cycle_score"]
332
333
        assert cell_cycle_score.shape == (6436,)
334
        assert cell_cycle_score[0] == pytest.approx(-0.24967776384597046)
335
        assert cell_cycle_score[1500] == pytest.approx(0.5859756395543293)
336
        assert cell_cycle_score[3000] == pytest.approx(0.06501555292615813)
337
        assert cell_cycle_score[4500] == pytest.approx(0.1406775909466575)
338
        assert cell_cycle_score[6000] == pytest.approx(-0.33825528386759895)