# scripts/OutcomeWeights_Extensions.R
#
# Rebuilds grf's average-effect estimates (ATE, ATT, ATU, overlap, LATE) on
# the hdm `pension` data from smoother / outcome-weight matrices and compares
# every rebuilt number against the corresponding package output.
#
# The statement order is significant: the forests are fitted after
# set.seed(1234), so reordering the fits may change their results.

# Install missing packages, then attach them. library() (unlike require())
# fails loudly if a package still cannot be loaded.
if (!requireNamespace("grf", quietly = TRUE)) {
  install.packages("grf", dependencies = TRUE)
}
if (!requireNamespace("hdm", quietly = TRUE)) {
  install.packages("hdm", dependencies = TRUE)
}
if (!requireNamespace("OutcomeWeights", quietly = TRUE)) {
  install.packages("OutcomeWeights", dependencies = TRUE)
}
library(grf)
library(hdm)
library(OutcomeWeights)

set.seed(1234)

data(pension)

D <- pension$p401      # treatment
Z <- pension$e401      # instrument (used for the LATE section)
Y <- pension$net_tfa   # outcome
X <- model.matrix(
  ~ 0 + age + db + educ + fsize + hown + inc + male + marr + pira + twoearn,
  data = pension
)
var_nm <- c("Age", "Benefit pension", "Education", "Family size", "Home owner",
            "Income", "Male", "Married", "IRA", "Two earners")
colnames(X) <- var_nm


# Outcome nuisance forest, its smoother matrix S, and the causal forest that
# reuses the same Y.hat.
rf_Y.hat <- regression_forest(X, Y)
Y.hat <- predict(rf_Y.hat)$predictions
S <- get_forest_weights(rf_Y.hat)
cf <- causal_forest(X, Y, D, Y.hat = Y.hat)

D.hat <- cf$W.hat
S.tau <- get_outcome_weights(cf, S = S)
# CATE check: outcome weights times Y should reproduce the forest predictions.
all.equal(as.numeric(S.tau$omega %*% Y),
          as.numeric(predict(cf)$predictions))

# ATE
## Package output
# Ignore the warning, which is overly cautious and paternalistic
average_treatment_effect(cf, target.sample = "all")

## Rebuild using weights following Appendix A.3.2
lambda1 <- D / D.hat
lambda0 <- (1 - D) / (1 - D.hat)
N <- length(Y)
ones <- matrix(1, N, 1)

### Horribly slow but close to formula (do not run)
# T_ate_slow = S.tau$omega + diag(lambda1 - lambda0) %*% (diag(N) - S - diag(as.numeric(D-D.hat)) %*% S.tau$omega)
# omega_ate_slow = t(ones) %*% T_ate_slow / N
# omega_ate_slow %*% Y

### Faster avoiding slow matrix multiplications
# A length-N vector times an N x N matrix recycles down the columns, i.e. it
# scales row i by element i -- the same as diag(v) %*% M without the product.
scaled_S.tau <- (D - D.hat) * S.tau$omega
S_adjusted <- diag(N) - S - scaled_S.tau
T_ate <- S.tau$omega + (lambda1 - lambda0) * S_adjusted
omega_ate <- t(ones) %*% T_ate / N

## Check numerical equivalence also with package weights
omega_ate_pkg <- get_outcome_weights(cf, target = "ATE", S = S, S.tau = S.tau$omega)
# all.equal() compares exactly two objects; a third positional argument would
# be taken as `tolerance`. Hence two pairwise comparisons.
ate_rebuilt <- as.numeric(omega_ate %*% Y)
all.equal(ate_rebuilt,
          as.numeric(average_treatment_effect(cf, target.sample = "all")[1]))
all.equal(ate_rebuilt,
          as.numeric(omega_ate_pkg$omega %*% Y))

##############################
# ATT Calculation with Updated Structure
##############################

# Built-in "ATT" from the GRF package
att_cf <- average_treatment_effect(cf, target.sample = "treated")[1]
cat("\n--- GRF's built-in ATT ---\n")
cat("GRF ATT:", att_cf, "\n")

##############################
# Manual DR formula for ATT
##############################
n <- length(Y)
treated.idx <- which(D == 1)
control.idx <- which(D == 0)

# Forest pointwise CATE
tau.hat.pointwise <- predict(cf)$predictions

