if (!require("grf")) install.packages("grf", dependencies = TRUE); library(grf)
if (!require("hdm")) install.packages("hdm", dependencies = TRUE); library(hdm)
if (!require("OutcomeWeights")) install.packages("OutcomeWeights", dependencies = TRUE); library(OutcomeWeights)
# Setup ----
# Fix the RNG so that the forests, and therefore every number below, are
# reproducible.
set.seed(1234)

# hdm's 401(k) data: treatment D = p401, instrument Z = e401 (used by the
# instrumental forest further below), outcome Y = net_tfa.
data(pension)
D <- pension$p401
Z <- pension$e401
Y <- pension$net_tfa

# Covariate matrix without an intercept column, relabelled for display.
X <- model.matrix(
  ~ 0 + age + db + educ + fsize + hown + inc + male + marr + pira + twoearn,
  data = pension
)
var_nm <- c(
  "Age", "Benefit pension", "Education", "Family size", "Home owner",
  "Income", "Male", "Married", "IRA", "Two earners"
)
colnames(X) <- var_nm

# Outcome nuisance: regression forest for E[Y|X], its predictions on the
# training sample (predict() without newdata) and the matching smoother S.
rf_Y.hat <- regression_forest(X = X, Y = Y)
Y.hat <- predict(rf_Y.hat)$predictions
S <- get_forest_weights(rf_Y.hat)

# Causal forest reusing those outcome predictions; its W.hat is the
# propensity-score estimate that the rest of the script calls D.hat.
cf <- causal_forest(X = X, Y = Y, W = D, Y.hat = Y.hat)
D.hat <- cf$W.hat

# CATE outcome weights: S.tau$omega %*% Y should reproduce the causal
# forest's pointwise predictions.
S.tau <- get_outcome_weights(cf, S = S)
all.equal(
  as.numeric(S.tau$omega %*% Y),
  as.numeric(predict(cf)$predictions)
)
# ATE ----
## Package output
# Ignore the warning grf may print here; it is overly cautious for this
# replication exercise.
average_treatment_effect(cf, target.sample = "all")

## Rebuild using weights following Appendix A.3.2
lambda1 <- D / D.hat              # inverse-propensity factor, treated
lambda0 <- (1 - D) / (1 - D.hat)  # inverse-propensity factor, controls
N <- length(Y)
ones <- matrix(1, N, 1)

### Horribly slow but close to formula (do not run)
# T_ate_slow = S.tau$omega + diag(lambda1 - lambda0) %*% (diag(N) - S - diag(as.numeric(D-D.hat)) %*% S.tau$omega)
# omega_ate_slow = t(ones) %*% T_ate_slow / N
# omega_ate_slow %*% Y

### Faster avoiding slow matrix multiplications:
### for a length-N vector v and an N x N matrix M, `v * M` recycles v down
### the columns, i.e. scales row i by v[i] -- the same as diag(v) %*% M.
scaled_S.tau <- (D - D.hat) * S.tau$omega
S_adjusted <- diag(N) - S - scaled_S.tau
T_ate <- S.tau$omega + (lambda1 - lambda0) * S_adjusted
omega_ate <- t(ones) %*% T_ate / N  # 1 x N row of column means of T_ate

## Check numerical equivalence also with package weights
omega_ate_pkg <- get_outcome_weights(
  cf, target = "ATE", S = S, S.tau = S.tau$omega
)
ate_rebuilt <- as.numeric(omega_ate %*% Y)
# all.equal() compares exactly two objects: a third positional argument is
# silently used as `tolerance`, which would make the check vacuous. Compare
# against each reference separately instead.
all.equal(
  ate_rebuilt,
  as.numeric(average_treatment_effect(cf, target.sample = "all")[1])
)
all.equal(ate_rebuilt, as.numeric(omega_ate_pkg$omega %*% Y))
##############################
# ATT and ATU: manual doubly robust replication and outcome weights
##############################

