|
a |
|
b/catenets/datasets/__init__.py |
|
|
1 |
# stdlib |
|
|
2 |
import os |
|
|
3 |
from pathlib import Path |
|
|
4 |
from typing import Any, Tuple |
|
|
5 |
|
|
|
6 |
from . import dataset_acic2016, dataset_ihdp, dataset_twins |
|
|
7 |
|
|
|
8 |
# Directory handed to every dataset-specific loader as its data location;
# it lives in a ``data`` folder next to this package's ``__init__.py``.
DATA_PATH = Path(os.path.dirname(__file__)) / "data"

# Create the data directory at import time. This is deliberately best-effort:
# an already-existing path or an unwritable install location must not make the
# package unimportable, so filesystem errors (OSError, which covers
# FileExistsError and PermissionError) are ignored. Only OSError is caught so
# that KeyboardInterrupt, SystemExit and genuine bugs still propagate.
try:
    DATA_PATH.mkdir(exist_ok=True)
except OSError:
    pass
|
|
14 |
|
|
|
15 |
|
|
|
16 |
def load(dataset: str, *args: Any, **kwargs: Any) -> Tuple:
    """Load one of the bundled benchmark datasets by name.

    The selected dataset module's ``load`` function is called with
    ``DATA_PATH`` as its first argument, followed by every extra positional
    and keyword argument passed here, unchanged.

    Args:
        dataset: Name of the dataset to load; one of ``"twins"``, ``"ihdp"``
            or ``"acic2016"``.
        *args: Extra positional arguments forwarded to the dataset loader.
        **kwargs: Extra keyword arguments forwarded to the dataset loader.

    Returns:
        The tuple returned by the dataset-specific loader. Its exact layout
        is defined by each dataset module (confirm there); it is expected to
        contain:

        - Train_X, Test_X: train and test features
        - Train_Y: observable outcomes
        - Train_T: assigned treatment
        - Test_Y: potential outcomes

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
            ``ValueError`` subclasses ``Exception``, so callers catching
            ``Exception`` are unaffected.
    """
    if dataset == "twins":
        return dataset_twins.load(DATA_PATH, *args, **kwargs)
    elif dataset == "ihdp":
        return dataset_ihdp.load(DATA_PATH, *args, **kwargs)
    elif dataset == "acic2016":
        return dataset_acic2016.load(DATA_PATH, *args, **kwargs)

    # Name the offending value and the valid choices so the failure is
    # actionable; keep the original "Unsupported dataset" message prefix.
    raise ValueError(
        f"Unsupported dataset: {dataset!r}. "
        "Expected one of 'twins', 'ihdp', 'acic2016'."
    )
|
|
34 |
|
|
|
35 |
|
|
|
36 |
# Public API of the ``catenets.datasets`` package: the per-dataset loader
# modules plus the generic ``load`` dispatcher.
__all__ = [
    "dataset_ihdp",
    "dataset_twins",
    "dataset_acic2016",
    "load",
]