Switch to side-by-side view

--- a
+++ b/tests/pycombat/test_pycombat.py
@@ -0,0 +1,184 @@
+import unittest
+
+import numpy as np
+import pandas as pd
+from anndata import AnnData
+
+from inmoose.pycombat import pycombat_norm
+from inmoose.pycombat.covariates import make_design_matrix
+from inmoose.pycombat.pycombat_norm import (
+    adjust_data,
+    calculate_mean_var,
+    calculate_stand_mean,
+    check_mean_only,
+    compute_prior,
+    fit_model,
+    postmean,
+    postvar,
+    standardise_data,
+)
+
+
+class test_pycombat(unittest.TestCase):
+    @classmethod
+    def setUpClass(self):
+        matrix = np.transpose(
+            [
+                np.random.normal(size=1000, loc=3, scale=1),
+                np.random.normal(size=1000, loc=3, scale=1),
+                np.random.normal(size=1000, loc=3, scale=1),
+                np.random.normal(size=1000, loc=2, scale=0.6),
+                np.random.normal(size=1000, loc=2, scale=0.6),
+                np.random.normal(size=1000, loc=4, scale=1),
+                np.random.normal(size=1000, loc=4, scale=1),
+                np.random.normal(size=1000, loc=4, scale=1),
+                np.random.normal(size=1000, loc=4, scale=1),
+            ]
+        )
+
+        self.matrix = pd.DataFrame(
+            data=matrix,
+            columns=["sample_" + str(i + 1) for i in range(9)],
+            index=["gene_" + str(i + 1) for i in range(1000)],
+        )
+
+        # test normal execution
+        self.batch = np.asarray([1, 1, 1, 2, 2, 3, 3, 3, 3])
+        self.matrix_adjusted = pycombat_norm(self.matrix, self.batch)
+
+        # useful constants for unit testing
+        ref_batch = None
+        mean_only = False
+        par_prior = False
+        precision = None
+        mod = None
+        self.dat = self.matrix.values
+        vci = make_design_matrix(
+            self.dat, self.batch, covar_mod=mod, ref_batch=ref_batch
+        )
+        self.dat = vci.counts
+        self.design = np.transpose(vci.design)
+        self.batchmod = vci.batch_mod
+        self.batches = [vci.batch_composition[b] for b in vci.batch.categories]
+        self.n_batches = [len(v) for v in self.batches]
+        self.batch = vci.batch
+        ref = vci.ref_batch_idx
+        self.n_batch = vci.n_batch
+        self.n_array = self.dat.shape[1]
+
+        self.B_hat, self.grand_mean, self.var_pooled = calculate_mean_var(
+            self.design,
+            self.batches,
+            ref,
+            self.dat,
+            self.n_batches,
+            self.n_batch,
+            self.n_array,
+        )
+        self.stand_mean = calculate_stand_mean(
+            self.grand_mean, self.n_array, self.design, self.n_batch, self.B_hat
+        )
+        self.s_data = standardise_data(
+            self.dat, self.stand_mean, self.var_pooled, self.n_array
+        )
+        self.gamma_star, self.delta_star, self.batch_design = fit_model(
+            self.design,
+            self.n_batch,
+            self.s_data,
+            self.batches,
+            mean_only,
+            par_prior,
+            precision,
+            ref,
+        )
+        self.bayes_data = adjust_data(
+            self.s_data,
+            self.gamma_star,
+            self.delta_star,
+            self.batch_design,
+            self.n_batches,
+            self.var_pooled,
+            self.stand_mean,
+            self.n_array,
+            ref,
+            self.batches,
+            self.dat,
+        )
+
+    def test_compute_prior(self):
+        print("aprior", compute_prior("a", self.gamma_star, False))
+        self.assertEqual(compute_prior("a", self.gamma_star, True), 1)
+        print("bprior", compute_prior("b", self.gamma_star, False))
+        self.assertEqual(compute_prior("b", self.gamma_star, True), 1)
+
+    def test_postmean(self):
+        self.assertEqual(
+            np.shape(
+                postmean(
+                    self.gamma_star, self.delta_star, self.gamma_star, self.delta_star
+                )
+            ),
+            np.shape(self.gamma_star),
+        )
+
+    def test_postvar(self):
+        self.assertEqual(np.sum(postvar([2, 4, 6], 2, 1, 1) - [2, 3, 4]), 0)
+
+    # def test_it_sol(self):
+    #    ()
+
+    # def test_int_eprior(self):
+    #    ()
+
+    @unittest.skip("automate the verification")
+    def test_check_mean_only(self):
+        check_mean_only(True)
+        check_mean_only(False)
+        print("Only one line of text should have been printed above.")
+
+    def test_calculate_mean_var(self):
+        self.assertEqual(np.shape(self.B_hat)[0], np.shape(self.design)[0])
+        self.assertEqual(np.shape(self.grand_mean)[0], np.shape(self.dat)[0])
+        self.assertEqual(np.shape(self.var_pooled)[0], np.shape(self.dat)[0])
+
+    def test_calculate_stand_mean(self):
+        self.assertEqual(np.shape(self.stand_mean), np.shape(self.dat))
+
+    def test_standardise_data(self):
+        self.assertEqual(np.shape(self.s_data), np.shape(self.dat))
+
+    def test_fit_model(self):
+        self.assertEqual(np.shape(self.gamma_star)[1], np.shape(self.dat)[0])
+        self.assertEqual(np.shape(self.delta_star)[1], np.shape(self.dat)[0])
+        self.assertEqual(np.shape(self.batch_design), np.shape(self.design))
+
+    def test_adjust_data(self):
+        self.assertEqual(np.shape(self.bayes_data), np.shape(self.dat))
+
+    def test_pycombat(self):
+        self.assertEqual(np.shape(self.matrix), np.shape(self.matrix_adjusted))
+        self.assertTrue(
+            np.mean(self.matrix.values) - np.mean(self.matrix_adjusted.values)
+            <= np.mean(self.matrix.values) * 0.05
+        )
+        self.assertTrue(
+            np.var(self.matrix_adjusted.values) <= np.var(self.matrix.values)
+        )
+        # test raise error for single sample batch
+        with self.assertRaisesRegex(
+            ValueError,
+            r"Batches 4 contain a single sample, which is not supported for batch effect correction. Please review your inputs.",
+        ):
+            pycombat_norm(self.matrix, np.asarray([1, 1, 1, 2, 2, 3, 3, 3, 4]))
+
+    def test_pycombat_anndata(self):
+        ad = AnnData(
+            self.matrix.T,
+            obs=pd.DataFrame({"batch": self.batch}, index=self.matrix.columns),
+        )
+        res = pycombat_norm(ad, batch="batch")
+        self.assertTrue(np.allclose(res.X.T, self.matrix_adjusted))
+        with self.assertRaisesRegex(
+            ValueError, 'the batch column "foo" must appear in'
+        ):
+            pycombat_norm(ad, batch="foo")