Switch to side-by-side view

--- a
+++ b/tests/preprocessing/test_quality_control.py
@@ -0,0 +1,244 @@
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import pytest
+from anndata import AnnData
+
+import ehrapy as ep
+from ehrapy.io._read import read_csv
+from ehrapy.preprocessing._encoding import encode
+from ehrapy.preprocessing._quality_control import _compute_obs_metrics, _compute_var_metrics, mcar_test
+from tests.conftest import ARRAY_TYPES, TEST_DATA_PATH, as_dense_dask_array
+
+CURRENT_DIR = Path(__file__).parent
+_TEST_PATH_ENCODE = f"{TEST_DATA_PATH}/encode"
+
+
+@pytest.mark.parametrize("array_type", ARRAY_TYPES)
+def test_qc_metrics_vanilla(array_type, missing_values_adata):
+    adata = missing_values_adata
+    adata.X = array_type(adata.X)
+    modification_copy = adata.copy()
+    obs_metrics, var_metrics = ep.pp.qc_metrics(adata)
+
+    assert np.array_equal(obs_metrics["missing_values_abs"].values, np.array([1, 2]))
+    assert np.allclose(obs_metrics["missing_values_pct"].values, np.array([33.3333, 66.6667]))
+
+    assert np.array_equal(var_metrics["missing_values_abs"].values, np.array([1, 2, 0]))
+    assert np.allclose(var_metrics["missing_values_pct"].values, np.array([50.0, 100.0, 0.0]))
+    assert np.allclose(var_metrics["mean"].values, np.array([0.21, np.nan, 24.327]), equal_nan=True)
+    assert np.allclose(var_metrics["median"].values, np.array([0.21, np.nan, 24.327]), equal_nan=True)
+    assert np.allclose(var_metrics["min"].values, np.array([0.21, np.nan, 7.234]), equal_nan=True)
+    assert np.allclose(var_metrics["max"].values, np.array([0.21, np.nan, 41.419998]), equal_nan=True)
+    assert (~var_metrics["iqr_outliers"]).all()
+
+    # check that none of the columns were modified
+    for key in modification_copy.obs.keys():
+        assert np.array_equal(modification_copy.obs[key], adata.obs[key])
+    for key in modification_copy.var.keys():
+        assert np.array_equal(modification_copy.var[key], adata.var[key])
+
+
+@pytest.mark.parametrize("array_type", ARRAY_TYPES)
+def test_obs_qc_metrics(array_type, missing_values_adata):
+    missing_values_adata.X = array_type(missing_values_adata.X)
+    mtx = missing_values_adata.X
+    obs_metrics = _compute_obs_metrics(mtx, missing_values_adata)
+
+    assert np.array_equal(obs_metrics["missing_values_abs"].values, np.array([1, 2]))
+    assert np.allclose(obs_metrics["missing_values_pct"].values, np.array([33.3333, 66.6667]))
+
+
+@pytest.mark.parametrize("array_type", ARRAY_TYPES)
+def test_var_qc_metrics(array_type, missing_values_adata):
+    missing_values_adata.X = array_type(missing_values_adata.X)
+    mtx = missing_values_adata.X
+    var_metrics = _compute_var_metrics(mtx, missing_values_adata)
+
+    assert np.array_equal(var_metrics["missing_values_abs"].values, np.array([1, 2, 0]))
+    assert np.allclose(var_metrics["missing_values_pct"].values, np.array([50.0, 100.0, 0.0]))
+    assert np.allclose(var_metrics["mean"].values, np.array([0.21, np.nan, 24.327]), equal_nan=True)
+    assert np.allclose(var_metrics["median"].values, np.array([0.21, np.nan, 24.327]), equal_nan=True)
+    assert np.allclose(var_metrics["min"].values, np.array([0.21, np.nan, 7.234]), equal_nan=True)
+    assert np.allclose(var_metrics["max"].values, np.array([0.21, np.nan, 41.419998]), equal_nan=True)
+    assert (~var_metrics["iqr_outliers"]).all()
+
+
+@pytest.mark.parametrize(
+    "array_type, expected_error",
+    [
+        (np.array, None),
+        (as_dense_dask_array, None),
+        # can't test sparse matrices because they don't support string values
+    ],
+)
+def test_obs_qc_metrics_array_types(array_type, expected_error):
+    adata = read_csv(dataset_path=f"{_TEST_PATH_ENCODE}/dataset1.csv")
+    adata.X = array_type(adata.X)
+    mtx = adata.X
+    if expected_error:
+        with pytest.raises(expected_error):
+            _compute_obs_metrics(mtx, adata)
+
+
+def test_obs_nan_qc_metrics():
+    adata = read_csv(dataset_path=f"{_TEST_PATH_ENCODE}/dataset1.csv")
+    adata.X[0][4] = np.nan
+    adata2 = encode(adata, encodings={"one-hot": ["clinic_day"]})
+    mtx = adata2.X
+    obs_metrics = _compute_obs_metrics(mtx, adata2)
+    assert obs_metrics.iloc[0].iloc[0] == 1
+
+
+@pytest.mark.parametrize(
+    "array_type, expected_error",
+    [
+        (np.array, None),
+        (as_dense_dask_array, None),
+        # can't test sparse matrices because they don't support string values
+    ],
+)
+def test_var_qc_metrics_array_types(array_type, expected_error):
+    adata = read_csv(dataset_path=f"{_TEST_PATH_ENCODE}/dataset1.csv")
+    adata.X = array_type(adata.X)
+    mtx = adata.X
+    if expected_error:
+        with pytest.raises(expected_error):
+            _compute_var_metrics(mtx, adata)
+
+
+def test_var_encoding_mode_does_not_modify_original_matrix():
+    adata = read_csv(dataset_path=f"{_TEST_PATH_ENCODE}/dataset1.csv")
+    adata2 = encode(adata, encodings={"one-hot": ["clinic_day"]})
+    mtx_copy = adata2.X.copy()
+    _compute_var_metrics(adata2.X, adata2)
+    assert np.array_equal(mtx_copy, adata2.X)
+
+
+def test_var_nan_qc_metrics():
+    adata = read_csv(dataset_path=f"{_TEST_PATH_ENCODE}/dataset1.csv")
+    adata.X[0][4] = np.nan
+    adata2 = encode(adata, encodings={"one-hot": ["clinic_day"]})
+    mtx = adata2.X
+    var_metrics = _compute_var_metrics(mtx, adata2)
+    assert var_metrics.iloc[0].iloc[0] == 1
+    assert var_metrics.iloc[1].iloc[0] == 1
+    assert var_metrics.iloc[2].iloc[0] == 1
+    assert var_metrics.iloc[3].iloc[0] == 1
+    assert var_metrics.iloc[4].iloc[0] == 1
+
+
+def test_calculate_qc_metrics(missing_values_adata):
+    obs_metrics, var_metrics = ep.pp.qc_metrics(missing_values_adata)
+
+    assert obs_metrics is not None
+    assert var_metrics is not None
+
+    assert missing_values_adata.obs.missing_values_abs is not None
+    assert missing_values_adata.var.missing_values_abs is not None
+
+
+def test_qc_lab_measurements_simple(lab_measurements_simple_adata):
+    expected_obs_data = pd.Series(
+        data={
+            "Acetaminophen normal": [True, True],
+            "Acetoacetic acid normal": [True, False],
+            "Beryllium, toxic normal": [False, True],
+        }
+    )
+
+    ep.pp.qc_lab_measurements(
+        lab_measurements_simple_adata,
+        measurements=list(lab_measurements_simple_adata.var_names),
+        unit="SI",
+    )
+
+    assert (
+        list(lab_measurements_simple_adata.obs["Acetaminophen normal"]) == (expected_obs_data["Acetaminophen normal"])
+    )
+    assert (
+        list(lab_measurements_simple_adata.obs["Acetoacetic acid normal"])
+        == (expected_obs_data["Acetoacetic acid normal"])
+    )
+    assert (
+        list(lab_measurements_simple_adata.obs["Beryllium, toxic normal"])
+        == (expected_obs_data["Beryllium, toxic normal"])
+    )
+
+
+def test_qc_lab_measurements_simple_layer(lab_measurements_layer_adata):
+    expected_obs_data = pd.Series(
+        data={
+            "Acetaminophen normal": [True, True],
+            "Acetoacetic acid normal": [True, False],
+            "Beryllium, toxic normal": [False, True],
+        }
+    )
+
+    ep.pp.qc_lab_measurements(
+        lab_measurements_layer_adata,
+        measurements=list(lab_measurements_layer_adata.var_names),
+        unit="SI",
+        layer="layer_copy",
+    )
+
+    assert list(lab_measurements_layer_adata.obs["Acetaminophen normal"]) == (expected_obs_data["Acetaminophen normal"])
+    assert (
+        list(lab_measurements_layer_adata.obs["Acetoacetic acid normal"])
+        == (expected_obs_data["Acetoacetic acid normal"])
+    )
+    assert (
+        list(lab_measurements_layer_adata.obs["Beryllium, toxic normal"])
+        == (expected_obs_data["Beryllium, toxic normal"])
+    )
+
+
+def test_qc_lab_measurements_age():
+    # TODO
+    pass
+
+
+def test_qc_lab_measurements_sex():
+    # TODO
+    pass
+
+
+def test_qc_lab_measurements_ethnicity():
+    # TODO
+    pass
+
+
+def test_qc_lab_measurements_multiple_measurements():
+    data = pd.DataFrame(np.array([[100, 98], [162, 107]]), columns=["oxygen saturation", "glucose"], index=[0, 1])
+
+    with pytest.raises(ValueError):
+        adata = ep.ad.df_to_anndata(data)
+        ep.pp.qc_lab_measurements(adata, measurements=["oxygen saturation", "glucose"], unit="SI")
+
+
+@pytest.mark.parametrize(
+    "method,expected_output_type",
+    [
+        ("little", float),
+        ("ttest", pd.DataFrame),
+    ],
+)
+def test_mcar_test_method_output_types(mar_adata, method, expected_output_type):
+    """Tests if mcar_test returns the correct output type for different methods."""
+    output = mcar_test(mar_adata, method=method)
+    assert isinstance(output, expected_output_type), (
+        f"Output type for method '{method}' should be {expected_output_type}, got {type(output)} instead."
+    )
+
+
+def test_mar_data_identification(mar_adata):
+    """Test that mcar_test correctly identifies data as not MCAR (i.e., MAR or NMAR)."""
+    p_value = mcar_test(mar_adata, method="little")
+    assert p_value <= 0.05, "The test should significantly reject the MCAR hypothesis for MAR data."
+
+
+def test_mcar_identification(mcar_adata):
+    """Test that mcar_test correctly identifies data as MCAR."""
+    p_value = mcar_test(mcar_adata, method="little")
+    assert p_value > 0.05, "The test should not significantly accept the MCAR hypothesis for MCAR data."