|
a |
|
b/catenets/README.md |
|
|
1 |
# CATENets - Conditional Average Treatment Effect Estimation Using Neural Networks |
|
|
2 |
|
|
|
3 |
[](https://github.com/AliciaCurth/CATENets/actions/workflows/test.yml) |
|
|
4 |
[](https://catenets.readthedocs.io/en/latest/?badge=latest) |
|
|
5 |
[](https://github.com/AliciaCurth/CATENets/blob/main/LICENSE) |
|
|
6 |
|
|
|
7 |
|
|
|
8 |
Code Author: Alicia Curth (amc253@cam.ac.uk) |
|
|
9 |
|
|
|
10 |
This repo contains Jax-based, sklearn-style implementations of Neural Network-based Conditional |
|
|
11 |
Average Treatment Effect (CATE) Estimators, which were used in the AISTATS21 paper |
|
|
12 |
['Nonparametric Estimation of Heterogeneous Treatment Effects: From Theory to Learning |
|
|
13 |
Algorithms']( https://arxiv.org/abs/2101.10943) (Curth & vd Schaar, 2021a) as well as the follow up |
|
|
14 |
NeurIPS21 paper ["On Inductive Biases for Heterogeneous Treatment Effect Estimation"](https://arxiv.org/abs/2106.03765) (Curth & vd |
|
|
15 |
Schaar, 2021b) and the NeurIPS21 Datasets & Benchmarks track paper ["Really Doing Great at Estimating CATE? A Critical Look at ML Benchmarking Practices in Treatment Effect Estimation"](https://openreview.net/forum?id=FQLzQqGEAH) (Curth et al, 2021). |
|
|
16 |
|
|
|
17 |
We implement the SNet-class we introduce in Curth & vd Schaar (2021a), as well as FlexTENet and |
|
|
18 |
OffsetNet as discussed in Curth & vd Schaar (2021b), and re-implement a number of |
|
|
19 |
NN-based algorithms from existing literature (Shalit et al (2017), Shi et al (2019), Hassanpour |
|
|
20 |
& Greiner (2020)). We also provide Neural Network (NN)-based instantiations of a number of so-called |
|
|
21 |
meta-learners for CATE estimation, including two-step pseudo-outcome regression estimators (the |
|
|
22 |
DR-learner (Kennedy, 2020) and single-robust propensity-weighted (PW) and regression-adjusted (RA) learners), Nie & Wager (2017)'s R-learner and Kuenzel et al (2019)'s X-learner. The jax implementations in ``catenets.models.jax`` were used in all papers listed; additionally, pytorch versions of some models (``catenets.models.torch``) were contributed by [Bogdan Cebere](https://github.com/bcebere). |
|
|
23 |
|
|
|
24 |
### Interface |
|
|
25 |
The repo contains a package ``catenets``, which contains all general code used for modeling and evaluation, and a folder ``experiments``, in which the code for replicating experimental results is contained. All implemented learning algorithms in ``catenets`` (``SNet, FlexTENet, OffsetNet, TNet, SNet1 (TARNet), SNet2 |
|
|
26 |
(DragonNet), SNet3, DRNet, RANet, PWNet, RNet, XNet``) come with a sklearn-style wrapper, implementing a ``.fit(X, y, w)`` and a ``.predict(X)`` method, where predict returns CATE by default. All hyperparameters are documented in detail in the respective files in ``catenets.models`` folder. |
|
|
27 |
|
|
|
28 |
Example usage: |
|
|
29 |
|
|
|
30 |
```python |
|
|
31 |
from catenets.models.jax import TNet, SNet |
|
|
32 |
from catenets.experiment_utils.simulation_utils import simulate_treatment_setup |
|
|
33 |
|
|
|
34 |
# simulate some data (here: unconfounded, 10 prognostic variables and 5 predictive variables) |
|
|
35 |
X, y, w, p, cate = simulate_treatment_setup(n=2000, n_o=10, n_t=5, n_c=0) |
|
|
36 |
|
|
|
37 |
# estimate CATE using TNet |
|
|
38 |
t = TNet() |
|
|
39 |
t.fit(X, y, w) |
|
|
40 |
cate_pred_t = t.predict(X) # without potential outcomes |
|
|
41 |
cate_pred_t, po0_pred_t, po1_pred_t = t.predict(X, return_po=True) # predict potential outcomes too |
|
|
42 |
|
|
|
43 |
# estimate CATE using SNet |
|
|
44 |
s = SNet(penalty_orthogonal=0.01) |
|
|
45 |
s.fit(X, y, w) |
|
|
46 |
cate_pred_s = s.predict(X) |
|
|
47 |
|
|
|
48 |
``` |
|
|
49 |
|
|
|
50 |
All experiments in Curth & vd Schaar (2021a) can be replicated using this repository; the necessary |
|
|
51 |
code is in ``experiments.experiments_AISTATS21``. To do so from shell, clone the repo, create a new |
|
|
52 |
virtual environment and run |
|
|
53 |
```shell |
|
|
54 |
pip install -r requirements.txt #install requirements |
|
|
55 |
python run_experiments_AISTATS.py |
|
|
56 |
``` |
|
|
57 |
```shell |
|
|
58 |
Options: |
|
|
59 |
--experiment # defaults to 'simulation', 'ihdp' will run ihdp experiments |
|
|
60 |
--setting # different simulation settings in synthetic experiments (can be 1-5) |
|
|
61 |
--models # defaults to None which will train all models considered in paper, |
|
|
62 |
# can be string of model name (e.g 'TNet'), 'plug' for all plugin models, |
|
|
63 |
# 'pseudo' for all pseudo-outcome regression models |
|
|
64 |
|
|
|
65 |
--file_name # base file name to write to, defaults to 'results' |
|
|
66 |
--n_repeats # number of experiments to run for each configuration, defaults to 10 (should be set to 100 for IHDP) |
|
|
67 |
``` |
|
|
68 |
|
|
|
69 |
Similarly, the experiments in Curth & vd Schaar (2021b) can be replicated using the code in |
|
|
70 |
``experiments.experiments_inductivebias_NeurIPS21`` (or from shell using ```python |
|
|
71 |
run_experiments_inductive_bias_NeurIPS.py```) and the experiments in Curth et al (2021) can be replicated using the code in ``experiments.experiments_benchmarks_NeurIPS21`` (the catenets experiments can also be run from shell using ``python run_experiments_benchmarks_NeurIPS``). |
|
|
72 |
|
|
|
73 |
The code can also be installed as a python package (``catenets``). From a local copy of the repo, run ``python setup.py install``. |
|
|
74 |
|
|
|
75 |
Note: jax is currently only supported on macOS and linux, but can be run from windows using WSL (the windows subsystem for linux). |
|
|
76 |
|
|
|
77 |
|
|
|
78 |
### Citing |
|
|
79 |
|
|
|
80 |
If you use this software please cite the corresponding paper(s): |
|
|
81 |
|
|
|
82 |
``` |
|
|
83 |
@inproceedings{curth2021nonparametric, |
|
|
84 |
title={Nonparametric Estimation of Heterogeneous Treatment Effects: From Theory to Learning Algorithms}, |
|
|
85 |
author={Curth, Alicia and van der Schaar, Mihaela}, |
|
|
86 |
year={2021}, |
|
|
87 |
booktitle={Proceedings of the 24th International Conference on Artificial |
|
|
88 |
Intelligence and Statistics (AISTATS)}, |
|
|
89 |
organization={PMLR} |
|
|
90 |
} |
|
|
91 |
|
|
|
92 |
@article{curth2021inductive, |
|
|
93 |
title={On Inductive Biases for Heterogeneous Treatment Effect Estimation}, |
|
|
94 |
author={Curth, Alicia and van der Schaar, Mihaela}, |
|
|
95 |
booktitle={Proceedings of the Thirty-Fifth Conference on Neural Information Processing Systems}, |
|
|
96 |
year={2021} |
|
|
97 |
} |
|
|
98 |
|
|
|
99 |
|
|
|
100 |
@article{curth2021really, |
|
|
101 |
title={Really Doing Great at Estimating CATE? A Critical Look at ML Benchmarking Practices in Treatment Effect Estimation}, |
|
|
102 |
author={Curth, Alicia and Svensson, David and Weatherall, James and van der Schaar, Mihaela}, |
|
|
103 |
booktitle={Proceedings of the Neural Information Processing Systems Track on Datasets and Benchmarks}, |
|
|
104 |
year={2021} |
|
|
105 |
} |
|
|
106 |
|
|
|
107 |
``` |