Switch to unified view

a b/tests/tools/causal/test_dowhy.py
1
import re
2
import warnings
3
4
import anndata
5
import dowhy
6
import dowhy.datasets
7
import matplotlib.pyplot as plt
8
import numpy as np
9
10
import ehrapy as ep
11
12
warnings.filterwarnings("ignore")
13
14
15
class TestCausal:
16
    def setup_method(self):
17
        linear_data = dowhy.datasets.linear_dataset(
18
            beta=10,
19
            num_common_causes=5,
20
            num_instruments=2,
21
            num_samples=1000,
22
            treatment_is_binary=True,
23
        )
24
        self.linear_data = anndata.AnnData(linear_data["df"].astype(np.float32))
25
        self.linear_graph = linear_data["gml_graph"]
26
        self.outcome_name = "y"
27
        self.treatment_name = "v0"
28
29
    def test_dowhy_linear_dataset(self):
30
        estimate, refute_results = ep.tl.causal_inference(
31
            adata=self.linear_data,
32
            graph=self.linear_graph,
33
            treatment=self.treatment_name,
34
            outcome=self.outcome_name,
35
            estimation_method="backdoor.linear_regression",
36
            return_as="estimate+refute",
37
        )
38
39
        assert isinstance(refute_results, dict)
40
        assert len(refute_results) == 6
41
        assert isinstance(estimate, dowhy.causal_estimator.CausalEstimate)
42
        assert np.isclose(
43
            np.round(refute_results["Refute: Add a random common cause"]["test_significance"], 3), 10.002, atol=0.005
44
        )
45
        assert np.isclose(
46
            np.round(refute_results["Refute: Add a random common cause"]["test_significance"], 3), 10.002, atol=0.005
47
        )
48
49
    def test_plot_causal_effect(self):
50
        estimate = ep.tl.causal_inference(
51
            adata=self.linear_data,
52
            graph=self.linear_graph,
53
            treatment=self.treatment_name,
54
            outcome=self.outcome_name,
55
            estimation_method="backdoor.linear_regression",
56
            return_as="estimate",
57
            show_graph=False,
58
            show_refute_plots=False,
59
        )
60
        ax = ep.pl.causal_effect(estimate)
61
62
        assert isinstance(ax, plt.Axes)
63
        legend = ax.get_legend()
64
        assert len(legend.get_texts()) == 2  # Check the number of legend labels
65
        assert legend.get_texts()[0].get_text() == "Observed data"
66
        assert legend.get_texts()[1].get_text() == "Causal variation"
67
        assert re.search(r"(9\.99\d+|10\.0)", str(ax.get_title()))