--- a +++ b/tests/conftest.py @@ -0,0 +1,210 @@ +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING + +import numpy as np +import pandas as pd +import pytest +from anndata import AnnData +from matplotlib.testing.compare import compare_images + +import ehrapy as ep +from ehrapy.io import read_csv + +if TYPE_CHECKING: + import os + + from matplotlib.figure import Figure + +TEST_DATA_PATH = Path(__file__).parent / "data" + + +@pytest.fixture +def root_dir(): + return Path(__file__).resolve().parent + + +@pytest.fixture +def rng(): + return np.random.default_rng(seed=42) + + +@pytest.fixture +def obs_data(): + return { + "disease": ["cancer", "tumor"], + "country": ["Germany", "switzerland"], + "sex": ["male", "female"], + } + + +@pytest.fixture +def var_data(): + return { + "alive": ["yes", "no", "maybe"], + "hospital": ["hospital 1", "hospital 2", "hospital 1"], + "crazy": ["yes", "yes", "yes"], + } + + +@pytest.fixture +def missing_values_adata(obs_data, var_data): + return AnnData( + X=np.array([[0.21, np.nan, 41.42], [np.nan, np.nan, 7.234]], dtype=np.float32), + obs=pd.DataFrame(data=obs_data), + var=pd.DataFrame(data=var_data, index=["Acetaminophen", "hospital", "crazy"]), + ) + + +@pytest.fixture +def lab_measurements_simple_adata(obs_data, var_data): + X = np.array([[73, 0.02, 1.00], [148, 0.25, 3.55]], dtype=np.float32) + return AnnData( + X=X, + obs=pd.DataFrame(data=obs_data), + var=pd.DataFrame(data=var_data, index=["Acetaminophen", "Acetoacetic acid", "Beryllium, toxic"]), + ) + + +@pytest.fixture +def lab_measurements_layer_adata(obs_data, var_data): + X = np.array([[73, 0.02, 1.00], [148, 0.25, 3.55]], dtype=np.float32) + return AnnData( + X=X, + obs=pd.DataFrame(data=obs_data), + var=pd.DataFrame(data=var_data, index=["Acetaminophen", "Acetoacetic acid", "Beryllium, toxic"]), + layers={"layer_copy": X}, + ) + + +@pytest.fixture +def mimic_2(): + adata = ep.dt.mimic_2() + return adata + + +@pytest.fixture +def mimic_2_encoded(): + adata = ep.dt.mimic_2(encoded=True) + return adata + + +@pytest.fixture +def mimic_2_10(): + mimic_2_10 = ep.dt.mimic_2()[:10] + + return mimic_2_10 + + +@pytest.fixture +def mar_adata(rng) -> AnnData: + """Generate MAR data using dependent columns.""" + data = rng.random((100, 10)) + # Assume missingness in the last column depends on the values of the first column + missing_indicator = data[:, 0] < np.percentile(data[:, 0], 0.1 * 100) + data[missing_indicator, -1] = np.nan # Only last column has missing values dependent on the first column + + return AnnData(data) + + +@pytest.fixture +def mcar_adata(rng) -> AnnData: + """Generate MCAR data by randomly sampling.""" + data = rng.random((100, 10)) + missing_indices = rng.choice(a=[False, True], size=data.shape, p=[1 - 0.1, 0.1]) + data[missing_indices] = np.nan + + return AnnData(data) + + +@pytest.fixture +def adata_mini(): + return read_csv(f"{TEST_DATA_PATH}/dataset1.csv", columns_obs_only=["glucose", "weight", "disease", "station"]) + + +@pytest.fixture +def adata_move_obs_num() -> AnnData: + return read_csv(TEST_DATA_PATH / "io/dataset_move_obs_num.csv") + + +@pytest.fixture +def adata_move_obs_mix() -> AnnData: + return read_csv(TEST_DATA_PATH / "io/dataset_move_obs_mix.csv") + + +@pytest.fixture +def impute_num_adata() -> AnnData: + adata = read_csv(dataset_path=f"{TEST_DATA_PATH}/imputation/test_impute_num.csv") + return adata + + +@pytest.fixture +def impute_adata() -> AnnData: + adata = read_csv(dataset_path=f"{TEST_DATA_PATH}/imputation/test_impute.csv") + return adata + + +@pytest.fixture +def impute_iris_adata() -> AnnData: + adata = read_csv(dataset_path=f"{TEST_DATA_PATH}/imputation/test_impute_iris.csv") + return adata + + +@pytest.fixture +def impute_titanic_adata(): + adata = read_csv(dataset_path=f"{TEST_DATA_PATH}/imputation/test_impute_titanic.csv") + return adata + + +@pytest.fixture +def encode_ds_1_adata() -> AnnData: + adata = read_csv(dataset_path=f"{TEST_DATA_PATH}/encode/dataset1.csv") + return adata + + +@pytest.fixture +def encode_ds_2_adata() -> AnnData: + adata = read_csv(dataset_path=f"{TEST_DATA_PATH}/encode/dataset2.csv") + return adata + + +# simplified from https://github.com/scverse/scanpy/blob/main/scanpy/tests/conftest.py +@pytest.fixture +def check_same_image(tmp_path): + def check_same_image( + fig: Figure, + base_path: Path | os.PathLike, + *, + tol: float, + ) -> None: + expected = Path(base_path).parent / (Path(base_path).name + "_expected.png") + if not Path(expected).is_file(): + raise OSError(f"No expected output found at {expected}.") + actual = tmp_path / "actual.png" + + fig.savefig(actual, dpi=80) + + result = compare_images(expected, actual, tol=tol, in_decorator=True) + + if result is None: + return None + + raise AssertionError(result) + + return check_same_image + + +def asarray(a): + import numpy as np + + return np.asarray(a) + + +def as_dense_dask_array(a, chunk_size=1000): + import dask.array as da + + return da.from_array(a, chunks=chunk_size) + + +ARRAY_TYPES = (asarray, as_dense_dask_array)