a b/tests/preprocessing/test_imputation.py
1
import os
import sys
import warnings
from collections.abc import Iterable
from pathlib import Path

import dask.array as da
import numpy as np
import pytest
from anndata import AnnData
from scipy import sparse
from sklearn.exceptions import ConvergenceWarning

from ehrapy.anndata.anndata_ext import _are_ndarrays_equal, _is_val_missing, _to_dense_matrix
from ehrapy.preprocessing._imputation import (
    _warn_imputation_threshold,
    explicit_impute,
    knn_impute,
    mice_forest_impute,
    miss_forest_impute,
    simple_impute,
)
from tests.conftest import ARRAY_TYPES, TEST_DATA_PATH
23
24
# Directory that contains this test module.
CURRENT_DIR = Path(__file__).parent
# Path string of the ``imputation`` subdirectory below ``TEST_DATA_PATH``.
_TEST_PATH = "{}/imputation".format(TEST_DATA_PATH)
26
27
28
def _base_check_imputation(
    adata_before_imputation: AnnData,
    adata_after_imputation: AnnData,
    before_imputation_layer: str | None = None,
    after_imputation_layer: str | None = None,
    imputed_var_names: Iterable[str] | None = None,
) -> None:
    """Provides a base check for all imputations:

    - Imputation doesn't leave any NaN behind in the imputed columns
    - Imputation doesn't modify anything in non-imputed columns (if the imputation on a subset was requested)
    - Imputation doesn't modify any data that wasn't NaN

    Dask-backed ``X`` matrices of both objects are computed and written back in place, so callers
    can keep working with ``X`` as an in-memory array after this check.

    Args:
        adata_before_imputation: AnnData before imputation
        adata_after_imputation: AnnData after imputation
        before_imputation_layer: Layer to consider in the original ``AnnData``, ``X`` if not specified
        after_imputation_layer: Layer to consider in the imputed ``AnnData``, ``X`` if not specified
        imputed_var_names: Names of the features that were imputed, will consider all of them if not specified.
            Any iterable works, including one-shot iterators such as generators.

    Raises:
        AssertionError: If any of the checks fail.
        KeyError: If a name in ``imputed_var_names`` is not a variable of ``adata_before_imputation``.
    """
    # Materialize dask arrays in place; some callers inspect ``X`` again after this check.
    if isinstance(adata_before_imputation.X, da.Array):
        adata_before_imputation.X = adata_before_imputation.X.compute()
    if isinstance(adata_after_imputation.X, da.Array):
        adata_after_imputation.X = adata_after_imputation.X.compute()

    layer_before = _to_dense_matrix(adata_before_imputation, before_imputation_layer)
    layer_after = _to_dense_matrix(adata_after_imputation, after_imputation_layer)

    if layer_before.shape != layer_after.shape:
        raise AssertionError("The shapes of the two layers do not match")

    if imputed_var_names is None:
        var_indices = np.arange(layer_before.shape[1])
    else:
        # Iterate the names exactly once. Re-testing membership against the same iterable would
        # consume one-shot iterators (e.g. generators) and silently drop names.
        var_indices = [adata_before_imputation.var_names.get_loc(var_name) for var_name in imputed_var_names]

    before_nan_mask = _is_val_missing(layer_before)
    imputed_mask = np.zeros(layer_before.shape[1], dtype=bool)
    imputed_mask[var_indices] = True

    # Ensure no NaN remains in the imputed columns of layer_after
    if np.any(before_nan_mask[:, imputed_mask] & _is_val_missing(layer_after[:, imputed_mask])):
        raise AssertionError("NaN found in imputed columns of layer_after.")

    # Ensure unchanged values outside imputed columns
    unchanged_mask = ~imputed_mask
    if not _are_ndarrays_equal(layer_before[:, unchanged_mask], layer_after[:, unchanged_mask]):
        raise AssertionError("Values outside imputed columns were modified.")

    # Ensure imputation does not alter non-NaN values in the imputed columns.
    # The 1-D column mask broadcasts across rows, yielding a per-cell mask.
    imputed_non_nan_mask = (~before_nan_mask) & imputed_mask
    if not _are_ndarrays_equal(layer_before[imputed_non_nan_mask], layer_after[imputed_non_nan_mask]):
        raise AssertionError("Non-NaN values in imputed columns were modified.")
