a b/tests/test_main.py
1
from context import models, pl, tl, score
2
import mudata as md
3
import anndata as ad
4
import torch
5
import numpy as np
6
7
# Define some gene names (useful for enrichment analysis).
# NOTE(review): in this section the list is only used for its length
# (`n_genes`); it is never assigned to `rna.var_names`. Presumably the
# enrichment test expects these identifiers on the RNA modality -- confirm.
gene_names = [
    "ENSG00000125877",
    "ENSG00000184840",
    "ENSG00000164440",
    "ENSG00000177144",
    "ENSG00000186815",
    "ENSG00000079974",
    "ENSG00000136159",
    "ENSG00000177243",
    "ENSG00000163932",
    "ENSG00000112799",
    "ENSG00000075618",
    "ENSG00000092531",
    "ENSG00000171408",
    "ENSG00000150527",
    "ENSG00000202429",
    "ENSG00000140807",
    "ENSG00000154589",
    "ENSG00000166263",
    "ENSG00000205268",
    "ENSG00000115008",
]

# Sizes of the synthetic dataset and of the latent space used by the tests.
n_cells, n_genes, n_peaks = 20, len(gene_names), 5
latent_dim = 5

# Draw every random fixture from one seeded generator so that the data and
# labels are identical on every run and a failing test can be reproduced.
# NOTE(review): this only controls the NumPy fixtures built here; any
# randomness inside the library under test is not affected.
_rng = np.random.default_rng(0)

# Create a random anndata object for RNA.
rna = ad.AnnData(_rng.random((n_cells, n_genes)))
rna.var["highly_variable"] = True

# Create a random anndata object for ATAC.
atac = ad.AnnData(_rng.random((n_cells, n_peaks)))
atac.var["highly_variable"] = True

# Create a MuData object combining RNA and ATAC.
mdata = md.MuData({"rna": rna, "atac": atac})

# Per-cell modality weights and a random categorical label for each cell.
mdata.obs["rna:mod_weight"] = 0.5
mdata.obs["atac:mod_weight"] = 0.5
mdata.obs["label"] = _rng.choice(["A", "B", "C"], size=n_cells)
48
49
50
def test_default_params():
    """Train a model built with only `latent_dim` and `cost_path` set.

    After training, the shapes of the embedding stored in `mdata.obsm` and of
    the per-modality dictionaries stored in `.uns` are checked.
    """
    # Where each modality's cost matrix file lives.
    cost_files = {"rna": "cost_rna.npy", "atac": "cost_atac.npy"}

    # Build and fit the model on the shared MuData fixture.
    mowgli = models.MowgliModel(latent_dim=latent_dim, cost_path=cost_files)
    mowgli.train(mdata)

    # The embedding must have one row per cell, one column per factor.
    embedding = mdata.obsm["W_OT"]
    assert embedding.shape == (n_cells, latent_dim)

    # Each dictionary must have one row per feature of its modality.
    for modality, n_features in (("rna", n_genes), ("atac", n_peaks)):
        dictionary = mdata[modality].uns["H_OT"]
        assert dictionary.shape == (n_features, latent_dim)
70
71
72
def test_custom_params():
    """Train a model with explicit, non-default options and check shapes.

    The model is built with regularization, modality weights and PCA cost
    enabled, its parameters are initialized explicitly on CPU, and training
    uses the "adam" optimizer. The same shape checks as in the default test
    are then applied.
    """
    # Constructor options, spelled out so they are easy to tweak.
    model_options = {
        "latent_dim": latent_dim,
        "h_regularization": {"rna": 0.1, "atac": 0.1},
        "use_mod_weight": True,
        "pca_cost": True,
        "cost_path": {"rna": "cost_rna.npy", "atac": "cost_atac.npy"},
    }
    mowgli = models.MowgliModel(**model_options)

    # Explicit parameter initialization before training.
    init_options = {
        "force_recompute": True,
        "normalize_rows": True,
        "dtype": torch.float,
        "device": "cpu",
    }
    mowgli.init_parameters(mdata, **init_options)

    # Fit with a named optimizer.
    mowgli.train(mdata, optim_name="adam")

    # The embedding must have one row per cell, one column per factor.
    embedding = mdata.obsm["W_OT"]
    assert embedding.shape == (n_cells, latent_dim)

    # Each dictionary must have one row per feature of its modality.
    for modality, n_features in (("rna", n_genes), ("atac", n_peaks)):
        dictionary = mdata[modality].uns["H_OT"]
        assert dictionary.shape == (n_features, latent_dim)
102
103
104
def test_plotting():
    """Smoke-test the plotting helpers on the shared MuData fixture.

    Each helper is called with `show=False`; nothing is asserted beyond the
    calls completing without raising.
    """
    # NOTE(review): presumably these plots need the factors written by a
    # training test that ran earlier in the session -- confirm.
    label_key = "label"

    # Clustermap of the fixture.
    pl.clustermap(mdata, show=False)

    # Violin plot of factor 0, grouped by label.
    pl.factor_violin(mdata, groupby=label_key, dim=0, show=False)

    # Heatmap grouped by label.
    pl.heatmap(mdata, groupby=label_key, show=False)
114
115
116
def test_tools():
    """Smoke-test the analysis tools on the shared MuData fixture.

    Top features are computed for factor 0 in both modalities, then the
    enrichment helper is run. Nothing is asserted on the results.
    """
    # Top genes first, then top peaks, both for dimension 0.
    for modality in ("rna", "atac"):
        tl.top_features(mdata, mod=modality, dim=0, threshold=0.2)

    # Enrichment on the top 10 genes, unordered.
    tl.enrich(mdata, n_genes=10, ordered=False)
126
127
128
def test_score():
    """Smoke-test the scoring helpers on the `W_OT` embedding.

    Silhouette, Leiden-across-resolutions, kNN construction and kNN purity
    are each called once; the returned values are not asserted.
    """
    # NOTE(review): `W_OT` is not created in this test. In this file it is
    # only checked after `model.train`, so this presumably relies on a
    # training test having run first -- confirm.
    embedding = mdata.obsm["W_OT"]
    labels = mdata.obs["label"]

    # Silhouette score of the labels in the embedding space.
    score.embedding_silhouette_score(
        embedding=embedding,
        labels=labels,
        metric="euclidean",
    )

    # Leiden clustering evaluated at several resolutions.
    score.embedding_leiden_across_resolutions(
        embedding=embedding,
        labels=labels,
        n_neighbors=10,
        resolutions=[0.1, 0.5, 1.0],
    )

    # Build a kNN structure from the embedding, then score its purity.
    neighbors = score.embedding_to_knn(
        embedding=embedding, k=15, metric="euclidean"
    )
    score.knn_purity_score(knn=neighbors, labels=labels)