[e5f1db]: / tests / tools / causal / test_dowhy.py

Download this file

68 lines (58 with data), 2.3 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import re
import warnings
import anndata
import dowhy
import dowhy.datasets
import matplotlib.pyplot as plt
import numpy as np
import ehrapy as ep
warnings.filterwarnings("ignore")
class TestCausal:
    """Tests for ``ep.tl.causal_inference`` and ``ep.pl.causal_effect`` on a synthetic DoWhy dataset."""

    def setup_method(self):
        """Build a fresh linear DoWhy dataset wrapped in an AnnData object before each test."""
        # NOTE(review): the dataset comes from an unseeded generator, so each
        # run sees different data; the value-based assertions below presumably
        # rely on the estimator recovering the generating beta=10 closely for
        # any draw.
        linear_data = dowhy.datasets.linear_dataset(
            beta=10,
            num_common_causes=5,
            num_instruments=2,
            num_samples=1000,
            treatment_is_binary=True,
        )
        # The generated DataFrame is cast to float32 before being wrapped in AnnData.
        self.linear_data = anndata.AnnData(linear_data["df"].astype(np.float32))
        self.linear_graph = linear_data["gml_graph"]
        self.outcome_name = "y"
        self.treatment_name = "v0"

    def teardown_method(self):
        """Close all matplotlib figures so figures opened by one test do not accumulate across the session."""
        plt.close("all")

    def test_dowhy_linear_dataset(self):
        """Estimate + refute returns a ``CausalEstimate`` and six refutation results near beta=10."""
        estimate, refute_results = ep.tl.causal_inference(
            adata=self.linear_data,
            graph=self.linear_graph,
            treatment=self.treatment_name,
            outcome=self.outcome_name,
            estimation_method="backdoor.linear_regression",
            return_as="estimate+refute",
        )

        assert isinstance(refute_results, dict)
        assert len(refute_results) == 6
        assert isinstance(estimate, dowhy.causal_estimator.CausalEstimate)

        # The estimated effect should recover the beta=10 used to generate the data.
        assert np.isclose(estimate.value, 10, atol=0.01)

        # Previously this exact assertion appeared twice (copy-paste); checked once here.
        random_common_cause = refute_results["Refute: Add a random common cause"]
        assert np.isclose(random_common_cause["test_significance"], 10.002, atol=0.005)

    def test_plot_causal_effect(self):
        """``ep.pl.causal_effect`` returns an Axes with the expected legend labels and title."""
        estimate = ep.tl.causal_inference(
            adata=self.linear_data,
            graph=self.linear_graph,
            treatment=self.treatment_name,
            outcome=self.outcome_name,
            estimation_method="backdoor.linear_regression",
            return_as="estimate",
            show_graph=False,
            show_refute_plots=False,
        )

        ax = ep.pl.causal_effect(estimate)
        assert isinstance(ax, plt.Axes)

        legend = ax.get_legend()
        # get_legend() returns None when no legend exists; fail with a clear message instead.
        assert legend is not None, "causal_effect plot has no legend"
        labels = [text.get_text() for text in legend.get_texts()]
        assert labels == ["Observed data", "Causal variation"]

        # Presumably the title embeds the estimated effect, which should be close to beta=10.
        assert re.search(r"(9\.99\d+|10\.0)", str(ax.get_title()))