---
title: "Causal Machine Learning Group Assignment"
subtitle: "Balancing Checks Replications"
author: "Sumitrra Bala Subramaniam, Jiha Kim, Timothy Leske, Pulkit Thukral"
date: "01/25"
output: 
  html_notebook:
    toc: true
    toc_float: true
    code_folding: show
---

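This notebook reproduces covariate balancing checks for the ATT, ATU, overlap-weighted ATE and LATE estimators. The checks use the outcome weights implied by generalized random forests, applied to the 401(k) data (`pension`) from the `hdm` package.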

## Loading required packages and setting the seed

```{r, message = FALSE, warning=FALSE}
if (!require("OutcomeWeights")) install.packages("OutcomeWeights", dependencies = TRUE); library(OutcomeWeights)
if (!require("hdm")) install.packages("hdm", dependencies = TRUE); library(hdm)
if (!require("grf")) install.packages("grf", dependencies = TRUE); library(grf)
if (!require("cobalt")) install.packages("cobalt", dependencies = TRUE); library(cobalt)
if (!require("tidyverse")) install.packages("tidyverse", dependencies = TRUE); library(tidyverse)
if (!require("viridis")) install.packages("viridis", dependencies = TRUE); library(viridis)
if (!require("gridExtra")) install.packages("gridExtra", dependencies = TRUE); library(gridExtra)

set.seed(1234)
```

## Loading 401(k) data

```{r}
data(pension) # type ?pension in the console for variable descriptions
# Treatment
D = pension$p401
# Instrument
Z = pension$e401
# Outcome
Y = pension$net_tfa
# Controls
X = model.matrix(~ 0 + age + db + educ + fsize + hown + inc + male + marr + pira + twoearn, data = pension)
var_nm = c("Age","Benefit pension","Education","Family size","Home owner","Income","Male","Married","IRA","Two earners")
colnames(X) = var_nm
```

## Fit causal forest and calculate weights

The code block below fits the models that all of the following estimators build on. It:

-   Trains a regression forest to predict the outcome Y from the controls X
-   Uses the trained forest to generate out-of-bag predictions Y.hat
-   Extracts the forest weights, i.e. the smoother matrix S with Y.hat = S %*% Y
-   Fits a causal forest to estimate the effect of the treatment D on the outcome Y conditional on the controls X, reusing Y.hat
-   Extracts the estimated treatment propensities D.hat from the causal forest
-   Computes the outcome weights of the causal forest's CATE predictions, i.e. the matrix S.tau$omega with tau.hat = S.tau$omega %*% Y

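```{r}
# Outcome regression forest: out-of-bag predictions of E[Y|X]
rf_Y.hat = regression_forest(X,Y)
Y.hat = predict(rf_Y.hat)$predictions
# Smoother matrix of the outcome forest, so that Y.hat = S %*% Y
S = get_forest_weights(rf_Y.hat)
# Causal forest that reuses the outcome predictions
cf = causal_forest(X,Y,D,Y.hat=Y.hat)
# Estimated propensity scores E[D|X]
D.hat = cf$W.hat
# Outcome weights of the CATE predictions: predict(cf) = S.tau$omega %*% Y
S.tau = get_outcome_weights(cf, S = S)
```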
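
Every forest prediction above is a weighted average of training outcomes, which is what makes the outcome-weight representation possible. As an illustration on made-up data, the sketch below uses a Gaussian kernel smoother as a stand-in for the regression forest; the kernel, the bandwidth of 0.1 and the data are our assumptions. It shows the two properties the later checks rely on: the rows of $S$ sum to one, so the rows of the residual maker $I - S$ sum to zero. The code is wrapped in `local()` so that it does not overwrite `X`, `Y` or `S`.

```{r}
toy_smoother <- local({
  x <- seq(0, 1, length.out = 8)                 # one made-up covariate, 8 units
  y <- sin(2 * pi * x) + x                       # made-up outcome
  K <- exp(-outer(x, x, "-")^2 / (2 * 0.1^2))    # Gaussian kernel, bandwidth 0.1
  S <- K / rowSums(K)                            # row-normalise: weights sum to 1
  R <- diag(length(x)) - S                       # residual maker: y - y_hat = R %*% y
  list(S = S, R = R, y = y, y_hat = as.numeric(S %*% y))
})
rowSums(toy_smoother$S)  # ones
rowSums(toy_smoother$R)  # zeros up to rounding
```

Any estimator that is linear in the residuals $(I - S)Y$ therefore has outcome weights that sum to zero.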

## Weight calculation for Average Treatment Effect on the Treated (ATT)

```{r}
##############################
# ATT: package estimate, manual AIPW replication and implied outcome weights
##############################

# Built-in "ATT" from the GRF package
att_cf <- average_treatment_effect(cf, target.sample = "treated")[1]
cat("\n--- GRF's built-in ATT ---\n")
cat("GRF ATT:", att_cf, "\n")

##############################
# Manual DR formula for ATT
##############################
n <- length(Y)
treated.idx <- which(D == 1)
control.idx <- which(D == 0)

# Forest pointwise CATE
tau.hat.pointwise <- predict(cf)$predictions

# "Raw" naive average among treated:
weights.all <- rep(1, n)
tau.avg.raw <- weighted.mean(tau.hat.pointwise[treated.idx], w = weights.all[treated.idx])

# Decomposition from cf
W.hat  <- cf$W.hat
Y.hat  <- cf$Y.hat
Y.hat0 <- Y.hat - W.hat * tau.hat.pointwise
Y.hat1 <- Y.hat + (1 - W.hat) * tau.hat.pointwise

# IPW-like gamma for ATT
#   Treated: gamma=1
#   Controls: gamma= p.hat / (1 - p.hat)
#   Then rescale to sum(weights.all)
gamma <- rep(NA, n)

# Treated
gamma.treat.raw <- rep(1, length(treated.idx))
sum.treat <- sum(weights.all[treated.idx] * gamma.treat.raw)
scaled.treat.gamma <- gamma.treat.raw / sum.treat * sum(weights.all)
gamma[treated.idx] <- scaled.treat.gamma

# Controls
gamma.ctrl.raw <- W.hat[control.idx] / (1 - W.hat[control.idx])
sum.ctrl <- sum(weights.all[control.idx] * gamma.ctrl.raw)
scaled.ctrl.gamma <- gamma.ctrl.raw / sum.ctrl * sum(weights.all)
gamma[control.idx] <- scaled.ctrl.gamma

# DR correction for ATT (NOTICE THE MINUS SIGN on the control term!)
dr.corr.all <- D * gamma * (Y - Y.hat1)  -  (1 - D) * gamma * (Y - Y.hat0)
dr.corr     <- weighted.mean(dr.corr.all, weights.all)

# Final manual DR ATT
tau.ATT.manual <- tau.avg.raw + dr.corr

cat("\n--- Manual DR (raw + correction) for ATT ---\n")
cat("Manual DR ATT:", tau.ATT.manual, "\n")

##############################
# Rebuild Using Weights (Smoothing Matrix)
##############################

# Extract the smoothing matrix from the causal forest
S.tau.treated <- S.tau$omega
S.tau.treated[D == 0, ] <- 0  # Zero out rows for controls

# Adjustment matrices for treated and control groups
N <- length(Y)
S_adjusted_treated <- diag(N) - S - (1 - W.hat) * S.tau$omega
S_adjusted_control <- diag(N) - S + W.hat * S.tau$omega

# Compute correction term for ATT weights
T_correction <- D * gamma * S_adjusted_treated -  
  (1 - D) * gamma * S_adjusted_control  

# Final ATT weights computation
omega_att <- rep(1/sum(D), N) %*% S.tau.treated + rep(1/N, N) %*% T_correction

cat("\n--- Single-weight ATT check ---\n")
cat("Sum(omega_att):     ", sum(omega_att), "\n")          # Should be (numerically) zero: effect weights sum to zero
cat("Sum(omega_att * Y): ", sum(omega_att * Y), "\n")      # Should be close to tau.ATT.manual
cat("Manual DR ATT:      ", tau.ATT.manual, "\n")

# Convert to numeric for plotting
omega_att <- as.numeric(omega_att)
```
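
The manual AIPW formula and the single weight vector agree because every ingredient is linear in $Y$. The toy check below runs both routes on made-up data. The kernel smoother, the assumed propensity `p` and the hypothetical CATE weight matrix `Omega` (built so that its rows sum to zero) stand in for the forests. Both routes should return the same number, and the weights should sum to zero.

```{r}
toy_att <- local({
  n <- 8
  x <- seq(0, 1, length.out = n)
  D <- c(1, 0, 1, 0, 0, 1, 1, 0)                  # made-up treatment
  p <- 0.3 + 0.4 * x                              # assumed propensity score
  y <- sin(2 * pi * x) + x + D * (1 + x)          # made-up outcome
  K <- exp(-outer(x, x, "-")^2 / (2 * 0.1^2))
  S <- K / rowSums(K)                             # outcome smoother: y_hat = S %*% y
  # Hypothetical linear CATE estimator tau = Omega %*% y (rows sum to zero)
  Omega <- S %*% diag(D - p) %*% (diag(n) - S)
  tau <- as.numeric(Omega %*% y)
  y_hat <- as.numeric(S %*% y)
  y_hat0 <- y_hat - p * tau
  y_hat1 <- y_hat + (1 - p) * tau
  # ATT gamma: treated = 1, controls = p / (1 - p); each arm rescaled to sum to n
  gamma <- numeric(n)
  gamma[D == 1] <- n / sum(D)
  g0 <- p[D == 0] / (1 - p[D == 0])
  gamma[D == 0] <- g0 / sum(g0) * n
  # Route 1: naive average among the treated plus DR correction
  att_manual <- mean(tau[D == 1]) +
    mean(D * gamma * (y - y_hat1) - (1 - D) * gamma * (y - y_hat0))
  # Route 2: one weight vector applied to y
  Omega_treated <- Omega
  Omega_treated[D == 0, ] <- 0
  A1 <- diag(n) - S - (1 - p) * Omega
  A0 <- diag(n) - S + p * Omega
  w <- as.numeric(rep(1 / sum(D), n) %*% Omega_treated +
                  rep(1 / n, n) %*% (D * gamma * A1 - (1 - D) * gamma * A0))
  list(w = w, att_manual = att_manual, att_weights = sum(w * y))
})
c(manual = toy_att$att_manual, weights = toy_att$att_weights, sum_w = sum(toy_att$w))
```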

## Weight calculation for Average Treatment Effect on the Untreated (ATU)

```{r}
##############################
# ATU: package estimate, manual AIPW replication and implied outcome weights
##############################

# Built-in "ATU" from the GRF package
atu_cf <- average_treatment_effect(cf, target.sample = "control")[1]
cat("\n--- GRF's built-in ATU ---\n")
cat("GRF ATU:", atu_cf, "\n")

##############################
# Manual DR formula for ATU
##############################
n <- length(Y)
treated.idx <- which(D == 1)
control.idx <- which(D == 0)

# Forest pointwise CATE
tau.hat.pointwise <- predict(cf)$predictions

# "Raw" naive average among controls:
weights.all <- rep(1, n)
tau.avg.raw <- weighted.mean(tau.hat.pointwise[control.idx], w = weights.all[control.idx])

# Decomposition from cf
W.hat  <- cf$W.hat
Y.hat  <- cf$Y.hat
Y.hat0 <- Y.hat - W.hat * tau.hat.pointwise
Y.hat1 <- Y.hat + (1 - W.hat) * tau.hat.pointwise

# IPW-like gamma for ATU
#   Controls: gamma=1
#   Treated: gamma= (1 - p.hat) / p.hat
#   Then rescale to sum(weights.all)
gamma <- rep(NA, n)

# Controls (ATU target group)
gamma.ctrl.raw <- rep(1, length(control.idx))
sum.ctrl <- sum(weights.all[control.idx] * gamma.ctrl.raw)
scaled.ctrl.gamma <- gamma.ctrl.raw / sum.ctrl * sum(weights.all)
gamma[control.idx] <- scaled.ctrl.gamma

# Treated
gamma.treat.raw <- (1 - W.hat[treated.idx]) / W.hat[treated.idx]
sum.treat <- sum(weights.all[treated.idx] * gamma.treat.raw)
scaled.treat.gamma <- gamma.treat.raw / sum.treat * sum(weights.all)
gamma[treated.idx] <- scaled.treat.gamma

# DR correction for ATU (same form as for the ATT; only gamma differs)
dr.corr.all <- D * gamma * (Y - Y.hat1)  -  (1 - D) * gamma * (Y - Y.hat0)
dr.corr     <- weighted.mean(dr.corr.all, weights.all)

# Final manual DR ATU
tau.ATU.manual <- tau.avg.raw + dr.corr

cat("\n--- Manual DR (raw + correction) for ATU ---\n")
cat("Manual DR ATU:", tau.ATU.manual, "\n")

##############################
# Rebuild Using Weights (Smoothing Matrix)
##############################

# Extract the smoothing matrix from the causal forest
S.tau.control <- S.tau$omega
S.tau.control[D == 1, ] <- 0  # Zero out rows for treated

# Adjustment matrices for treated and control groups
N <- length(Y)
S_adjusted_treated <- diag(N) - S - (1 - W.hat) * S.tau$omega
S_adjusted_control <- diag(N) - S + W.hat * S.tau$omega

# Compute correction term for ATU weights (same form as for the ATT; only gamma differs)
T_correction <- D * gamma * S_adjusted_treated -  
  (1 - D) * gamma * S_adjusted_control  

# Final ATU weights computation
omega_atu <- rep(1 / sum(1 - D), N) %*% S.tau.control + rep(1 / N, N) %*% T_correction

cat("\n--- Single-weight ATU check ---\n")
cat("Sum(omega_atu):     ", sum(omega_atu), "\n")          # Should be (numerically) zero: effect weights sum to zero
cat("Sum(omega_atu * Y): ", sum(omega_atu * Y), "\n")      # Should be close to tau.ATU.manual
cat("Manual DR ATU:      ", tau.ATU.manual, "\n")

# Convert to numeric for plotting
omega_atu <- as.numeric(omega_atu)

```
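
The ATT and ATU code differ only in which arm receives $\gamma = 1$ and how the other arm is reweighted. The DR correction itself keeps the same form. The hypothetical helper `dr_gamma()` below is our own name, not a `grf` function. Applied to made-up data, it shows the mirror image, and after rescaling the $\gamma$ values of each arm sum to $n$.

```{r}
toy_gamma <- local({
  # Hypothetical helper (not part of grf): IPW-type gamma for the DR
  # correction, with each treatment arm rescaled to sum to n
  dr_gamma <- function(D, p, target = c("treated", "control")) {
    target <- match.arg(target)
    raw <- if (target == "treated") {
      ifelse(D == 1, 1, p / (1 - p))   # controls reweighted towards the treated
    } else {
      ifelse(D == 0, 1, (1 - p) / p)   # treated reweighted towards the controls
    }
    gamma <- numeric(length(D))
    for (g in c(0, 1)) {
      idx <- D == g
      gamma[idx] <- raw[idx] / sum(raw[idx]) * length(D)
    }
    gamma
  }
  D <- c(1, 0, 1, 0, 0, 1, 1, 0)            # made-up treatment
  p <- seq(0.2, 0.8, length.out = 8)        # assumed propensity scores
  list(D = D, att = dr_gamma(D, p, "treated"), atu = dr_gamma(D, p, "control"))
})
# gamma sums per arm (rows: D = 0, D = 1) for ATT and ATU weighting
sapply(toy_gamma[c("att", "atu")], function(g) tapply(g, toy_gamma$D, sum))
```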

## Weight calculation for the overlap-weighted Average Treatment Effect

```{r}
# Overlap
## Package output
average_treatment_effect(cf, target.sample = "overlap") 
# Rebuild using weights
N <- nrow(X)  
ones <- matrix(1, N, 1)
diag_alpha <- diag(rep(1, N))  
# Define residual maker matrix (M_alpha_N) for alpha = 1
M_alpha_N <- diag(N) - ones %*% solve(t(ones) %*% diag_alpha %*% ones) %*% t(ones) %*% diag_alpha
# V.hat = Actual treatment (D) - Predicted treatment (D.hat)
V.hat = D-D.hat
# S is the outcome-forest smoother matrix from above (Y.hat = S %*% Y)
omega_overlap <- solve(t(V.hat) %*% M_alpha_N %*% diag_alpha %*% V.hat) %*%
  t(V.hat) %*% M_alpha_N %*% diag_alpha %*% (diag(N) - S)
omega_overlap <- as.numeric(omega_overlap)
# Package-calculated overlap ATE using causal_forest
overlap_pkg = as.numeric(average_treatment_effect(cf, target.sample = "overlap")[1])
# Reconstructed overlap ATE using calculated weights
overlap_rebuilt = omega_overlap %*% Y
## Check numerical equivalence
all.equal(as.numeric(overlap_rebuilt), as.numeric(overlap_pkg))
```
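
With $\alpha = 1$ the residual maker $M$ is just the centering matrix. The overlap weights therefore reproduce the slope of an OLS regression, with intercept, of the outcome residuals $Y - SY$ on the treatment residuals $D - \hat{D}$. The sketch below checks this on made-up data, using a kernel smoother and an assumed propensity in place of the forests.

```{r}
toy_overlap <- local({
  n <- 8
  x <- seq(0, 1, length.out = n)
  D <- c(1, 0, 1, 0, 0, 1, 1, 0)                    # made-up treatment
  p <- 0.3 + 0.4 * x                                # assumed propensity score
  y <- sin(2 * pi * x) + x + D * (1 + x)            # made-up outcome
  K <- exp(-outer(x, x, "-")^2 / (2 * 0.1^2))
  S <- K / rowSums(K)                               # outcome smoother
  V <- D - p                                        # treatment residual
  M <- diag(n) - matrix(1 / n, n, n)                # centering matrix (alpha = 1)
  omega <- as.numeric(solve(t(V) %*% M %*% V) %*% t(V) %*% M %*% (diag(n) - S))
  y_res <- as.numeric(y - S %*% y)                  # outcome residual
  ols <- unname(coef(lm(y_res ~ V))[2])             # slope of regression with intercept
  list(omega = omega, weighted = sum(omega * y), ols = ols)
})
c(weights = toy_overlap$weighted, ols = toy_overlap$ols, sum_w = sum(toy_overlap$omega))
```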

## Weight calculation for the Local Average Treatment Effect (LATE)

```{r}
# LATE
ivf = instrumental_forest(X,Y,D,Z,Y.hat=Y.hat)
omega_if = get_outcome_weights(ivf, S = S)
all.equal(as.numeric(omega_if$omega %*% Y),
          as.numeric(predict(ivf)$predictions))

## Package output
average_treatment_effect(ivf)

## Rebuild using weights

##############################################################################
# Step 1: "Naive" average of pointwise LATE predictions.
#         This is analogous to   mean( tau.hat.pointwise ).
##############################################################################
tau.hat.pointwise <- predict(ivf)$predictions  # = \hat{\tau}^{ivf}(X_i)
N <- length(tau.hat.pointwise)
weights.all <- rep(1, N)                       # or sample weights if you have them
tau.avg.raw <- weighted.mean(tau.hat.pointwise, weights.all)

##############################################################################
# Step 2: Build the compliance score "Delta(x)".
##############################################################################
# Estimate the compliance score Delta(x) = E[D|X=x, Z=1] - E[D|X=x, Z=0]
# with an auxiliary causal forest that treats D as the outcome and Z as the
# "treatment", as grf does when no compliance.score is supplied.
compliance.forest <- causal_forest(
  X, D, Z,
  Y.hat = ivf$W.hat,  # "W.hat" in 'ivf' is E[D|X], the outcome model for D
  W.hat = ivf$Z.hat   # "Z.hat" in 'ivf' is E[Z|X], reused instead of re-estimated
)
compliance.score <- predict(compliance.forest)$predictions  # = Delta(x_i)

##############################################################################
# Step 3: Construct the "AIPW" style correction:
#         gamma_i * [ Y_i - (Y.hat_i + tau.hat_i * (D_i - W.hat_i)) ]
#   Where gamma_i = (Z_i - Z.hat_i)/(Z.hat_i*(1-Z.hat_i)*Delta_i).
##############################################################################
Z.hat <- ivf$Z.hat    # E[Z|X]
W.hat <- ivf$W.hat    # E[D|X], used in the outcome residual
Y.hat <- ivf$Y.hat    # E[Y|X], used in the outcome residual

gamma <- (Z - Z.hat) / (Z.hat * (1 - Z.hat)) / compliance.score
Y.residual <- Y - (Y.hat + tau.hat.pointwise * (D - W.hat))

dr.correction.all <- gamma * Y.residual
dr.correction <- weighted.mean(dr.correction.all, weights.all)

##############################################################################
# Step 4: Final LATE = naive average + DR correction
##############################################################################
tau.late.aipw <- tau.avg.raw + dr.correction

cat("Manual LATE (AIPW replication):", tau.late.aipw, "\n")


## Check numerical equivalence

late_pkg <- average_treatment_effect(ivf, compliance.score = compliance.score)
cat("Package LATE:", late_pkg[1], "\n")

all.equal(
  as.numeric(tau.late.aipw),
  as.numeric(late_pkg[1])
)

##############################################################################
# Step 5: Rebuild the LATE as a single vector of outcome weights, mirroring
#         the ATT/ATU construction. With Y.hat = S %*% Y and
#         tau.hat.pointwise = omega_if$omega %*% Y (checked above):
#   LATE = 1/N * 1' Omega Y + 1/N * gamma' (I - S - diag(D - W.hat) Omega) Y
##############################################################################
S_adjusted_late <- diag(N) - S - (D - W.hat) * omega_if$omega
omega_late <- rep(1 / N, N) %*% omega_if$omega +
  rep(1 / N, N) %*% (gamma * S_adjusted_late)
omega_late <- as.numeric(omega_late)

cat("\n--- Single-weight LATE check ---\n")
cat("Sum(omega_late):     ", sum(omega_late), "\n")      # should be (numerically) zero
cat("Sum(omega_late * Y): ", sum(omega_late * Y), "\n")  # should equal the manual DR LATE
cat("Manual DR LATE:      ", tau.late.aipw, "\n")
cat("Package LATE:        ", late_pkg[1], "\n")

# all.equal() compares exactly two objects (a third positional argument is
# read as the tolerance), so each pair is checked separately
all.equal(sum(omega_late * Y), as.numeric(tau.late.aipw))
all.equal(sum(omega_late * Y), as.numeric(late_pkg[1]))
```
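
The linearity argument used for the ATT carries over to the LATE. With $\hat{Y} = SY$ and $\hat{\tau} = \Omega Y$, the naive average plus the correction `gamma * Y.residual` equals $\omega' Y$ with $\omega' = \frac{1}{N}\mathbf{1}'\Omega + \frac{1}{N}\gamma'\left(I - S - \mathrm{diag}(D - \hat{W})\,\Omega\right)$, where $\hat{W}$ is `W.hat`, the estimate of E[D|X]. The toy check below verifies this identity on made-up data. The instrument, $\hat{Z} = 0.5$, the constant compliance score of 0.6 and the stand-in matrix `Omega` are all assumptions.

```{r}
toy_late <- local({
  n <- 8
  x <- seq(0, 1, length.out = n)
  Z <- c(1, 1, 0, 0, 1, 0, 1, 0)                    # made-up instrument
  D <- c(1, 0, 0, 0, 1, 0, 1, 1)                    # made-up treatment take-up
  z_hat <- rep(0.5, n)                              # assumed E[Z|X]
  d_hat <- 0.2 + 0.5 * x                            # assumed E[D|X]
  delta <- rep(0.6, n)                              # assumed compliance score
  y <- cos(2 * pi * x) + 2 * D                      # made-up outcome
  K <- exp(-outer(x, x, "-")^2 / (2 * 0.1^2))
  S <- K / rowSums(K)                               # outcome smoother
  Omega <- S %*% diag(Z - z_hat) %*% (diag(n) - S)  # stand-in LATE weights (zero row sums)
  tau <- as.numeric(Omega %*% y)
  gamma <- (Z - z_hat) / (z_hat * (1 - z_hat)) / delta
  resid <- as.numeric(y - S %*% y - tau * (D - d_hat))
  late_manual <- mean(tau) + mean(gamma * resid)    # naive average + DR correction
  A <- diag(n) - S - (D - d_hat) * Omega
  omega <- as.numeric(rep(1 / n, n) %*% Omega + rep(1 / n, n) %*% (gamma * A))
  list(omega = omega, manual = late_manual, weighted = sum(omega * y))
})
c(manual = toy_late$manual, weights = toy_late$weighted, sum_w = sum(toy_late$omega))
```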

## Check covariate balancing

We use the infrastructure of the `cobalt` package to plot standardized mean differences. The outcome weights of a treatment effect estimator enter the untreated outcomes with a negative sign. For all four estimators we therefore flip the sign of the untreated outcome weights to make them compatible with the package framework. This is achieved by multiplying the outcome weights by $2D - 1$, which is $+1$ for treated and $-1$ for untreated units:

```{r, message = F}
threshold = 0.1

create_love_plot = function(title, omega, flip_sign = FALSE) {
  love.plot(
    D ~ X,
    weights = list(
      "outcome weights" = if (flip_sign) omega * (2*D-1) else omega
    ),
    position = "bottom",
    title = title,
    thresholds = c(m = threshold),
    var.order = "unadjusted",
    binary = "std",
    abs = TRUE,
    line = TRUE,
    colors = viridis(2), # color-blind-friendly
    shapes = c("circle", "triangle")
  )
}

# Create one love plot per estimator
love_plot_att = create_love_plot("ATT", omega_att, flip_sign = TRUE)
love_plot_atu = create_love_plot("ATU", omega_atu, flip_sign = TRUE)
love_plot_overlap = create_love_plot("Overlap", omega_overlap, flip_sign = TRUE)
love_plot_late = create_love_plot("LATE", omega_late, flip_sign = TRUE)

love_plot_att
love_plot_atu
love_plot_overlap
love_plot_late
```
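
The sketch below shows what the sign flip does, using one made-up covariate and made-up signed weights. The treated weights sum to one and the control weights sum to minus one. Applying the signed weights directly to the covariate gives the weighted difference in means between the arms, which is the quantity the balance check inspects. After multiplying by $2D - 1$, all weights are positive. The pooled-standard-deviation formula is a simplification, and `cobalt`'s own standardization choices may differ.

```{r}
toy_smd <- local({
  D   <- c(1, 0, 1, 0, 0, 1, 1, 0)
  age <- c(30, 45, 38, 50, 41, 29, 35, 47)                       # made-up covariate
  w   <- c(0.30, -0.20, 0.25, -0.35, -0.15, 0.20, 0.25, -0.30)  # made-up signed weights
  w_flip <- w * (2 * D - 1)                                      # control weights turn positive
  wmean  <- function(v, wt) sum(v * wt) / sum(wt)
  diff_signed <- sum(w * age)                                    # weights applied to the covariate
  diff_flip   <- wmean(age[D == 1], w_flip[D == 1]) - wmean(age[D == 0], w_flip[D == 0])
  pooled_sd <- sqrt((var(age[D == 1]) + var(age[D == 0])) / 2)
  smd <- c(unadjusted = (mean(age[D == 1]) - mean(age[D == 0])) / pooled_sd,
           adjusted   = diff_flip / pooled_sd)
  list(D = D, w_flip = w_flip, diff_signed = diff_signed, diff_flip = diff_flip, smd = smd)
})
toy_smd$smd
```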

Combined plot of the four different effects:

```{r, results='hide', fig.width=12, fig.height=8}
figure2 = grid.arrange(
  love_plot_att, love_plot_atu,
  love_plot_overlap,love_plot_late,
  nrow = 2
)
```