Switch to side-by-side view

--- a
+++ b/tests/preprocessing/test_imputation.py
@@ -0,0 +1,358 @@
+import os
+import warnings
+from collections.abc import Iterable
+from pathlib import Path
+
+import dask.array as da
+import numpy as np
+import pytest
+from anndata import AnnData
+from scipy import sparse
+from sklearn.exceptions import ConvergenceWarning
+
+from ehrapy.anndata.anndata_ext import _are_ndarrays_equal, _is_val_missing, _to_dense_matrix
+from ehrapy.preprocessing._imputation import (
+    _warn_imputation_threshold,
+    explicit_impute,
+    knn_impute,
+    mice_forest_impute,
+    miss_forest_impute,
+    simple_impute,
+)
+from tests.conftest import ARRAY_TYPES, TEST_DATA_PATH
+
+CURRENT_DIR = Path(__file__).parent
+_TEST_PATH = f"{TEST_DATA_PATH}/imputation"
+
+
+def _base_check_imputation(
+    adata_before_imputation: AnnData,
+    adata_after_imputation: AnnData,
+    before_imputation_layer: str | None = None,
+    after_imputation_layer: str | None = None,
+    imputed_var_names: Iterable[str] | None = None,
+):
+    """Provides a base check for all imputations:
+
+    - Imputation doesn't leave any NaN behind
+    - Imputation doesn't modify anything in non-imputated columns (if the imputation on a subset was requested)
+    - Imputation doesn't modify any data that wasn't NaN
+
+    Args:
+        adata_before_imputation: AnnData before imputation
+        adata_after_imputation: AnnData after imputation
+        before_imputation_layer: Layer to consider in the original ``AnnData``, ``X`` if not specified
+        after_imputation_layer: Layer to consider in the imputated ``AnnData``, ``X`` if not specified
+        imputed_var_names: Names of the features that were imputated, will consider all of them if not specified
+
+    Raises:
+        AssertionError: If any of the checks fail.
+    """
+    # Convert dask arrays to numpy arrays
+    if isinstance(adata_before_imputation.X, da.Array):
+        adata_before_imputation.X = adata_before_imputation.X.compute()
+    if isinstance(adata_after_imputation.X, da.Array):
+        adata_after_imputation.X = adata_after_imputation.X.compute()
+
+    layer_before = _to_dense_matrix(adata_before_imputation, before_imputation_layer)
+    layer_after = _to_dense_matrix(adata_after_imputation, after_imputation_layer)
+
+    if layer_before.shape != layer_after.shape:
+        raise AssertionError("The shapes of the two layers do not match")
+
+    var_indices = (
+        np.arange(layer_before.shape[1])
+        if imputed_var_names is None
+        else [
+            adata_before_imputation.var_names.get_loc(var_name)
+            for var_name in imputed_var_names
+            if var_name in imputed_var_names
+        ]
+    )
+
+    before_nan_mask = _is_val_missing(layer_before)
+    imputed_mask = np.zeros(layer_before.shape[1], dtype=bool)
+    imputed_mask[var_indices] = True
+
+    # Ensure no NaN remains in the imputed columns of layer_after
+    if np.any(before_nan_mask[:, imputed_mask] & _is_val_missing(layer_after[:, imputed_mask])):
+        raise AssertionError("NaN found in imputed columns of layer_after.")
+
+    # Ensure unchanged values outside imputed columns
+    unchanged_mask = ~imputed_mask
+    if not _are_ndarrays_equal(layer_before[:, unchanged_mask], layer_after[:, unchanged_mask]):
+        raise AssertionError("Values outside imputed columns were modified.")
+
+    # Ensure imputation does not alter non-NaN values in the imputed columns
+    imputed_non_nan_mask = (~before_nan_mask) & imputed_mask
+    if not _are_ndarrays_equal(layer_before[imputed_non_nan_mask], layer_after[imputed_non_nan_mask]):
+        raise AssertionError("Non-NaN values in imputed columns were modified.")
+
+    # If reaching here: all checks passed
+    return
+
+
+def test_base_check_imputation_incompatible_shapes(impute_num_adata):
+    adata_imputed = knn_impute(impute_num_adata, copy=True)
+    with pytest.raises(AssertionError):
+        _base_check_imputation(impute_num_adata, adata_imputed[1:, :])
+    with pytest.raises(AssertionError):
+        _base_check_imputation(impute_num_adata, adata_imputed[:, 1:])
+
+
+def test_base_check_imputation_nan_detected_after_complete_imputation(impute_num_adata):
+    adata_imputed = knn_impute(impute_num_adata, copy=True)
+    adata_imputed.X[0, 2] = np.nan
+    with pytest.raises(AssertionError):
+        _base_check_imputation(impute_num_adata, adata_imputed)
+
+
+def test_base_check_imputation_nan_detected_after_partial_imputation(impute_num_adata):
+    var_names = ("col2", "col3")
+    adata_imputed = knn_impute(impute_num_adata, var_names=var_names, copy=True)
+    adata_imputed.X[0, 2] = np.nan
+    with pytest.raises(AssertionError):
+        _base_check_imputation(impute_num_adata, adata_imputed, imputed_var_names=var_names)
+
+
+def test_base_check_imputation_nan_ignored_if_not_in_imputed_column(impute_num_adata):
+    var_names = ("col2", "col3")
+    adata_imputed = knn_impute(impute_num_adata, var_names=var_names, copy=True)
+    # col1 has a NaN at row 2, should get ignored
+    _base_check_imputation(impute_num_adata, adata_imputed, imputed_var_names=var_names)
+
+
+def test_base_check_imputation_change_detected_in_non_imputed_column(impute_num_adata):
+    var_names = ("col2", "col3")
+    adata_imputed = knn_impute(impute_num_adata, var_names=var_names, copy=True)
+    # col1 has a NaN at row 2, let's simulate it has been imputed by mistake
+    adata_imputed.X[2, 0] = 42.0
+    with pytest.raises(AssertionError):
+        _base_check_imputation(impute_num_adata, adata_imputed, imputed_var_names=var_names)
+
+
+def test_base_check_imputation_change_detected_in_imputed_column(impute_num_adata):
+    adata_imputed = knn_impute(impute_num_adata, copy=True)
+    # col3 didn't have a NaN at row 1, let's simulate it has been modified by mistake
+    adata_imputed.X[1, 2] = 42.0
+    with pytest.raises(AssertionError):
+        _base_check_imputation(impute_num_adata, adata_imputed)
+
+
+def test_mean_impute_no_copy(impute_num_adata):
+    adata_not_imputed = impute_num_adata.copy()
+    simple_impute(impute_num_adata)
+
+    _base_check_imputation(adata_not_imputed, impute_num_adata)
+
+
+def test_mean_impute_copy(impute_num_adata):
+    adata_imputed = simple_impute(impute_num_adata, copy=True)
+
+    assert id(impute_num_adata) != id(adata_imputed)
+    _base_check_imputation(impute_num_adata, adata_imputed)
+
+
+def test_mean_impute_throws_error_non_numerical(impute_adata):
+    with pytest.raises(ValueError):
+        simple_impute(impute_adata)
+
+
+def test_mean_impute_subset(impute_adata):
+    var_names = ("intcol", "indexcol")
+    adata_imputed = simple_impute(impute_adata, var_names=var_names, copy=True)
+
+    _base_check_imputation(impute_adata, adata_imputed, imputed_var_names=var_names)
+    assert np.any([item != item for item in adata_imputed.X[::, 3:4]])
+
+
+def test_median_impute_no_copy(impute_num_adata):
+    adata_not_imputed = impute_num_adata.copy()
+    simple_impute(impute_num_adata, strategy="median")
+
+    _base_check_imputation(adata_not_imputed, impute_num_adata)
+
+
+def test_median_impute_copy(impute_num_adata):
+    adata_imputed = simple_impute(impute_num_adata, strategy="median", copy=True)
+
+    _base_check_imputation(impute_num_adata, adata_imputed)
+    assert id(impute_num_adata) != id(adata_imputed)
+
+
+def test_median_impute_throws_error_non_numerical(impute_adata):
+    with pytest.raises(ValueError):
+        simple_impute(impute_adata, strategy="median")
+
+
+def test_median_impute_subset(impute_adata):
+    var_names = ("intcol", "indexcol")
+    adata_imputed = simple_impute(impute_adata, var_names=var_names, strategy="median", copy=True)
+
+    _base_check_imputation(impute_adata, adata_imputed, imputed_var_names=var_names)
+
+
+def test_most_frequent_impute_no_copy(impute_adata):
+    adata_not_imputed = impute_adata.copy()
+    simple_impute(impute_adata, strategy="most_frequent")
+
+    _base_check_imputation(adata_not_imputed, impute_adata)
+
+
+def test_most_frequent_impute_copy(impute_adata):
+    adata_imputed = simple_impute(impute_adata, strategy="most_frequent", copy=True)
+
+    _base_check_imputation(impute_adata, adata_imputed)
+    assert id(impute_adata) != id(adata_imputed)
+
+
+def test_unknown_simple_imputation_strategy(impute_adata):
+    with pytest.raises(ValueError):
+        simple_impute(impute_adata, strategy="invalid_strategy", copy=True)  # type: ignore
+
+
+def test_most_frequent_impute_subset(impute_adata):
+    var_names = ("intcol", "strcol")
+    adata_imputed = simple_impute(impute_adata, var_names=var_names, strategy="most_frequent", copy=True)
+
+    _base_check_imputation(impute_adata, adata_imputed, imputed_var_names=var_names)
+
+
+def test_knn_impute_check_backend(impute_num_adata):
+    knn_impute(impute_num_adata, backend="faiss", copy=True)
+    knn_impute(impute_num_adata, backend="scikit-learn", copy=True)
+    with pytest.raises(
+        ValueError,
+        match="Unknown backend 'invalid_backend' for KNN imputation. Choose between 'scikit-learn' and 'faiss'.",
+    ):
+        knn_impute(impute_num_adata, backend="invalid_backend")  # type: ignore
+
+
+def test_knn_impute_no_copy(impute_num_adata):
+    adata_not_imputed = impute_num_adata.copy()
+    knn_impute(impute_num_adata)
+
+    _base_check_imputation(adata_not_imputed, impute_num_adata)
+
+
+def test_knn_impute_copy(impute_num_adata):
+    adata_imputed = knn_impute(impute_num_adata, n_neighbors=3, copy=True)
+
+    _base_check_imputation(impute_num_adata, adata_imputed)
+    assert id(impute_num_adata) != id(adata_imputed)
+
+
+def test_knn_impute_non_numerical_data(impute_adata):
+    with pytest.raises(ValueError):
+        knn_impute(impute_adata, n_neighbors=3, copy=True)
+
+
+def test_knn_impute_numerical_data(impute_num_adata):
+    adata_imputed = knn_impute(impute_num_adata, copy=True)
+
+    _base_check_imputation(impute_num_adata, adata_imputed)
+
+
+def test_missforest_impute_non_numerical_data(impute_adata):
+    with pytest.raises(ValueError):
+        miss_forest_impute(impute_adata, copy=True)
+
+
+def test_missforest_impute_numerical_data(impute_num_adata):
+    warnings.filterwarnings("ignore", category=ConvergenceWarning)
+    adata_imputed = miss_forest_impute(impute_num_adata, copy=True)
+
+    _base_check_imputation(impute_num_adata, adata_imputed)
+
+
+def test_missforest_impute_subset(impute_num_adata):
+    warnings.filterwarnings("ignore", category=ConvergenceWarning)
+    var_names = ("col2", "col3")
+    adata_imputed = miss_forest_impute(impute_num_adata, var_names=var_names, copy=True)
+
+    _base_check_imputation(impute_num_adata, adata_imputed, imputed_var_names=var_names)
+
+
+@pytest.mark.parametrize(
+    "array_type,expected_error",
+    [
+        (np.array, None),
+        (da.from_array, NotImplementedError),
+        (sparse.csr_matrix, NotImplementedError),
+    ],
+)
+def test_miceforest_array_types(impute_num_adata, array_type, expected_error):
+    impute_num_adata.X = array_type(impute_num_adata.X)
+    if expected_error:
+        with pytest.raises(expected_error):
+            mice_forest_impute(impute_num_adata, copy=True)
+
+
+@pytest.mark.skipif(os.name == "Darwin", reason="miceforest Imputation not supported by MacOS.")
+def test_miceforest_impute_no_copy(impute_iris_adata):
+    adata_not_imputed = impute_iris_adata.copy()
+    mice_forest_impute(impute_iris_adata)
+
+    _base_check_imputation(adata_not_imputed, impute_iris_adata)
+
+
+@pytest.mark.skipif(os.name == "Darwin", reason="miceforest Imputation not supported by MacOS.")
+def test_miceforest_impute_copy(impute_iris_adata):
+    adata_imputed = mice_forest_impute(impute_iris_adata, copy=True)
+
+    _base_check_imputation(impute_iris_adata, adata_imputed)
+    assert id(impute_iris_adata) != id(adata_imputed)
+
+
+@pytest.mark.skipif(os.name == "Darwin", reason="miceforest Imputation not supported by MacOS.")
+def test_miceforest_impute_non_numerical_data(impute_titanic_adata):
+    with pytest.raises(ValueError):
+        mice_forest_impute(impute_titanic_adata)
+
+
+@pytest.mark.skipif(os.name == "Darwin", reason="miceforest Imputation not supported by MacOS.")
+def test_miceforest_impute_numerical_data(impute_iris_adata):
+    adata_not_imputed = impute_iris_adata.copy()
+    mice_forest_impute(impute_iris_adata)
+
+    _base_check_imputation(adata_not_imputed, impute_iris_adata)
+
+
+@pytest.mark.parametrize(
+    "array_type,expected_error",
+    [
+        (np.array, None),
+        (da.from_array, None),
+        (sparse.csr_matrix, NotImplementedError),
+    ],
+)
+def test_explicit_impute_array_types(impute_num_adata, array_type, expected_error):
+    impute_num_adata.X = array_type(impute_num_adata.X)
+    if expected_error:
+        with pytest.raises(expected_error):
+            explicit_impute(impute_num_adata, replacement=1011, copy=True)
+
+
+@pytest.mark.parametrize("array_type", ARRAY_TYPES)
+def test_explicit_impute_all(array_type, impute_num_adata):
+    impute_num_adata.X = array_type(impute_num_adata.X)
+    warnings.filterwarnings("ignore", category=FutureWarning)
+    adata_imputed = explicit_impute(impute_num_adata, replacement=1011, copy=True)
+
+    _base_check_imputation(impute_num_adata, adata_imputed)
+    assert np.sum([adata_imputed.X == 1011]) == 3
+
+
+@pytest.mark.parametrize("array_type", ARRAY_TYPES)
+def test_explicit_impute_subset(impute_adata, array_type):
+    impute_adata.X = array_type(impute_adata.X)
+    adata_imputed = explicit_impute(impute_adata, replacement={"strcol": "REPLACED", "intcol": 1011}, copy=True)
+
+    _base_check_imputation(impute_adata, adata_imputed, imputed_var_names=("strcol", "intcol"))
+    assert np.sum([adata_imputed.X == 1011]) == 1
+    assert np.sum([adata_imputed.X == "REPLACED"]) == 1
+
+
+def test_warning(impute_num_adata):
+    warning_results = _warn_imputation_threshold(impute_num_adata, threshold=20, var_names=None)
+    assert warning_results == {"col1": 25, "col3": 50}