# Manual doubly robust (AIPW) estimate of a subgroup average effect, plus the
# outcome weights that reproduce it, following the construction used by
# grf's average_treatment_effect() for target.sample "treated" / "control".
#
# Arguments:
#   target  "treated" for the ATT, "control" for the ATU.
#   forest  fitted causal forest; its predictions, W.hat and Y.hat are used.
#   D       binary (0/1) treatment vector the forest was fitted with.
#   Y       outcome vector the forest was fitted with.
#   S       N x N outcome smoother; assumed to satisfy forest$Y.hat == S %*% Y.
#   S.tau   N x N CATE weights; assumed S.tau %*% Y == forest predictions.
# Returns a list with
#   estimate  raw mean CATE over the target group plus the DR correction;
#   omega     numeric weights of length N with sum(omega * Y) == estimate
#             up to floating-point error (given the two assumptions above).
dr_subgroup_effect <- function(target = c("treated", "control"),
                               forest, D, Y, S, S.tau) {
  target <- match.arg(target)
  n <- length(Y)
  stopifnot(length(D) == n, all(D %in% c(0, 1)))
  treated <- D == 1
  control <- D == 0
  in.target <- if (target == "treated") treated else control

  # Forest pointwise CATE and implied potential-outcome predictions.
  tau.hat <- predict(forest)$predictions
  W.hat <- forest$W.hat
  Y.hat <- forest$Y.hat
  Y.hat0 <- Y.hat - W.hat * tau.hat
  Y.hat1 <- Y.hat + (1 - W.hat) * tau.hat

  # "Raw" naive average of the CATEs within the target group.
  tau.avg.raw <- mean(tau.hat[in.target])

  # IPW-like gamma: 1 for the target group, odds weights for the other one
  # (ATT: p / (1 - p) on controls; ATU: (1 - p) / p on treated). Each
  # group's gammas are rescaled so that they sum to n.
  gamma.raw <- rep(1, n)
  if (target == "treated") {
    gamma.raw[control] <- W.hat[control] / (1 - W.hat[control])
  } else {
    gamma.raw[treated] <- (1 - W.hat[treated]) / W.hat[treated]
  }
  gamma <- rep(NA_real_, n)
  gamma[treated] <- gamma.raw[treated] / sum(gamma.raw[treated]) * n
  gamma[control] <- gamma.raw[control] / sum(gamma.raw[control]) * n

  # DR correction. The control term carries a minus sign for BOTH the ATT
  # and the ATU; the two targets differ only through gamma.
  dr.corr.all <- D * gamma * (Y - Y.hat1) - (1 - D) * gamma * (Y - Y.hat0)
  estimate <- tau.avg.raw + mean(dr.corr.all)

  # Outcome weights. With Y.hat = S Y and tau.hat = S.tau Y:
  #   Y - Y.hat1 = (I - S - diag(1 - W.hat) S.tau) Y
  #   Y - Y.hat0 = (I - S + diag(W.hat) S.tau) Y
  # so estimate = omega' Y with
  #   omega' = (1_target / n_target)' S.tau
  #            + (1/n) [g' - g' S - (g1 * (1 - W.hat) + g0 * W.hat)' S.tau],
  # where g1 = D * gamma, g0 = (1 - D) * gamma and g = g1 - g0.
  # Evaluating this as row-vector products avoids building N x N matrices.
  g1 <- D * gamma
  g0 <- (1 - D) * gamma
  g <- g1 - g0
  target.row <- as.numeric(in.target) / sum(in.target)
  omega <- as.numeric(target.row %*% S.tau) +
    (g - as.numeric(g %*% S) -
       as.numeric((g1 * (1 - W.hat) + g0 * W.hat) %*% S.tau)) / n

  list(estimate = estimate, omega = omega)
}

## ATT ----
# Built-in "ATT" from the GRF package
att_cf <- average_treatment_effect(cf, target.sample = "treated")[1]
cat("\n--- GRF's built-in ATT ---\n")
cat("GRF ATT:", att_cf, "\n")

