Switch to side-by-side view

--- a
+++ b/tests/preprocessing/test_encoding.py
@@ -0,0 +1,316 @@
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import pytest
+from pandas import CategoricalDtype, DataFrame
+from pandas.testing import assert_frame_equal
+
+from ehrapy.anndata._constants import CATEGORICAL_TAG, FEATURE_TYPE_KEY, NUMERIC_TAG
+from ehrapy.io._read import read_csv
+from ehrapy.preprocessing._encoding import _reorder_encodings, encode
+from tests.conftest import TEST_DATA_PATH
+
+CURRENT_DIR = Path(__file__).parent
+_TEST_PATH = f"{TEST_DATA_PATH}/encode"
+
+
+def test_unknown_encode_mode(encode_ds_1_adata):
+    with pytest.raises(ValueError):
+        encoded_ann_data = encode(encode_ds_1_adata, autodetect=False, encodings={"unknown_mode": ["survival"]})  # noqa: F841
+
+
+def test_duplicate_column_encoding(encode_ds_1_adata):
+    with pytest.raises(ValueError):
+        encoded_ann_data = encode(  # noqa: F841
+            encode_ds_1_adata,
+            autodetect=False,
+            encodings={"label": ["survival"], "one-hot": ["survival"]},
+        )
+
+
+def test_autodetect_encode(encode_ds_1_adata):
+    encoded_ann_data = encode(encode_ds_1_adata, autodetect=True)
+    assert list(encoded_ann_data.obs.columns) == ["survival", "clinic_day"]
+    assert set(encoded_ann_data.var_names) == {
+        "ehrapycat_survival_False",
+        "ehrapycat_survival_True",
+        "ehrapycat_clinic_day_Monday",
+        "ehrapycat_clinic_day_Friday",
+        "ehrapycat_clinic_day_Saturday",
+        "ehrapycat_clinic_day_Sunday",
+        "patient_id",
+        "los_days",
+        "b12_values",
+    }
+
+    assert np.all(
+        encoded_ann_data.var["unencoded_var_names"]
+        == [
+            "survival",
+            "survival",
+            "clinic_day",
+            "clinic_day",
+            "clinic_day",
+            "clinic_day",
+            "patient_id",
+            "los_days",
+            "b12_values",
+        ]
+    )
+
+    assert np.all(encoded_ann_data.var["encoding_mode"][:6] == ["one-hot"] * 6)
+    assert np.all(enc is None for enc in encoded_ann_data.var["encoding_mode"][6:])
+
+    assert id(encoded_ann_data.X) != id(encoded_ann_data.layers["original"])
+    assert (
+        encode_ds_1_adata is not None
+        and encode_ds_1_adata.X is not None
+        and encode_ds_1_adata.obs is not None
+        and encode_ds_1_adata.uns is not None
+    )
+    assert id(encoded_ann_data) != id(encode_ds_1_adata)
+    assert id(encoded_ann_data.obs) != id(encode_ds_1_adata.obs)
+    assert id(encoded_ann_data.uns) != id(encode_ds_1_adata.uns)
+    assert id(encoded_ann_data.var) != id(encode_ds_1_adata.var)
+    assert all(column in set(encoded_ann_data.obs.columns) for column in ["survival", "clinic_day"])
+    assert not any(column in set(encode_ds_1_adata.obs.columns) for column in ["survival", "clinic_day"])
+
+    assert_frame_equal(
+        encode_ds_1_adata.var,
+        DataFrame(
+            {FEATURE_TYPE_KEY: [NUMERIC_TAG, NUMERIC_TAG, NUMERIC_TAG, CATEGORICAL_TAG, CATEGORICAL_TAG]},
+            index=["patient_id", "los_days", "b12_values", "survival", "clinic_day"],
+        ),
+    )
+
+    assert np.all(
+        encoded_ann_data.var[FEATURE_TYPE_KEY]
+        == [
+            CATEGORICAL_TAG,
+            CATEGORICAL_TAG,
+            CATEGORICAL_TAG,
+            CATEGORICAL_TAG,
+            CATEGORICAL_TAG,
+            CATEGORICAL_TAG,
+            NUMERIC_TAG,
+            NUMERIC_TAG,
+            NUMERIC_TAG,
+        ]
+    )
+
+    assert pd.api.types.is_bool_dtype(encoded_ann_data.obs["survival"].dtype)
+    assert isinstance(encoded_ann_data.obs["clinic_day"].dtype, CategoricalDtype)
+
+
+def test_autodetect_num_only(capfd, encode_ds_2_adata):
+    encoded_ann_data = encode(encode_ds_2_adata, autodetect=True)
+    out, err = capfd.readouterr()
+    assert id(encoded_ann_data) == id(encode_ds_2_adata)
+
+
+def test_autodetect_custom_mode(encode_ds_1_adata):
+    encoded_ann_data = encode(encode_ds_1_adata, autodetect=True, encodings="label")
+    assert list(encoded_ann_data.obs.columns) == ["survival", "clinic_day"]
+    assert set(encoded_ann_data.var_names) == {
+        "ehrapycat_survival",
+        "ehrapycat_clinic_day",
+        "patient_id",
+        "los_days",
+        "b12_values",
+    }
+
+    assert np.all(
+        encoded_ann_data.var["unencoded_var_names"]
+        == ["survival", "clinic_day", "patient_id", "los_days", "b12_values"]
+    )
+    assert np.all(encoded_ann_data.var["encoding_mode"][:2] == ["label"] * 2)
+    assert np.all(enc is None for enc in encoded_ann_data.var["encoding_mode"][2:])
+
+    assert id(encoded_ann_data.X) != id(encoded_ann_data.layers["original"])
+    assert (
+        encode_ds_1_adata is not None
+        and encode_ds_1_adata.X is not None
+        and encode_ds_1_adata.obs is not None
+        and encode_ds_1_adata.uns is not None
+    )
+    assert id(encoded_ann_data) != id(encode_ds_1_adata)
+    assert id(encoded_ann_data.obs) != id(encode_ds_1_adata.obs)
+    assert id(encoded_ann_data.uns) != id(encode_ds_1_adata.uns)
+    assert id(encoded_ann_data.var) != id(encode_ds_1_adata.var)
+    assert all(column in set(encoded_ann_data.obs.columns) for column in ["survival", "clinic_day"])
+    assert not any(column in set(encode_ds_1_adata.obs.columns) for column in ["survival", "clinic_day"])
+
+    assert_frame_equal(
+        encode_ds_1_adata.var,
+        DataFrame(
+            {FEATURE_TYPE_KEY: [NUMERIC_TAG, NUMERIC_TAG, NUMERIC_TAG, CATEGORICAL_TAG, CATEGORICAL_TAG]},
+            index=["patient_id", "los_days", "b12_values", "survival", "clinic_day"],
+        ),
+    )
+
+    assert np.all(
+        encoded_ann_data.var[FEATURE_TYPE_KEY]
+        == [
+            CATEGORICAL_TAG,
+            CATEGORICAL_TAG,
+            NUMERIC_TAG,
+            NUMERIC_TAG,
+            NUMERIC_TAG,
+        ]
+    )
+
+    assert pd.api.types.is_bool_dtype(encoded_ann_data.obs["survival"].dtype)
+    assert isinstance(encoded_ann_data.obs["clinic_day"].dtype, CategoricalDtype)
+
+
+def test_autodetect_encode_again(encode_ds_1_adata):
+    encoded_ann_data = encode(encode_ds_1_adata, autodetect=True)
+    encoded_ann_data_again = encode(encoded_ann_data, autodetect=True)
+    assert id(encoded_ann_data_again) == id(encoded_ann_data)
+
+
+def test_custom_encode(encode_ds_1_adata):
+    encoded_ann_data = encode(
+        encode_ds_1_adata,
+        autodetect=False,
+        encodings={"label": ["survival"], "one-hot": ["clinic_day"]},
+    )
+    assert encoded_ann_data.X.shape == (5, 8)
+    assert list(encoded_ann_data.obs.columns) == ["survival", "clinic_day"]
+    assert "ehrapycat_survival" in list(encoded_ann_data.var_names)
+    assert all(
+        clinic_day in list(encoded_ann_data.var_names)
+        for clinic_day in [
+            "ehrapycat_clinic_day_Friday",
+            "ehrapycat_clinic_day_Monday",
+            "ehrapycat_clinic_day_Saturday",
+            "ehrapycat_clinic_day_Sunday",
+        ]
+    )
+
+    assert np.all(
+        encoded_ann_data.var["unencoded_var_names"]
+        == ["clinic_day", "clinic_day", "clinic_day", "clinic_day", "survival", "patient_id", "los_days", "b12_values"]
+    )
+    assert np.all(encoded_ann_data.var["encoding_mode"][:5] == ["one-hot"] * 4 + ["label"])
+    assert np.all(enc is None for enc in encoded_ann_data.var["encoding_mode"][5:])
+
+    assert id(encoded_ann_data.X) != id(encoded_ann_data.layers["original"])
+    assert (
+        encode_ds_1_adata is not None
+        and encode_ds_1_adata.X is not None
+        and encode_ds_1_adata.obs is not None
+        and encode_ds_1_adata.uns is not None
+    )
+    assert id(encoded_ann_data) != id(encode_ds_1_adata)
+    assert id(encoded_ann_data.obs) != id(encode_ds_1_adata.obs)
+    assert id(encoded_ann_data.uns) != id(encode_ds_1_adata.uns)
+    assert id(encoded_ann_data.var) != id(encode_ds_1_adata.var)
+    assert all(column in set(encoded_ann_data.obs.columns) for column in ["survival", "clinic_day"])
+    assert not any(column in set(encode_ds_1_adata.obs.columns) for column in ["survival", "clinic_day"])
+
+    assert_frame_equal(
+        encode_ds_1_adata.var,
+        DataFrame(
+            {FEATURE_TYPE_KEY: [NUMERIC_TAG, NUMERIC_TAG, NUMERIC_TAG, CATEGORICAL_TAG, CATEGORICAL_TAG]},
+            index=["patient_id", "los_days", "b12_values", "survival", "clinic_day"],
+        ),
+    )
+
+    assert np.all(
+        encoded_ann_data.var[FEATURE_TYPE_KEY]
+        == [
+            CATEGORICAL_TAG,
+            CATEGORICAL_TAG,
+            CATEGORICAL_TAG,
+            CATEGORICAL_TAG,
+            CATEGORICAL_TAG,
+            NUMERIC_TAG,
+            NUMERIC_TAG,
+            NUMERIC_TAG,
+        ]
+    )
+
+    assert pd.api.types.is_bool_dtype(encoded_ann_data.obs["survival"].dtype)
+    assert isinstance(encoded_ann_data.obs["clinic_day"].dtype, CategoricalDtype)
+
+
+def test_custom_encode_again_single_columns_encoding(encode_ds_1_adata):
+    encoded_ann_data = encode(
+        encode_ds_1_adata,
+        autodetect=False,
+        encodings={"label": ["survival"], "one-hot": ["clinic_day"]},
+    )
+    encoded_ann_data_again = encode(encoded_ann_data, autodetect=False, encodings={"label": ["clinic_day"]})
+    assert encoded_ann_data_again.X.shape == (5, 5)
+    assert len(encoded_ann_data_again.obs.columns) == 2
+    assert set(encoded_ann_data_again.obs.columns) == {"survival", "clinic_day"}
+    assert "ehrapycat_survival" in list(encoded_ann_data_again.var_names)
+    assert "ehrapycat_clinic_day" in list(encoded_ann_data_again.var_names)
+    assert all(
+        clinic_day not in list(encoded_ann_data_again.var_names)
+        for clinic_day in [
+            "ehrapycat_clinic_day_Friday",
+            "ehrapycat_clinic_day_Monday",
+            "ehrapycat_clinic_day_Saturday",
+            "ehrapycat_clinic_day_Sunday",
+        ]
+    )
+
+    assert np.all(
+        encoded_ann_data_again.var["encoding_mode"].loc[["ehrapycat_survival", "ehrapycat_clinic_day"]]
+        == ["label", "label"]
+    )
+
+    assert id(encoded_ann_data_again.X) != id(encoded_ann_data_again.layers["original"])
+    assert pd.api.types.is_bool_dtype(encoded_ann_data.obs["survival"].dtype)
+    assert isinstance(encoded_ann_data.obs["clinic_day"].dtype, CategoricalDtype)
+
+
+def test_custom_encode_again_multiple_columns_encoding(encode_ds_1_adata):
+    encoded_ann_data = encode(encode_ds_1_adata, autodetect=False, encodings={"one-hot": ["clinic_day", "survival"]})
+    encoded_ann_data_again = encode(
+        encoded_ann_data,
+        autodetect=False,
+        encodings={"label": ["survival"], "one-hot": ["clinic_day"]},
+    )
+    assert encoded_ann_data_again.X.shape == (5, 8)
+    assert len(encoded_ann_data_again.obs.columns) == 2
+    assert set(encoded_ann_data_again.obs.columns) == {"survival", "clinic_day"}
+    assert "ehrapycat_survival" in list(encoded_ann_data_again.var_names)
+    assert "ehrapycat_clinic_day_Friday" in list(encoded_ann_data_again.var_names)
+    assert all(
+        survival_outcome not in list(encoded_ann_data_again.var_names)
+        for survival_outcome in ["ehrapycat_survival_False", "ehrapycat_survival_True"]
+    )
+
+    assert np.all(
+        encoded_ann_data_again.var.loc[encoded_ann_data_again.var["unencoded_var_names"] == "survival", "encoding_mode"]
+        == "label"
+    )
+    assert np.all(
+        encoded_ann_data_again.var.loc[
+            encoded_ann_data_again.var["unencoded_var_names"] == "clinic_day", "encoding_mode"
+        ]
+        == "one-hot"
+    )
+
+    assert id(encoded_ann_data_again.X) != id(encoded_ann_data_again.layers["original"])
+    assert pd.api.types.is_bool_dtype(encoded_ann_data.obs["survival"].dtype)
+    assert isinstance(encoded_ann_data.obs["clinic_day"].dtype, CategoricalDtype)
+
+
+def test_update_encoding_scheme_1(encode_ds_1_adata):
+    encode_ds_1_adata.var["unencoded_var_names"] = ["col1", "col2", "col3", "col4", "col5"]
+    encode_ds_1_adata.var["encoding_mode"] = ["label", "label", "label", "one-hot", "one-hot"]
+
+    new_encodings = {"one-hot": ["col1"], "label": ["col2", "col3", "col4"]}
+
+    expected_encodings = {
+        "label": ["col2", "col3", "col4"],
+        "one-hot": ["col1", "col5"],
+    }
+    updated_encodings = _reorder_encodings(encode_ds_1_adata, new_encodings)
+
+    assert expected_encodings == updated_encodings