--- a +++ b/tests/pycombat/test_pycombat.py @@ -0,0 +1,184 @@ +import unittest + +import numpy as np +import pandas as pd +from anndata import AnnData + +from inmoose.pycombat import pycombat_norm +from inmoose.pycombat.covariates import make_design_matrix +from inmoose.pycombat.pycombat_norm import ( + adjust_data, + calculate_mean_var, + calculate_stand_mean, + check_mean_only, + compute_prior, + fit_model, + postmean, + postvar, + standardise_data, +) + + +class test_pycombat(unittest.TestCase): + @classmethod + def setUpClass(self): + matrix = np.transpose( + [ + np.random.normal(size=1000, loc=3, scale=1), + np.random.normal(size=1000, loc=3, scale=1), + np.random.normal(size=1000, loc=3, scale=1), + np.random.normal(size=1000, loc=2, scale=0.6), + np.random.normal(size=1000, loc=2, scale=0.6), + np.random.normal(size=1000, loc=4, scale=1), + np.random.normal(size=1000, loc=4, scale=1), + np.random.normal(size=1000, loc=4, scale=1), + np.random.normal(size=1000, loc=4, scale=1), + ] + ) + + self.matrix = pd.DataFrame( + data=matrix, + columns=["sample_" + str(i + 1) for i in range(9)], + index=["gene_" + str(i + 1) for i in range(1000)], + ) + + # test normal execution + self.batch = np.asarray([1, 1, 1, 2, 2, 3, 3, 3, 3]) + self.matrix_adjusted = pycombat_norm(self.matrix, self.batch) + + # useful constants for unit testing + ref_batch = None + mean_only = False + par_prior = False + precision = None + mod = None + self.dat = self.matrix.values + vci = make_design_matrix( + self.dat, self.batch, covar_mod=mod, ref_batch=ref_batch + ) + self.dat = vci.counts + self.design = np.transpose(vci.design) + self.batchmod = vci.batch_mod + self.batches = [vci.batch_composition[b] for b in vci.batch.categories] + self.n_batches = [len(v) for v in self.batches] + self.batch = vci.batch + ref = vci.ref_batch_idx + self.n_batch = vci.n_batch + self.n_array = self.dat.shape[1] + + self.B_hat, self.grand_mean, self.var_pooled = calculate_mean_var( + self.design, + self.batches, + ref, + self.dat, + self.n_batches, + self.n_batch, + self.n_array, + ) + self.stand_mean = calculate_stand_mean( + self.grand_mean, self.n_array, self.design, self.n_batch, self.B_hat + ) + self.s_data = standardise_data( + self.dat, self.stand_mean, self.var_pooled, self.n_array + ) + self.gamma_star, self.delta_star, self.batch_design = fit_model( + self.design, + self.n_batch, + self.s_data, + self.batches, + mean_only, + par_prior, + precision, + ref, + ) + self.bayes_data = adjust_data( + self.s_data, + self.gamma_star, + self.delta_star, + self.batch_design, + self.n_batches, + self.var_pooled, + self.stand_mean, + self.n_array, + ref, + self.batches, + self.dat, + ) + + def test_compute_prior(self): + print("aprior", compute_prior("a", self.gamma_star, False)) + self.assertEqual(compute_prior("a", self.gamma_star, True), 1) + print("bprior", compute_prior("b", self.gamma_star, False)) + self.assertEqual(compute_prior("b", self.gamma_star, True), 1) + + def test_postmean(self): + self.assertEqual( + np.shape( + postmean( + self.gamma_star, self.delta_star, self.gamma_star, self.delta_star + ) + ), + np.shape(self.gamma_star), + ) + + def test_postvar(self): + self.assertEqual(np.sum(postvar([2, 4, 6], 2, 1, 1) - [2, 3, 4]), 0) + + # def test_it_sol(self): + # () + + # def test_int_eprior(self): + # () + + @unittest.skip("automate the verification") + def test_check_mean_only(self): + check_mean_only(True) + check_mean_only(False) + print("Only one line of text should have been printed above.") + + def test_calculate_mean_var(self): + self.assertEqual(np.shape(self.B_hat)[0], np.shape(self.design)[0]) + self.assertEqual(np.shape(self.grand_mean)[0], np.shape(self.dat)[0]) + self.assertEqual(np.shape(self.var_pooled)[0], np.shape(self.dat)[0]) + + def test_calculate_stand_mean(self): + self.assertEqual(np.shape(self.stand_mean), np.shape(self.dat)) + + def test_standardise_data(self): + self.assertEqual(np.shape(self.s_data), np.shape(self.dat)) + + def test_fit_model(self): + self.assertEqual(np.shape(self.gamma_star)[1], np.shape(self.dat)[0]) + self.assertEqual(np.shape(self.delta_star)[1], np.shape(self.dat)[0]) + self.assertEqual(np.shape(self.batch_design), np.shape(self.design)) + + def test_adjust_data(self): + self.assertEqual(np.shape(self.bayes_data), np.shape(self.dat)) + + def test_pycombat(self): + self.assertEqual(np.shape(self.matrix), np.shape(self.matrix_adjusted)) + self.assertTrue( + np.mean(self.matrix.values) - np.mean(self.matrix_adjusted.values) + <= np.mean(self.matrix.values) * 0.05 + ) + self.assertTrue( + np.var(self.matrix_adjusted.values) <= np.var(self.matrix.values) + ) + # test raise error for single sample batch + with self.assertRaisesRegex( + ValueError, + r"Batches 4 contain a single sample, which is not supported for batch effect correction. Please review your inputs.", + ): + pycombat_norm(self.matrix, np.asarray([1, 1, 1, 2, 2, 3, 3, 3, 4])) + + def test_pycombat_anndata(self): + ad = AnnData( + self.matrix.T, + obs=pd.DataFrame({"batch": self.batch}, index=self.matrix.columns), + ) + res = pycombat_norm(ad, batch="batch") + self.assertTrue(np.allclose(res.X.T, self.matrix_adjusted)) + with self.assertRaisesRegex( + ValueError, 'the batch column "foo" must appear in' + ): + pycombat_norm(ad, batch="foo")