--- a +++ b/catenets/experiment_utils/tester.py @@ -0,0 +1,73 @@ +# stdlib +import copy +from typing import Any, Tuple + +# third party +import numpy as np +import torch +from sklearn.model_selection import KFold, StratifiedKFold + +from catenets.experiment_utils.torch_metrics import abs_error_ATE, sqrt_PEHE + + +def generate_score(metric: np.ndarray) -> Tuple[float, float]: + percentile_val = 1.96 + return (np.mean(metric), percentile_val * np.std(metric) / np.sqrt(len(metric))) + + +def print_score(score: Tuple[float, float]) -> str: + return str(round(score[0], 4)) + " +/- " + str(round(score[1], 4)) + + +def evaluate_treatments_model( + estimator: Any, + X: torch.Tensor, + Y: torch.Tensor, + Y_full: torch.Tensor, + W: torch.Tensor, + n_folds: int = 3, + seed: int = 0, +) -> dict: + metric_pehe = np.zeros(n_folds) + metric_ate = np.zeros(n_folds) + + indx = 0 + if len(np.unique(Y)) == 2: + skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed) + else: + skf = KFold(n_splits=n_folds, shuffle=True, random_state=seed) + + for train_index, test_index in skf.split(X, Y): + + X_train = X[train_index] + Y_train = Y[train_index] + W_train = W[train_index] + + X_test = X[test_index] + Y_full_test = Y_full[test_index] + + model = copy.deepcopy(estimator) + model.fit(X_train, Y_train, W_train) + + try: + te_pred = model.predict(X_test).detach().cpu().numpy() + except BaseException: + te_pred = np.asarray(model.predict(X_test)) + + metric_ate[indx] = abs_error_ATE(Y_full_test, te_pred) + metric_pehe[indx] = sqrt_PEHE(Y_full_test, te_pred) + indx += 1 + + output_pehe = generate_score(metric_pehe) + output_ate = generate_score(metric_ate) + + return { + "raw": { + "pehe": output_pehe, + "ate": output_ate, + }, + "str": { + "pehe": print_score(output_pehe), + "ate": print_score(output_ate), + }, + }