Diff of /src/CV_TMLE.py [000000] .. [9e1f38]

Switch to unified view

a b/src/CV_TMLE.py
1
import numpy as np
2
from scipy.special import logit, expit
3
from scipy.optimize import minimize
4
5
import numpy as np
6
from scipy.special import logit
7
import itertools
8
import sklearn.linear_model as lm
9
10
import numpy as np
11
from sklearn.feature_extraction.text import CountVectorizer
12
13
# Seed NumPy's global random number generator with a fixed value. This runs as
# a side effect of importing the module, so it also resets the global RNG state
# for any other code in the same process that relies on np.random.
np.random.seed(0)
14
15
16
class CVTMLE:
    """
    Cross-validated targeted maximum likelihood estimation (CV-TMLE).

    Method reference:
    Levy, Jonathan. "An easy implementation of CV-TMLE." arXiv preprint arXiv:1811.04573 (2018).
    Influence curves follow Gruber S, van der Laan MJ (2011).

    The estimator consumes out-of-fold nuisance estimates (outcome predictions under
    control/treatment and propensity scores) and returns a point estimate with a 95% CI.
    """

    # Probabilities are kept this far away from 0 and 1 before taking log/logit,
    # otherwise a single boundary value yields inf/NaN in the fluctuation fit.
    _PROB_EPS = np.finfo(float).eps

    def __init__(self, q_t0=None, q_t1=None, g=None, t=None, y=None, fromFolds=None, est_keys=None,
                 truncate_level=0.05):
        """
        :param q_t0: initial outcome estimate under control exposure
        :param q_t1: initial outcome estimate under treatment/non-control exposure
        :param g: predicted propensity score
        :param t: treatment label
        :param y: factual outcome
        :param fromFolds: list of per-fold npz files (paths or file objects); when given, the
            first five parameters are read from these files and need not be provided
        :param est_keys: dict of the keys used to extract the estimates from the npz files
            (see collateFromFolds for the expected entries and their defaults)
        :param truncate_level: truncation for propensity scores (0.05 default means that only
            patients with estimates between 0.05 and 0.95 will be considered)
        """
        self.q_t0 = q_t0
        self.q_t1 = q_t1
        self.g = g
        self.t = t
        self.y = y
        self.est_keys = est_keys
        self.truncate_level = truncate_level
        if fromFolds is not None:
            self.q_t0, self.q_t1, self.y, self.g, self.t = self.collateFromFolds(fromFolds)

    def _prepared_inputs(self):
        """
        Validate the stored estimates, flatten them to 1-D float arrays and apply the
        propensity-score truncation.

        Flattening guards against mixing (n,) and (n, 1) arrays, which would otherwise
        broadcast silently to (n, n) in the estimators.

        :return: q_t0, q_t1, g, t, y restricted to the rows kept by truncate_all_by_g
        :raises ValueError: if an estimate is missing or the inputs differ in length
        """
        named = {'q_t0': self.q_t0, 'q_t1': self.q_t1, 'g': self.g, 't': self.t, 'y': self.y}
        missing = [name for name, value in named.items() if value is None]
        if missing:
            raise ValueError('missing inputs for CV-TMLE: ' + ', '.join(missing))
        arrays = [np.asarray(value, dtype=float).ravel() for value in named.values()]
        if len({arr.shape[0] for arr in arrays}) != 1:
            raise ValueError('q_t0, q_t1, g, t and y must all have the same number of elements')
        return self.truncate_all_by_g(*arrays, self.truncate_level)

    def _perturbed_model_bin_outcome(self, q_t0, q_t1, g, t, eps):
        """
        Helper for run_tmle_binary.

        Returns q_eps(t, x) and the h term, i.e. the value of the perturbed predictor at
        (t, eps, x), where q_t0, q_t1 and g are all evaluated at x.
        """
        h = t * (1. / g) - (1. - t) / (1. - g)
        lo, hi = self._PROB_EPS, 1. - self._PROB_EPS
        # logit(0) / logit(1) are -inf / +inf; clip so the fluctuation stays finite
        lq_t0 = logit(np.clip(np.asarray(q_t0, dtype=float), lo, hi))
        lq_t1 = logit(np.clip(np.asarray(q_t1, dtype=float), lo, hi))
        full_lq = (1. - t) * lq_t0 + t * lq_t1  # logit predictions from unperturbed model
        return expit(full_lq + eps * h), h

    def run_tmle_binary(self):
        """
        CV-TMLE for binary outcomes yielding the risk ratio with a 95% CI.

        The CI is built on the log scale from the influence curve of log(RR):
            (1/mu1) * (t/g * (y - Q*) + Q1* - mu1) - (1/mu0) * ((1-t)/(1-g) * (y - Q*) + Q0* - mu0)
        where Q* is the targeted prediction at the observed treatment and Q1*/Q0* are the
        targeted counterfactual predictions.

        :return: [rr, lower bound, upper bound]
        :raises ValueError: if an estimate is missing or the inputs differ in length
        """
        print('running CV-TMLE for binary outcomes...')
        q_t0, q_t1, g, t, y = self._prepared_inputs()

        def loss(eps):
            return self.cross_entropy(y, self._perturbed_model_bin_outcome(q_t0, q_t1, g, t, eps)[0])

        eps_hat = minimize(loss, 0., method='Nelder-Mead').x[0]

        # targeted predictions at the observed and at both counterfactual treatments
        q_star, _ = self._perturbed_model_bin_outcome(q_t0, q_t1, g, t, eps_hat)
        qq1, _ = self._perturbed_model_bin_outcome(q_t0, q_t1, g, np.ones_like(t), eps_hat)
        qq0, _ = self._perturbed_model_bin_outcome(q_t0, q_t1, g, np.zeros_like(t), eps_hat)

        mu1, mu0 = np.mean(qq1), np.mean(qq0)
        rr = mu1 / mu0

        # The residual only contributes to the arm that was actually observed, hence the
        # t and (1 - t) indicators on the inverse-probability weights.
        resid = y - q_star
        ic = ((t / g * resid + qq1 - mu1) / mu1
              - ((1. - t) / (1. - g) * resid + qq0 - mu0) / mu0)
        psi_tmle_std = 1.96 * np.sqrt(np.var(ic) / (t.shape[0]))

        return [rr, np.exp(np.log(rr) - psi_tmle_std), np.exp(np.log(rr) + psi_tmle_std)]

    def run_tmle_continuous(self):
        """
        CV-TMLE for continuous outcomes yielding the ATE / mean difference with a 95% CI.

        The CI uses the efficient influence curve of the ATE:
            (t/g - (1-t)/(1-g)) * (y - Q*) + (Q1* - Q0*) - rd
        with Q* the targeted prediction at the observed treatment.

        :return: [rd, lower bound, upper bound]
        :raises ValueError: if an estimate is missing or the inputs differ in length
        """
        print('running CV-TMLE for continuous outcomes...')
        q_t0, q_t1, g, t, y = self._prepared_inputs()

        h = t * (1.0 / g) - (1.0 - t) / (1.0 - g)
        full_q = (1.0 - t) * q_t0 + t * q_t1  # predictions from unperturbed model
        # closed-form least-squares fluctuation along the clever covariate h
        eps_hat = np.sum(h * (y - full_q)) / np.sum(np.square(h))

        # targeted predictions: observed arm and both counterfactual arms
        q_star = full_q + eps_hat * h
        qq1 = q_t1 + eps_hat * (1.0 / g)
        qq0 = q_t0 - eps_hat * (1.0 / (1.0 - g))

        rd = np.mean(qq1 - qq0)
        # The weight must be the clever covariate at the OBSERVED treatment; summing the
        # two counterfactual covariates would cancel to zero at g = 0.5.
        ic = (h * (y - q_star)) + (qq1 - qq0) - rd
        psi_tmle_std = 1.96 * np.sqrt(np.var(ic) / (t.shape[0]))

        return [rd, rd - psi_tmle_std, rd + psi_tmle_std]

    def truncate_by_g(self, attribute, g, level=0.1):
        """
        Keep the entries of attribute whose propensity score g lies in [level, 1 - level].
        """
        keep_these = np.logical_and(g >= level, g <= 1. - level)
        return attribute[keep_these]

    def truncate_all_by_g(self, q_t0, q_t1, g, t, y, truncate_level=0.05):
        """
        Helper function to clean up nuisance parameter estimates: drops every observation
        whose propensity score falls outside [truncate_level, 1 - truncate_level].

        :return: q_t0, q_t1, g, t, y restricted to the kept observations (new arrays)
        """
        orig_g = np.asarray(g)
        # one mask for all five arrays so they stay row-aligned
        keep_these = np.logical_and(orig_g >= truncate_level, orig_g <= 1. - truncate_level)
        q_t0, q_t1, g, t, y = (np.asarray(arr)[keep_these] for arr in (q_t0, q_t1, orig_g, t, y))
        return q_t0, q_t1, g, t, y

    def cross_entropy(self, y, p):
        """
        Mean binary cross-entropy of predicted probabilities p against labels y.

        p is clipped away from exactly 0 and 1 so that log() never produces -inf
        (and 0 * -inf never produces NaN).
        """
        p = np.clip(np.asarray(p, dtype=float), self._PROB_EPS, 1. - self._PROB_EPS)
        return -np.mean((y * np.log(p) + (1. - y) * np.log(1. - p)))

    def collateFromFolds(self, foldNPZ):
        """
        Read the per-fold npz files and concatenate their estimates in fold order.

        FYI: keys can be provided via est_keys but the default is below
        est_keys = {
        'treatment_label_key' : 'treatment_label',
        'outcome_key' : 'outcome',
        'treatment_pred_key' : 'treatment',
        'outcome_label_key' : 'outcome_label'}

        The outcome array is read as two columns (column 0 -> q_t0, column 1 -> q_t1) and
        column 1 of the treatment prediction array is used as the propensity score.

        :param foldNPZ: iterable of npz files (paths or file objects), one per fold
        :return: q_t0, q_t1, y, g, t as flat arrays
        """
        if self.est_keys is None:
            self.est_keys = {
                'treatment_label_key': 'treatment_label',
                'outcome_key': 'outcome',
                'treatment_pred_key': 'treatment',
                'outcome_label_key': 'outcome_label'}
        keys = self.est_keys

        t_all, q1_all, q0_all, g_all, y_all = [], [], [], [], []
        for fold in foldNPZ:
            # context manager closes the archive handle once the arrays are extracted
            with np.load(fold) as ld:
                outcome = ld[keys['outcome_key']]  # decompress once, slice twice
                t_all.append(ld[keys['treatment_label_key']])
                q0_all.append(outcome[:, 0])
                q1_all.append(outcome[:, 1])
                y_all.append(ld[keys['outcome_label_key']])
                g_all.append(ld[keys['treatment_pred_key']][:, 1])

        def flat(parts):
            # np.concatenate rejects an empty list; no folds means no observations
            if not parts:
                return np.array([])
            return np.concatenate([np.ravel(part) for part in parts])

        return flat(q0_all), flat(q1_all), flat(y_all), flat(g_all), flat(t_all)