att_manual <- dr_subgroup_effect("treated", cf, D, Y, S, S.tau$omega)
tau.ATT.manual <- att_manual$estimate
cat("\n--- Manual DR (raw + correction) for ATT ---\n")
cat("Manual DR ATT:", tau.ATT.manual, "\n")

# Single-weight check. The sum is expected to be ~0, not 1: if the rows of
# S sum to one, every term maps a constant outcome to zero.
omega_att <- att_manual$omega  # numeric vector, ready for plotting
cat("\n--- Single-weight ATT check ---\n")
cat("Sum(omega_att): ", sum(omega_att), "\n")
cat("Sum(omega_att * Y): ", sum(omega_att * Y), "\n") # Should be close to tau.ATT.manual
cat("Manual DR ATT: ", tau.ATT.manual, "\n")

## ATU ----
# Built-in "ATU" from the GRF package
atu_cf <- average_treatment_effect(cf, target.sample = "control")[1]
cat("\n--- GRF's built-in ATU ---\n")
cat("GRF ATU:", atu_cf, "\n")

atu_manual <- dr_subgroup_effect("control", cf, D, Y, S, S.tau$omega)
tau.ATU.manual <- atu_manual$estimate
cat("\n--- Manual DR (raw + correction) for ATU ---\n")
cat("Manual DR ATU:", tau.ATU.manual, "\n")

# Single-weight check; as for the ATT, the sum is expected to be ~0.
omega_atu <- atu_manual$omega  # numeric vector, ready for plotting
cat("\n--- Single-weight ATU check ---\n")
cat("Sum(omega_atu): ", sum(omega_atu), "\n")
cat("Sum(omega_atu * Y): ", sum(omega_atu * Y), "\n") # Should be close to tau.ATU.manual
cat("Manual DR ATU: ", tau.ATU.manual, "\n")
# Overlap ----
## Package output
average_treatment_effect(cf, target.sample = "overlap")

## Rebuild using weights
# Formula: omega_overlap' = (V' M A V)^{-1} V' M A (I - S), with
# V = D - D.hat, A = diag(alpha) and residual maker
# M = I - 1 (1' A 1)^{-1} 1' A.
# Because V' M = V' - (sum(V) / sum(alpha)) * alpha', the row vector
# r' = V' M A can be computed elementwise, and the weights reduce to
#   omega_overlap = (r - S' r) / sum(r * V).
# This is algebraically identical to the literal matrix form but avoids its
# dense N x N matrices and the O(N^3) product of two of them.
N <- nrow(X)
alpha_overlap <- rep(1, N)  # alpha = 1 for every observation
# V.hat = Actual treatment (D) - Predicted treatment (D.hat)
V.hat <- D - D.hat
r_overlap <- (V.hat - sum(V.hat) / sum(alpha_overlap) * alpha_overlap) *
  alpha_overlap
# S is the outcome-forest smoother, S = get_forest_weights(rf_Y.hat)
omega_overlap <- (r_overlap - as.numeric(r_overlap %*% S)) /
  sum(r_overlap * V.hat)

# Package-calculated overlap ATE using causal_forest
overlap_pkg <- as.numeric(
  average_treatment_effect(cf, target.sample = "overlap")[1]
)
# Reconstructed overlap ATE using calculated weights
overlap_rebuilt <- omega_overlap %*% Y
## Check numerical equivalence
all.equal(as.numeric(overlap_rebuilt), as.numeric(overlap_pkg))
# LATE ----
# Instrumental forest with Z as instrument for D, reusing the outcome
# predictions Y.hat. Its outcome weights should reproduce the pointwise
# LATE predictions.
ivf <- instrumental_forest(X = X, Y = Y, W = D, Z = Z, Y.hat = Y.hat)
omega_if <- get_outcome_weights(ivf, S = S)
all.equal(
  as.numeric(omega_if$omega %*% Y),
  as.numeric(predict(ivf)$predictions)
)

## Package output
average_treatment_effect(ivf)

