Diff of /tests/conftest.py [000000] .. [e5f1db]

Switch to unified view

a b/tests/conftest.py
1
from __future__ import annotations
2
3
from pathlib import Path
4
from typing import TYPE_CHECKING
5
6
import numpy as np
7
import pandas as pd
8
import pytest
9
from anndata import AnnData
10
from matplotlib.testing.compare import compare_images
11
12
import ehrapy as ep
13
from ehrapy.io import read_csv
14
15
if TYPE_CHECKING:
16
    import os
17
18
    from matplotlib.figure import Figure
19
20
# Directory named "data" that sits next to this conftest file.
TEST_DATA_PATH = Path(__file__).parent.joinpath("data")
21
22
23
@pytest.fixture
def root_dir():
    """Provide the resolved (absolute) directory that contains this conftest file."""
    conftest_file = Path(__file__).resolve()
    return conftest_file.parent
26
27
28
@pytest.fixture
def rng():
    """Provide a NumPy random generator created with the fixed seed 42."""
    fixed_seed = 42
    return np.random.default_rng(seed=fixed_seed)
31
32
33
@pytest.fixture
def obs_data():
    """Provide a mapping of three column names to two string values each."""
    diseases = ["cancer", "tumor"]
    countries = ["Germany", "switzerland"]
    sexes = ["male", "female"]
    return dict(disease=diseases, country=countries, sex=sexes)
40
41
42
@pytest.fixture
def var_data():
    """Provide a mapping of three column names to three string values each."""
    alive = ["yes", "no", "maybe"]
    hospitals = ["hospital 1", "hospital 2", "hospital 1"]
    crazy = ["yes", "yes", "yes"]
    return dict(alive=alive, hospital=hospitals, crazy=crazy)
49
50
51
@pytest.fixture
def missing_values_adata(obs_data, var_data):
    """Build a 2 x 3 float32 AnnData whose matrix contains NaN entries.

    ``obs`` and ``var`` are data frames built from the ``obs_data`` and
    ``var_data`` fixtures; ``var`` is indexed by three explicit names.
    """
    matrix = np.array(
        [[0.21, np.nan, 41.42], [np.nan, np.nan, 7.234]],
        dtype=np.float32,
    )
    obs_frame = pd.DataFrame(data=obs_data)
    var_frame = pd.DataFrame(data=var_data, index=["Acetaminophen", "hospital", "crazy"])
    return AnnData(X=matrix, obs=obs_frame, var=var_frame)
58
59
60
@pytest.fixture
def lab_measurements_simple_adata(obs_data, var_data):
    """Build a 2 x 3 float32 AnnData without missing values.

    ``obs`` and ``var`` are data frames built from the ``obs_data`` and
    ``var_data`` fixtures; ``var`` is indexed by three explicit names.
    """
    measurements = np.array([[73, 0.02, 1.00], [148, 0.25, 3.55]], dtype=np.float32)
    obs_frame = pd.DataFrame(data=obs_data)
    var_frame = pd.DataFrame(
        data=var_data,
        index=["Acetaminophen", "Acetoacetic acid", "Beryllium, toxic"],
    )
    return AnnData(X=measurements, obs=obs_frame, var=var_frame)
68
69
70
@pytest.fixture
def lab_measurements_layer_adata(obs_data, var_data):
    """Build a 2 x 3 float32 AnnData that also carries its matrix in a layer.

    ``X`` and the ``"layer_copy"`` layer start out with equal values but are
    backed by separate arrays, so an in-place change to one cannot leak into
    the other.  ``obs`` and ``var`` are data frames built from the ``obs_data``
    and ``var_data`` fixtures.
    """
    X = np.array([[73, 0.02, 1.00], [148, 0.25, 3.55]], dtype=np.float32)
    return AnnData(
        X=X,
        obs=pd.DataFrame(data=obs_data),
        var=pd.DataFrame(data=var_data, index=["Acetaminophen", "Acetoacetic acid", "Beryllium, toxic"]),
        # Give the layer its own buffer: passing ``X`` itself would hand AnnData the
        # same ndarray object twice, leaving it to AnnData whether the two alias.
        layers={"layer_copy": X.copy()},
    )
79
80
81
@pytest.fixture
def mimic_2():
    """Return the result of ``ep.dt.mimic_2()`` called with its default arguments."""
    return ep.dt.mimic_2()
85
86
87
@pytest.fixture
def mimic_2_encoded():
    """Return the result of ``ep.dt.mimic_2`` called with ``encoded=True``."""
    return ep.dt.mimic_2(encoded=True)
91
92
93
@pytest.fixture
def mimic_2_10():
    """Return the ``[:10]`` slice of the object produced by ``ep.dt.mimic_2()``."""
    # NOTE(review): presumably this selects the first ten observations — confirm.
    full_dataset = ep.dt.mimic_2()
    return full_dataset[:10]
98
99
100
@pytest.fixture
def mar_adata(rng) -> AnnData:
    """Generate a 100 x 10 matrix with MAR-style missingness in the last column.

    Rows whose first-column value lies below that column's 10th percentile get
    NaN in the last column; every other entry stays a random draw from ``rng``.
    """
    values = rng.random((100, 10))
    driver_column = values[:, 0]
    # Missingness of the last column is decided purely by the first column.
    cutoff = np.percentile(driver_column, 0.1 * 100)
    rows_to_blank = driver_column < cutoff
    values[rows_to_blank, -1] = np.nan

    return AnnData(values)
