--- a +++ b/tests/anndata/test_anndata_ext.py @@ -0,0 +1,525 @@ +from __future__ import annotations + +import numpy as np +import pandas as pd +import pytest +from anndata import AnnData +from pandas import DataFrame +from pandas.testing import assert_frame_equal + +import ehrapy as ep +from ehrapy.anndata._constants import CATEGORICAL_TAG, FEATURE_TYPE_KEY, NUMERIC_TAG +from ehrapy.anndata.anndata_ext import ( + NotEncodedError, + _are_ndarrays_equal, + _assert_encoded, + _is_val_missing, + anndata_to_df, + assert_numeric_vars, + delete_from_obs, + df_to_anndata, + get_numeric_vars, + move_to_obs, + move_to_x, + set_numeric_vars, +) +from tests.conftest import TEST_DATA_PATH + + +@pytest.fixture +def setup_df_to_anndata() -> tuple[DataFrame, list, list, list]: + col1_val = ["str" + str(idx) for idx in range(100)] + col2_val = ["another_str" + str(idx) for idx in range(100)] + col3_val = list(range(100)) + df = DataFrame({"col1": col1_val, "col2": col2_val, "col3": col3_val}) + + return df, col1_val, col2_val, col3_val + + +@pytest.fixture +def setup_binary_df_to_anndata() -> DataFrame: + col1_val = ["str" + str(idx) for idx in range(100)] + col2_val = ["another_str" + str(idx) for idx in range(100)] + col3_val = [0 for _ in range(100)] + col4_val = [1.0 for _ in range(100)] + col5_val = [0.0 if idx % 2 == 0 else np.nan for idx in range(100)] + col6_val = [idx % 2 for idx in range(100)] + col7_val = [float(idx % 2) for idx in range(100)] + col8_val = [idx % 3 if idx % 3 in {0, 1} else np.nan for idx in range(100)] + df = DataFrame( + { + "col1": col1_val, + "col2": col2_val, + "col3": col3_val, + "col4": col4_val, + "col5": col5_val, + "col6_binary_int": col6_val, + "col7_binary_float": col7_val, + "col8_binary_missing_values": col8_val, + } + ) + + return df + + +@pytest.fixture +def setup_anndata_to_df() -> tuple[list, list, list]: + col1_val = ["patient" + str(idx) for idx in range(100)] + col2_val = ["feature" + str(idx) for idx in range(100)] + col3_val = list(range(100)) + + return col1_val, col2_val, col3_val + + +def test_move_to_obs_only_num(adata_move_obs_num: AnnData): + move_to_obs(adata_move_obs_num, ["los_days", "b12_values"]) + assert list(adata_move_obs_num.obs.columns) == ["los_days", "b12_values"] + assert {str(col) for col in adata_move_obs_num.obs.dtypes} == {"float32"} + assert_frame_equal( + adata_move_obs_num.obs, + DataFrame( + {"los_days": [14.0, 7.0, 10.0, 11.0, 3.0], "b12_values": [500.0, 330.0, 800.0, 765.0, 800.0]}, + index=[str(idx) for idx in range(5)], + ).astype({"b12_values": "float32", "los_days": "float32"}), + ) + + +def test_move_to_obs_mixed(adata_move_obs_mix: AnnData): + move_to_obs(adata_move_obs_mix, ["name", "clinic_id"]) + assert set(adata_move_obs_mix.obs.columns) == {"name", "clinic_id"} + assert {str(col) for col in adata_move_obs_mix.obs.dtypes} == {"float32", "category"} + assert_frame_equal( + adata_move_obs_mix.obs, + DataFrame( + {"clinic_id": list(range(1, 6)), "name": ["foo", "bar", "baz", "buz", "ber"]}, + index=[str(idx) for idx in range(5)], + ).astype({"clinic_id": "float32", "name": "category"}), + ) + + +def test_move_to_obs_copy_obs(adata_move_obs_mix: AnnData): + adata_dim_old = adata_move_obs_mix.X.shape + move_to_obs(adata_move_obs_mix, ["name", "clinic_id"], copy_obs=True) + assert set(adata_move_obs_mix.obs.columns) == {"name", "clinic_id"} + assert adata_move_obs_mix.X.shape == adata_dim_old + assert {str(col) for col in adata_move_obs_mix.obs.dtypes} == {"float32", "category"} + assert_frame_equal( + adata_move_obs_mix.obs, + DataFrame( + {"clinic_id": list(range(1, 6)), "name": ["foo", "bar", "baz", "buz", "ber"]}, + index=[str(idx) for idx in range(5)], + ).astype({"clinic_id": "float32", "name": "category"}), + ) + + +def test_move_to_obs_invalid_column_name(adata_move_obs_mix: AnnData): + with pytest.raises(ValueError): + _ = move_to_obs(adata_move_obs_mix, "nam") + _ = move_to_obs(adata_move_obs_mix, "clic_id") + _ = move_to_obs(adata_move_obs_mix, ["nam", "clic_id"]) + + +def test_move_to_x(adata_move_obs_mix): + move_to_obs(adata_move_obs_mix, ["name"], copy_obs=True) + move_to_obs(adata_move_obs_mix, ["clinic_id"], copy_obs=False) + new_adata_non_num = move_to_x(adata_move_obs_mix, ["name"]) + new_adata_num = move_to_x(adata_move_obs_mix, ["clinic_id"]) + assert set(new_adata_non_num.obs.columns) == {"name", "clinic_id"} + assert set(new_adata_num.obs.columns) == {"name"} + assert {str(col) for col in new_adata_num.obs.dtypes} == {"category"} + assert {str(col) for col in new_adata_non_num.obs.dtypes} == {"float32", "category"} + + assert_frame_equal( + new_adata_non_num.var, + DataFrame( + {FEATURE_TYPE_KEY: [NUMERIC_TAG, NUMERIC_TAG, CATEGORICAL_TAG]}, + index=["los_days", "b12_values", "name"], + ), + ) + + assert_frame_equal( + new_adata_num.var, + DataFrame( + {FEATURE_TYPE_KEY: [NUMERIC_TAG, NUMERIC_TAG, CATEGORICAL_TAG, np.nan]}, + index=["los_days", "b12_values", "name", "clinic_id"], + ), + ) + ep.ad.infer_feature_types(new_adata_num, output=None) + assert np.all(new_adata_num.var[FEATURE_TYPE_KEY] == [NUMERIC_TAG, NUMERIC_TAG, CATEGORICAL_TAG, NUMERIC_TAG]) + + assert_frame_equal( + new_adata_num.obs, + DataFrame( + {"name": ["foo", "bar", "baz", "buz", "ber"]}, + index=[str(idx) for idx in range(5)], + ).astype({"name": "category"}), + ) + + assert_frame_equal( + new_adata_non_num.obs, + DataFrame( + {"name": ["foo", "bar", "baz", "buz", "ber"], "clinic_id": list(range(1, 6))}, + index=[str(idx) for idx in range(5)], + ).astype({"clinic_id": "float32", "name": "category"}), + ) + + +def test_move_to_x_copy_x(adata_move_obs_mix): + move_to_obs(adata_move_obs_mix, ["name"], copy_obs=False) + obs_df = adata_move_obs_mix.obs.copy() + new_adata = move_to_x(adata_move_obs_mix, ["name"], copy_x=True) + assert_frame_equal(new_adata.obs, obs_df) + + +def test_move_to_x_invalid_column_names(adata_move_obs_mix): + move_to_obs(adata_move_obs_mix, ["name"], copy_obs=True) + move_to_obs(adata_move_obs_mix, ["clinic_id"], copy_obs=False) + with pytest.raises(ValueError): + _ = move_to_x(adata_move_obs_mix, ["blabla1"]) + _ = move_to_x(adata_move_obs_mix, ["blabla1", "blabla2"]) + + +def test_move_to_x_move_to_obs(adata_move_obs_mix): + adata_dim_old = adata_move_obs_mix.X.shape + # moving columns from X to obs and back + # case 1: move some column from obs to X and this col was copied previously from X to obs + move_to_obs(adata_move_obs_mix, ["name"], copy_obs=True) + adata = move_to_x(adata_move_obs_mix, ["name"]) + assert {"name"}.issubset(set(adata.var_names)) + assert adata.X.shape == adata_dim_old + delete_from_obs(adata, ["name"]) + + # case 2: move some column from obs to X and this col was previously moved inplace from X to obs + move_to_obs(adata, ["clinic_id"], copy_obs=False) + adata = move_to_x(adata, ["clinic_id"]) + assert not {"clinic_id"}.issubset(set(adata.obs.columns)) + assert {"clinic_id"}.issubset(set(adata.var_names)) + assert adata.X.shape == adata_dim_old + + # case 3: move multiple columns from obs to X and some of them were copied or moved inplace previously from X to obs + move_to_obs(adata, ["los_days"], copy_obs=True) + move_to_obs(adata, ["b12_values"], copy_obs=False) + adata = move_to_x(adata, ["los_days", "b12_values"]) + delete_from_obs(adata, ["los_days"]) + assert not {"los_days"}.issubset( + set(adata.obs.columns) + ) # check if the copied column was removed from obs by delete_from_obs() + assert not {"b12_values"}.issubset(set(adata.obs.columns)) + assert {"los_days", "b12_values"}.issubset(set(adata.var_names)) + assert adata.X.shape == adata_dim_old + + +def test_delete_from_obs(adata_move_obs_mix): + adata = move_to_obs(adata_move_obs_mix, ["los_days"], copy_obs=True) + adata = delete_from_obs(adata, ["los_days"]) + assert not {"los_days"}.issubset(set(adata.obs.columns)) + assert {"los_days"}.issubset(set(adata.var_names)) + + +def test_df_to_anndata_simple(setup_df_to_anndata): + df, col1_val, col2_val, col3_val = setup_df_to_anndata + expected_x = np.array([col1_val, col2_val, col3_val], dtype="object").transpose() + adata = df_to_anndata(df) + + assert adata.X.dtype == "object" + assert adata.X.shape == (100, 3) + np.testing.assert_array_equal(adata.X, expected_x) + + +def test_df_to_anndata_index_column(setup_df_to_anndata): + df, col1_val, col2_val, col3_val = setup_df_to_anndata + expected_x = np.array([col2_val, col3_val], dtype="object").transpose() + adata = df_to_anndata(df, index_column="col1") + + assert adata.X.dtype == "object" + assert adata.X.shape == (100, 2) + np.testing.assert_array_equal(adata.X, expected_x) + assert list(adata.obs.index) == col1_val + assert adata.obs.index.name == "col1" + + +def test_df_to_anndata_index_column_num(setup_df_to_anndata): + df, col1_val, col2_val, col3_val = setup_df_to_anndata + expected_x = np.array([col2_val, col3_val], dtype="object").transpose() + adata = df_to_anndata(df, index_column=0) + + assert adata.X.dtype == "object" + assert adata.X.shape == (100, 2) + np.testing.assert_array_equal(adata.X, expected_x) + assert list(adata.obs.index) == col1_val + assert adata.obs.index.name == "col1" + + +def test_df_to_anndata_index_column_index(): + d = {"col1": [0, 1, 2, 3], "col2": pd.Series([2, 3])} + df = pd.DataFrame(data=d, index=[0, 1, 2, 3]) + df.index.set_names("quarter", inplace=True) + adata = ep.ad.df_to_anndata(df, index_column="quarter") + assert adata.obs.index.name == "quarter" + assert list(adata.obs.index) == ["0", "1", "2", "3"] + + +def test_df_to_anndata_invalid_index_throws_error(setup_df_to_anndata): + df, col1_val, col2_val, col3_val = setup_df_to_anndata + with pytest.raises(ValueError): + _ = df_to_anndata(df, index_column="UnknownCol") + + +def test_df_to_anndata_cols_obs_only(setup_df_to_anndata): + df, col1_val, col2_val, col3_val = setup_df_to_anndata + adata = df_to_anndata(df, columns_obs_only=["col1", "col2"]) + assert adata.X.dtype == "float32" + assert adata.X.shape == (100, 1) + assert_frame_equal( + adata.obs, + DataFrame({"col1": col1_val, "col2": col2_val}, index=[str(idx) for idx in range(100)]).astype("category"), + ) + + +def test_df_to_anndata_all_num(): + test_array = np.random.default_rng().integers(0, 100, (4, 5)) + df = DataFrame(test_array, columns=["col" + str(idx) for idx in range(5)]) + adata = df_to_anndata(df) + + assert adata.X.dtype == "float32" + np.testing.assert_array_equal(test_array, adata.X) + + +def test_df_to_anndata_index_col_obs_only(setup_df_to_anndata): + """Passing index_column and columns_obs_only at the same time.""" + df, col1_val, col2_val, col3_val = setup_df_to_anndata + adata = df_to_anndata(df, index_column="col1", columns_obs_only=["col1", "col2"]) + assert list(adata.obs.index) == col1_val + + +def test_anndata_to_df_simple(setup_anndata_to_df): + col1_val, col2_val, col3_val = setup_anndata_to_df + expected_df = DataFrame({"col1": col1_val, "col2": col2_val, "col3": col3_val}, dtype="object") + adata_x = np.array([col1_val, col2_val, col3_val], dtype="object").transpose() + adata = AnnData( + X=adata_x, + obs=DataFrame(index=list(range(100))), + var=DataFrame(index=["col" + str(idx) for idx in range(1, 4)]), + ) + anndata_df = anndata_to_df(adata) + + assert_frame_equal(anndata_df, expected_df) + + +def test_anndata_to_df_all_from_obs(setup_anndata_to_df): + col1_val, col2_val, col3_val = setup_anndata_to_df + expected_df = DataFrame({"col1": col1_val, "col2": col2_val, "col3": col3_val}) + obs = DataFrame({"col2": col2_val, "col3": col3_val}) + adata_x = np.array([col1_val], dtype="object").transpose() + adata = AnnData(X=adata_x, obs=obs, var=DataFrame(index=["col1"])) + anndata_df = anndata_to_df(adata, obs_cols=list(adata.obs.columns)) + + assert_frame_equal(anndata_df, expected_df) + + +def test_anndata_to_df_some_from_obs(setup_anndata_to_df): + col1_val, col2_val, col3_val = setup_anndata_to_df + expected_df = DataFrame({"col1": col1_val, "col3": col3_val}) + obs = DataFrame({"col2": col2_val, "col3": col3_val}) + adata_x = np.array([col1_val], dtype="object").transpose() + adata = AnnData(X=adata_x, obs=obs, var=DataFrame(index=["col1"])) + anndata_df = anndata_to_df(adata, obs_cols=["col3"]) + + assert_frame_equal(anndata_df, expected_df) + + +def test_anndata_to_df_throws_error_with_empty_obs(): + col1_val = ["patient" + str(idx) for idx in range(100)] + adata_x = np.array([col1_val], dtype="object").transpose() + adata = AnnData(X=adata_x, obs=DataFrame(index=list(range(100))), var=DataFrame(index=["col1"])) + + with pytest.raises(ValueError): + _ = anndata_to_df(adata, obs_cols=["some_missing_column"]) + + +def test_anndata_to_df_all_columns(setup_anndata_to_df): + col1_val, col2_val, col3_val = setup_anndata_to_df + expected_df = DataFrame({"col1": col1_val}) + var = DataFrame(index=["col1"]) + adata_x = np.array([col1_val], dtype="object").transpose() + adata = AnnData(X=adata_x, obs=DataFrame({"col2": col2_val, "col3": col3_val}), var=var) + anndata_df = anndata_to_df(adata, obs_cols=list(adata.var.columns)) + + assert_frame_equal(anndata_df, expected_df) + + +def test_anndata_to_df_layers(setup_anndata_to_df): + col1_val, col2_val, col3_val = setup_anndata_to_df + expected_df = DataFrame({"col1": col1_val, "col2": col2_val, "col3": col3_val}) + obs = DataFrame({"col2": col2_val, "col3": col3_val}) + adata_x = np.array([col1_val], dtype="object").transpose() + adata = AnnData(X=adata_x, obs=obs, var=DataFrame(index=["col1"]), layers={"raw": adata_x.copy()}) + anndata_df = anndata_to_df(adata, obs_cols=list(adata.obs.columns), layer="raw") + + assert_frame_equal(anndata_df, expected_df) + + +def test_detect_binary_columns(setup_binary_df_to_anndata): + adata = df_to_anndata(setup_binary_df_to_anndata) + ep.ad.infer_feature_types(adata, output=None) + + assert_frame_equal( + adata.var, + DataFrame( + { + FEATURE_TYPE_KEY: [ + CATEGORICAL_TAG, + CATEGORICAL_TAG, + CATEGORICAL_TAG, + CATEGORICAL_TAG, + CATEGORICAL_TAG, + CATEGORICAL_TAG, + CATEGORICAL_TAG, + CATEGORICAL_TAG, + ] + }, + index=[ + "col1", + "col2", + "col3", + "col4", + "col5", + "col6_binary_int", + "col7_binary_float", + "col8_binary_missing_values", + ], + ), + ) + + +def test_detect_mixed_binary_columns(): + df = pd.DataFrame( + {"Col1": list(range(4)), "Col2": ["str" + str(i) for i in range(4)], "Col3": [1.0, 0.0, np.nan, 1.0]} + ) + adata = ep.ad.df_to_anndata(df) + ep.ad.infer_feature_types(adata, output=None) + + assert_frame_equal( + adata.var, + DataFrame( + {FEATURE_TYPE_KEY: [NUMERIC_TAG, CATEGORICAL_TAG, CATEGORICAL_TAG]}, + index=["Col1", "Col2", "Col3"], + ), + ) + + +@pytest.fixture +def adata_strings_encoded(): + obs_data = {"ID": ["Patient1", "Patient2", "Patient3"], "Age": [31, 94, 62]} + X_strings = np.array( + [ + [1, 3.4, "A string", "A different string"], + [2, 5.4, "Silly string", "A different string"], + [2, 5.7, "A string", "What string?"], + ], + dtype=pd.StringDtype, + ) + var_strings = { + "Feature": ["Numeric1", "Numeric2", "String1", "String2"], + "Type": ["Numeric", "Numeric", "String", "String"], + } + + adata_strings = AnnData( + X=X_strings, + obs=pd.DataFrame(data=obs_data), + var=pd.DataFrame(data=var_strings, index=var_strings["Feature"]), + ) + adata_strings.var[FEATURE_TYPE_KEY] = [NUMERIC_TAG, NUMERIC_TAG, CATEGORICAL_TAG, CATEGORICAL_TAG] + + adata_encoded = ep.pp.encode(adata_strings.copy(), autodetect=True, encodings="label") + + return adata_strings, adata_encoded + + +@pytest.fixture +def adata_encoded(adata_strings): + return ep.pp.encode(adata_strings.copy(), autodetect=True, encodings="label") + + +def test_assert_encoded(adata_strings_encoded): + adata_strings, adata_encoded = adata_strings_encoded + _assert_encoded(adata_encoded) + with pytest.raises(NotEncodedError, match=r"not yet been encoded"): + _assert_encoded(adata_strings) + + +def test_get_numeric_vars(adata_strings_encoded): + adata_strings, adata_encoded = adata_strings_encoded + vars = get_numeric_vars(adata_encoded) + assert vars == ["Numeric1", "Numeric2"] + with pytest.raises(NotEncodedError, match=r"not yet been encoded"): + get_numeric_vars(adata_strings) + + +def test_get_numeric_vars_numeric_only(): + adata = AnnData(X=np.array([[1, 2, 3], [4, 0, 6]], dtype=np.float32)) + vars = get_numeric_vars(adata) + assert vars == ["0", "1", "2"] + + +def test_assert_numeric_vars(adata_strings_encoded): + adata_strings, adata_encoded = adata_strings_encoded + assert_numeric_vars(adata_encoded, ["Numeric1", "Numeric2"]) + with pytest.raises(ValueError, match=r"Some selected vars are not numeric"): + assert_numeric_vars(adata_encoded, ["Numeric2", "String1"]) + + +def test_set_numeric_vars(adata_strings_encoded): + """Test for the numeric vars setter.""" + adata_strings, adata_encoded = adata_strings_encoded + values = np.array( + [[1.2, 2.2], [3.2, 4.2], [5.2, 6.2]], + dtype=np.dtype(np.float32), + ) + adata_set = set_numeric_vars(adata_encoded, values, copy=True) + np.testing.assert_array_equal(adata_set.X[:, 2], values[:, 0]) and np.testing.assert_array_equal( + adata_set.X[:, 3], values[:, 1] + ) + + with pytest.raises(ValueError, match=r"Some selected vars are not numeric"): + set_numeric_vars(adata_encoded, values, vars=["ehrapycat_String1"]) + + string_values = np.array( + [ + ["A"], + ["B"], + ["A"], + ] + ) + + with pytest.raises(TypeError, match=r"Values must be numeric"): + set_numeric_vars(adata_encoded, string_values) + + extra_values = np.array( + [ + [1.2, 1.3, 1.4], + [2.2, 2.3, 2.4], + [2.2, 2.3, 2.4], + ], + dtype=np.dtype(np.float32), + ) + + with pytest.raises(ValueError, match=r"does not match number of vars"): + set_numeric_vars(adata_encoded, extra_values) + + with pytest.raises(NotEncodedError, match=r"not yet been encoded"): + set_numeric_vars(adata_strings, values) + + +def test_are_ndarrays_equal(impute_num_adata): + impute_num_adata_copy = impute_num_adata.copy() + assert _are_ndarrays_equal(impute_num_adata.X, impute_num_adata_copy.X) + impute_num_adata_copy.X[0, 0] = 42.0 + assert not _are_ndarrays_equal(impute_num_adata.X, impute_num_adata_copy.X) + + +def test_is_val_missing(impute_num_adata): + assert np.array_equal( + _is_val_missing(impute_num_adata.X), + np.array([[False, False, True], [False, False, False], [True, False, False], [False, False, True]]), + )