[a29fce]: / scripts / OutcomeWeights_Extensions.R

Download this file

350 lines (273 with data), 12.4 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
if (!require("grf")) install.packages("grf", dependencies = TRUE); library(grf)
if (!require("hdm")) install.packages("hdm", dependencies = TRUE); library(hdm)
if (!require("OutcomeWeights")) install.packages("OutcomeWeights", dependencies = TRUE); library(OutcomeWeights)
# Fix the RNG seed before any forest is fitted
set.seed(1234)

# Load the pension data and pull out treatment (D), instrument (Z), outcome (Y)
data(pension)
D <- pension$p401
Z <- pension$e401
Y <- pension$net_tfa

# Covariate matrix without an intercept column, relabelled with readable names
X <- model.matrix(
  ~ 0 + age + db + educ + fsize + hown + inc + male + marr + pira + twoearn,
  data = pension
)
var_nm <- c(
  "Age", "Benefit pension", "Education", "Family size", "Home owner",
  "Income", "Male", "Married", "IRA", "Two earners"
)
colnames(X) <- var_nm

# Outcome nuisance: regression forest, its predictions, and its weight matrix S
rf_Y.hat <- regression_forest(X, Y)
Y.hat <- predict(rf_Y.hat)$predictions
S <- get_forest_weights(rf_Y.hat)

# Causal forest that reuses the outcome predictions; W.hat is not supplied,
# so the treatment predictions are read back from the fitted object
cf <- causal_forest(X, Y, D, Y.hat = Y.hat)
D.hat <- cf$W.hat

# CATE outcome weights: omega %*% Y should reproduce the forest's predictions
S.tau <- get_outcome_weights(cf, S = S)
all.equal(
  as.numeric(S.tau$omega %*% Y),
  as.numeric(predict(cf)$predictions)
)
# ATE
## Package output
average_treatment_effect(cf, target.sample = "all") # Ignore the warning, which is overly cautious and paternalistic
## Rebuild using weights following Appendix A.3.2
lambda1 <- D / D.hat
lambda0 <- (1 - D) / (1 - D.hat)
N <- length(Y)
ones <- matrix(1, N, 1)
### Horribly slow but close to formula (do not run)
# T_ate_slow = S.tau$omega + diag(lambda1 - lambda0) %*% (diag(N) - S - diag(as.numeric(D-D.hat)) %*% S.tau$omega)
# omega_ate_slow = t(ones) %*% T_ate_slow / N
# omega_ate_slow %*% Y
### Faster avoiding slow matrix multiplications
# `vector * matrix` recycles the length-N vector down the columns, i.e. it
# scales row i by element i, which is what diag(v) %*% M does in the formula.
scaled_S.tau <- (D - D.hat) * S.tau$omega
S_adjusted <- diag(N) - S - scaled_S.tau
T_ate <- S.tau$omega + (lambda1 - lambda0) * S_adjusted
# Average the N rows of T_ate into a single 1 x N weight vector
omega_ate <- t(ones) %*% T_ate / N
## Check numerical equivalence also with package weights
omega_ate_pkg <- get_outcome_weights(cf, target = "ATE", S = S, S.tau = S.tau$omega)
# all.equal() compares exactly two objects; a third positional argument is
# matched to `tolerance`, which would make the check pass for almost any
# values. Each pairwise comparison is therefore made separately.
ate_rebuilt <- as.numeric(omega_ate %*% Y)
all.equal(ate_rebuilt,
          as.numeric(average_treatment_effect(cf, target.sample = "all")[1]))
all.equal(ate_rebuilt,
          as.numeric(omega_ate_pkg$omega %*% Y))
