--- a
+++ b/scripts/OutcomeWeights_Extensions.R
@@ -0,0 +1,363 @@
+# Rebuild grf average-effect estimates (ATE, ATT, ATU, overlap-weighted ATE
+# and LATE) as explicit outcome weights omega with estimate = omega %*% Y, and
+# compare each rebuild with the grf / OutcomeWeights package output.
+# NOTE: several steps build dense N x N matrices (N = number of observations),
+# so memory use grows quadratically in N.
+
+for (pkg in c("grf", "hdm", "OutcomeWeights")) {
+  if (!requireNamespace(pkg, quietly = TRUE)) {
+    install.packages(pkg, dependencies = TRUE)
+  }
+}
+library(grf)
+library(hdm)
+library(OutcomeWeights)
+
+set.seed(1234)
+
+data(pension)
+
+D <- pension$p401     # treatment
+Z <- pension$e401     # instrument (used for the LATE)
+Y <- pension$net_tfa  # outcome
+X <- model.matrix(~ 0 + age + db + educ + fsize + hown + inc + male + marr + pira + twoearn, data = pension)
+var_nm <- c("Age","Benefit pension","Education","Family size","Home owner","Income","Male","Married","IRA","Two earners")
+colnames(X) <- var_nm
+
+# Outcome nuisance forest; the rebuilds below treat Y.hat as S %*% Y
+rf_Y.hat <- regression_forest(X, Y)
+Y.hat <- predict(rf_Y.hat)$predictions
+S <- get_forest_weights(rf_Y.hat)
+cf <- causal_forest(X, Y, D, Y.hat = Y.hat)
+
+D.hat <- cf$W.hat
+# CATE weights: S.tau$omega %*% Y should reproduce predict(cf)$predictions
+S.tau <- get_outcome_weights(cf, S = S)
+all.equal(as.numeric(S.tau$omega %*% Y),
+          as.numeric(predict(cf)$predictions))
+
+# ATE
+## Package output (grf may print a warning here; it does not stop the script)
+average_treatment_effect(cf, target.sample = "all")
+
+## Rebuild using weights following Appendix A.3.2
+lambda1 <- D / D.hat
+lambda0 <- (1 - D) / (1 - D.hat)
+N <- length(Y)
+ones <- matrix(1, N, 1)
+
+### Horribly slow but close to formula (do not run)
+# T_ate_slow = S.tau$omega + diag(lambda1 - lambda0) %*% (diag(N) - S - diag(as.numeric(D-D.hat)) %*% S.tau$omega)
+# omega_ate_slow = t(ones) %*% T_ate_slow / N
+# omega_ate_slow %*% Y
+
+### Faster: v * M (length(v) == nrow(M)) scales row i of M by v[i], i.e. it
+### equals diag(v) %*% M without building the N x N diagonal matrix
+scaled_S.tau <- (D - D.hat) * S.tau$omega
+S_adjusted <- diag(N) - S - scaled_S.tau
+T_ate <- S.tau$omega + (lambda1 - lambda0) * S_adjusted
+omega_ate <- t(ones) %*% T_ate / N
+
+## Check numerical equivalence, also with the package weights.
+## all.equal() compares exactly two objects (a third positional argument is
+## taken as `tolerance`), so each pair gets its own call.
+omega_ate_pkg <- get_outcome_weights(cf, target = "ATE", S = S, S.tau = S.tau$omega)
+ate_pkg <- as.numeric(average_treatment_effect(cf, target.sample = "all")[1])
+all.equal(as.numeric(omega_ate %*% Y), ate_pkg)
+all.equal(as.numeric(omega_ate_pkg$omega %*% Y), ate_pkg)
+
+##############################
+# ATT Calculation with Updated Structure
+##############################
+
+# Built-in "ATT" from the GRF package
+att_cf <- average_treatment_effect(cf, target.sample = "treated")[1]
+cat("\n--- GRF's built-in ATT ---\n")
+cat("GRF ATT:", att_cf, "\n")
+
+##############################
+# Manual DR formula for ATT
+##############################
+n <- length(Y)
+treated.idx <- which(D == 1)
+control.idx <- which(D == 0)
+
+# Forest pointwise CATE
+tau.hat.pointwise <- predict(cf)$predictions
+
+# "Raw" naive average among treated:
+weights.all <- rep(1, n)
+tau.avg.raw <- weighted.mean(tau.hat.pointwise[treated.idx], w = weights.all[treated.idx])
+
+# Decomposition from cf: implied regressions for the treated (1) / control (0) outcome
+W.hat  <- cf$W.hat
+Y.hat  <- cf$Y.hat
+Y.hat0 <- Y.hat - W.hat * tau.hat.pointwise
+Y.hat1 <- Y.hat + (1 - W.hat) * tau.hat.pointwise
+
+# IPW-like gamma for ATT
+#   Treated: gamma = 1
+#   Controls: gamma = p.hat / (1 - p.hat)
+#   Each group is then rescaled so its (weighted) gammas sum to sum(weights.all)
+gamma <- rep(NA, n)
+
+# Treated
+gamma.treat.raw <- rep(1, length(treated.idx))
+sum.treat <- sum(weights.all[treated.idx] * gamma.treat.raw)
+scaled.treat.gamma <- gamma.treat.raw / sum.treat * sum(weights.all)
+gamma[treated.idx] <- scaled.treat.gamma
+
+# Controls
+gamma.ctrl.raw <- W.hat[control.idx] / (1 - W.hat[control.idx])
+sum.ctrl <- sum(weights.all[control.idx] * gamma.ctrl.raw)
+scaled.ctrl.gamma <- gamma.ctrl.raw / sum.ctrl * sum(weights.all)
+gamma[control.idx] <- scaled.ctrl.gamma
+
+# DR correction for ATT (NOTICE THE MINUS SIGN on the control term!)
+dr.corr.all <- D * gamma * (Y - Y.hat1)  -  (1 - D) * gamma * (Y - Y.hat0)
+dr.corr     <- weighted.mean(dr.corr.all, weights.all)
+
+# Final manual DR ATT
+tau.ATT.manual <- tau.avg.raw + dr.corr
+
+cat("\n--- Manual DR (raw + correction) for ATT ---\n")
+cat("Manual DR ATT:", tau.ATT.manual, "\n")
+all.equal(as.numeric(tau.ATT.manual), as.numeric(att_cf))
+
+##############################
+# Rebuild Using Weights (Smoothing Matrix)
+##############################
+
+# CATE weights restricted to the treated rows (control rows zeroed)
+S.tau.treated <- S.tau$omega
+S.tau.treated[D == 0, ] <- 0
+
+# With Y.hat = S %*% Y and tau.hat.pointwise = S.tau$omega %*% Y:
+#   Y - Y.hat1 = S_adjusted_treated %*% Y,  Y - Y.hat0 = S_adjusted_control %*% Y
+N <- length(Y)
+S_adjusted_treated <- diag(N) - S - (1 - W.hat) * S.tau$omega
+S_adjusted_control <- diag(N) - S + W.hat * S.tau$omega
+
+# Weight-matrix version of dr.corr.all (same signs)
+T_correction <- D * gamma * S_adjusted_treated -
+  (1 - D) * gamma * S_adjusted_control
+
+# Final ATT weights: mean CATE weight over the treated + mean correction weight.
+# Flattened from a 1 x N matrix to a plain numeric vector (e.g. for plotting).
+omega_att <- rep(1 / sum(D), N) %*% S.tau.treated + rep(1 / N, N) %*% T_correction
+omega_att <- as.numeric(omega_att)
+
+cat("\n--- Single-weight ATT check ---\n")
+# NOTE(review): treatment-effect weights are often expected to sum to ~0
+# rather than 1; confirm which applies here.
+cat("Sum(omega_att):     ", sum(omega_att), "\n")
+cat("  treated / control:", sum(omega_att[D == 1]), "/", sum(omega_att[D == 0]), "\n")
+cat("Sum(omega_att * Y): ", sum(omega_att * Y), "\n")      # Should match tau.ATT.manual
+cat("Manual DR ATT:      ", tau.ATT.manual, "\n")
+all.equal(sum(omega_att * Y), as.numeric(tau.ATT.manual))
+
+##############################
+# ATU Calculation with Updated Structure
+##############################
+
+# Built-in "ATU" from the GRF package
+atu_cf <- average_treatment_effect(cf, target.sample = "control")[1]
+cat("\n--- GRF's built-in ATU ---\n")
+cat("GRF ATU:", atu_cf, "\n")
+
+##############################
+# Manual DR formula for ATU
+##############################
+n <- length(Y)
+treated.idx <- which(D == 1)
+control.idx <- which(D == 0)
+
+# Forest pointwise CATE
+tau.hat.pointwise <- predict(cf)$predictions
+
+# "Raw" naive average among controls:
+weights.all <- rep(1, n)
+tau.avg.raw <- weighted.mean(tau.hat.pointwise[control.idx], w = weights.all[control.idx])
+
+# Decomposition from cf: implied regressions for the treated (1) / control (0) outcome
+W.hat  <- cf$W.hat
+Y.hat  <- cf$Y.hat
+Y.hat0 <- Y.hat - W.hat * tau.hat.pointwise
+Y.hat1 <- Y.hat + (1 - W.hat) * tau.hat.pointwise
+
+# IPW-like gamma for ATU
+#   Controls: gamma = 1
+#   Treated: gamma = (1 - p.hat) / p.hat
+#   Each group is then rescaled so its (weighted) gammas sum to sum(weights.all)
+gamma <- rep(NA, n)
+
+# Controls (ATU target group)
+gamma.ctrl.raw <- rep(1, length(control.idx))
+sum.ctrl <- sum(weights.all[control.idx] * gamma.ctrl.raw)
+scaled.ctrl.gamma <- gamma.ctrl.raw / sum.ctrl * sum(weights.all)
+gamma[control.idx] <- scaled.ctrl.gamma
+
+# Treated
+gamma.treat.raw <- (1 - W.hat[treated.idx]) / W.hat[treated.idx]
+sum.treat <- sum(weights.all[treated.idx] * gamma.treat.raw)
+scaled.treat.gamma <- gamma.treat.raw / sum.treat * sum(weights.all)
+gamma[treated.idx] <- scaled.treat.gamma
+
+# DR correction for ATU: the same expression as for the ATT -- only gamma and
+# the target group of tau.avg.raw differ
+dr.corr.all <- D * gamma * (Y - Y.hat1)  -  (1 - D) * gamma * (Y - Y.hat0)
+dr.corr     <- weighted.mean(dr.corr.all, weights.all)
+
+# Final manual DR ATU
+tau.ATU.manual <- tau.avg.raw + dr.corr
+
+cat("\n--- Manual DR (raw + correction) for ATU ---\n")
+cat("Manual DR ATU:", tau.ATU.manual, "\n")
+all.equal(as.numeric(tau.ATU.manual), as.numeric(atu_cf))
+
+##############################
+# Rebuild Using Weights (Smoothing Matrix)
+##############################
+
+# CATE weights restricted to the control rows (treated rows zeroed)
+S.tau.control <- S.tau$omega
+S.tau.control[D == 1, ] <- 0
+
+# With Y.hat = S %*% Y and tau.hat.pointwise = S.tau$omega %*% Y:
+#   Y - Y.hat1 = S_adjusted_treated %*% Y,  Y - Y.hat0 = S_adjusted_control %*% Y
+N <- length(Y)
+S_adjusted_treated <- diag(N) - S - (1 - W.hat) * S.tau$omega
+S_adjusted_control <- diag(N) - S + W.hat * S.tau$omega
+
+# Weight-matrix version of dr.corr.all (same expression as for the ATT)
+T_correction <- D * gamma * S_adjusted_treated -
+  (1 - D) * gamma * S_adjusted_control
+
+# Final ATU weights: mean CATE weight over the controls + mean correction weight.
+# Flattened from a 1 x N matrix to a plain numeric vector (e.g. for plotting).
+omega_atu <- rep(1 / sum(1 - D), N) %*% S.tau.control + rep(1 / N, N) %*% T_correction
+omega_atu <- as.numeric(omega_atu)
+
+cat("\n--- Single-weight ATU check ---\n")
+# NOTE(review): treatment-effect weights are often expected to sum to ~0
+# rather than 1; confirm which applies here.
+cat("Sum(omega_atu):     ", sum(omega_atu), "\n")
+cat("  treated / control:", sum(omega_atu[D == 1]), "/", sum(omega_atu[D == 0]), "\n")
+cat("Sum(omega_atu * Y): ", sum(omega_atu * Y), "\n")      # Should match tau.ATU.manual
+cat("Manual DR ATU:      ", tau.ATU.manual, "\n")
+all.equal(sum(omega_atu * Y), as.numeric(tau.ATU.manual))
+
+# Overlap
+## Package output
+average_treatment_effect(cf, target.sample = "overlap")
+## Rebuild using weights: slope of (I - S) %*% Y on V.hat = D - D.hat with an
+## intercept and unit observation weights (alpha = 1), i.e.
+##   omega = (V' M V)^{-1} V' M (I - S),  M = I - 1 (1'1)^{-1} 1'
+N <- nrow(X)
+# V.hat = Actual treatment (D) - Predicted treatment (D.hat)
+V.hat <- D - D.hat
+
+### Literal matrix formula; building M_alpha_N is O(N^3) (do not run)
+# diag_alpha <- diag(rep(1, N))
+# M_alpha_N <- diag(N) - ones %*% solve(t(ones) %*% diag_alpha %*% ones) %*% t(ones) %*% diag_alpha
+# omega_overlap_slow <- solve(t(V.hat) %*% M_alpha_N %*% diag_alpha %*% V.hat) %*%
+#   t(V.hat) %*% M_alpha_N %*% diag_alpha %*% (diag(N) - S)
+
+### Faster: with alpha = 1, t(V.hat) %*% M_alpha_N is the demeaned V.hat, so
+### V' M V = sum(V.tilde * V.hat) and V' M (I - S) = V.tilde - t(S) %*% V.tilde
+V.tilde <- V.hat - mean(V.hat)
+omega_overlap <- (V.tilde - as.numeric(crossprod(S, V.tilde))) / sum(V.tilde * V.hat)
+# Package-calculated overlap ATE using causal_forest
+overlap_pkg <- as.numeric(average_treatment_effect(cf, target.sample = "overlap")[1])
+# Reconstructed overlap ATE using calculated weights
+overlap_rebuilt <- omega_overlap %*% Y
+## Check numerical equivalence
+all.equal(as.numeric(overlap_rebuilt), as.numeric(overlap_pkg))
+
+# LATE
+ivf <- instrumental_forest(X, Y, D, Z, Y.hat = Y.hat)
+omega_if <- get_outcome_weights(ivf, S = S)
+all.equal(as.numeric(omega_if$omega %*% Y),
+          as.numeric(predict(ivf)$predictions))
+
+## Package output (no compliance.score is passed here, so grf presumably
+## estimates its own; late_pkg below passes the score explicitly)
+average_treatment_effect(ivf)
+
+## Rebuild using weights
+
+##############################################################################
+# Step 1: "Naive" average of pointwise LATE predictions.
+#         This is analogous to   mean( tau.hat.pointwise ).
+##############################################################################
+tau.hat.pointwise <- predict(ivf)$predictions  # = \hat{\tau}^{ivf}(X_i)
+N <- length(tau.hat.pointwise)
+weights.all <- rep(1, N)                       # unit weights (Step 5 relies on this)
+tau.avg.raw <- weighted.mean(tau.hat.pointwise, weights.all)
+
+##############################################################################
+# Step 2: Build the compliance score "Delta(x)".
+##############################################################################
+# For demonstration, we re-estimate a compliance forest:
+#   This is basically a "causal_forest" of D ~ Z with the same X,
+#   to predict E[D|X,Z], then subtract for Z=1 vs Z=0.
+compliance.forest <- causal_forest(
+  X, D, Z,
+  Y.hat = ivf$W.hat,  # "W.hat" in 'ivf' is E[D|X]
+  W.hat = ivf$Z.hat   # "Z.hat" is E[Z|X], used to speed up fitting
+)
+compliance.score <- predict(compliance.forest)$predictions  # = Delta(x_i)
+
+##############################################################################
+# Step 3: Construct the "AIPW" style correction:
+#         gamma_i * [ Y_i - (Y.hat_i + tau.hat_i * (D_i - W.hat_i)) ]
+#   Where gamma_i = (Z_i - Z.hat_i)/(Z.hat_i*(1-Z.hat_i)*Delta_i).
+##############################################################################
+Z.hat <- ivf$Z.hat    # E[Z|X]
+W.hat <- ivf$W.hat    # E[D|X], used in the outcome residual
+Y.hat <- ivf$Y.hat    # E[Y|X], used in the outcome residual
+
+gamma <- (Z - Z.hat) / (Z.hat * (1 - Z.hat)) / compliance.score
+Y.residual <- Y - (Y.hat + tau.hat.pointwise * (D - W.hat))
+
+dr.correction.all <- gamma * Y.residual
+dr.correction <- weighted.mean(dr.correction.all, weights.all)
+
+##############################################################################
+# Step 4: Final LATE = naive average + DR correction
+##############################################################################
+tau.late.aipw <- tau.avg.raw + dr.correction
+
+cat("Manual LATE (AIPW replication):", tau.late.aipw, "\n")
+
+## Check numerical equivalence
+late_pkg <- average_treatment_effect(ivf, compliance.score = compliance.score)
+cat("Package LATE:", late_pkg[1], "\n")
+
+all.equal(
+  as.numeric(tau.late.aipw),
+  as.numeric(late_pkg[1])
+)
+
+##############################################################################
+# Step 5: Rebuild the LATE as explicit outcome weights (cf. the ATE rebuild).
+# With tau.hat.pointwise = Omega %*% Y (Omega = omega_if$omega, checked above),
+# Y.hat = S %*% Y and unit weights.all, every term of tau.late.aipw is linear
+# in Y:
+#   omega_late = [ 1' Omega + gamma' (I - S - diag(D - W.hat) Omega) ] / N
+# Only matrix-vector products are used, so no further N x N matrices are built.
+##############################################################################
+Omega <- omega_if$omega
+omega_late <- (colSums(Omega) + gamma -
+  as.numeric(crossprod(S, gamma)) -
+  as.numeric(crossprod(Omega, gamma * (D - W.hat)))) / N
+
+cat("\n--- LATE outcome-weights check ---\n")
+cat("Sum(omega_late):   ", sum(omega_late), "\n")
+cat("Weights-based LATE:", sum(omega_late * Y), "\n")
+cat("Manual DR LATE:    ", tau.late.aipw, "\n")
+cat("Package LATE:      ", late_pkg[1], "\n")
+
+# One all.equal() call per pair (a third positional argument would be `tolerance`)
+all.equal(as.numeric(omega_late %*% Y), as.numeric(tau.late.aipw))
+all.equal(as.numeric(omega_late %*% Y), as.numeric(late_pkg[1]))