# "Raw" naive average among treated:
weights.all <- rep(1, n)
tau.avg.raw <- weighted.mean(tau.hat.pointwise[treated.idx],
                             w = weights.all[treated.idx])

# Decomposition from cf: implied potential-outcome predictions
W.hat <- cf$W.hat
Y.hat <- cf$Y.hat
Y.hat0 <- Y.hat - W.hat * tau.hat.pointwise
Y.hat1 <- Y.hat + (1 - W.hat) * tau.hat.pointwise

# IPW-like gamma for ATT
#   Treated:  gamma = 1
#   Controls: gamma = p.hat / (1 - p.hat)
# Then rescale each group to sum(weights.all)
gamma <- rep(NA_real_, n)

# Treated
gamma.treat.raw <- rep(1, length(treated.idx))
sum.treat <- sum(weights.all[treated.idx] * gamma.treat.raw)
scaled.treat.gamma <- gamma.treat.raw / sum.treat * sum(weights.all)
gamma[treated.idx] <- scaled.treat.gamma

# Controls
gamma.ctrl.raw <- W.hat[control.idx] / (1 - W.hat[control.idx])
sum.ctrl <- sum(weights.all[control.idx] * gamma.ctrl.raw)
scaled.ctrl.gamma <- gamma.ctrl.raw / sum.ctrl * sum(weights.all)
gamma[control.idx] <- scaled.ctrl.gamma

# DR correction for ATT (NOTICE THE MINUS SIGN on the control term!)
dr.corr.all <- D * gamma * (Y - Y.hat1) - (1 - D) * gamma * (Y - Y.hat0)
dr.corr <- weighted.mean(dr.corr.all, weights.all)

# Final manual DR ATT
tau.ATT.manual <- tau.avg.raw + dr.corr

cat("\n--- Manual DR (raw + correction) for ATT ---\n")
cat("Manual DR ATT:", tau.ATT.manual, "\n")

##############################
# Rebuild Using Weights (Smoothing Matrix)
##############################

# CATE outcome weights restricted to treated units
S.tau.treated <- S.tau$omega
S.tau.treated[D == 0, ] <- 0 # Zero out rows for controls

# Adjustment matrices for treated and control groups; these are the matrix
# counterparts of (Y - Y.hat1) and (Y - Y.hat0) above.
N <- length(Y)
S_adjusted_treated <- diag(N) - S - (1 - W.hat) * S.tau$omega
S_adjusted_control <- diag(N) - S + W.hat * S.tau$omega

# Compute correction term for ATT weights
T_correction <- D * gamma * S_adjusted_treated -
  (1 - D) * gamma * S_adjusted_control

# Final ATT weights computation
omega_att <- rep(1 / sum(D), N) %*% S.tau.treated + rep(1 / N, N) %*% T_correction

cat("\n--- Single-weight ATT check ---\n")
# NOTE(review): the original comment expected this sum to be close to 1;
# effect weights built from (I - S) terms usually sum to ~0 -- confirm.
cat("Sum(omega_att): ", sum(omega_att), "\n")
cat("Sum(omega_att * Y): ", sum(omega_att * Y), "\n") # Should be close to tau.ATT.manual
cat("Manual DR ATT: ", tau.ATT.manual, "\n")

# Convert to numeric for plotting
omega_att <- as.numeric(omega_att)

# ATU
##############################
# ATU Calculation with Updated Structure
##############################

# Built-in "ATU" from the GRF package
atu_cf <- average_treatment_effect(cf, target.sample = "control")[1]
cat("\n--- GRF's built-in ATU ---\n")
cat("GRF ATU:", atu_cf, "\n")

##############################
# Manual DR formula for ATU
##############################
n <- length(Y)
treated.idx <- which(D == 1)
control.idx <- which(D == 0)

# Forest pointwise CATE
tau.hat.pointwise <- predict(cf)$predictions

# "Raw" naive average among controls:
weights.all <- rep(1, n)
tau.avg.raw <- weighted.mean(tau.hat.pointwise[control.idx],
                             w = weights.all[control.idx])

# Decomposition from cf: implied potential-outcome predictions
W.hat <- cf$W.hat
Y.hat <- cf$Y.hat
Y.hat0 <- Y.hat - W.hat * tau.hat.pointwise
Y.hat1 <- Y.hat + (1 - W.hat) * tau.hat.pointwise