93
94
95
def test_base_check_imputation_incompatible_shapes(impute_num_adata):
    """The base check rejects an imputed object missing its first observation or its first variable."""
    imputed = knn_impute(impute_num_adata, copy=True)
    without_first_obs = (slice(1, None), slice(None))
    without_first_var = (slice(None), slice(1, None))
    for index in (without_first_obs, without_first_var):
        with pytest.raises(AssertionError):
            _base_check_imputation(impute_num_adata, imputed[index])
101
102
103
def test_base_check_imputation_nan_detected_after_complete_imputation(impute_num_adata):
    """A NaN written back into a fully KNN-imputed matrix is reported by the base check."""
    imputed = knn_impute(impute_num_adata, copy=True)
    row, col = 0, 2
    imputed.X[row, col] = np.nan

    with pytest.raises(AssertionError):
        _base_check_imputation(impute_num_adata, imputed)
108
109
110
def test_base_check_imputation_nan_detected_after_partial_imputation(impute_num_adata):
    """A NaN written back after subset imputation is reported when that subset is checked."""
    imputed_columns = ("col2", "col3")
    imputed = knn_impute(impute_num_adata, var_names=imputed_columns, copy=True)
    row, col = 0, 2
    imputed.X[row, col] = np.nan

    with pytest.raises(AssertionError):
        _base_check_imputation(impute_num_adata, imputed, imputed_var_names=imputed_columns)
116
117
118
def test_base_check_imputation_nan_ignored_if_not_in_imputed_column(impute_num_adata):
    """NaNs in columns outside ``imputed_var_names`` are tolerated by the base check."""
    subset = ("col2", "col3")
    imputed = knn_impute(impute_num_adata, var_names=subset, copy=True)
    # col1 is excluded from the subset and keeps its NaN at row 2; the check must not flag it.
    _base_check_imputation(impute_num_adata, imputed, imputed_var_names=subset)
123
124
125
def test_base_check_imputation_change_detected_in_non_imputed_column(impute_num_adata):
    """Filling a cell in a column outside the imputed subset is reported by the base check."""
    subset = ("col2", "col3")
    imputed = knn_impute(impute_num_adata, var_names=subset, copy=True)
    # Simulate col1's NaN at row 2 being filled although col1 was not part of the subset.
    imputed.X[2, 0] = 42.0

    with pytest.raises(AssertionError):
        _base_check_imputation(impute_num_adata, imputed, imputed_var_names=subset)
132
133
134
def test_base_check_imputation_change_detected_in_imputed_column(impute_num_adata):
    """Overwriting an observed value inside an imputed column is reported by the base check."""
    imputed = knn_impute(impute_num_adata, copy=True)
    # col3 held a real (non-NaN) value at row 1; clobber it as a faulty imputer might.
    imputed.X[1, 2] = 42.0

    with pytest.raises(AssertionError):
        _base_check_imputation(impute_num_adata, imputed)
140
141
142
def test_mean_impute_no_copy(impute_num_adata):
    """In-place ``simple_impute`` with its default strategy passes the base checks."""
    snapshot = impute_num_adata.copy()
    simple_impute(impute_num_adata)
    _base_check_imputation(snapshot, impute_num_adata)
147
148
149
def test_mean_impute_copy(impute_num_adata):
    """``copy=True`` returns a distinct, correctly imputed object."""
    imputed = simple_impute(impute_num_adata, copy=True)

    assert imputed is not impute_num_adata
    _base_check_imputation(impute_num_adata, imputed)
