
--- a
+++ b/catenets/models/jax/xnet.py
@@ -0,0 +1,508 @@
+"""
+Module implements X-learner from Kuenzel et al (2019) using NNs
+"""
+# Author: Alicia Curth
+from typing import Callable, Optional, Tuple
+
+import jax.numpy as jnp
+
+import catenets.logger as log
+from catenets.models.constants import (
+    DEFAULT_AVG_OBJECTIVE,
+    DEFAULT_BATCH_SIZE,
+    DEFAULT_LAYERS_OUT,
+    DEFAULT_LAYERS_OUT_T,
+    DEFAULT_LAYERS_R,
+    DEFAULT_LAYERS_R_T,
+    DEFAULT_N_ITER,
+    DEFAULT_N_ITER_MIN,
+    DEFAULT_N_ITER_PRINT,
+    DEFAULT_NONLIN,
+    DEFAULT_PATIENCE,
+    DEFAULT_PENALTY_L2,
+    DEFAULT_SEED,
+    DEFAULT_STEP_SIZE,
+    DEFAULT_STEP_SIZE_T,
+    DEFAULT_UNITS_OUT,
+    DEFAULT_UNITS_OUT_T,
+    DEFAULT_UNITS_R,
+    DEFAULT_UNITS_R_T,
+    DEFAULT_VAL_SPLIT,
+)
+from catenets.models.jax.base import BaseCATENet, train_output_net_only
+from catenets.models.jax.model_utils import check_shape_1d_data, check_X_is_np
+from catenets.models.jax.pseudo_outcome_nets import (  # same strategies as other nets
+    ALL_STRATEGIES,
+    FLEX_STRATEGY,
+    OFFSET_STRATEGY,
+    S1_STRATEGY,
+    S2_STRATEGY,
+    S3_STRATEGY,
+    S_STRATEGY,
+    T_STRATEGY,
+    predict_flextenet,
+    predict_offsetnet,
+    predict_snet,
+    predict_snet1,
+    predict_snet2,
+    predict_snet3,
+    predict_t_net,
+    train_flextenet,
+    train_offsetnet,
+    train_snet,
+    train_snet1,
+    train_snet2,
+    train_snet3,
+    train_tnet,
+)
+
+
+class XNet(BaseCATENet):
+    """
+    Class implements X-learner using NNs.
+
+    Parameters
+    ----------
+    weight_strategy: int, default None
+        Which strategy to use to weight the two second-stage CATE estimators. For
+        tau(x) = g(x) * tau_0(x) + (1 - g(x)) * tau_1(x) [eq. (9), Kuenzel et al. (2019)],
+        weight_strategy=0 sets g(x)=0, weight_strategy=1 sets g(x)=1, weight_strategy=None
+        sets g(x)=pi(x) [the propensity score], and weight_strategy=-1 sets g(x)=1-pi(x)
+    first_stage_strategy: str, default T_STRATEGY
+        Which strategy to use for the first-stage potential outcome estimators; one of the
+        strategies in catenets.models.jax.pseudo_outcome_nets.ALL_STRATEGIES
+    first_stage_args: dict, default None
+        Additional arguments to pass to the first-stage training function
+    binary_y: bool, default False
+        Whether the outcome is binary
+    n_layers_out: int
+        Number of first-stage hypothesis layers (n_layers_out x n_units_out + 1 x Dense layer)
+    n_units_out: int
+        Number of hidden units in each first-stage hypothesis layer
+    n_layers_r: int
+        Number of first-stage representation layers before the hypothesis layers (the
+        distinction between hypothesis and representation layers is made to match TARNet &
+        SNets)
+    n_units_r: int
+        Number of hidden units in each first-stage representation layer
+    n_layers_out_t: int
+        Number of second-stage hypothesis layers (n_layers_out_t x n_units_out_t + 1 x Dense
+        layer)
+    n_units_out_t: int
+        Number of hidden units in each second-stage hypothesis layer
+    n_layers_r_t: int
+        Number of second-stage representation layers before the hypothesis layers
+    n_units_r_t: int
+        Number of hidden units in each second-stage representation layer
+    penalty_l2: float
+        First-stage l2 (ridge) penalty
+    penalty_l2_t: float
+        Second-stage l2 (ridge) penalty
+    step_size: float
+        First-stage learning rate for the optimizer
+    step_size_t: float
+        Second-stage learning rate for the optimizer
+    n_iter: int
+        Maximum number of iterations
+    batch_size: int
+        Batch size
+    val_split_prop: float
+        Proportion of samples used for validation split (can be 0)
+    early_stopping: bool, default True
+        Whether to use early stopping
+    patience: int
+        Number of iterations to wait without improvement in validation loss before early
+        stopping
+    n_iter_min: int
+        Minimum number of iterations to go through before starting early stopping
+    n_iter_print: int
+        Number of iterations after which to print updates
+    seed: int
+        Seed used
+    nonlin: string, default 'elu'
+        Nonlinearity to use in NN
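+
+    Examples
+    --------
+    A minimal usage sketch (assuming covariates X, outcomes y and treatment indicators w are
+    already loaded as arrays; fit is inherited from BaseCATENet):
+
+    >>> from catenets.models.jax.xnet import XNet
+    >>> model = XNet(weight_strategy=None)  # g(x) = pi(x): propensity-weighted combination
+    >>> model.fit(X, y, w)  # stage 1: PO regressions (+ propensity net); stage 2: tau nets
+    >>> cate_pred = model.predict(X)  # CATE estimates tau(x)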
+    """
+
+    def __init__(
+        self,
+        weight_strategy: Optional[int] = None,
+        first_stage_strategy: str = T_STRATEGY,
+        first_stage_args: Optional[dict] = None,
+        binary_y: bool = False,
+        n_layers_out: int = DEFAULT_LAYERS_OUT,
+        n_layers_r: int = DEFAULT_LAYERS_R,
+        n_layers_out_t: int = DEFAULT_LAYERS_OUT_T,
+        n_layers_r_t: int = DEFAULT_LAYERS_R_T,
+        n_units_out: int = DEFAULT_UNITS_OUT,
+        n_units_r: int = DEFAULT_UNITS_R,
+        n_units_out_t: int = DEFAULT_UNITS_OUT_T,
+        n_units_r_t: int = DEFAULT_UNITS_R_T,
+        penalty_l2: float = DEFAULT_PENALTY_L2,
+        penalty_l2_t: float = DEFAULT_PENALTY_L2,
+        step_size: float = DEFAULT_STEP_SIZE,
+        step_size_t: float = DEFAULT_STEP_SIZE_T,
+        n_iter: int = DEFAULT_N_ITER,
+        batch_size: int = DEFAULT_BATCH_SIZE,
+        n_iter_min: int = DEFAULT_N_ITER_MIN,
+        val_split_prop: float = DEFAULT_VAL_SPLIT,
+        early_stopping: bool = True,
+        patience: int = DEFAULT_PATIENCE,
+        n_iter_print: int = DEFAULT_N_ITER_PRINT,
+        seed: int = DEFAULT_SEED,
+        nonlin: str = DEFAULT_NONLIN,
+    ):
+        # settings
+        self.weight_strategy = weight_strategy
+        self.first_stage_strategy = first_stage_strategy
+        self.first_stage_args = first_stage_args
+        self.binary_y = binary_y
+
+        # model architecture hyperparams
+        self.n_layers_out = n_layers_out
+        self.n_layers_out_t = n_layers_out_t
+        self.n_layers_r = n_layers_r
+        self.n_layers_r_t = n_layers_r_t
+        self.n_units_out = n_units_out
+        self.n_units_out_t = n_units_out_t
+        self.n_units_r = n_units_r
+        self.n_units_r_t = n_units_r_t
+        self.nonlin = nonlin
+
+        # other hyperparameters
+        self.penalty_l2 = penalty_l2
+        self.penalty_l2_t = penalty_l2_t
+        self.step_size = step_size
+        self.step_size_t = step_size_t
+        self.n_iter = n_iter
+        self.batch_size = batch_size
+        self.n_iter_print = n_iter_print
+        self.seed = seed
+        self.val_split_prop = val_split_prop
+        self.early_stopping = early_stopping
+        self.patience = patience
+        self.n_iter_min = n_iter_min
+
+    def _get_train_function(self) -> Callable:
+        return train_x_net
+
+    def _get_predict_function(self) -> Callable:
+        # used by the custom predict method below
+        return predict_x_net
+
+    def predict(
+        self, X: jnp.ndarray, return_po: bool = False, return_prop: bool = False
+    ) -> jnp.ndarray:
+        """
+        Predict treatment effect estimates using a CATENet. Depending on method, can also return
+        potential outcome estimate and propensity score estimate.
+
+        Parameters
+        ----------
+        X: pd.DataFrame or np.array
+            Covariate matrix
+        return_po: bool, default False
+            Whether to return potential outcome estimate
+        return_prop: bool, default False
+            Whether to return propensity estimate
+
+        Returns
+        -------
+        array of CATE estimates, optionally also potential outcomes and propensity
+        """
+        X = check_X_is_np(X)
+        predict_func = self._get_predict_function()
+        return predict_func(
+            X,
+            trained_params=self._params,
+            predict_funs=self._predict_funs,
+            return_po=return_po,
+            return_prop=return_prop,
+            weight_strategy=self.weight_strategy,
+        )
+
+
+def train_x_net(
+    X: jnp.ndarray,
+    y: jnp.ndarray,
+    w: jnp.ndarray,
+    weight_strategy: Optional[int] = None,
+    first_stage_strategy: str = T_STRATEGY,
+    first_stage_args: Optional[dict] = None,
+    binary_y: bool = False,
+    n_layers_out: int = DEFAULT_LAYERS_OUT,
+    n_layers_r: int = DEFAULT_LAYERS_R,
+    n_layers_out_t: int = DEFAULT_LAYERS_OUT_T,
+    n_layers_r_t: int = DEFAULT_LAYERS_R_T,
+    n_units_out: int = DEFAULT_UNITS_OUT,
+    n_units_r: int = DEFAULT_UNITS_R,
+    n_units_out_t: int = DEFAULT_UNITS_OUT_T,
+    n_units_r_t: int = DEFAULT_UNITS_R_T,
+    penalty_l2: float = DEFAULT_PENALTY_L2,
+    penalty_l2_t: float = DEFAULT_PENALTY_L2,
+    step_size: float = DEFAULT_STEP_SIZE,
+    step_size_t: float = DEFAULT_STEP_SIZE_T,
+    n_iter: int = DEFAULT_N_ITER,
+    batch_size: int = DEFAULT_BATCH_SIZE,
+    n_iter_min: int = DEFAULT_N_ITER_MIN,
+    val_split_prop: float = DEFAULT_VAL_SPLIT,
+    early_stopping: bool = True,
+    patience: int = DEFAULT_PATIENCE,
+    n_iter_print: int = DEFAULT_N_ITER_PRINT,
+    seed: int = DEFAULT_SEED,
+    nonlin: str = DEFAULT_NONLIN,
+    return_val_loss: bool = False,
+    avg_objective: bool = DEFAULT_AVG_OBJECTIVE,
+) -> Tuple:
+    y = check_shape_1d_data(y)
+    if len(w.shape) > 1:
+        w = w.reshape((len(w),))
+
+    if weight_strategy not in [0, 1, -1, None]:
+        # weight_strategy is coded as follows:
+        # for tau(x) = g(x) * tau_0(x) + (1 - g(x)) * tau_1(x) [eq. (9), Kuenzel et al. (2019)]
+        # weight_strategy=0 sets g(x)=0, weight_strategy=1 sets g(x)=1,
+        # weight_strategy=None sets g(x)=pi(x) [the propensity score],
+        # weight_strategy=-1 sets g(x)=1-pi(x)
+        raise ValueError("XNet only implements weight_strategy in [0, 1, -1, None]")
+
+    if first_stage_strategy not in ALL_STRATEGIES:
+        raise ValueError(
+            "Parameter first_stage_strategy should be in "
+            "catenets.models.jax.pseudo_outcome_nets.ALL_STRATEGIES. "
+            "You passed {}".format(first_stage_strategy)
+        )
+
+    # first stage: get estimates of PO regression
+    log.debug("Training first stage")
+
+    mu_hat_0, mu_hat_1 = _get_first_stage_pos(
+        X,
+        y,
+        w,
+        binary_y=binary_y,
+        n_layers_out=n_layers_out,
+        n_units_out=n_units_out,
+        n_layers_r=n_layers_r,
+        n_units_r=n_units_r,
+        penalty_l2=penalty_l2,
+        step_size=step_size,
+        n_iter=n_iter,
+        batch_size=batch_size,
+        val_split_prop=val_split_prop,
+        early_stopping=early_stopping,
+        patience=patience,
+        n_iter_min=n_iter_min,
+        n_iter_print=n_iter_print,
+        seed=seed,
+        nonlin=nonlin,
+        avg_objective=avg_objective,
+        first_stage_strategy=first_stage_strategy,
+        first_stage_args=first_stage_args,
+    )
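+    # mu_hat_0: control-outcome predictions evaluated on the treated samples;
+    # mu_hat_1: treated-outcome predictions evaluated on the control samples.
+    # These imputed counterfactuals are used to construct the pseudo-outcomes below.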
+
+    if weight_strategy is None or weight_strategy == -1:
+        # also fit propensity estimator
+        log.debug("Training propensity net")
+        params_prop, predict_fun_prop = train_output_net_only(
+            X,
+            w,
+            binary_y=True,
+            n_layers_out=n_layers_out,
+            n_units_out=n_units_out,
+            n_layers_r=n_layers_r,
+            n_units_r=n_units_r,
+            penalty_l2=penalty_l2,
+            step_size=step_size,
+            n_iter=n_iter,
+            batch_size=batch_size,
+            val_split_prop=val_split_prop,
+            early_stopping=early_stopping,
+            patience=patience,
+            n_iter_min=n_iter_min,
+            n_iter_print=n_iter_print,
+            seed=seed,
+            nonlin=nonlin,
+            avg_objective=avg_objective,
+        )
+
+    else:
+        params_prop, predict_fun_prop = None, None
+
+    # second stage
+    log.debug("Training second stage")
+    if weight_strategy != 0:
+        # fit tau_0
+        log.debug("Fitting tau_0")
+        pseudo_outcome0 = mu_hat_1 - y[w == 0]
+        params_tau0, predict_fun_tau0 = train_output_net_only(
+            X[w == 0],
+            pseudo_outcome0,
+            binary_y=False,
+            n_layers_out=n_layers_out_t,
+            n_units_out=n_units_out_t,
+            n_layers_r=n_layers_r_t,
+            n_units_r=n_units_r_t,
+            penalty_l2=penalty_l2_t,
+            step_size=step_size_t,
+            n_iter=n_iter,
+            batch_size=batch_size,
+            val_split_prop=val_split_prop,
+            early_stopping=early_stopping,
+            patience=patience,
+            n_iter_min=n_iter_min,
+            n_iter_print=n_iter_print,
+            seed=seed,
+            return_val_loss=return_val_loss,
+            nonlin=nonlin,
+            avg_objective=avg_objective,
+        )
+    else:
+        params_tau0, predict_fun_tau0 = None, None
+
+    if weight_strategy != 1:
+        # fit tau_1
+        log.debug("Fitting tau_1")
+        pseudo_outcome1 = y[w == 1] - mu_hat_0
+        params_tau1, predict_fun_tau1 = train_output_net_only(
+            X[w == 1],
+            pseudo_outcome1,
+            binary_y=False,
+            n_layers_out=n_layers_out_t,
+            n_units_out=n_units_out_t,
+            n_layers_r=n_layers_r_t,
+            n_units_r=n_units_r_t,
+            penalty_l2=penalty_l2_t,
+            step_size=step_size_t,
+            n_iter=n_iter,
+            batch_size=batch_size,
+            val_split_prop=val_split_prop,
+            early_stopping=early_stopping,
+            patience=patience,
+            n_iter_min=n_iter_min,
+            n_iter_print=n_iter_print,
+            seed=seed,
+            return_val_loss=return_val_loss,
+            nonlin=nonlin,
+            avg_objective=avg_objective,
+        )
+
+    else:
+        params_tau1, predict_fun_tau1 = None, None
+
+    params = params_tau0, params_tau1, params_prop
+    predict_funs = predict_fun_tau0, predict_fun_tau1, predict_fun_prop
+
+    return params, predict_funs
+
+
+def _get_first_stage_pos(
+    X: jnp.ndarray,
+    y: jnp.ndarray,
+    w: jnp.ndarray,
+    first_stage_strategy: str = T_STRATEGY,
+    first_stage_args: Optional[dict] = None,
+    binary_y: bool = False,
+    n_layers_out: int = DEFAULT_LAYERS_OUT,
+    n_layers_r: int = DEFAULT_LAYERS_R,
+    n_units_out: int = DEFAULT_UNITS_OUT,
+    n_units_r: int = DEFAULT_UNITS_R,
+    penalty_l2: float = DEFAULT_PENALTY_L2,
+    step_size: float = DEFAULT_STEP_SIZE,
+    n_iter: int = DEFAULT_N_ITER,
+    batch_size: int = DEFAULT_BATCH_SIZE,
+    n_iter_min: int = DEFAULT_N_ITER_MIN,
+    val_split_prop: float = DEFAULT_VAL_SPLIT,
+    early_stopping: bool = True,
+    patience: int = DEFAULT_PATIENCE,
+    n_iter_print: int = DEFAULT_N_ITER_PRINT,
+    seed: int = DEFAULT_SEED,
+    nonlin: str = DEFAULT_NONLIN,
+    avg_objective: bool = DEFAULT_AVG_OBJECTIVE,
+) -> Tuple[jnp.ndarray, jnp.ndarray]:
+    if first_stage_args is None:
+        first_stage_args = {}
+
+    train_fun: Callable
+    predict_fun: Callable
+
+    if first_stage_strategy == T_STRATEGY:
+        train_fun, predict_fun = train_tnet, predict_t_net
+    elif first_stage_strategy == S_STRATEGY:
+        train_fun, predict_fun = train_snet, predict_snet
+    elif first_stage_strategy == S1_STRATEGY:
+        train_fun, predict_fun = train_snet1, predict_snet1
+    elif first_stage_strategy == S2_STRATEGY:
+        train_fun, predict_fun = train_snet2, predict_snet2
+    elif first_stage_strategy == S3_STRATEGY:
+        train_fun, predict_fun = train_snet3, predict_snet3
+    elif first_stage_strategy == OFFSET_STRATEGY:
+        train_fun, predict_fun = train_offsetnet, predict_offsetnet
+    elif first_stage_strategy == FLEX_STRATEGY:
+        train_fun, predict_fun = train_flextenet, predict_flextenet
+    else:
+        # guard against unknown strategies when this helper is called directly
+        raise ValueError("Unknown first_stage_strategy: {}".format(first_stage_strategy))
+
+    trained_params, pred_fun = train_fun(
+        X,
+        y,
+        w,
+        binary_y=binary_y,
+        n_layers_r=n_layers_r,
+        n_units_r=n_units_r,
+        n_layers_out=n_layers_out,
+        n_units_out=n_units_out,
+        penalty_l2=penalty_l2,
+        step_size=step_size,
+        n_iter=n_iter,
+        batch_size=batch_size,
+        val_split_prop=val_split_prop,
+        early_stopping=early_stopping,
+        patience=patience,
+        n_iter_min=n_iter_min,
+        n_iter_print=n_iter_print,
+        seed=seed,
+        nonlin=nonlin,
+        avg_objective=avg_objective,
+        **first_stage_args
+    )
+
+    _, mu_0, mu_1 = predict_fun(X, trained_params, pred_fun, return_po=True)
+
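+    # return imputed counterfactuals: mu_0 on treated units, mu_1 on control units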
+    return mu_0[w == 1], mu_1[w == 0]
+
+
+def predict_x_net(
+    X: jnp.ndarray,
+    trained_params: dict,
+    predict_funs: list,
+    return_po: bool = False,
+    return_prop: bool = False,
+    weight_strategy: Optional[int] = None,
+) -> jnp.ndarray:
+    if return_po:
+        raise NotImplementedError("TwoStepNets have no Potential outcome predictors.")
+
+    if return_prop:
+        raise NotImplementedError("TwoStepNets have no Propensity predictors.")
+
+    params_tau0, params_tau1, params_prop = trained_params
+    predict_fun_tau0, predict_fun_tau1, predict_fun_prop = predict_funs
+
+    tau0_pred: jnp.ndarray
+    tau1_pred: jnp.ndarray
+
+    if weight_strategy != 0:
+        tau0_pred = predict_fun_tau0(params_tau0, X)
+    else:
+        tau0_pred = 0
+
+    if weight_strategy != 1:
+        tau1_pred = predict_fun_tau1(params_tau1, X)
+    else:
+        tau1_pred = 0
+
+    if weight_strategy is None or weight_strategy == -1:
+        prop_pred = predict_fun_prop(params_prop, X)
+
+    if weight_strategy is None:
+        weight = prop_pred
+    elif weight_strategy == -1:
+        weight = 1 - prop_pred
+    elif weight_strategy == 0:
+        weight = 0
+    elif weight_strategy == 1:
+        weight = 1
+
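+    # final X-learner combination: tau(x) = g(x) * tau_0(x) + (1 - g(x)) * tau_1(x)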
+    return weight * tau0_pred + (1 - weight) * tau1_pred