a b/catenets/datasets/__init__.py
1
# stdlib
2
import os
3
from pathlib import Path
4
from typing import Any, Tuple
5
6
from . import dataset_acic2016, dataset_ihdp, dataset_twins
7
8
# Directory (next to this package's source) handed to every dataset loader below.
DATA_PATH = Path(os.path.dirname(__file__)) / "data"

# Best-effort creation at import time: an already-existing directory is not an
# error (exist_ok=True), and filesystem failures such as a read-only install
# location are deliberately ignored so that importing the package never fails.
# Only OSError is suppressed, so KeyboardInterrupt/SystemExit still propagate.
try:
    DATA_PATH.mkdir(exist_ok=True)
except OSError:
    pass
14
15
16
def load(dataset: str, *args: Any, **kwargs: Any) -> Tuple:
    """Load one of the supported datasets by name.

    Args:
        dataset: Name of the dataset to load; one of ``"twins"``, ``"ihdp"``
            or ``"acic2016"`` (compared case-sensitively).
        *args: Extra positional arguments forwarded unchanged to the selected
            dataset module's ``load`` function, after ``DATA_PATH``.
        **kwargs: Extra keyword arguments forwarded unchanged to the selected
            dataset module's ``load`` function.

    Returns:
        The value returned by the selected dataset module's ``load``. As
        documented for this entry point, that is:

        - Train_X, Test_X: Train and Test features
        - Train_Y: Observable outcomes
        - Train_T: Assigned treatment
        - Test_Y: Potential outcomes.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    # Resolve the module first so an unknown name fails before any loading work.
    if dataset == "twins":
        module = dataset_twins
    elif dataset == "ihdp":
        module = dataset_ihdp
    elif dataset == "acic2016":
        module = dataset_acic2016
    else:
        # ValueError subclasses Exception, so callers that caught the previous
        # generic Exception keep working; the name is included for debugging.
        raise ValueError(f"Unsupported dataset: {dataset!r}")

    return module.load(DATA_PATH, *args, **kwargs)
34
35
36
# Names exported by ``from catenets.datasets import *``: the three dataset
# submodules imported above plus the ``load`` dispatcher.
__all__ = [
    "dataset_ihdp",
    "dataset_twins",
    "dataset_acic2016",
    "load",
]