Switch to unified view

a b/catenets/experiment_utils/base.py
1
"""
2
Some utils for experiments
3
"""
4
# Author: Alicia Curth
5
from typing import Callable, Dict, Optional, Union
6
7
import jax.numpy as jnp
8
9
from catenets.models.jax import (
10
    DRNET_NAME,
11
    PSEUDOOUT_NAME,
12
    RANET_NAME,
13
    RNET_NAME,
14
    SNET1_NAME,
15
    SNET2_NAME,
16
    SNET3_NAME,
17
    SNET_NAME,
18
    T_NAME,
19
    XNET_NAME,
20
    PseudoOutcomeNet,
21
    get_catenet,
22
)
23
from catenets.models.jax.base import check_shape_1d_data
24
from catenets.models.jax.transformation_utils import (
25
    DR_TRANSFORMATION,
26
    PW_TRANSFORMATION,
27
    RA_TRANSFORMATION,
28
)
29
30
SEP = "_"
31
32
33
def eval_mse_model(
34
    inputs: jnp.ndarray,
35
    targets: jnp.ndarray,
36
    predict_fun: Callable,
37
    params: jnp.ndarray,
38
) -> jnp.ndarray:
39
    # evaluate the mse of a model given its function and params
40
    preds = predict_fun(params, inputs)
41
    return jnp.mean((preds - targets) ** 2)
42
43
44
def eval_mse(preds: jnp.ndarray, targets: jnp.ndarray) -> jnp.ndarray:
45
    preds = check_shape_1d_data(preds)
46
    targets = check_shape_1d_data(targets)
47
    return jnp.mean((preds - targets) ** 2)
48
49
50
def eval_root_mse(cate_pred: jnp.ndarray, cate_true: jnp.ndarray) -> jnp.ndarray:
51
    cate_true = check_shape_1d_data(cate_true)
52
    cate_pred = check_shape_1d_data(cate_pred)
53
    return jnp.sqrt(eval_mse(cate_pred, cate_true))
54
55
56
def eval_abs_error_ate(cate_pred: jnp.ndarray, cate_true: jnp.ndarray) -> jnp.ndarray:
57
    cate_true = check_shape_1d_data(cate_true)
58
    cate_pred = check_shape_1d_data(cate_pred)
59
    return jnp.abs(jnp.mean(cate_pred) - jnp.mean(cate_true))
60
61
62
def get_model_set(
63
    model_selection: Union[str, list] = "all", model_params: Optional[dict] = None
64
) -> Dict:
65
    """Helper function to retrieve a set of models"""
66
    # get model selection
67
    if type(model_selection) is str:
68
        if model_selection == "snet":
69
            models = get_all_snets()
70
        elif model_selection == "pseudo":
71
            models = get_all_pseudoout_models()
72
        elif model_selection == "twostep":
73
            models = get_all_twostep_models()
74
        elif model_selection == "all":
75
            models = dict(**get_all_snets(), **get_all_pseudoout_models())
76
        else:
77
            models = {model_selection: get_catenet(model_selection)()}  # type: ignore
78
    elif type(model_selection) is list:
79
        models = {}
80
        for model in model_selection:
81
            models.update({model: get_catenet(model)()})
82
    else:
83
        raise ValueError("model_selection should be string or list.")
84
85
    # set hyperparameters
86
    if model_params is not None:
87
        for model in models.values():
88
            existing_params = model.get_params()
89
            new_params = {
90
                key: val
91
                for key, val in model_params.items()
92
                if key in existing_params.keys()
93
            }
94
            model.set_params(**new_params)
95
96
    return models
97
98
99
ALL_SNETS = [T_NAME, SNET1_NAME, SNET2_NAME, SNET3_NAME, SNET_NAME]
100
ALL_PSEUDOOUT_MODELS = [DR_TRANSFORMATION, PW_TRANSFORMATION, RA_TRANSFORMATION]
101
ALL_TWOSTEP_MODELS = [DRNET_NAME, RANET_NAME, XNET_NAME, RNET_NAME]
102
103
104
def get_all_snets() -> Dict:
105
    model_dict = {}
106
    for name in ALL_SNETS:
107
        model_dict.update({name: get_catenet(name)()})
108
    return model_dict
109
110
111
def get_all_pseudoout_models() -> Dict:  # DR, RA, PW learner
112
    model_dict = {}
113
    for trans in ALL_PSEUDOOUT_MODELS:
114
        model_dict.update(
115
            {PSEUDOOUT_NAME + SEP + trans: PseudoOutcomeNet(transformation=trans)}
116
        )
117
    return model_dict
118
119
120
def get_all_twostep_models() -> Dict:  # DR, RA, R, X learner
121
    model_dict = {}
122
    for name in ALL_TWOSTEP_MODELS:
123
        model_dict.update({name: get_catenet(name)()})
124
    return model_dict