--- a
+++ b/tests/test_anndata.py
@@ -0,0 +1,71 @@
+import unittest
+import pytest
+
+import numpy as np
+import pandas as pd
+
+
+@pytest.mark.usefixtures("filepath_hdf5")
+class TestAnnData:
+    """Tests for calling the mofapy2 ``mofa`` entry point on AnnData objects."""
+
+    def test_multi_group(self, filepath_hdf5):
+        """Run ``mofa`` on a two-group AnnData; check obs order and factor values."""
+        pytest.importorskip("anndata")
+
+        from anndata import AnnData
+        from mofapy2.run.entry_point import mofa
+
+        # Set dimensions: samples per group, number of features, number of factors
+        n_g1, n_g2 = 10, 20
+        d_m = 30
+        k = 5
+
+        # Generate data. The expected values asserted at the end presumably depend
+        # on this exact sequence of seeded draws, so keep the order unchanged.
+        np.random.seed(42)
+        z1 = np.random.normal(size=(n_g1, k))
+        z2 = np.random.normal(size=(n_g2, k))
+        z = np.concatenate([z1, z2], axis=0)
+        w = np.random.normal(size=(d_m, k))
+
+        e1 = np.random.normal(size=(n_g1, d_m))
+        e2 = np.random.normal(size=(n_g2, d_m))
+        e = np.concatenate([e1, e2], axis=0)
+
+        y = np.dot(z, w.T) + e
+
+        # Sample names encode the group; shuffle them so groups are interleaved
+        samples_names = [
+            f"sample{i}_group{g}"
+            for g, g_size in {"A": n_g1, "B": n_g2}.items()
+            for i in range(g_size)
+        ]
+        np.random.shuffle(samples_names)
+        samples_groups = [s.split("_")[1] for s in samples_names]
+
+        adata = AnnData(
+            X=y,
+            obs=pd.DataFrame(
+                {"sample": samples_names, "group": samples_groups}, index=samples_names
+            ),
+        )
+
+        mofa(
+            adata, groups_label="group", outfile=filepath_hdf5, expectations=["W", "Z"],
+        )
+
+        # Re-derive groups from the sample names to check obs rows stayed aligned
+        adata.obs["true_group"] = [s.split("_")[1] for s in adata.obs["sample"]]
+        assert all(adata.obs["group"].values == adata.obs["true_group"].values)
+
+        for sample, value in (
+            ("sample0_groupA", 0.145915),
+            ("sample7_groupB", -0.182254),
+        ):
+            si = np.where(adata.obs_names == sample)[0][0]
+            assert adata.obsm["X_mofa"][si, 0] == pytest.approx(value, rel=1e-5)
+
+
+if __name__ == "__main__":
+    raise SystemExit(pytest.main([__file__]))  # unittest.main() would find no tests