154
155
156
def test_mean_impute_throws_error_non_numerical(impute_adata):
    """Default-strategy imputation over non-numerical data raises ``ValueError``."""
    pytest.raises(ValueError, simple_impute, impute_adata)
159
160
161
def test_mean_impute_subset(impute_adata):
    """Imputing a subset of columns leaves NaNs in the remaining columns."""
    subset = ("intcol", "indexcol")
    imputed = simple_impute(impute_adata, var_names=subset, copy=True)

    _base_check_imputation(impute_adata, imputed, imputed_var_names=subset)
    # NaN is the only value that compares unequal to itself; column 3 is expected to still hold one.
    untouched_column = imputed.X[::, 3:4]
    assert np.any([row != row for row in untouched_column])
167
168
169
def test_median_impute_no_copy(impute_num_adata):
    """In-place median imputation passes the base checks."""
    snapshot = impute_num_adata.copy()
    simple_impute(impute_num_adata, strategy="median")
    _base_check_imputation(snapshot, impute_num_adata)
174
175
176
def test_median_impute_copy(impute_num_adata):
    """Median imputation with ``copy=True`` returns a distinct, correctly imputed object."""
    imputed = simple_impute(impute_num_adata, strategy="median", copy=True)

    _base_check_imputation(impute_num_adata, imputed)
    assert imputed is not impute_num_adata
181
182
183
def test_median_impute_throws_error_non_numerical(impute_adata):
    """Median imputation over non-numerical data raises ``ValueError``."""
    pytest.raises(ValueError, simple_impute, impute_adata, strategy="median")
186
187
188
def test_median_impute_subset(impute_adata):
    """Median imputation restricted to a column subset passes the base checks for that subset."""
    subset = ("intcol", "indexcol")
    imputed = simple_impute(impute_adata, var_names=subset, strategy="median", copy=True)
    _base_check_imputation(impute_adata, imputed, imputed_var_names=subset)
193
194
195
def test_most_frequent_impute_no_copy(impute_adata):
    """In-place most-frequent imputation (mixed data) passes the base checks."""
    snapshot = impute_adata.copy()
    simple_impute(impute_adata, strategy="most_frequent")
    _base_check_imputation(snapshot, impute_adata)
200
201
202
def test_most_frequent_impute_copy(impute_adata):
    """Most-frequent imputation with ``copy=True`` returns a distinct, correctly imputed object."""
    imputed = simple_impute(impute_adata, strategy="most_frequent", copy=True)

    _base_check_imputation(impute_adata, imputed)
    assert imputed is not impute_adata
207
208
209
def test_unknown_simple_imputation_strategy(impute_adata):
    """An unrecognised strategy name is rejected with ``ValueError``."""
    bad_strategy = "invalid_strategy"
    with pytest.raises(ValueError):
        simple_impute(impute_adata, strategy=bad_strategy, copy=True)  # type: ignore
212
213
214
def test_most_frequent_impute_subset(impute_adata):
    """Most-frequent imputation of a numeric and a string column passes the base checks for both."""
    subset = ("intcol", "strcol")
    imputed = simple_impute(impute_adata, var_names=subset, strategy="most_frequent", copy=True)
    _base_check_imputation(impute_adata, imputed, imputed_var_names=subset)
219
220
221
def test_knn_impute_check_backend(impute_num_adata):
    """Both supported KNN backends run, and an unknown backend name is rejected."""
    for backend in ("faiss", "scikit-learn"):
        knn_impute(impute_num_adata, backend=backend, copy=True)

    # ``match`` is a regex; the unescaped '.' characters act as wildcards, which is harmless here.
    expected_message = (
        "Unknown backend 'invalid_backend' for KNN imputation. "
        "Choose between 'scikit-learn' and 'faiss'."
    )
    with pytest.raises(ValueError, match=expected_message):
        knn_impute(impute_num_adata, backend="invalid_backend")  # type: ignore
