--- a +++ b/tests/problems/time/test_mixins.py @@ -0,0 +1,411 @@ +from typing import Tuple + +import pytest + +import numpy as np +import pandas as pd + +from anndata import AnnData + +from moscot.problems.time import TemporalProblem +from tests._utils import MockSolverOutput + + +class TestTemporalMixin: + @pytest.mark.fast + @pytest.mark.parametrize("forward", [True, False]) + def test_cell_transition_full_pipeline(self, gt_temporal_adata: AnnData, forward: bool): + config = gt_temporal_adata.uns + key = config["key"] + key_1 = config["key_1"] + key_2 = config["key_2"] + key_3 = config["key_3"] + cell_types = set(gt_temporal_adata.obs["cell_type"].cat.categories) + problem = TemporalProblem(gt_temporal_adata) + problem = problem.prepare(key) + assert set(problem.problems.keys()) == {(key_1, key_2), (key_2, key_3)} + problem[key_1, key_2]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_105"]) + problem[key_2, key_3]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_105_11"]) + + cell_types_present_key_1 = ( + gt_temporal_adata[gt_temporal_adata.obs[key] == key_1].obs["cell_type"].cat.categories + ) + cell_types_present_key_2 = ( + gt_temporal_adata[gt_temporal_adata.obs[key] == key_2].obs["cell_type"].cat.categories + ) + + result = problem.cell_transition( + key_1, + key_2, + "cell_type", + "cell_type", + forward=forward, + ) + assert isinstance(result, pd.DataFrame) + expected_shape = (len(cell_types_present_key_1), len(cell_types_present_key_2)) + assert result.shape == expected_shape + assert set(result.index) == set(cell_types_present_key_1) if forward else set(cell_types) + assert set(result.columns) == set(cell_types_present_key_2) if not forward else set(cell_types) + marginal = result.sum(axis=forward == 1).values + present_cell_type_marginal = marginal[marginal > 0] + np.testing.assert_allclose(present_cell_type_marginal, 1.0) + + @pytest.mark.fast + @pytest.mark.parametrize("forward", [True, False]) + @pytest.mark.parametrize("mapping_mode", ["max", "sum"]) + @pytest.mark.parametrize("batch_size", [3, 7, None]) + @pytest.mark.parametrize("problem_kind", ["temporal"]) + def test_annotation_mapping(self, adata_anno: AnnData, forward: bool, mapping_mode, batch_size, gt_tm_annotation): + problem = TemporalProblem(adata_anno) + problem_keys = (0, 1) + problem = problem.prepare(time_key="day", joint_attr="X_pca") + assert set(problem.problems.keys()) == {problem_keys} + problem[problem_keys]._solution = MockSolverOutput(gt_tm_annotation) + annotation_label = "celltype1" if forward else "celltype2" + result = problem.annotation_mapping( + mapping_mode=mapping_mode, + annotation_label=annotation_label, + forward=forward, + source=0, + target=1, + batch_size=batch_size, + ) + if forward: + expected_result = ( + adata_anno.uns["expected_max1"] if mapping_mode == "max" else adata_anno.uns["expected_sum1"] + ) + else: + expected_result = ( + adata_anno.uns["expected_max2"] if mapping_mode == "max" else adata_anno.uns["expected_sum2"] + ) + assert (result[annotation_label] == expected_result).all() + + @pytest.mark.fast + @pytest.mark.parametrize("forward", [True, False]) + def test_cell_transition_different_groups(self, gt_temporal_adata: AnnData, forward: bool): + config = gt_temporal_adata.uns + key = config["key"] + key_1 = config["key_1"] + key_2 = config["key_2"] + key_3 = config["key_3"] + + gt_temporal_adata.obs["batch"] = gt_temporal_adata.obs["batch"].astype("category") + batches = set(gt_temporal_adata.obs["batch"].cat.categories) + problem = TemporalProblem(gt_temporal_adata) + problem = problem.prepare(key) + assert set(problem.problems.keys()) == {(key_1, key_2), (key_2, key_3)} + problem[key_1, key_2]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_105"]) + problem[key_2, key_3]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_105_11"]) + + result = problem.cell_transition( + key_1, + key_2, + "cell_type", + "batch", + forward=forward, + ) + assert isinstance(result, pd.DataFrame) + cell_types = set(gt_temporal_adata[gt_temporal_adata.obs[key] == key_1].obs["cell_type"].cat.categories) + batches = set(gt_temporal_adata[gt_temporal_adata.obs[key] == key_2].obs["batch"].cat.categories) + assert set(result.index) == cell_types + assert set(result.columns) == batches + + @pytest.mark.fast + @pytest.mark.parametrize("forward", [True, False]) + def test_cell_transition_subset_pipeline(self, gt_temporal_adata: AnnData, forward: bool): + config = gt_temporal_adata.uns + key = config["key"] + key_1 = config["key_1"] + key_2 = config["key_2"] + key_3 = config["key_3"] + problem = TemporalProblem(gt_temporal_adata) + problem = problem.prepare(key) + assert set(problem.problems.keys()) == {(key_1, key_2), (key_2, key_3)} + problem[key_1, key_2]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_105"]) + problem[key_2, key_3]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_105_11"]) + + early_annotation = ["Stromal", "unknown"] + late_annotation = ["Stromal", "Epithelial"] + result = problem.cell_transition( + key_1, + key_2, + {"cell_type": early_annotation}, + {"cell_type": late_annotation}, + forward=forward, + ) + assert isinstance(result, pd.DataFrame) + assert result.shape == (len(early_annotation), len(late_annotation)) + assert set(result.index) == set(early_annotation) + assert set(result.columns) == set(late_annotation) + + marginal = result.sum(axis=forward == 1).values + present_cell_type_marginal = marginal[marginal > 0] + np.testing.assert_allclose(present_cell_type_marginal, np.ones(len(present_cell_type_marginal))) + + @pytest.mark.parametrize("forward", [True, False]) + def test_cell_transition_regression(self, gt_temporal_adata: AnnData, forward: bool): + config = gt_temporal_adata.uns + key = config["key"] + key_1 = config["key_1"] + key_2 = config["key_2"] + key_3 = config["key_3"] + problem = TemporalProblem(gt_temporal_adata) + problem = problem.prepare(key) + assert set(problem.problems.keys()) == {(key_1, key_2), (key_2, key_3)} + problem[key_1, key_2]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_105"]) + problem[key_2, key_3]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_105_11"]) + result = problem.cell_transition( + 10, + 10.5, + source_groups="cell_type", + target_groups="cell_type", + forward=forward, + ) + cell_types_present_key_1 = ( + gt_temporal_adata[gt_temporal_adata.obs[key] == key_1].obs["cell_type"].cat.categories + ) + cell_types_present_key_2 = ( + gt_temporal_adata[gt_temporal_adata.obs[key] == key_2].obs["cell_type"].cat.categories + ) + expected_shape = ( + (len(cell_types_present_key_1), len(cell_types_present_key_2)) + if forward + else (len(cell_types_present_key_1), len(cell_types_present_key_2)) + ) + assert result.shape == expected_shape + marginal = result.sum(axis=forward == 1).values + present_cell_type_marginal = marginal[marginal > 0] + np.testing.assert_allclose(present_cell_type_marginal, 1.0, rtol=1e-6, atol=1e-6) + + direction = "forward" if forward else "backward" + gt = gt_temporal_adata.uns[f"cell_transition_10_105_{direction}"] + gt = gt.sort_index() + result = result.sort_index() + result = result[gt.columns] + np.testing.assert_allclose(result.values, gt.values, rtol=1e-6, atol=1e-6) + + def test_compute_time_point_distances_pipeline(self, adata_time: AnnData): + problem = TemporalProblem(adata_time).prepare("time") + distance_source_intermediate, distance_intermediate_target = problem.compute_time_point_distances( + source=0, + intermediate=1, + target=2, + posterior_marginals=False, + epsilon=10, + ) + assert distance_source_intermediate > 0 + assert distance_source_intermediate < 100 + assert distance_intermediate_target > 0 + + def test_batch_distances_pipeline(self, adata_time: AnnData): + problem = TemporalProblem(adata_time) + problem.prepare("time") + + batch_distance = problem.compute_batch_distances(time=1, batch_key="batch", epsilon=10) + assert batch_distance > 0 + + @pytest.mark.parametrize("account_for_unbalancedness", [True, False]) + def test_compute_interpolated_distance_pipeline(self, gt_temporal_adata: AnnData, account_for_unbalancedness: bool): + config = gt_temporal_adata.uns + key = config["key"] + key_1 = config["key_1"] + key_2 = config["key_2"] + key_3 = config["key_3"] + problem = TemporalProblem(gt_temporal_adata) + problem = problem.prepare( + key, + subset=[(key_1, key_2), (key_2, key_3), (key_1, key_3)], + policy="explicit", + xy_callback_kwargs={"n_comps": 50}, + ) + assert set(problem.problems.keys()) == {(key_1, key_2), (key_2, key_3), (key_1, key_3)} + problem[key_1, key_2]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_105"]) + problem[key_2, key_3]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_105_11"]) + problem[key_1, key_3]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_11"]) + + interpolation_result = problem.compute_interpolated_distance( + key_1, + key_2, + key_3, + account_for_unbalancedness=account_for_unbalancedness, + posterior_marginals=False, + seed=config["seed"], + epsilon=10, + ) + assert isinstance(interpolation_result, float) + assert interpolation_result > 0 + + def test_compute_interpolated_distance_regression(self, gt_temporal_adata: AnnData): + config = gt_temporal_adata.uns + key = config["key"] + key_1 = config["key_1"] + key_2 = config["key_2"] + key_3 = config["key_3"] + problem = TemporalProblem(gt_temporal_adata) + problem = problem.prepare( + key, + subset=[(key_1, key_2), (key_2, key_3), (key_1, key_3)], + policy="explicit", + xy_callback_kwargs={"n_comps": 50}, + ) + assert set(problem.problems.keys()) == {(key_1, key_2), (key_2, key_3), (key_1, key_3)} + problem[key_1, key_2]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_105"]) + problem[key_2, key_3]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_105_11"]) + problem[key_1, key_3]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_11"]) + + interpolation_result = problem.compute_interpolated_distance( + key_1, key_2, key_3, posterior_marginals=False, seed=config["seed"], epsilon=10 + ) + assert isinstance(interpolation_result, float) + assert interpolation_result > 0 + np.testing.assert_allclose( + interpolation_result, gt_temporal_adata.uns["interpolated_distance_10_105_11"], rtol=1e-6, atol=1e-6 + ) + + def test_compute_time_point_distances_regression(self, gt_temporal_adata: AnnData): + config = gt_temporal_adata.uns + key = config["key"] + key_1 = config["key_1"] + key_2 = config["key_2"] + key_3 = config["key_3"] + problem = TemporalProblem(gt_temporal_adata) + problem = problem.prepare( + key, + subset=[(key_1, key_2), (key_2, key_3), (key_1, key_3)], + policy="explicit", + xy_callback_kwargs={"n_comps": 50}, + ) + assert set(problem.problems.keys()) == {(key_1, key_2), (key_2, key_3), (key_1, key_3)} + problem[key_1, key_2]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_105"]) + problem[key_2, key_3]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_105_11"]) + problem[key_1, key_3]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_11"]) + + result = problem.compute_time_point_distances(key_1, key_2, key_3, posterior_marginals=False, epsilon=10) + assert isinstance(result, tuple) + assert result[0] > 0 + assert result[1] > 0 + np.testing.assert_allclose( + result[0], gt_temporal_adata.uns["time_point_distances_10_105_11"][0], rtol=1e-6, atol=1e-6 + ) + np.testing.assert_allclose( + result[1], gt_temporal_adata.uns["time_point_distances_10_105_11"][1], rtol=1e-6, atol=1e-6 + ) + + def test_compute_batch_distances_regression(self, gt_temporal_adata: AnnData): + config = gt_temporal_adata.uns + key = config["key"] + key_1 = config["key_1"] + key_2 = config["key_2"] + key_3 = config["key_3"] + problem = TemporalProblem(gt_temporal_adata) + problem = problem.prepare( + key, + subset=[(key_1, key_2), (key_2, key_3), (key_1, key_3)], + policy="explicit", + xy_callback_kwargs={"n_comps": 50}, + ) + assert set(problem.problems.keys()) == {(key_1, key_2), (key_2, key_3), (key_1, key_3)} + problem[key_1, key_2]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_105"]) + problem[key_2, key_3]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_105_11"]) + problem[key_1, key_3]._solution = MockSolverOutput(gt_temporal_adata.uns["tmap_10_11"]) + + result = problem.compute_batch_distances(key_1, "batch", epsilon=10) + assert isinstance(result, float) + np.testing.assert_allclose(result, gt_temporal_adata.uns["batch_distances_10"], rtol=1e-5) + + def test_compute_random_distance_regression(self, gt_temporal_adata: AnnData): + config = gt_temporal_adata.uns + key = config["key"] + key_1 = config["key_1"] + key_2 = config["key_2"] + key_3 = config["key_3"] + problem = TemporalProblem(gt_temporal_adata) + problem = problem.prepare( + key, + subset=[(key_1, key_2), (key_2, key_3), (key_1, key_3)], + policy="explicit", + xy_callback_kwargs={"n_comps": 50}, + ) + assert set(problem.problems.keys()) == {(key_1, key_2), (key_2, key_3), (key_1, key_3)} + + result = problem.compute_random_distance( + key_1, key_2, key_3, posterior_marginals=False, seed=config["seed"], epsilon=10 + ) + assert isinstance(result, float) + np.testing.assert_allclose(result, gt_temporal_adata.uns["random_distance_10_105_11"], rtol=1e-6, atol=1e-6) + + # TODO(MUCDK): split into 2 tests + @pytest.mark.fast + @pytest.mark.parametrize("only_start", [True, False]) + def test_get_data_pipeline(self, adata_time: AnnData, only_start: bool): + problem = TemporalProblem(adata_time) + problem.prepare("time") + + # TODO(MUCDK): use namedtuple + result = ( + problem._get_data(0, only_start=only_start, posterior_marginals=False) + if only_start + else problem._get_data(0, 1, 2, posterior_marginals=False) + ) + + assert isinstance(result, tuple) + assert len(result) == 2 if only_start else len(result) == 5 + if only_start: + assert isinstance(result[0], np.ndarray) + assert isinstance(result[1], AnnData) + else: + assert isinstance(result[0], np.ndarray) + # assert isinstance(result[1], np.ndarray) # FIXME: None growth-rates + assert isinstance(result[2], np.ndarray) + assert isinstance(result[3], AnnData) + assert isinstance(result[4], np.ndarray) + + @pytest.mark.parametrize("time_points", [(0, 1, 2), (0, 2, 1), ()]) + def test_get_interp_param_pipeline(self, adata_time: AnnData, time_points: Tuple[float]): + start, intermediate, end = time_points if len(time_points) else (42, 43, 44) + interpolation_parameter = None if len(time_points) == 3 else 0.5 + problem = TemporalProblem(adata_time) + problem.prepare("time") + problem.solve(max_iterations=2) + + if intermediate <= start or end <= intermediate: + with np.testing.assert_raises(ValueError): + problem._get_interp_param(start, intermediate, end, interpolation_parameter) + else: + inter_param = problem._get_interp_param(start, intermediate, end, interpolation_parameter) + assert inter_param == 0.5 + + @pytest.mark.fast + def test_cell_transition_regression_notparam( + self, + adata_time_with_tmap: AnnData, + ): # TODO(MUCDK): please check. + problem = TemporalProblem(adata_time_with_tmap).prepare("time") + problem[0, 1]._solution = MockSolverOutput(adata_time_with_tmap.uns["transport_matrix"]) + + result = problem.cell_transition( + 0, + 1, + source_groups="cell_type", + target_groups="cell_type", + forward=True, + ) + res = result.sort_index().sort_index(axis=1) + df_expected = adata_time_with_tmap.uns["cell_transition_gt"].sort_index().sort_index(axis=1) + # TODO(MUCDK): use pandas.testing + np.testing.assert_allclose(res.values, df_expected.values, rtol=1e-6, atol=1e-6) + + @pytest.mark.fast + @pytest.mark.parametrize("temporal_key", ["celltype", "time", "missing"]) + def test_temporal_key_numeric(self, adata_time: AnnData, temporal_key: str): + problem = TemporalProblem(adata_time) + if temporal_key == "missing": + with pytest.raises(KeyError, match=r"Unable to find temporal key"): + _ = problem.prepare(temporal_key) + elif temporal_key == "celltype": + with pytest.raises(TypeError, match=rf"Expected `adata.obs\[{temporal_key!r}\]`.*"): + _ = problem.prepare(temporal_key) + elif temporal_key == "time": + _ = problem.prepare(temporal_key) + else: + raise NotImplementedError(temporal_key)