|
a |
|
b/src/CV_TMLE.py |
|
|
1 |
import numpy as np |
|
|
2 |
from scipy.special import logit, expit |
|
|
3 |
from scipy.optimize import minimize |
|
|
4 |
|
|
|
5 |
import numpy as np |
|
|
6 |
from scipy.special import logit |
|
|
7 |
import itertools |
|
|
8 |
import sklearn.linear_model as lm |
|
|
9 |
|
|
|
10 |
import numpy as np |
|
|
11 |
from sklearn.feature_extraction.text import CountVectorizer |
|
|
12 |
|
|
|
13 |
np.random.seed(0) |
|
|
14 |
|
|
|
15 |
|
|
|
16 |
class CVTMLE: |
|
|
17 |
def __init__(self, q_t0=None, q_t1=None, g=None, t=None, y=None, fromFolds=None, est_keys=None, |
|
|
18 |
truncate_level=0.05): |
|
|
19 |
""" |
|
|
20 |
CVTMLE as conceived by Levi, 2018: |
|
|
21 |
Levy, Jonathan. "An easy implementation of CV-TMLE." arXiv preprint arXiv:1811.04573 (2018). |
|
|
22 |
|
|
|
23 |
:param q_t0: initial estimate with control exposure |
|
|
24 |
:param q_t1: initial estimate with treatment/non-control exposure |
|
|
25 |
:param g: prediction of propensity score |
|
|
26 |
:param t: treatment label |
|
|
27 |
:param y: factual outcome |
|
|
28 |
:param fromFolds: if files for estimates per fold are provided (type list) which are npz files, then no need to provide first five parameterse |
|
|
29 |
:param est_keys: once npz files are read, the keys are needed to extract estimates (i.e., first five parameters in this init) |
|
|
30 |
:param truncate_level: truncation for propensity scores (0.05 default means that only patients with estimates between 0.05 and 0.95 will be considered) |
|
|
31 |
""" |
|
|
32 |
|
|
|
33 |
|
|
|
34 |
self.q_t0 = q_t0 |
|
|
35 |
self.q_t1 = q_t1 |
|
|
36 |
self.g = g |
|
|
37 |
self.t = t |
|
|
38 |
self.y = y |
|
|
39 |
self.est_keys = est_keys |
|
|
40 |
self.truncate_level = truncate_level |
|
|
41 |
if fromFolds is not None: |
|
|
42 |
self.q_t0, self.q_t1, self.y, self.g, self.t = self.collateFromFolds(fromFolds) |
|
|
43 |
|
|
|
44 |
def _perturbed_model_bin_outcome(self, q_t0, q_t1, g, t, eps): |
|
|
45 |
""" |
|
|
46 |
Helper for psi_tmle_bin_outcome |
|
|
47 |
|
|
|
48 |
Returns q_\eps (t,x) and the h term |
|
|
49 |
(i.e., value of perturbed predictor at t, eps, x; where q_t0, q_t1, g are all evaluated at x |
|
|
50 |
""" |
|
|
51 |
h = t * (1. / g) - (1. - t) / (1. - g) |
|
|
52 |
full_lq = (1. - t) * logit(q_t0) + t * logit(q_t1) # logit predictions from unperturbed model |
|
|
53 |
logit_perturb = full_lq + eps * h |
|
|
54 |
return expit(logit_perturb), h |
|
|
55 |
|
|
|
56 |
def run_tmle_binary(self): |
|
|
57 |
""" |
|
|
58 |
This is for CV-TMLE on binary outcomes yielding risk ratio with 95% CI. Read Levi et al for methodological details. |
|
|
59 |
Influence curves coded from Gruber S, van der Laan, MJ. (2011). |
|
|
60 |
|
|
|
61 |
""" |
|
|
62 |
|
|
|
63 |
print('running CV-TMLE for binary outcomes...') |
|
|
64 |
q_t0, q_t1, g, t, y, truncatel = np.copy(self.q_t0), np.copy(self.q_t1), np.copy(self.g), np.copy( |
|
|
65 |
self.t), np.copy(self.y), np.copy(self.truncate_level) |
|
|
66 |
q_t0, q_t1, g, t, y = self.truncate_all_by_g(q_t0, q_t1, g, t, y, truncatel) |
|
|
67 |
|
|
|
68 |
eps_hat = minimize( |
|
|
69 |
lambda eps: self.cross_entropy(y, self._perturbed_model_bin_outcome(q_t0, q_t1, g, t, eps)[0]), 0., |
|
|
70 |
method='Nelder-Mead') |
|
|
71 |
eps_hat = eps_hat.x[0] |
|
|
72 |
|
|
|
73 |
def q1(t_cf): |
|
|
74 |
return self._perturbed_model_bin_outcome(q_t0, q_t1, g, t_cf, eps_hat) |
|
|
75 |
|
|
|
76 |
qall = ((1. - t) * (q_t0)) + (t * (q_t1)) # full predictions from unperturbed model |
|
|
77 |
|
|
|
78 |
qq1, h1 = q1(np.ones_like(t)) |
|
|
79 |
qq0, h0 = q1(np.zeros_like(t)) |
|
|
80 |
rr = np.mean(qq1) / np.mean(qq0) |
|
|
81 |
|
|
|
82 |
ic = (1 / np.mean(qq1) * (h1 * (y - qall) + qq1 - np.mean(qq1)) - |
|
|
83 |
(1 / np.mean(qq0)) * (-1 * h0 * (y - qall) + qq0 - np.mean(qq0))) |
|
|
84 |
psi_tmle_std = 1.96 * np.sqrt(np.var(ic) / (t.shape[0])) |
|
|
85 |
|
|
|
86 |
return [rr, np.exp(np.log(rr) - psi_tmle_std), np.exp(np.log(rr) + psi_tmle_std)] |
|
|
87 |
|
|
|
88 |
def run_tmle_continuous(self): |
|
|
89 |
""" |
|
|
90 |
This is for CV-TMLE on continuous outcomes yielding ATE/MD with 95% CI. Read Levi et al for methodological details. |
|
|
91 |
Influence curves coded from Gruber S, van der Laan, MJ. (2011). |
|
|
92 |
|
|
|
93 |
""" |
|
|
94 |
print('running CV-TMLE for continuous outcomes...') |
|
|
95 |
|
|
|
96 |
q_t0, q_t1, g, t, y, truncatel = np.copy(self.q_t0), np.copy(self.q_t1), np.copy(self.g), np.copy( |
|
|
97 |
self.t), np.copy(self.y), np.copy(self.truncate_level) |
|
|
98 |
q_t0, q_t1, g, t, y = self.truncate_all_by_g(q_t0, q_t1, g, t, y, truncatel) |
|
|
99 |
|
|
|
100 |
h = t * (1.0 / g) - (1.0 - t) / (1.0 - g) |
|
|
101 |
full_q = (1.0 - t) * q_t0 + t * q_t1 |
|
|
102 |
eps_hat = np.sum(h * (y - full_q)) / np.sum(np.square(h)) |
|
|
103 |
|
|
|
104 |
def q1(t_cf): |
|
|
105 |
h_cf = t_cf * (1.0 / g) - (1.0 - t_cf) / (1.0 - g) |
|
|
106 |
full_q = ((1.0 - t_cf) * q_t0) + (t_cf * q_t1) |
|
|
107 |
return full_q + eps_hat * h_cf, h_cf |
|
|
108 |
|
|
|
109 |
qq1, h_cf1 = q1(np.ones_like(t)) |
|
|
110 |
qq0, h_cf0 = q1(np.zeros_like(t)) |
|
|
111 |
haw = h_cf0 + h_cf1 |
|
|
112 |
|
|
|
113 |
rd = np.mean(qq1 - qq0) |
|
|
114 |
ic = (haw * (y - full_q)) + (qq1 - qq0) - rd |
|
|
115 |
psi_tmle_std = 1.96 * np.sqrt(np.var(ic) / (t.shape[0])) |
|
|
116 |
|
|
|
117 |
return [rd, rd - psi_tmle_std, rd + psi_tmle_std] |
|
|
118 |
|
|
|
119 |
def truncate_by_g(self, attribute, g, level=0.1): |
|
|
120 |
keep_these = np.logical_and(g >= level, g <= 1. - level) |
|
|
121 |
return attribute[keep_these] |
|
|
122 |
|
|
|
123 |
def truncate_all_by_g(self, q_t0, q_t1, g, t, y, truncate_level=0.05): |
|
|
124 |
""" |
|
|
125 |
Helper function to clean up nuisance parameter estimates. |
|
|
126 |
""" |
|
|
127 |
orig_g = np.copy(g) |
|
|
128 |
q_t0 = self.truncate_by_g(np.copy(q_t0), orig_g, truncate_level) |
|
|
129 |
q_t1 = self.truncate_by_g(np.copy(q_t1), orig_g, truncate_level) |
|
|
130 |
g = self.truncate_by_g(np.copy(g), orig_g, truncate_level) |
|
|
131 |
t = self.truncate_by_g(np.copy(t), orig_g, truncate_level) |
|
|
132 |
y = self.truncate_by_g(np.copy(y), orig_g, truncate_level) |
|
|
133 |
return q_t0, q_t1, g, t, y |
|
|
134 |
|
|
|
135 |
def cross_entropy(self, y, p): |
|
|
136 |
return -np.mean((y * np.log(p) + (1. - y) * np.log(1. - p))) |
|
|
137 |
|
|
|
138 |
def collateFromFolds(self, foldNPZ): |
|
|
139 |
""" |
|
|
140 |
FYI: keys can be provided but default is below |
|
|
141 |
est_keys = { |
|
|
142 |
'treatment_label_key' : 'treatment_label', |
|
|
143 |
'outcome_key' : 'outcome', |
|
|
144 |
'treatment_pred_key' : 'treatment', |
|
|
145 |
'outcome_label_key' : 'outcome_label'} |
|
|
146 |
|
|
|
147 |
""" |
|
|
148 |
if self.est_keys is None: |
|
|
149 |
self.est_keys = { |
|
|
150 |
'treatment_label_key': 'treatment_label', |
|
|
151 |
'outcome_key': 'outcome', |
|
|
152 |
'treatment_pred_key': 'treatment', |
|
|
153 |
'outcome_label_key': 'outcome_label'} |
|
|
154 |
t_all = [] |
|
|
155 |
q1_all = [] |
|
|
156 |
q0_all = [] |
|
|
157 |
g_all = [] |
|
|
158 |
y_all = [] |
|
|
159 |
for fold in foldNPZ: |
|
|
160 |
ld = np.load(fold) |
|
|
161 |
t_all.append(ld[self.est_keys['treatment_label_key']]) |
|
|
162 |
q0_all.append(ld[self.est_keys['outcome_key']][:, 0]) |
|
|
163 |
q1_all.append(ld[self.est_keys['outcome_key']][:, 1]) |
|
|
164 |
y_all.append(ld[self.est_keys['outcome_label_key']]) |
|
|
165 |
g_all.append(ld[self.est_keys['treatment_pred_key']][:, 1]) |
|
|
166 |
t_all = np.array(list(itertools.chain(*t_all))).flatten() |
|
|
167 |
g_all = np.array(list(itertools.chain(*g_all))).flatten() |
|
|
168 |
q0_all = np.array(list(itertools.chain(*q0_all))).flatten() |
|
|
169 |
q1_all = np.array(list(itertools.chain(*q1_all))).flatten() |
|
|
170 |
y_all = np.array(list(itertools.chain(*y_all))).flatten() |
|
|
171 |
return q0_all, q1_all, y_all, g_all, t_all |