229
230
231
def test_knn_impute_no_copy(impute_num_adata):
    """In-place KNN imputation passes the base checks."""
    snapshot = impute_num_adata.copy()
    knn_impute(impute_num_adata)
    _base_check_imputation(snapshot, impute_num_adata)
236
237
238
def test_knn_impute_copy(impute_num_adata):
    """KNN imputation with three neighbours and ``copy=True`` returns a distinct, correctly imputed object."""
    imputed = knn_impute(impute_num_adata, n_neighbors=3, copy=True)

    _base_check_imputation(impute_num_adata, imputed)
    assert imputed is not impute_num_adata
243
244
245
def test_knn_impute_non_numerical_data(impute_adata):
    """KNN imputation over non-numerical data raises ``ValueError``."""
    pytest.raises(ValueError, knn_impute, impute_adata, n_neighbors=3, copy=True)
248
249
250
def test_knn_impute_numerical_data(impute_num_adata):
    """KNN imputation with default settings on numerical data passes the base checks."""
    imputed = knn_impute(impute_num_adata, copy=True)
    _base_check_imputation(impute_num_adata, imputed)
254
255
256
def test_missforest_impute_non_numerical_data(impute_adata):
    """MissForest imputation over non-numerical data raises ``ValueError``."""
    pytest.raises(ValueError, miss_forest_impute, impute_adata, copy=True)
259
260
261
def test_missforest_impute_numerical_data(impute_num_adata):
    """MissForest imputation of purely numerical data passes the base checks."""
    # Scope the suppression to this test; a bare ``warnings.filterwarnings`` call would mutate
    # the process-wide filter list and silence ConvergenceWarning for unrelated code as well.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=ConvergenceWarning)
        adata_imputed = miss_forest_impute(impute_num_adata, copy=True)

        _base_check_imputation(impute_num_adata, adata_imputed)
266
267
268
def test_missforest_impute_subset(impute_num_adata):
    """MissForest imputation restricted to a column subset passes the base checks for that subset."""
    var_names = ("col2", "col3")
    # Scope the suppression to this test instead of mutating the process-wide warning filters.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=ConvergenceWarning)
        adata_imputed = miss_forest_impute(impute_num_adata, var_names=var_names, copy=True)

        _base_check_imputation(impute_num_adata, adata_imputed, imputed_var_names=var_names)
274
275
276
@pytest.mark.parametrize(
    ("array_type", "expected_error"),
    [
        (np.array, None),
        (da.from_array, NotImplementedError),
        (sparse.csr_matrix, NotImplementedError),
    ],
)
def test_miceforest_array_types(impute_num_adata, array_type, expected_error):
    """Unsupported array backends make ``mice_forest_impute`` raise the parametrized error."""
    impute_num_adata.X = array_type(impute_num_adata.X)
    # NOTE(review): the supported case (``expected_error is None``) currently asserts nothing.
    if not expected_error:
        return
    with pytest.raises(expected_error):
        mice_forest_impute(impute_num_adata, copy=True)
289
290
291
# ``os.name`` is only ever "posix", "nt" or "java", so comparing it to "Darwin" never skipped anything;
# ``sys.platform`` reports "darwin" on macOS.
@pytest.mark.skipif(sys.platform == "darwin", reason="miceforest Imputation not supported by MacOS.")
def test_miceforest_impute_no_copy(impute_iris_adata):
    """In-place miceforest imputation of the iris fixture passes the base checks."""
    adata_not_imputed = impute_iris_adata.copy()
    mice_forest_impute(impute_iris_adata)

    _base_check_imputation(adata_not_imputed, impute_iris_adata)
