Download this file

125 lines (101 with data), 3.7 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
"""
Some utils for experiments
"""
# Author: Alicia Curth
from typing import Callable, Dict, Optional, Union
import jax.numpy as jnp
from catenets.models.jax import (
DRNET_NAME,
PSEUDOOUT_NAME,
RANET_NAME,
RNET_NAME,
SNET1_NAME,
SNET2_NAME,
SNET3_NAME,
SNET_NAME,
T_NAME,
XNET_NAME,
PseudoOutcomeNet,
get_catenet,
)
from catenets.models.jax.base import check_shape_1d_data
from catenets.models.jax.transformation_utils import (
DR_TRANSFORMATION,
PW_TRANSFORMATION,
RA_TRANSFORMATION,
)
# Separator placed between PSEUDOOUT_NAME and the transformation name when
# get_all_pseudoout_models builds its dictionary keys.
SEP = "_"
def eval_mse_model(
    inputs: jnp.ndarray,
    targets: jnp.ndarray,
    predict_fun: Callable,
    params: jnp.ndarray,
) -> jnp.ndarray:
    """Return the mean squared error of a model's predictions.

    Parameters
    ----------
    inputs :
        Data handed to ``predict_fun`` as its second argument.
    targets :
        Values the predictions are compared against.
    predict_fun :
        Callable invoked as ``predict_fun(params, inputs)``.
    params :
        Model parameters handed to ``predict_fun`` as its first argument.

    Returns
    -------
    The mean of the squared differences between predictions and targets.
    """
    residuals = predict_fun(params, inputs) - targets
    return jnp.mean(residuals ** 2)
def eval_mse(preds: jnp.ndarray, targets: jnp.ndarray) -> jnp.ndarray:
    """Return the mean squared error between ``preds`` and ``targets``.

    Both arrays are first passed through ``check_shape_1d_data`` (imported
    from ``catenets.models.jax.base``); the squared differences of the
    results are then averaged.
    """
    checked_preds = check_shape_1d_data(preds)
    checked_targets = check_shape_1d_data(targets)
    squared_errors = (checked_preds - checked_targets) ** 2
    return jnp.mean(squared_errors)
def eval_root_mse(cate_pred: jnp.ndarray, cate_true: jnp.ndarray) -> jnp.ndarray:
    """Return the square root of the MSE between ``cate_pred`` and ``cate_true``.

    Both arrays go through ``check_shape_1d_data`` before being handed to
    ``eval_mse``.
    """
    true_values = check_shape_1d_data(cate_true)
    pred_values = check_shape_1d_data(cate_pred)
    mse = eval_mse(pred_values, true_values)
    return jnp.sqrt(mse)
def eval_abs_error_ate(cate_pred: jnp.ndarray, cate_true: jnp.ndarray) -> jnp.ndarray:
    """Return the absolute gap between the means of ``cate_pred`` and ``cate_true``.

    Both arrays go through ``check_shape_1d_data`` first; each is then
    averaged and the absolute value of the difference of the two averages
    is returned.
    """
    true_values = check_shape_1d_data(cate_true)
    pred_values = check_shape_1d_data(cate_pred)
    mean_gap = jnp.mean(pred_values) - jnp.mean(true_values)
    return jnp.abs(mean_gap)
def get_model_set(
    model_selection: Union[str, list] = "all", model_params: Optional[dict] = None
) -> Dict:
    """Build a dictionary of freshly constructed models keyed by name.

    Parameters
    ----------
    model_selection : str or list (tuples are accepted as well)
        * ``"snet"``    -> result of ``get_all_snets()``
        * ``"pseudo"``  -> result of ``get_all_pseudoout_models()``
        * ``"twostep"`` -> result of ``get_all_twostep_models()``
        * ``"all"``     -> the SNet and pseudo-outcome sets merged; the
          two-step set is NOT part of ``"all"``
        * any other string -> a single model obtained via
          ``get_catenet(model_selection)()``
        * a list/tuple of names -> one ``get_catenet(name)()`` per entry
    model_params : dict, optional
        Hyperparameters to apply. For every model only those keys that
        already appear in ``model.get_params()`` are forwarded to
        ``model.set_params``; all other keys are silently ignored for
        that model.

    Returns
    -------
    dict
        Mapping from model name to model instance.

    Raises
    ------
    ValueError
        If ``model_selection`` is neither a string nor a list/tuple.
    """
    # isinstance instead of an exact type() comparison so that str/list
    # subclasses (and tuples of names) are accepted as well.
    if isinstance(model_selection, str):
        if model_selection == "snet":
            models = get_all_snets()
        elif model_selection == "pseudo":
            models = get_all_pseudoout_models()
        elif model_selection == "twostep":
            models = get_all_twostep_models()
        elif model_selection == "all":
            models = dict(**get_all_snets(), **get_all_pseudoout_models())
        else:
            # not a group keyword: treat the string as a single model name
            models = {model_selection: get_catenet(model_selection)()}  # type: ignore
    elif isinstance(model_selection, (list, tuple)):
        models = {name: get_catenet(name)() for name in model_selection}
    else:
        raise ValueError("model_selection should be string or list.")

    # set hyperparameters
    if model_params is not None:
        for model in models.values():
            existing_params = model.get_params()
            # forward only the settings this particular model knows about
            new_params = {
                key: val for key, val in model_params.items() if key in existing_params
            }
            model.set_params(**new_params)

    return models
# Model names that get_all_snets passes to get_catenet.
ALL_SNETS = [T_NAME, SNET1_NAME, SNET2_NAME, SNET3_NAME, SNET_NAME]
# Transformations that get_all_pseudoout_models hands to PseudoOutcomeNet.
ALL_PSEUDOOUT_MODELS = [DR_TRANSFORMATION, PW_TRANSFORMATION, RA_TRANSFORMATION]
# Model names that get_all_twostep_models passes to get_catenet.
ALL_TWOSTEP_MODELS = [DRNET_NAME, RANET_NAME, XNET_NAME, RNET_NAME]
def get_all_snets() -> Dict:
    """Return a new instance of every model named in ``ALL_SNETS``, keyed by name.

    Each instance is created with its default constructor arguments via
    ``get_catenet(name)()``.
    """
    return {snet_name: get_catenet(snet_name)() for snet_name in ALL_SNETS}
def get_all_pseudoout_models() -> Dict:
    """Return one ``PseudoOutcomeNet`` per entry of ``ALL_PSEUDOOUT_MODELS``.

    These are the DR, RA and PW learners. Keys have the form
    ``PSEUDOOUT_NAME + SEP + <transformation>``.
    """
    return {
        PSEUDOOUT_NAME + SEP + transformation: PseudoOutcomeNet(
            transformation=transformation
        )
        for transformation in ALL_PSEUDOOUT_MODELS
    }
def get_all_twostep_models() -> Dict:
    """Return a new instance of every model named in ``ALL_TWOSTEP_MODELS``.

    These are the DR, RA, R and X learners. Each instance is created with
    its default constructor arguments via ``get_catenet(name)()`` and
    stored under its name.
    """
    return {
        learner_name: get_catenet(learner_name)() for learner_name in ALL_TWOSTEP_MODELS
    }