## Rebuild using weights
# Step 1 -- naive average of the pointwise LATE predictions
# (the analogue of mean(tau.hat.pointwise)).
tau.hat.pointwise <- predict(ivf)$predictions  # \hat{\tau}^{ivf}(X_i)
N <- length(tau.hat.pointwise)
weights.all <- rep(1, N)  # unit weights; swap in sample weights if available
tau.avg.raw <- weighted.mean(tau.hat.pointwise, weights.all)

# Step 2 -- compliance score Delta(x), estimated by a causal forest of D on
# Z over the same X (E[D|X,Z=1] - E[D|X,Z=0]). It reuses the IV forest's
# nuisances: ivf$W.hat is E[D|X], ivf$Z.hat is E[Z|X].
compliance.forest <- causal_forest(
  X = X, Y = D, W = Z,
  Y.hat = ivf$W.hat,
  W.hat = ivf$Z.hat
)
compliance.score <- predict(compliance.forest)$predictions  # Delta(x_i)

# Step 3 -- AIPW-style correction
#   gamma_i * [Y_i - (Y.hat_i + tau.hat_i * (D_i - W.hat_i))]
# with gamma_i = (Z_i - Z.hat_i) / (Z.hat_i * (1 - Z.hat_i) * Delta_i).
Z.hat <- ivf$Z.hat  # E[Z|X]
W.hat <- ivf$W.hat  # E[D|X]
Y.hat <- ivf$Y.hat  # E[Y|X]
gamma <- (Z - Z.hat) / (Z.hat * (1 - Z.hat)) / compliance.score
Y.residual <- Y - (Y.hat + tau.hat.pointwise * (D - W.hat))
dr.correction.all <- gamma * Y.residual
dr.correction <- weighted.mean(dr.correction.all, weights.all)

# Step 4 -- final LATE = naive average + DR correction.
tau.late.aipw <- tau.avg.raw + dr.correction
cat("Manual LATE (AIPW replication):", tau.late.aipw, "\n")

## Check numerical equivalence
late_pkg <- average_treatment_effect(ivf, compliance.score = compliance.score)
cat("Package LATE:", late_pkg[1], "\n")
all.equal(as.numeric(tau.late.aipw), as.numeric(late_pkg[1]))
## LATE outcome weights
# The manual AIPW LATE above is linear in Y:
#   tau.late.aipw = mean(tau.hat) + mean(gamma * (Y - Y.hat - tau.hat * (D - W.hat)))
# With tau.hat = Omega %*% Y (Omega = omega_if$omega, checked above) and,
# assumed here, Y.hat = S %*% Y (ivf was given the outcome forest's
# predictions, whose smoother is S), the LATE weights follow as
#   alpha' = (1/N) [1' Omega + gamma' (I - S - diag(D - W.hat) Omega)].
# The weights are derived from the estimator rather than fitted to its value,
# so the checks below are a genuine test. They are evaluated as row-vector
# products, so no extra N x N matrix is formed.
Omega_if <- omega_if$omega
alpha <- (as.numeric(rep(1, N) %*% Omega_if) +
            gamma -
            as.numeric(gamma %*% S) -
            as.numeric((gamma * (D - W.hat)) %*% Omega_if)) / N

cat("\n--- Single alpha vector check ---\n")
cat("Sum(alpha): ", sum(alpha), "\n") # expected ~0 (not 1) if rows of S sum to one
cat("Sum(alpha * Y): ", sum(alpha * Y), "\n") # should be LATE
cat("Alpha-based LATE:", sum(alpha * Y), "\n")
cat("Manual DR LATE: ", tau.late.aipw, "\n")
cat("Package LATE: ", late_pkg[1], "\n")

alpha_late <- as.numeric(alpha %*% Y)
# all.equal() compares two objects; a third positional argument would be
# taken as `tolerance`, so compare against each reference separately.
all.equal(alpha_late, as.numeric(tau.late.aipw))
all.equal(alpha_late, as.numeric(late_pkg[1]))