297
298
299
# ``os.name`` never equals "Darwin"; ``sys.platform`` is "darwin" on macOS, so the skip actually applies.
@pytest.mark.skipif(sys.platform == "darwin", reason="miceforest Imputation not supported by MacOS.")
def test_miceforest_impute_copy(impute_iris_adata):
    """Miceforest imputation with ``copy=True`` returns a distinct, correctly imputed object."""
    adata_imputed = mice_forest_impute(impute_iris_adata, copy=True)

    _base_check_imputation(impute_iris_adata, adata_imputed)
    assert id(impute_iris_adata) != id(adata_imputed)
305
306
307
# ``os.name`` never equals "Darwin"; ``sys.platform`` is "darwin" on macOS, so the skip actually applies.
@pytest.mark.skipif(sys.platform == "darwin", reason="miceforest Imputation not supported by MacOS.")
def test_miceforest_impute_non_numerical_data(impute_titanic_adata):
    """Miceforest imputation over non-numerical data raises ``ValueError``."""
    with pytest.raises(ValueError):
        mice_forest_impute(impute_titanic_adata)
311
312
313
# ``os.name`` never equals "Darwin"; ``sys.platform`` is "darwin" on macOS, so the skip actually applies.
@pytest.mark.skipif(sys.platform == "darwin", reason="miceforest Imputation not supported by MacOS.")
def test_miceforest_impute_numerical_data(impute_iris_adata):
    """In-place miceforest imputation of numerical data passes the base checks."""
    adata_not_imputed = impute_iris_adata.copy()
    mice_forest_impute(impute_iris_adata)

    _base_check_imputation(adata_not_imputed, impute_iris_adata)
319
320
321
@pytest.mark.parametrize(
    ("array_type", "expected_error"),
    [
        (np.array, None),
        (da.from_array, None),
        (sparse.csr_matrix, NotImplementedError),
    ],
)
def test_explicit_impute_array_types(impute_num_adata, array_type, expected_error):
    """Unsupported array backends make ``explicit_impute`` raise the parametrized error."""
    impute_num_adata.X = array_type(impute_num_adata.X)
    # Supported backends (``expected_error is None``) are exercised elsewhere; nothing to assert here.
    if not expected_error:
        return
    with pytest.raises(expected_error):
        explicit_impute(impute_num_adata, replacement=1011, copy=True)
334
335
336
@pytest.mark.parametrize("array_type", ARRAY_TYPES)
def test_explicit_impute_all(array_type, impute_num_adata):
    """Replacing every missing value with a constant passes the base checks and hits each gap."""
    impute_num_adata.X = array_type(impute_num_adata.X)
    # Silence FutureWarnings only for the rest of this test instead of mutating the process-wide filters.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=FutureWarning)
        adata_imputed = explicit_impute(impute_num_adata, replacement=1011, copy=True)

        _base_check_imputation(impute_num_adata, adata_imputed)
        # The fixture is expected to contain exactly three missing cells, all now holding the replacement.
        assert np.sum([adata_imputed.X == 1011]) == 3
344
345
346
@pytest.mark.parametrize("array_type", ARRAY_TYPES)
def test_explicit_impute_subset(impute_adata, array_type):
    """Per-column replacement values are applied only to the named columns."""
    impute_adata.X = array_type(impute_adata.X)
    replacements = {"strcol": "REPLACED", "intcol": 1011}
    imputed = explicit_impute(impute_adata, replacement=replacements, copy=True)

    # Dict keys iterate in insertion order, i.e. ("strcol", "intcol").
    _base_check_imputation(impute_adata, imputed, imputed_var_names=tuple(replacements))
    assert np.sum([imputed.X == 1011]) == 1
    assert np.sum([imputed.X == "REPLACED"]) == 1
354
355
356
def test_warning(impute_num_adata):
    """``_warn_imputation_threshold`` reports col1 and col3 for a threshold of 20."""
    expected = {"col1": 25, "col3": 50}
    assert _warn_imputation_threshold(impute_num_adata, threshold=20, var_names=None) == expected