109
110
111
@pytest.fixture
def mcar_adata(rng) -> AnnData:
    """Generate a 100 x 10 matrix with MCAR-style missingness.

    Every entry is independently replaced by NaN with probability 0.1, using a
    boolean mask drawn from ``rng`` after the values themselves.
    """
    missing_rate = 0.1
    values = rng.random((100, 10))
    nan_mask = rng.choice(
        a=[False, True],
        size=values.shape,
        p=[1 - missing_rate, missing_rate],
    )
    values[nan_mask] = np.nan

    return AnnData(values)
119
120
121
@pytest.fixture
def adata_mini():
    """Read ``dataset1.csv`` from the test data folder via ``read_csv``.

    Four column names are passed as ``columns_obs_only``.
    """
    obs_only_columns = ["glucose", "weight", "disease", "station"]
    csv_location = f"{TEST_DATA_PATH}/dataset1.csv"
    return read_csv(csv_location, columns_obs_only=obs_only_columns)
124
125
126
@pytest.fixture
def adata_move_obs_num() -> AnnData:
    """Read ``io/dataset_move_obs_num.csv`` from the test data folder via ``read_csv``."""
    csv_file = TEST_DATA_PATH.joinpath("io/dataset_move_obs_num.csv")
    return read_csv(csv_file)
129
130
131
@pytest.fixture
def adata_move_obs_mix() -> AnnData:
    """Read ``io/dataset_move_obs_mix.csv`` from the test data folder via ``read_csv``."""
    csv_file = TEST_DATA_PATH.joinpath("io/dataset_move_obs_mix.csv")
    return read_csv(csv_file)
134
135
136
@pytest.fixture
def impute_num_adata() -> AnnData:
    """Read ``imputation/test_impute_num.csv`` from the test data folder via ``read_csv``."""
    return read_csv(dataset_path=f"{TEST_DATA_PATH}/imputation/test_impute_num.csv")
140
141
142
@pytest.fixture
def impute_adata() -> AnnData:
    """Read ``imputation/test_impute.csv`` from the test data folder via ``read_csv``."""
    return read_csv(dataset_path=f"{TEST_DATA_PATH}/imputation/test_impute.csv")
146
147
148
@pytest.fixture
def impute_iris_adata() -> AnnData:
    """Read ``imputation/test_impute_iris.csv`` from the test data folder via ``read_csv``."""
    return read_csv(dataset_path=f"{TEST_DATA_PATH}/imputation/test_impute_iris.csv")
152
153
154
@pytest.fixture
def impute_titanic_adata():
    """Read ``imputation/test_impute_titanic.csv`` from the test data folder via ``read_csv``."""
    return read_csv(dataset_path=f"{TEST_DATA_PATH}/imputation/test_impute_titanic.csv")
158
159
160
@pytest.fixture
def encode_ds_1_adata() -> AnnData:
    """Read ``encode/dataset1.csv`` from the test data folder via ``read_csv``."""
    return read_csv(dataset_path=f"{TEST_DATA_PATH}/encode/dataset1.csv")
164
165
166
@pytest.fixture
def encode_ds_2_adata() -> AnnData:
    """Read ``encode/dataset2.csv`` from the test data folder via ``read_csv``."""
    return read_csv(dataset_path=f"{TEST_DATA_PATH}/encode/dataset2.csv")
170
171
172
# simplified from https://github.com/scverse/scanpy/blob/main/scanpy/tests/conftest.py
173
@pytest.fixture
def check_same_image(tmp_path):
    """Provide a callable that compares a figure against a stored reference PNG.

    The returned callable takes a figure, a base path and a keyword-only
    tolerance ``tol``.  The reference image is looked up in the parent
    directory of ``base_path`` under ``<base name>_expected.png``.  The figure
    is saved as ``actual.png`` inside ``tmp_path`` at 80 dpi, and both files are
    passed to matplotlib's ``compare_images``.

    The returned callable raises:
        OSError: if the reference image is not an existing file.
        AssertionError: if ``compare_images`` returns anything but ``None``;
            the returned value becomes the exception argument.
    """

    def _compare(
        fig: Figure,
        base_path: Path | os.PathLike,
        *,
        tol: float,
    ) -> None:
        base = Path(base_path)
        expected = base.parent / (base.name + "_expected.png")
        if not expected.is_file():
            raise OSError(f"No expected output found at {expected}.")

        actual = tmp_path / "actual.png"
        fig.savefig(actual, dpi=80)

        mismatch = compare_images(expected, actual, tol=tol, in_decorator=True)
        if mismatch is not None:
            raise AssertionError(mismatch)

    return _compare
196
197
198
def asarray(a):
    """Convert ``a`` with ``numpy.asarray`` and return the result."""
    import numpy

    return numpy.asarray(a)
202
203
204
def as_dense_dask_array(a, chunk_size=1000):
    """Wrap ``a`` with ``dask.array.from_array``, passing ``chunk_size`` as ``chunks``."""
    from dask.array import from_array

    return from_array(a, chunks=chunk_size)
208
209
210
# Array converters gathered in one tuple: NumPy (`asarray`) and dask (`as_dense_dask_array`).
# NOTE(review): presumably used to parametrize tests over array backends — confirm at call sites.
ARRAY_TYPES = (asarray, as_dense_dask_array)