# IPW-like gamma for ATU
#   Controls: gamma = 1
#   Treated:  gamma = (1 - p.hat) / p.hat
# Then rescale each group to sum(weights.all)
gamma <- rep(NA_real_, n)

# Controls (ATU target group)
gamma.ctrl.raw <- rep(1, length(control.idx))
sum.ctrl <- sum(weights.all[control.idx] * gamma.ctrl.raw)
scaled.ctrl.gamma <- gamma.ctrl.raw / sum.ctrl * sum(weights.all)
gamma[control.idx] <- scaled.ctrl.gamma

# Treated
gamma.treat.raw <- (1 - W.hat[treated.idx]) / W.hat[treated.idx]
sum.treat <- sum(weights.all[treated.idx] * gamma.treat.raw)
scaled.treat.gamma <- gamma.treat.raw / sum.treat * sum(weights.all)
gamma[treated.idx] <- scaled.treat.gamma

# DR correction for ATU. The expression (including the minus sign on the
# control term) is the same as for the ATT; only gamma differs.
dr.corr.all <- D * gamma * (Y - Y.hat1) - (1 - D) * gamma * (Y - Y.hat0)
dr.corr <- weighted.mean(dr.corr.all, weights.all)

# Final manual DR ATU
tau.ATU.manual <- tau.avg.raw + dr.corr

cat("\n--- Manual DR (raw + correction) for ATU ---\n")
cat("Manual DR ATU:", tau.ATU.manual, "\n")

##############################
# Rebuild Using Weights (Smoothing Matrix)
##############################

# CATE outcome weights restricted to control units
S.tau.control <- S.tau$omega
S.tau.control[D == 1, ] <- 0 # Zero out rows for treated

# Adjustment matrices for treated and control groups
N <- length(Y)
S_adjusted_treated <- diag(N) - S - (1 - W.hat) * S.tau$omega
S_adjusted_control <- diag(N) - S + W.hat * S.tau$omega

# Compute correction term for ATU weights (same expression as for the ATT,
# evaluated with the ATU gamma)
T_correction <- D * gamma * S_adjusted_treated -
  (1 - D) * gamma * S_adjusted_control

# Final ATU weights computation
omega_atu <- rep(1 / sum(1 - D), N) %*% S.tau.control + rep(1 / N, N) %*% T_correction

cat("\n--- Single-weight ATU check ---\n")
# NOTE(review): the original comment expected this sum to be close to 1;
# effect weights built from (I - S) terms usually sum to ~0 -- confirm.
cat("Sum(omega_atu): ", sum(omega_atu), "\n")
cat("Sum(omega_atu * Y): ", sum(omega_atu * Y), "\n") # Should be close to tau.ATU.manual
cat("Manual DR ATU: ", tau.ATU.manual, "\n")

# Convert to numeric for plotting
omega_atu <- as.numeric(omega_atu)

# Overlap
## Package output
average_treatment_effect(cf, target.sample = "overlap")
# Rebuild using weights
N <- nrow(X)
ones <- matrix(1, N, 1)
diag_alpha <- diag(rep(1, N))
# Define residual maker matrix (M_alpha_N) for alpha = 1
M_alpha_N <- diag(N) -
  ones %*% solve(t(ones) %*% diag_alpha %*% ones) %*% t(ones) %*% diag_alpha
# V.hat = Actual treatment (D) - Predicted treatment (D.hat)
V.hat <- D - D.hat
# S = get_forest_weights(rf_Y.hat)
omega_overlap <- solve(t(V.hat) %*% M_alpha_N %*% diag_alpha %*% V.hat) %*%
  t(V.hat) %*% M_alpha_N %*% diag_alpha %*% (diag(N) - S)
omega_overlap <- as.numeric(omega_overlap)
# Package-calculated overlap ATE using causal_forest
overlap_pkg <- as.numeric(average_treatment_effect(cf, target.sample = "overlap")[1])
# Reconstructed overlap ATE using calculated weights
overlap_rebuilt <- omega_overlap %*% Y
## Check numerical equivalence
all.equal(as.numeric(overlap_rebuilt), as.numeric(overlap_pkg))






# LATE
ivf <- instrumental_forest(X, Y, D, Z, Y.hat = Y.hat)
omega_if <- get_outcome_weights(ivf, S = S)
all.equal(as.numeric(omega_if$omega %*% Y),
          as.numeric(predict(ivf)$predictions))

