# scripts/Task 2 starting point.R
#
# Starting point for Task 2: rebuild the average effects reported by grf as
# weighted sums of the outcome (omega %*% Y) using OutcomeWeights.
# The ATE section is worked out in full; the ATT, ATU, overlap and LATE
# sections contain the package output and placeholders to be filled in.

# Install missing packages, then attach them. requireNamespace() only probes
# for availability; library() errors loudly if attaching still fails.
if (!requireNamespace("grf", quietly = TRUE)) install.packages("grf", dependencies = TRUE)
library(grf)
if (!requireNamespace("hdm", quietly = TRUE)) install.packages("hdm", dependencies = TRUE)
library(hdm)
if (!requireNamespace("OutcomeWeights", quietly = TRUE)) install.packages("OutcomeWeights", dependencies = TRUE)
library(OutcomeWeights)

set.seed(1234)

# Load the pension data set (presumably shipped with hdm, attached above)
data(pension)

D = pension$p401     # Treatment (passed as treatment to the forests below)
Z = pension$e401     # Instrument (only used by the instrumental forest)
Y = pension$net_tfa  # Outcome
X = model.matrix(~ 0 + age + db + educ + fsize + hown + inc + male + marr + pira + twoearn, data = pension)
var_nm = c("Age","Benefit pension","Education","Family size","Home owner","Income","Male","Married","IRA","Two earners")
colnames(X) = var_nm


# Outcome regression forest: its predictions are reused as Y.hat in the causal
# and instrumental forests, and its forest weights as the smoother matrix S
rf_Y.hat = regression_forest(X,Y)
Y.hat = predict(rf_Y.hat)$predictions
S = get_forest_weights(rf_Y.hat)
cf = causal_forest(X,Y,D,Y.hat=Y.hat)

D.hat = cf$W.hat
S.tau = get_outcome_weights(cf, S = S)

# Check that the outcome weights reproduce the causal forest predictions
all.equal(as.numeric(S.tau$omega %*% Y),
          as.numeric(predict(cf)$predictions))

# ATE
## Package output
average_treatment_effect(cf, target.sample = "all") # Ignore the warning, which is overly cautious and paternalistic

## Rebuild using weights following Appendix A.3.2
lambda1 = D / D.hat
lambda0 = (1-D) / (1-D.hat)
N = length(Y)
ones = matrix(1,N,1)

### Horribly slow but close to formula (do not run)
# T_ate_slow = S.tau$omega + diag(lambda1 - lambda0) %*% (diag(N) - S - diag(as.numeric(D-D.hat)) %*% S.tau$omega)
# omega_ate_slow = t(ones) %*% T_ate_slow / N
# omega_ate_slow %*% Y

### Faster avoiding slow matrix multiplications
### (vector * matrix recycles down the columns, i.e. it scales row i by element i,
###  which is what diag(vector) %*% matrix does in the slow version)
scaled_S.tau = (D - D.hat) * S.tau$omega # Element-wise multiplication
S_adjusted = diag(N) - S - scaled_S.tau
T_ate = S.tau$omega + (lambda1 - lambda0) * S_adjusted # Element-wise multiplication
omega_ate = t(ones) %*% T_ate / N

## Check numerical equivalence also with package weights
omega_ate_pkg = get_outcome_weights(cf, target = "ATE", S = S, S.tau = S.tau$omega)
# all.equal() compares exactly two objects; a third positional argument would
# be taken as the tolerance, so the three values are compared pairwise.
all.equal(as.numeric(omega_ate %*% Y),
          as.numeric(average_treatment_effect(cf, target.sample = "all")[1]))
all.equal(as.numeric(omega_ate %*% Y),
          as.numeric(omega_ate_pkg$omega %*% Y))

# ATT
## Package output
average_treatment_effect(cf, target.sample = "treated")

## Rebuild using weights

## Check numerical equivalence


# ATU
## Package output
average_treatment_effect(cf, target.sample = "control") # Ignore the warning, which is overly cautious and paternalistic

## Rebuild using weights

## Check numerical equivalence


# Overlap
## Package output
average_treatment_effect(cf, target.sample = "overlap") # Ignore the warning, which is overly cautious and paternalistic

## Rebuild using weights

## Check numerical equivalence


# LATE
ivf = instrumental_forest(X,Y,D,Z,Y.hat=Y.hat)
omega_if = get_outcome_weights(ivf, S = S)

# Check that the outcome weights reproduce the instrumental forest predictions
all.equal(as.numeric(omega_if$omega %*% Y),
          as.numeric(predict(ivf)$predictions))

## Package output
average_treatment_effect(ivf)

## Rebuild using weights

## Check numerical equivalence
