--- a
+++ b/scripts/Task 2 starting point.R
@@ -0,0 +1,113 @@
+# Task 2 starting point: reproduce grf average treatment effect estimates
+# (ATE, ATT, ATU, overlap, LATE) as weighted sums of the outcome,
+# estimate = omega %*% Y, with weights from the OutcomeWeights package.
+
+# Install any missing package, then attach it. requireNamespace() loads the
+# namespace if available without attaching it or warning, unlike require().
+for (pkg in c("grf", "hdm", "OutcomeWeights")) {
+  if (!requireNamespace(pkg, quietly = TRUE)) {
+    install.packages(pkg, dependencies = TRUE)
+  }
+  library(pkg, character.only = TRUE)
+}
+
+set.seed(1234)
+
+data(pension)
+
+# Treatment (p401), instrument (e401), outcome (net_tfa) and covariate matrix
+D = pension$p401
+Z = pension$e401
+Y = pension$net_tfa
+X = model.matrix(~ 0 + age + db + educ + fsize + hown + inc + male + marr + pira + twoearn, data = pension)
+var_nm = c("Age","Benefit pension","Education","Family size","Home owner","Income","Male","Married","IRA","Two earners")
+colnames(X) = var_nm
+
+
+# Outcome forest: Y.hat is passed to the causal and instrumental forests, and its
+# forest weight matrix S to get_outcome_weights() and the ATE weights below
+rf_Y.hat = regression_forest(X,Y)
+Y.hat = predict(rf_Y.hat)$predictions
+S = get_forest_weights(rf_Y.hat)
+cf = causal_forest(X,Y,D,Y.hat=Y.hat)
+
+D.hat = cf$W.hat  # propensity score estimated by the causal forest
+# S.tau$omega: weights whose product with Y should reproduce the CATE predictions
+S.tau = get_outcome_weights(cf, S = S)
+all.equal(as.numeric(S.tau$omega %*% Y),
+          as.numeric(predict(cf)$predictions))
+
+# ATE
+## Package output
+average_treatment_effect(cf, target.sample = "all") # The warning can be ignored here; it is overly cautious
+
+## Rebuild using weights following Appendix A.3.2
+lambda1 = D / D.hat            # inverse-propensity weights of the treated (0 for controls)
+lambda0 = (1-D) / (1-D.hat)    # inverse-propensity weights of the controls (0 for treated)
+N = length(Y)
+ones = matrix(1,N,1)
+
+### Horribly slow but close to formula (do not run)
+# T_ate_slow = S.tau$omega + diag(lambda1 - lambda0) %*% (diag(N) - S - diag(as.numeric(D-D.hat)) %*% S.tau$omega)
+# omega_ate_slow = t(ones) %*% T_ate_slow / N
+# omega_ate_slow %*% Y
+
+### Faster avoiding slow matrix multiplications: for a length-N vector v,
+### v * M scales row i of M by v[i], i.e. equals diag(v) %*% M without building
+### diag(v). Note that diag(N) below still creates a dense N x N identity.
+scaled_S.tau = (D - D.hat) * S.tau$omega  # Element-wise multiplication
+S_adjusted = diag(N) - S - scaled_S.tau
+T_ate = S.tau$omega + (lambda1 - lambda0) * S_adjusted  # Element-wise multiplication
+omega_ate = t(ones) %*% T_ate / N  # 1 x N row of ATE weights (column means of T_ate)
+
+## Check numerical equivalence with the package estimate and the package weights.
+## Compare pairwise: all.equal() compares only two objects; a third positional
+## argument is silently taken as `tolerance`, which would make the check vacuous.
+omega_ate_pkg = get_outcome_weights(cf, target = "ATE", S = S, S.tau = S.tau$omega)
+ate_rebuilt = as.numeric(omega_ate %*% Y)
+all.equal(ate_rebuilt,
+          as.numeric(average_treatment_effect(cf, target.sample = "all")["estimate"]))
+all.equal(ate_rebuilt,
+          as.numeric(omega_ate_pkg$omega %*% Y))
+
+# ATT
+## Package output
+average_treatment_effect(cf, target.sample = "treated")
+
+## Rebuild using weights
+
+## Check numerical equivalence
+
+
+# ATU
+## Package output
+average_treatment_effect(cf, target.sample = "control") # The warning can be ignored here; it is overly cautious
+
+## Rebuild using weights
+
+## Check numerical equivalence
+
+
+# Overlap
+## Package output
+average_treatment_effect(cf, target.sample = "overlap") # The warning can be ignored here; it is overly cautious
+
+## Rebuild using weights
+
+## Check numerical equivalence
+
+
+# LATE
+ivf = instrumental_forest(X,Y,D,Z,Y.hat=Y.hat)
+omega_if = get_outcome_weights(ivf, S = S)
+all.equal(as.numeric(omega_if$omega %*% Y),
+          as.numeric(predict(ivf)$predictions))
+
+## Package output
+average_treatment_effect(ivf)
+
+## Rebuild using weights
+
+## Check numerical equivalence
+
+