##############################
# ATT Calculation with Updated Structure
# Three steps: (1) package estimate, (2) manual doubly robust (DR) replication,
# (3) the same estimate written as one weight vector omega_att applied to Y.
##############################
# Built-in "ATT" from the GRF package ([1] keeps the first element of the result)
att_cf <- average_treatment_effect(cf, target.sample = "treated")[1]
cat("\n--- GRF's built-in ATT ---\n")
cat("GRF ATT:", att_cf, "\n")
##############################
# Manual DR formula for ATT
##############################
n <- length(Y)
treated.idx <- which(D == 1)
control.idx <- which(D == 0)
# Forest pointwise CATE
tau.hat.pointwise <- predict(cf)$predictions
# "Raw" naive average among treated
# (weights.all is all ones, so this is the plain mean over treated units)
weights.all <- rep(1, n)
tau.avg.raw <- weighted.mean(tau.hat.pointwise[treated.idx], w = weights.all[treated.idx])
# Decomposition from cf: outcome predictions under control (Y.hat0) and
# treatment (Y.hat1) implied by Y.hat, W.hat and the pointwise CATE
W.hat <- cf$W.hat
Y.hat <- cf$Y.hat
Y.hat0 <- Y.hat - W.hat * tau.hat.pointwise
Y.hat1 <- Y.hat + (1 - W.hat) * tau.hat.pointwise
# IPW-like gamma for ATT
# Treated: gamma=1
# Controls: gamma= p.hat / (1 - p.hat)
# Then rescale within each group so that the group's gammas sum to sum(weights.all)
gamma <- rep(NA, n)
# Treated
gamma.treat.raw <- rep(1, length(treated.idx))
sum.treat <- sum(weights.all[treated.idx] * gamma.treat.raw)
scaled.treat.gamma <- gamma.treat.raw / sum.treat * sum(weights.all)
gamma[treated.idx] <- scaled.treat.gamma
# Controls
gamma.ctrl.raw <- W.hat[control.idx] / (1 - W.hat[control.idx])
sum.ctrl <- sum(weights.all[control.idx] * gamma.ctrl.raw)
scaled.ctrl.gamma <- gamma.ctrl.raw / sum.ctrl * sum(weights.all)
gamma[control.idx] <- scaled.ctrl.gamma
# DR correction for ATT (NOTICE THE MINUS SIGN on the control term!)
dr.corr.all <- D * gamma * (Y - Y.hat1) - (1 - D) * gamma * (Y - Y.hat0)
dr.corr <- weighted.mean(dr.corr.all, weights.all)
# Final manual DR ATT
tau.ATT.manual <- tau.avg.raw + dr.corr
cat("\n--- Manual DR (raw + correction) for ATT ---\n")
cat("Manual DR ATT:", tau.ATT.manual, "\n")
##############################
# Rebuild Using Weights (Smoothing Matrix)
##############################
# Copy the CATE outcome-weight matrix and keep only the rows of treated units,
# so that averaging its rows with 1/sum(D) reproduces tau.avg.raw
S.tau.treated <- S.tau$omega
S.tau.treated[D == 0, ] <- 0 # Zero out rows for controls
# Adjustment matrices for treated and control groups: matrix versions of
# (Y - Y.hat1) and (Y - Y.hat0) with Y.hat written as S %*% Y and
# tau.hat.pointwise as S.tau$omega %*% Y
# NOTE(review): assumes cf$Y.hat equals the regression-forest predictions that
# S reproduces — confirm
N <- length(Y)
S_adjusted_treated <- diag(N) - S - (1 - W.hat) * S.tau$omega
S_adjusted_control <- diag(N) - S + W.hat * S.tau$omega
# Compute correction term for ATT weights (same signs as dr.corr.all above)
T_correction <- D * gamma * S_adjusted_treated -
(1 - D) * gamma * S_adjusted_control
# Final ATT weights computation: row average of the treated CATE weights plus
# the row average of the correction term
omega_att <- rep(1/sum(D), N) %*% S.tau.treated + rep(1/N, N) %*% T_correction
cat("\n--- Single-weight ATT check ---\n")
# NOTE(review): if every row of S sums to one, the rows of (I - S) sum to zero
# and this total would be expected near 0 rather than 1 — confirm
cat("Sum(omega_att): ", sum(omega_att), "\n") # Total of the ATT outcome weights
cat("Sum(omega_att * Y): ", sum(omega_att * Y), "\n") # Should be close to tau.ATT.manual
cat("Manual DR ATT: ", tau.ATT.manual, "\n")
# Convert to numeric for plotting
omega_att <- as.numeric(omega_att)
# ATU
##############################
# ATU Calculation with Updated Structure
# Three steps: (1) package estimate, (2) manual doubly robust (DR) replication,
# (3) the same estimate written as one weight vector omega_atu applied to Y.
##############################
# Built-in "ATU" from the GRF package ([1] keeps the first element of the result)
atu_cf <- average_treatment_effect(cf, target.sample = "control")[1]
cat("\n--- GRF's built-in ATU ---\n")
cat("GRF ATU:", atu_cf, "\n")
##############################
# Manual DR formula for ATU
##############################
n <- length(Y)
treated.idx <- which(D == 1)
control.idx <- which(D == 0)
# Forest pointwise CATE
tau.hat.pointwise <- predict(cf)$predictions
# "Raw" naive average among controls
# (weights.all is all ones, so this is the plain mean over control units)
weights.all <- rep(1, n)
tau.avg.raw <- weighted.mean(tau.hat.pointwise[control.idx], w = weights.all[control.idx])
# Decomposition from cf: outcome predictions under control (Y.hat0) and
# treatment (Y.hat1) implied by Y.hat, W.hat and the pointwise CATE
W.hat <- cf$W.hat
Y.hat <- cf$Y.hat
Y.hat0 <- Y.hat - W.hat * tau.hat.pointwise
Y.hat1 <- Y.hat + (1 - W.hat) * tau.hat.pointwise
# IPW-like gamma for ATU
# Controls: gamma=1
# Treated: gamma= (1 - p.hat) / p.hat
# Then rescale within each group so that the group's gammas sum to sum(weights.all)
gamma <- rep(NA, n)
# Controls (ATU target group)
gamma.ctrl.raw <- rep(1, length(control.idx))
sum.ctrl <- sum(weights.all[control.idx] * gamma.ctrl.raw)
scaled.ctrl.gamma <- gamma.ctrl.raw / sum.ctrl * sum(weights.all)
gamma[control.idx] <- scaled.ctrl.gamma
# Treated
gamma.treat.raw <- (1 - W.hat[treated.idx]) / W.hat[treated.idx]
sum.treat <- sum(weights.all[treated.idx] * gamma.treat.raw)
scaled.treat.gamma <- gamma.treat.raw / sum.treat * sum(weights.all)
gamma[treated.idx] <- scaled.treat.gamma
# DR correction for ATU. The expression and its signs are the same as in the
# ATT section (treated term added, control term subtracted); only gamma and
# the group averaged in tau.avg.raw differ.
dr.corr.all <- D * gamma * (Y - Y.hat1) - (1 - D) * gamma * (Y - Y.hat0)
dr.corr <- weighted.mean(dr.corr.all, weights.all)
# Final manual DR ATU
tau.ATU.manual <- tau.avg.raw + dr.corr
cat("\n--- Manual DR (raw + correction) for ATU ---\n")
cat("Manual DR ATU:", tau.ATU.manual, "\n")
##############################
# Rebuild Using Weights (Smoothing Matrix)
##############################
# Copy the CATE outcome-weight matrix and keep only the rows of control units,
# so that averaging its rows with 1/sum(1 - D) reproduces tau.avg.raw
S.tau.control <- S.tau$omega
S.tau.control[D == 1, ] <- 0 # Zero out rows for treated
# Adjustment matrices for treated and control groups: matrix versions of
# (Y - Y.hat1) and (Y - Y.hat0) with Y.hat written as S %*% Y and
# tau.hat.pointwise as S.tau$omega %*% Y
# NOTE(review): assumes cf$Y.hat equals the regression-forest predictions that
# S reproduces — confirm
N <- length(Y)
S_adjusted_treated <- diag(N) - S - (1 - W.hat) * S.tau$omega
S_adjusted_control <- diag(N) - S + W.hat * S.tau$omega
# Compute correction term for ATU weights (same signs as dr.corr.all above)
T_correction <- D * gamma * S_adjusted_treated -
(1 - D) * gamma * S_adjusted_control
# Final ATU weights computation: row average of the control CATE weights plus
# the row average of the correction term
omega_atu <- rep(1 / sum(1 - D), N) %*% S.tau.control + rep(1 / N, N) %*% T_correction
cat("\n--- Single-weight ATU check ---\n")
# NOTE(review): if every row of S sums to one, the rows of (I - S) sum to zero
# and this total would be expected near 0 rather than 1 — confirm
cat("Sum(omega_atu): ", sum(omega_atu), "\n") # Total of the ATU outcome weights
cat("Sum(omega_atu * Y): ", sum(omega_atu * Y), "\n") # Should be close to tau.ATU.manual
cat("Manual DR ATU: ", tau.ATU.manual, "\n")
# Convert to numeric for plotting
omega_atu <- as.numeric(omega_atu)
# Overlap
## Package output
average_treatment_effect(cf, target.sample = "overlap")
# Rebuild using weights
N <- nrow(X)
# V.hat = Actual treatment (D) - Predicted treatment (D.hat)
V.hat <- D - D.hat
# S = get_forest_weights(rf_Y.hat)
### Close to formula but very slow: building M_alpha_N multiplies two dense
### N x N matrices (do not run)
# ones <- matrix(1, N, 1)
# diag_alpha <- diag(rep(1, N))
# M_alpha_N <- diag(N) - ones %*% solve(t(ones) %*% diag_alpha %*% ones) %*% t(ones) %*% diag_alpha
# omega_overlap <- solve(t(V.hat) %*% M_alpha_N %*% diag_alpha %*% V.hat) %*%
#   t(V.hat) %*% M_alpha_N %*% diag_alpha %*% (diag(N) - S)
### Faster, algebraically identical for alpha = 1:
### the residual maker is I - 11'/N, so t(V.hat) %*% M_alpha_N is the demeaned
### V.hat and the scalar t(V.hat) %*% M_alpha_N %*% V.hat is
### sum((V.hat - mean(V.hat)) * V.hat).
V.centered <- V.hat - mean(V.hat)
omega_overlap <- t(V.centered) %*% (diag(N) - S) / sum(V.centered * V.hat)
omega_overlap <- as.numeric(omega_overlap)
# Package-calculated overlap ATE using causal_forest
overlap_pkg <- as.numeric(average_treatment_effect(cf, target.sample = "overlap")[1])
# Reconstructed overlap ATE using calculated weights
overlap_rebuilt <- omega_overlap %*% Y
## Check numerical equivalence
all.equal(as.numeric(overlap_rebuilt), as.numeric(overlap_pkg))
# LATE
# Instrumental forest of Y on D with instrument Z. Y.hat at this point is the
# value last read from cf$Y.hat in the ATU section above.
ivf = instrumental_forest(X,Y,D,Z,Y.hat=Y.hat)
# Pointwise outcome weights of the instrumental forest; omega %*% Y should
# reproduce the forest's own predictions
omega_if = get_outcome_weights(ivf, S = S)
all.equal(as.numeric(omega_if$omega %*% Y),
as.numeric(predict(ivf)$predictions))
## Package output
# Called here without compliance.score; the comparison further below passes
# the compliance score estimated in Step 2 explicitly.
average_treatment_effect(ivf)
## Rebuild using weights
##############################################################################
# Step 1: "Naive" average of pointwise LATE predictions.
# This is analogous to mean( tau.hat.pointwise ).
##############################################################################
tau.hat.pointwise <- predict(ivf)$predictions # = \hat{\tau}^{ivf}(X_i)
N <- length(tau.hat.pointwise)
weights.all <- rep(1, N) # or sample weights if you have them
tau.avg.raw <- weighted.mean(tau.hat.pointwise, weights.all)
##############################################################################
# Step 2: Build the compliance score "Delta(x)".
##############################################################################
# For demonstration, we re-estimate a compliance forest:
# a causal_forest with D as the outcome and Z as the "treatment", on the same X.
# Its pointwise predictions (effect of Z on D given X) are used as Delta(x).
compliance.forest <- causal_forest(
X, D, Z,
Y.hat = ivf$W.hat, # "W.hat" in 'ivf' is E[D|X]
W.hat = ivf$Z.hat # "Z.hat" is E[Z|X], used to speed up fitting
)
compliance.score <- predict(compliance.forest)$predictions # = Delta(x_i)
##############################################################################
# Step 3: Construct the "AIPW" style correction:
# gamma_i * [ Y_i - (Y.hat_i + tau.hat_i * (D_i - W.hat_i)) ]
# Where gamma_i = (Z_i - Z.hat_i)/(Z.hat_i*(1-Z.hat_i)*Delta_i).
##############################################################################
Z.hat <- ivf$Z.hat # E[Z|X]
W.hat <- ivf$W.hat # E[D|X], used in the outcome residual
Y.hat <- ivf$Y.hat # E[Y|X], used in the outcome residual
gamma <- (Z - Z.hat) / (Z.hat * (1 - Z.hat)) / compliance.score
Y.residual <- Y - (Y.hat + tau.hat.pointwise * (D - W.hat))
dr.correction.all <- gamma * Y.residual
# weights.all is all ones, so this is the plain mean of the correction terms
dr.correction <- weighted.mean(dr.correction.all, weights.all)
##############################################################################
# Step 4: Final LATE = naive average + DR correction
##############################################################################
tau.late.aipw <- tau.avg.raw + dr.correction
cat("Manual LATE (AIPW replication):", tau.late.aipw, "\n")
## Check numerical equivalence
# Pass the same compliance score to the package so both estimates are built
# from identical inputs; [1] picks the first element of the returned result
late_pkg <- average_treatment_effect(ivf, compliance.score = compliance.score)
cat("Package LATE:", late_pkg[1], "\n")
all.equal(
as.numeric(tau.late.aipw),
as.numeric(late_pkg[1])
)
# Express the LATE estimate through a single weight vector alpha.
# We'll pick alpha_i = a * gamma_i + b, then solve for (a,b) so that:
#   sum_i alpha_i     = a * sum_i gamma_i     + b * N         = 1
#   sum_i alpha_i Y_i = a * sum_i gamma_i Y_i + b * sum_i Y_i = tau.late.aipw
# Both constraints hold by construction, so the checks below verify the
# linear solve rather than provide an independent derivation of the weights.
S1 <- sum(gamma)      # sum of gamma_i
S2 <- sum(gamma * Y)  # sum of gamma_i * Y_i
S3 <- sum(Y)          # sum of Y_i
A <- matrix(c(S1, N,   # [ S1 N  ]
              S2, S3), # [ S2 S3 ]
            nrow = 2, byrow = TRUE)
rhs <- c(1, tau.late.aipw)
# Solve the 2x2 system
ab <- solve(A, rhs)
a <- ab[1]
b <- ab[2]
# Construct alpha
alpha <- a * gamma + b
cat("\n--- Single alpha vector check ---\n")
cat("Sum(alpha): ", sum(alpha), "\n") # should be 1
cat("Sum(alpha * Y): ", sum(alpha * Y), "\n") # should be LATE
cat("Alpha-based LATE:", sum(alpha * Y), "\n")
cat("Manual DR LATE: ", tau.late.aipw, "\n")
cat("Package LATE: ", late_pkg[1], "\n")
# all.equal() compares exactly two objects; a third positional argument is
# matched to `tolerance`, which would make the check pass for almost any
# values. Each pairwise comparison is therefore made separately.
late_alpha <- as.numeric(alpha %*% Y)
all.equal(late_alpha, as.numeric(tau.late.aipw))
all.equal(late_alpha, as.numeric(late_pkg[1]))