Switch to unified view

a b/scripts/Task 2 starting point.R
1
if (!require("grf")) install.packages("grf", dependencies = TRUE); library(grf)
2
if (!require("hdm")) install.packages("hdm", dependencies = TRUE); library(hdm)
3
if (!require("OutcomeWeights")) install.packages("OutcomeWeights", dependencies = TRUE); library(OutcomeWeights)
4
5
set.seed(1234)
6
7
data(pension)
8
9
D = pension$p401
10
Z = pension$e401
11
Y = pension$net_tfa
12
X = model.matrix(~ 0 + age + db + educ + fsize + hown + inc + male + marr + pira + twoearn, data = pension)
13
var_nm = c("Age","Benefit pension","Education","Family size","Home owner","Income","Male","Married","IRA","Two earners")
14
colnames(X) = var_nm
15
16
17
rf_Y.hat = regression_forest(X,Y)
18
Y.hat = predict(rf_Y.hat)$predictions
19
S = get_forest_weights(rf_Y.hat)
20
cf = causal_forest(X,Y,D,Y.hat=Y.hat)
21
22
D.hat = cf$W.hat
23
S.tau = get_outcome_weights(cf, S = S)
24
all.equal(as.numeric(S.tau$omega %*% Y),
25
          as.numeric(predict(cf)$predictions))
26
27
# ATE
28
## Package output
29
average_treatment_effect(cf, target.sample = "all") # Ignore the warning, which is overly cautious and paternalistic
30
31
## Rebuild using weights following Appendix A.3.2
32
lambda1 = D / D.hat
33
lambda0 = (1-D) / (1-D.hat)
34
N = length(Y)
35
ones = matrix(1,N,1)
36
37
### Horribly slow but close to formula (do not run)
38
# T_ate_slow = S.tau$omega + diag(lambda1 - lambda0) %*% (diag(N) - S - diag(as.numeric(D-D.hat)) %*% S.tau$omega) 
39
# omega_ate_slow = t(ones) %*% T_ate_slow / N
40
# omega_ate_slow %*% Y
41
42
### Faster avoiding slow matrix multiplications
43
scaled_S.tau = (D - D.hat) * S.tau$omega  # Element-wise multiplication
44
S_adjusted = diag(N) - S - scaled_S.tau
45
T_ate = S.tau$omega + (lambda1 - lambda0) * S_adjusted  # Element-wise multiplication
46
omega_ate = t(ones) %*% T_ate / N
47
48
## Check numerical equivalence also with package weights
49
omega_ate_pkg = get_outcome_weights(cf, target = "ATE", S = S, S.tau = S.tau$omega)
50
all.equal(as.numeric(omega_ate %*% Y),
51
          as.numeric(average_treatment_effect(cf, target.sample = "all")[1]),
52
          as.numeric(omega_ate_pkg$omega %*% Y))
53
54
# ATT
55
## Package output
56
average_treatment_effect(cf, target.sample = "treated")
57
58
## Rebuild using weights
59
60
## Check numerical equivalence
61
62
63
# ATU
64
## Package output
65
average_treatment_effect(cf, target.sample = "control") # Ignore the warning, which is overly cautious and paternalistic
66
67
## Rebuild using weights
68
69
## Check numerical equivalence
70
71
72
# Overlap
73
## Package output
74
average_treatment_effect(cf, target.sample = "overlap") # Ignore the warning, this is overly cautious and paternalistic
75
76
## Rebuild using weights
77
78
## Check numerical equivalence
79
80
81
# LATE
82
ivf = instrumental_forest(X,Y,D,Z,Y.hat=Y.hat)
83
omega_if = get_outcome_weights(ivf, S = S)
84
all.equal(as.numeric(omega_if$omega %*% Y),
85
          as.numeric(predict(ivf)$predictions))
86
87
## Package output
88
average_treatment_effect(ivf)
89
90
## Rebuild using weights
91
92
## Check numerical equivalence
93
94