|
a |
|
b/tests/tools/causal/test_dowhy.py |
|
|
1 |
import re |
|
|
2 |
import warnings |
|
|
3 |
|
|
|
4 |
import anndata |
|
|
5 |
import dowhy |
|
|
6 |
import dowhy.datasets |
|
|
7 |
import matplotlib.pyplot as plt |
|
|
8 |
import numpy as np |
|
|
9 |
|
|
|
10 |
import ehrapy as ep |
|
|
11 |
|
|
|
12 |
warnings.filterwarnings("ignore") |
|
|
13 |
|
|
|
14 |
|
|
|
15 |
class TestCausal: |
|
|
16 |
def setup_method(self): |
|
|
17 |
linear_data = dowhy.datasets.linear_dataset( |
|
|
18 |
beta=10, |
|
|
19 |
num_common_causes=5, |
|
|
20 |
num_instruments=2, |
|
|
21 |
num_samples=1000, |
|
|
22 |
treatment_is_binary=True, |
|
|
23 |
) |
|
|
24 |
self.linear_data = anndata.AnnData(linear_data["df"].astype(np.float32)) |
|
|
25 |
self.linear_graph = linear_data["gml_graph"] |
|
|
26 |
self.outcome_name = "y" |
|
|
27 |
self.treatment_name = "v0" |
|
|
28 |
|
|
|
29 |
def test_dowhy_linear_dataset(self): |
|
|
30 |
estimate, refute_results = ep.tl.causal_inference( |
|
|
31 |
adata=self.linear_data, |
|
|
32 |
graph=self.linear_graph, |
|
|
33 |
treatment=self.treatment_name, |
|
|
34 |
outcome=self.outcome_name, |
|
|
35 |
estimation_method="backdoor.linear_regression", |
|
|
36 |
return_as="estimate+refute", |
|
|
37 |
) |
|
|
38 |
|
|
|
39 |
assert isinstance(refute_results, dict) |
|
|
40 |
assert len(refute_results) == 6 |
|
|
41 |
assert isinstance(estimate, dowhy.causal_estimator.CausalEstimate) |
|
|
42 |
assert np.isclose( |
|
|
43 |
np.round(refute_results["Refute: Add a random common cause"]["test_significance"], 3), 10.002, atol=0.005 |
|
|
44 |
) |
|
|
45 |
assert np.isclose( |
|
|
46 |
np.round(refute_results["Refute: Add a random common cause"]["test_significance"], 3), 10.002, atol=0.005 |
|
|
47 |
) |
|
|
48 |
|
|
|
49 |
def test_plot_causal_effect(self): |
|
|
50 |
estimate = ep.tl.causal_inference( |
|
|
51 |
adata=self.linear_data, |
|
|
52 |
graph=self.linear_graph, |
|
|
53 |
treatment=self.treatment_name, |
|
|
54 |
outcome=self.outcome_name, |
|
|
55 |
estimation_method="backdoor.linear_regression", |
|
|
56 |
return_as="estimate", |
|
|
57 |
show_graph=False, |
|
|
58 |
show_refute_plots=False, |
|
|
59 |
) |
|
|
60 |
ax = ep.pl.causal_effect(estimate) |
|
|
61 |
|
|
|
62 |
assert isinstance(ax, plt.Axes) |
|
|
63 |
legend = ax.get_legend() |
|
|
64 |
assert len(legend.get_texts()) == 2 # Check the number of legend labels |
|
|
65 |
assert legend.get_texts()[0].get_text() == "Observed data" |
|
|
66 |
assert legend.get_texts()[1].get_text() == "Causal variation" |
|
|
67 |
assert re.search(r"(9\.99\d+|10\.0)", str(ax.get_title())) |