Switch to side-by-side view

--- a
+++ b/tests/preprocessing/test_balanced_sampling.py
@@ -0,0 +1,62 @@
+from pathlib import Path
+
+import pytest
+
+import ehrapy as ep
+from ehrapy.io._read import read_csv
+from tests.conftest import TEST_DATA_PATH
+
+CURRENT_DIR = Path(__file__).parent
+
+
+@pytest.fixture
+def adata_mini():
+    return read_csv(f"{TEST_DATA_PATH}/encode/dataset1.csv", columns_obs_only=["clinic_day"])
+
+
+def test_balanced_sampling_basic(adata_mini):
+    # no key
+    with pytest.raises(TypeError):
+        ep.pp.balanced_sample(adata_mini)
+
+    # invalid key
+    with pytest.raises(ValueError):
+        ep.pp.balanced_sample(adata_mini, key="non_existing_column")
+
+    # invalid method
+    with pytest.raises(ValueError):
+        ep.pp.balanced_sample(adata_mini, key="clinic_day", method="non_existing_method")
+
+    # undersampling
+    adata_sampled = ep.pp.balanced_sample(adata_mini, key="clinic_day", method="RandomUnderSampler", copy=True)
+    assert adata_sampled.n_obs == 4
+    assert adata_sampled.obs.clinic_day.value_counts().min() == adata_sampled.obs.clinic_day.value_counts().max()
+
+    # oversampling
+    adata_sampled = ep.pp.balanced_sample(adata_mini, key="clinic_day", method="RandomOverSampler", copy=True)
+    assert adata_sampled.n_obs == 8
+    assert adata_sampled.obs.clinic_day.value_counts().min() == adata_sampled.obs.clinic_day.value_counts().max()
+
+    # undersampling, no copy
+    adata_mini_for_undersampling = adata_mini.copy()
+    output = ep.pp.balanced_sample(
+        adata_mini_for_undersampling, key="clinic_day", method="RandomUnderSampler", copy=False
+    )
+    assert output is None
+    assert adata_mini_for_undersampling.n_obs == 4
+    assert (
+        adata_mini_for_undersampling.obs.clinic_day.value_counts().min()
+        == adata_mini_for_undersampling.obs.clinic_day.value_counts().max()
+    )
+
+    # oversampling, no copy
+    adata_mini_for_oversampling = adata_mini.copy()
+    output = ep.pp.balanced_sample(
+        adata_mini_for_oversampling, key="clinic_day", method="RandomOverSampler", copy=False
+    )
+    assert output is None
+    assert adata_mini_for_oversampling.n_obs == 8
+    assert (
+        adata_mini_for_oversampling.obs.clinic_day.value_counts().min()
+        == adata_mini_for_oversampling.obs.clinic_day.value_counts().max()
+    )