[9e1f38]: / src / CV_TMLE.py

Download this file

172 lines (139 with data), 7.1 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import numpy as np
from scipy.special import logit, expit
from scipy.optimize import minimize
import numpy as np
from scipy.special import logit
import itertools
import sklearn.linear_model as lm
import numpy as np
from sklearn.feature_extraction.text import CountVectorizer
# Import-time side effect: seed NumPy's legacy global RNG so that any code
# drawing from np.random after this module is imported is reproducible.
# NOTE(review): the CVTMLE class in this file never calls np.random itself, so
# this seed only affects other code sharing the global RNG -- confirm it is
# still wanted.
np.random.seed(0)
class CVTMLE:
    """
    Cross-validated targeted maximum likelihood estimation (CV-TMLE) from pooled,
    out-of-fold nuisance estimates.

    The estimator takes initial outcome predictions under control/treatment
    (``q_t0``/``q_t1``), propensity scores (``g``), treatment labels (``t``) and
    factual outcomes (``y``), drops observations whose propensity score lies
    outside ``[truncate_level, 1 - truncate_level]``, fits a one-dimensional
    fluctuation parameter and reports a point estimate with a Wald-type 95%
    confidence interval derived from the influence curve.
    """

    # Default npz keys used by collateFromFolds when no est_keys are supplied.
    _DEFAULT_EST_KEYS = {
        'treatment_label_key': 'treatment_label',
        'outcome_key': 'outcome',
        'treatment_pred_key': 'treatment',
        'outcome_label_key': 'outcome_label'}

    # Probabilities are clipped to [_PROB_EPS, 1 - _PROB_EPS] before logit/log so
    # that predictions saturated at exactly 0 or 1 do not produce inf/NaN.
    _PROB_EPS = 1e-12

    def __init__(self, q_t0=None, q_t1=None, g=None, t=None, y=None, fromFolds=None, est_keys=None,
                 truncate_level=0.05):
        """
        CVTMLE as conceived by Levy, 2018:
        Levy, Jonathan. "An easy implementation of CV-TMLE." arXiv preprint arXiv:1811.04573 (2018).

        :param q_t0: initial estimate with control exposure
        :param q_t1: initial estimate with treatment/non-control exposure
        :param g: prediction of propensity score
        :param t: treatment label
        :param y: factual outcome
        :param fromFolds: list of per-fold npz files; when given, the five arrays above are read from
            the files and any values passed for them are overwritten
        :param est_keys: mapping of the keys used to extract the estimates from each npz file
            (see collateFromFolds for the expected entries and their defaults)
        :param truncate_level: truncation for propensity scores (0.05 default means that only patients
            with estimates between 0.05 and 0.95 will be considered)
        """
        self.q_t0 = q_t0
        self.q_t1 = q_t1
        self.g = g
        self.t = t
        self.y = y
        self.est_keys = est_keys
        self.truncate_level = truncate_level
        if fromFolds is not None:
            self.q_t0, self.q_t1, self.y, self.g, self.t = self.collateFromFolds(fromFolds)

    @staticmethod
    def _clip_prob(p):
        """Return p as float64, clipped away from exactly 0 and 1."""
        # Cast first: in float32 the upper bound 1 - 1e-12 would round back to 1.0.
        return np.clip(np.asarray(p, dtype=float), CVTMLE._PROB_EPS, 1. - CVTMLE._PROB_EPS)

    @staticmethod
    def _stack(parts):
        """Concatenate per-fold arrays into one flat array (empty array if there are no folds)."""
        if not parts:
            return np.array([])
        return np.concatenate([np.ravel(part) for part in parts])

    def _prepare_inputs(self):
        """
        Copy the stored estimates and apply propensity-score truncation.

        :return: tuple (q_t0, q_t1, g, t, y) restricted to the retained observations
        :raises ValueError: if any estimate is missing or no observation survives truncation
        """
        names = ('q_t0', 'q_t1', 'g', 't', 'y')
        missing = [name for name in names if getattr(self, name) is None]
        if missing:
            raise ValueError('missing estimates: {}; pass them to the constructor or use fromFolds'.format(
                ', '.join(missing)))
        q_t0, q_t1, g, t, y = (np.copy(getattr(self, name)) for name in names)
        q_t0, q_t1, g, t, y = self.truncate_all_by_g(q_t0, q_t1, g, t, y, self.truncate_level)
        if np.size(t) == 0:
            raise ValueError('no observations left after truncating propensity scores at level {}'.format(
                self.truncate_level))
        return q_t0, q_t1, g, t, y

    def _perturbed_model_bin_outcome(self, q_t0, q_t1, g, t, eps):
        """
        Helper for run_tmle_binary.

        Returns q_eps(t, x) and the h term, i.e. the value of the perturbed predictor at
        (t, eps, x), where q_t0, q_t1 and g are all evaluated at x.

        :param q_t0: initial outcome probability under control
        :param q_t1: initial outcome probability under treatment
        :param g: propensity score
        :param t: treatment value at which to evaluate the model (observed or counterfactual)
        :param eps: fluctuation parameter
        :return: tuple (perturbed outcome probability, clever covariate h)
        """
        h = t * (1. / g) - (1. - t) / (1. - g)
        # Clip before logit: a prediction of exactly 0 or 1 would give +/-inf, and the
        # 0 * inf product for the arm that is not selected would turn into NaN.
        full_lq = (1. - t) * logit(self._clip_prob(q_t0)) + t * logit(self._clip_prob(q_t1))
        logit_perturb = full_lq + eps * h
        return expit(logit_perturb), h

    def run_tmle_binary(self):
        """
        This is for CV-TMLE on binary outcomes yielding risk ratio with 95% CI. Read Levy (2018) for
        methodological details. Influence curve follows Gruber S, van der Laan, MJ. (2011).

        :return: list [risk ratio, lower 95% bound, upper 95% bound]
        :raises ValueError: if estimates are missing or no observation survives truncation
        """
        print('running CV-TMLE for binary outcomes...')
        q_t0, q_t1, g, t, y = self._prepare_inputs()

        # One-dimensional logistic fluctuation: choose eps minimising the cross-entropy loss.
        fit = minimize(
            lambda eps: self.cross_entropy(y, self._perturbed_model_bin_outcome(q_t0, q_t1, g, t, eps)[0]), 0.,
            method='Nelder-Mead')
        eps_hat = fit.x[0]

        # Targeted counterfactual predictions under treatment and under control.
        qq1, _ = self._perturbed_model_bin_outcome(q_t0, q_t1, g, np.ones_like(t), eps_hat)
        qq0, _ = self._perturbed_model_bin_outcome(q_t0, q_t1, g, np.zeros_like(t), eps_hat)
        mu1, mu0 = np.mean(qq1), np.mean(qq0)
        rr = mu1 / mu0

        # Influence curve of log(RR). The residual uses the *targeted* prediction at the
        # observed treatment, and each arm's inverse-probability weight only applies to
        # observations that actually received that arm (indicator t, resp. 1 - t).
        residual = y - (t * qq1 + (1. - t) * qq0)
        ic_treated = (t / g) * residual + qq1 - mu1
        ic_control = ((1. - t) / (1. - g)) * residual + qq0 - mu0
        ic = ic_treated / mu1 - ic_control / mu0

        half_width = 1.96 * np.sqrt(np.var(ic) / (t.shape[0]))
        log_rr = np.log(rr)
        # The interval is built on the log scale and mapped back to the ratio scale.
        return [rr, np.exp(log_rr - half_width), np.exp(log_rr + half_width)]

    def run_tmle_continuous(self):
        """
        This is for CV-TMLE on continuous outcomes yielding ATE/MD with 95% CI. Read Levy (2018) for
        methodological details. Influence curve follows Gruber S, van der Laan, MJ. (2011).

        :return: list [mean difference, lower 95% bound, upper 95% bound]
        :raises ValueError: if estimates are missing or no observation survives truncation
        """
        print('running CV-TMLE for continuous outcomes...')
        q_t0, q_t1, g, t, y = self._prepare_inputs()

        # Clever covariate and initial prediction at the observed treatment.
        h = t * (1.0 / g) - (1.0 - t) / (1.0 - g)
        full_q = (1.0 - t) * q_t0 + t * q_t1
        # Closed-form least-squares solution for the linear fluctuation parameter.
        eps_hat = np.sum(h * (y - full_q)) / np.sum(np.square(h))

        # Targeted counterfactual predictions: clever covariate is 1/g under treatment
        # and -1/(1-g) under control.
        qq1 = q_t1 + eps_hat * (1.0 / g)
        qq0 = q_t0 - eps_hat * (1.0 / (1.0 - g))
        rd = np.mean(qq1 - qq0)

        # Influence curve of the ATE: observed clever covariate times the residual of the
        # *targeted* prediction, plus the centred plug-in term.
        ic = h * (y - (full_q + eps_hat * h)) + (qq1 - qq0) - rd
        half_width = 1.96 * np.sqrt(np.var(ic) / (t.shape[0]))
        return [rd, rd - half_width, rd + half_width]

    def truncate_by_g(self, attribute, g, level=0.1):
        """
        Keep the entries of attribute whose propensity score lies in [level, 1 - level].

        :param attribute: array to filter
        :param g: propensity scores aligned with attribute
        :param level: truncation level
        :return: filtered copy of attribute
        """
        keep_these = np.logical_and(g >= level, g <= 1. - level)
        return attribute[keep_these]

    def truncate_all_by_g(self, q_t0, q_t1, g, t, y, truncate_level=0.05):
        """
        Helper function to clean up nuisance parameter estimates: applies the same
        propensity-score truncation to every array so they stay aligned.

        :return: tuple (q_t0, q_t1, g, t, y) restricted to the retained observations
        """
        # Filter everything against the untouched scores, including g itself.
        orig_g = np.copy(g)
        q_t0, q_t1, g, t, y = (
            self.truncate_by_g(np.copy(values), orig_g, truncate_level)
            for values in (q_t0, q_t1, g, t, y))
        return q_t0, q_t1, g, t, y

    def cross_entropy(self, y, p):
        """
        Mean binary cross-entropy between outcomes y and predicted probabilities p.

        p is clipped away from 0 and 1 so that saturated predictions give a large but
        finite loss instead of inf/NaN.
        """
        p = self._clip_prob(p)
        return -np.mean((y * np.log(p) + (1. - y) * np.log(1. - p)))

    def collateFromFolds(self, foldNPZ):
        """
        Read the per-fold npz files and pool their estimates into flat arrays.

        FYI: keys can be provided but default is below
        est_keys = {
        'treatment_label_key' : 'treatment_label',
        'outcome_key' : 'outcome',
        'treatment_pred_key' : 'treatment',
        'outcome_label_key' : 'outcome_label'}

        Column 0 / column 1 of the 'outcome_key' array are used as q_t0 / q_t1, and
        column 1 of the 'treatment_pred_key' array is used as g.

        :param foldNPZ: iterable of npz files (anything accepted by numpy.load)
        :return: tuple (q_t0, q_t1, y, g, t) of flat arrays, folds concatenated in order
        """
        if self.est_keys is None:
            self.est_keys = dict(self._DEFAULT_EST_KEYS)
        keys = self.est_keys

        t_parts, q0_parts, q1_parts, g_parts, y_parts = [], [], [], [], []
        for fold in foldNPZ:
            # The context manager closes the archive once the arrays are in memory.
            with np.load(fold) as ld:
                outcome_pred = ld[keys['outcome_key']]  # read once, used for both arms
                t_parts.append(ld[keys['treatment_label_key']])
                q0_parts.append(outcome_pred[:, 0])
                q1_parts.append(outcome_pred[:, 1])
                y_parts.append(ld[keys['outcome_label_key']])
                g_parts.append(ld[keys['treatment_pred_key']][:, 1])

        return (self._stack(q0_parts), self._stack(q1_parts), self._stack(y_parts),
                self._stack(g_parts), self._stack(t_parts))