## Package output
average_treatment_effect(ivf)

## Rebuild using weights

##############################################################################
# Step 1: "Naive" average of pointwise LATE predictions.
# This is analogous to mean( tau.hat.pointwise ).
##############################################################################
tau.hat.pointwise <- predict(ivf)$predictions # = \hat{\tau}^{ivf}(X_i)
N <- length(tau.hat.pointwise)
weights.all <- rep(1, N) # or sample weights if you have them
tau.avg.raw <- weighted.mean(tau.hat.pointwise, weights.all)

##############################################################################
# Step 2: Build the compliance score "Delta(x)".
##############################################################################
# For demonstration, we re-estimate a compliance forest:
# This is basically a "causal_forest" of D ~ Z with the same X,
# to predict E[D|X,Z], then subtract for Z=1 vs Z=0.
compliance.forest <- causal_forest(
  X, D, Z,
  Y.hat = ivf$W.hat, # "W.hat" in 'ivf' is E[D|X]
  W.hat = ivf$Z.hat  # "Z.hat" is E[Z|X], used to speed up fitting
)
compliance.score <- predict(compliance.forest)$predictions # = Delta(x_i)

##############################################################################
# Step 3: Construct the "AIPW" style correction:
#   gamma_i * [ Y_i - (Y.hat_i + tau.hat_i * (D_i - W.hat_i)) ]
# Where gamma_i = (Z_i - Z.hat_i)/(Z.hat_i*(1-Z.hat_i)*Delta_i).
##############################################################################
Z.hat <- ivf$Z.hat # E[Z|X]
W.hat <- ivf$W.hat # E[D|X], used in the outcome residual
Y.hat <- ivf$Y.hat # E[Y|X], used in the outcome residual

gamma <- (Z - Z.hat) / (Z.hat * (1 - Z.hat)) / compliance.score
Y.residual <- Y - (Y.hat + tau.hat.pointwise * (D - W.hat))

dr.correction.all <- gamma * Y.residual
dr.correction <- weighted.mean(dr.correction.all, weights.all)

##############################################################################
# Step 4: Final LATE = naive average + DR correction
##############################################################################
tau.late.aipw <- tau.avg.raw + dr.correction

cat("Manual LATE (AIPW replication):", tau.late.aipw, "\n")


## Check numerical equivalence

late_pkg <- average_treatment_effect(ivf, compliance.score = compliance.score)
cat("Package LATE:", late_pkg[1], "\n")

all.equal(
  as.numeric(tau.late.aipw),
  as.numeric(late_pkg[1])
)

# We'll pick alpha_i = a * gamma_i + b, then solve for (a,b) so that:
#   sum_i alpha_i     = a * sum_i gamma_i     + b * N         = 1
#   sum_i alpha_i Y_i = a * sum_i gamma_i Y_i + b * sum_i Y_i = tau.late.aipw
# NOTE(review): both constraints are imposed by the solve below, so the
# sum(alpha) and sum(alpha * Y) checks further down hold by construction;
# alpha is not derived from the forest's outcome weights.

S1 <- sum(gamma)     # sum of gamma_i
S2 <- sum(gamma * Y) # sum of gamma_i * Y_i
S3 <- sum(Y)         # sum of Y_i

A <- matrix(c(S1, N,   # [ S1  N  ]
              S2, S3), # [ S2  S3 ]
            nrow = 2, byrow = TRUE)
rhs <- c(1, tau.late.aipw)

# Solve the 2x2 system
ab <- solve(A, rhs)
a <- ab[1]
b <- ab[2]

# Construct alpha
alpha <- a * gamma + b

cat("\n--- Single alpha vector check ---\n")
cat("Sum(alpha): ", sum(alpha), "\n") # should be 1
cat("Sum(alpha * Y): ", sum(alpha * Y), "\n") # should be LATE
cat("Alpha-based LATE:", sum(alpha * Y), "\n")
cat("Manual DR LATE: ", tau.late.aipw, "\n")
cat("Package LATE: ", late_pkg[1], "\n")

# Two pairwise comparisons: a third positional argument to all.equal() would
# be interpreted as `tolerance`, not as another value to compare.
late_from_alpha <- as.numeric(alpha %*% Y)
all.equal(late_from_alpha, as.numeric(tau.late.aipw))
all.equal(late_from_alpha, as.numeric(late_pkg[1]))