# overview/variable-effects.R

# Input files ----
# Bootstrapped fit of the discretised model (loaded with readRDS)
cox.disc.filename <- '../../output/all-cv-survreg-boot-try5-surv-model.rds'
# Coefficients from the model fitted with missing data
caliber.missing.coefficients.filename <-
  '../../output/caliber-replicate-with-missing-survreg-6-linear-age-coeffs-3.csv'
# Random forest variable effects
rf.filename <- '../../output/rfsrc-cv-nsplit-try3-var-effects.csv'

# Shared project code (presumably where requirePlus is defined), then the
# plotting package used to assemble the final figure
source('../lib/shared.R')
requirePlus('cowplot')

# Plot settings ----
# Fraction of the plotted data range added at the right-hand side of the x-axis
# to make space for missing values
missing.padding <- 0.05

# Continuous variables which are binned by quantile for the discretised model
continuous.vars <- c(
  'age', 'total_chol_6mo', 'hdl_6mo', 'pulse_6mo', 'crea_6mo',
  'total_wbc_6mo', 'haemoglobin_6mo'
)

# Read the discretised Cox model which will be plotted
surv.model.fit.boot <- readRDS(cox.disc.filename)

# Read the coefficients from the model with missing data
caliber.missing.coeffs <- read.csv(caliber.missing.coefficients.filename)
# Take the negative log of the central estimate and of both interval bounds, to
# get them on the same scale as the discrete model
caliber.missing.coeffs[c('our_value', 'our_lower', 'our_upper')] <-
  lapply(
    caliber.missing.coeffs[c('our_value', 'our_lower', 'our_upper')],
    function(estimate) -log(estimate)
  )

# Read the cohort data. data.filename is not set in this script; presumably it
# comes from the shared code sourced at the top.
COHORT.use <- data.frame(fread(data.filename))

# Cross-validation results, used to find the best binning scheme
calibration.filename <- '../../output/survreg-crossvalidation-try5.csv'
cv.performance <- read.csv(calibration.filename)

# Average c.index.val of each calibration across the cross-validation folds
cv.performance.average <-
  aggregate(c.index.val ~ calibration, data = cv.performance, FUN = mean)

# The calibration whose average is highest...
best.calibration <-
  with(cv.performance.average, calibration[which.max(c.index.val)])
# ...and the first row of the results belonging to it, which carries its
# n.bins values
best.calibration.row1 <-
  min(which(cv.performance$calibration == best.calibration))

# Its parameters: the value stored for each continuous variable, transposed so
# that there is one row per variable
n.bins <- t(cv.performance[best.calibration.row1, continuous.vars])

# Prepare the data with those settings...

# Columns which are passed through with no processing method or settings
unprocessed.vars <- c('anonpatid', 'time_death', 'imd_score', 'exclude')

# Build the process settings in one go: the unprocessed base columns first,
# then every continuous variable binned by quantile. Building the vectors whole
# (rather than appending inside a 1:length() loop) also behaves correctly if
# continuous.vars is empty.
process.settings <-
  list(
    var      = c(unprocessed.vars, continuous.vars),
    method   = c(
      rep(NA, length(unprocessed.vars)),
      rep('binByQuantile', length(continuous.vars))
    ),
    settings = c(
      rep(list(NA), length(unprocessed.vars)),
      unname(
        lapply(
          n.bins[seq_along(continuous.vars)],
          # Quantiles are between 0 and 1; the stored value is used as the
          # number of evenly spaced quantile boundaries.
          # NOTE(review): n boundaries delimit n - 1 intervals, so the stored
          # value is a count of breaks rather than of bins -- confirm this
          # matches how the saved model was fitted before changing it.
          function(n.breaks) seq(0, 1, length.out = n.breaks)
        )
      )
    )
  )

# Prep the data given the variables provided. The whole of COHORT.use is passed
# in; cols.keep, the surv.* settings and caliberExtraPrep are not defined in
# this script and are presumably provided by the shared code.
COHORT.optimised <-
  prepData(
    COHORT.use,
    cols.keep,
    process.settings,
    surv.time, surv.event,
    surv.event.yes,
    extra.fun = caliberExtraPrep
  )

# Summarise the bootstrapped fit with 95% confidence intervals, passing `-` as
# the transform...
boot.stats <-
  bootStats(surv.model.fit.boot, uncertainty = '95ci', transform = `-`)
# ...and unpack the variable and level names
cph.coeffs <-
  cphCoeffs(
    boot.stats, COHORT.optimised, surv.predict,
    model.type = 'boot.survreg'
  )

# The CALIBER scaling functions are needed for plotting
source('../cox-ph/caliber-scale.R')

# The plots are collected in this list, keyed by variable name
cox.discrete.plots <- list()

# Placeholder columns for the x-position of the missing-value points of the
# continuous and the discretised model; filled in variable by variable
caliber.missing.coeffs[['missing.x.pos.cont']] <- NA
cph.coeffs[['missing.x.pos.disc']] <- NA

# For every variable which was binned by quantile, draw the discretised model's
# coefficients as a stepped line with its confidence limits, overlay the
# straight line from the continuous model, and add a point for the
# missing-value coefficient of either model where there is one. The finished
# plot is stored in cox.discrete.plots under the variable's name.
for (variable in unique(cph.coeffs$var)) {
  # Only variables with an entry in the process settings can have been binned
  if (!(variable %in% process.settings$var)) {
    next
  }
  process.i <- which(variable == process.settings$var)
  # %in% rather than ==, so that a variable whose method is NA (left
  # unprocessed) is skipped instead of making if() fail on a missing value
  if (!(process.settings$method[[process.i]] %in% 'binByQuantile')) {
    next
  }

  # The real boundaries of the bins
  variable.quantiles <-
    getQuantiles(
      COHORT.use[, variable],
      process.settings$settings[[process.i]]
    )
  n.var.bins <- length(variable.quantiles) - 1

  # Work out data range by taking the 1st and 99th percentiles. The upper value
  # caps the final bin; the range is also used for the x-axis limits, unless
  # there are missing values to accommodate on the right-hand edge.
  x.data.range <-
    quantile(COHORT.use[, variable], c(0.01, 0.99), na.rm = TRUE)
  x.axis.limits <- x.data.range
  # Width of the strip to the right of the data reserved for missing values
  missing.width <- diff(x.data.range) * missing.padding

  # For those rows which relate to this variable, and whose level isn't
  # missing, put in the appropriate quantile boundaries for plotting
  is.bin <- cph.coeffs$var == variable & cph.coeffs$level != 'missing'
  cph.coeffs$bin.min[is.bin] <- variable.quantiles[seq_len(n.var.bins)]
  cph.coeffs$bin.max[is.bin] <- variable.quantiles[seq_len(n.var.bins) + 1]
  # Make the final bin end at the 99th percentile
  cph.coeffs$bin.max[is.bin][n.var.bins] <- x.data.range[2]

  # Add a fake data point to finish the graph: a copy of the final bin which
  # starts where that bin ends, so the last step is drawn across its full width
  final.bin <- cph.coeffs[is.bin, ][n.var.bins, ]
  final.bin$bin.min <- final.bin$bin.max
  cph.coeffs <- rbind(cph.coeffs, final.bin)

  # Finally, we need to scale this such that the baseline value is equal to the
  # value for the equivalent place in the continuous model, to make the risks
  # comparable...

  # First, find the average value of this variable in the lowest bin (which is
  # always the baseline here)
  baseline.bin <- variable.quantiles[1:2]
  in.baseline.bin <-
    inRange(COHORT.use[, variable], baseline.bin, na.false = TRUE)
  baseline.bin.avg <- mean(COHORT.use[in.baseline.bin, variable])

  # Then, scale it with the CALIBER scaling and multiply by the continuous
  # model's coefficient for this variable
  is.continuous.coeff <- caliber.missing.coeffs$quantity == variable
  baseline.bin.val <-
    caliberScaleUnits(baseline.bin.avg, variable) *
    caliber.missing.coeffs$our_value[is.continuous.coeff]

  # And now shift all of this variable's discretised estimates (including the
  # fake point just added) by subtracting that value
  is.var <- cph.coeffs$var == variable
  estimate.cols <- c('val', 'lower', 'upper')
  cph.coeffs[is.var, estimate.cols] <-
    cph.coeffs[is.var, estimate.cols] - baseline.bin.val

  # Plot this variable as a stepped line using the quantile boundaries, with
  # the confidence limits in grey
  cox.discrete.plot <-
    ggplot(
      subset(cph.coeffs, var == variable),
      aes(x = bin.min, y = val)
    ) +
    geom_step() +
    geom_step(aes(y = lower), colour = 'grey') +
    geom_step(aes(y = upper), colour = 'grey') +
    labs(x = variable, y = 'Bx')

  # If the discretised model has a missing-value risk, add it
  is.missing.disc <- is.var & cph.coeffs$level == 'missing'
  if (any(is.missing.disc)) {
    # Expand the x-axis to squeeze the missing values in
    x.axis.limits[2] <- x.data.range[2] + missing.width
    # Put this missing value a third of the way into the missing area. The
    # position is measured from the edge of the data, not from the axis limit
    # which has just been widened, so that the point falls inside the padding.
    cph.coeffs$missing.x.pos.disc[is.missing.disc] <-
      x.data.range[2] + missing.width / 3

    # Add the point to the graph (the axis limits are set at the end)
    cox.discrete.plot <-
      cox.discrete.plot +
      geom_pointrange(
        data = cph.coeffs[is.missing.disc, ],
        aes(x = missing.x.pos.disc, y = val, ymin = lower, ymax = upper),
        colour = 'red'
      )
  }

  # Now add the line from the continuous model. Only two points are needed
  # because the lines are straight!
  continuous.cox <- data.frame(var.x.values = x.data.range)
  # Scale the x-values
  continuous.cox$var.x.scaled <-
    caliberScaleUnits(continuous.cox$var.x.values, variable)
  # Use the coefficients to calculate risk per x for the central estimate and
  # the two limits
  continuous.cox$y <-
    -caliber.missing.coeffs$our_value[is.continuous.coeff] *
    continuous.cox$var.x.scaled
  continuous.cox$upper <-
    -caliber.missing.coeffs$our_upper[is.continuous.coeff] *
    continuous.cox$var.x.scaled
  continuous.cox$lower <-
    -caliber.missing.coeffs$our_lower[is.continuous.coeff] *
    continuous.cox$var.x.scaled

  cox.discrete.plot <-
    cox.discrete.plot +
    geom_line(
      data = continuous.cox,
      aes(x = var.x.values, y = y),
      colour = 'blue'
    ) +
    geom_line(
      data = continuous.cox,
      aes(x = var.x.values, y = upper),
      colour = 'lightblue'
    ) +
    geom_line(
      data = continuous.cox,
      aes(x = var.x.values, y = lower),
      colour = 'lightblue'
    )

  # If there is one, add the missing-value risk from the continuous model
  is.missing.cont <-
    caliber.missing.coeffs$quantity == paste0(variable, '_missing') &
    caliber.missing.coeffs$unit == 'missing'
  if (any(is.missing.cont)) {
    # Put this missing value 2/3rds of the way into the missing area, again
    # measured from the edge of the data
    caliber.missing.coeffs$missing.x.pos.cont[is.missing.cont] <-
      x.data.range[2] + 2 * missing.width / 3

    cox.discrete.plot <-
      cox.discrete.plot +
      geom_pointrange(
        data = caliber.missing.coeffs[is.missing.cont, ],
        aes(
          x = missing.x.pos.cont,
          y = our_value, ymin = our_lower, ymax = our_upper
        ),
        colour = 'blue'
      )
  }

  # Finally, set the x-axis limits; they are just the data range, or the data
  # range plus a bit if there are missing values to squeeze in
  cox.discrete.plot <-
    cox.discrete.plot +
    coord_cartesian(xlim = x.axis.limits) +
    theme(axis.title.y = element_blank()) +
    theme(plot.margin = unit(c(0.2, 0.1, 0.2, 0.1), "cm"))

  cox.discrete.plots[[variable]] <- cox.discrete.plot
}

# Load the random forest variable effects file
risk.by.variables <- read.csv(rf.filename)
rf.vareff.plots <- list()

# One plot per variable: a faint line per id showing the log of the normalised
# risk against the variable's value, with the median across ids in blue
for (variable in unique(risk.by.variables$var)) {
  # The rows relating to this variable
  risk.this.variable <- subset(risk.by.variables, var == variable)

  # Get the median of the normalised risk for every value of the variable
  risk.aggregated <-
    aggregate(risk.normalised ~ val, data = risk.this.variable, FUN = median)

  # Work out the limits on the axes: the 1st and 99th percentiles of the data
  # for x, and the 5th and 95th percentiles of the normalised risk for y
  x.data.range <-
    quantile(COHORT.use[, variable], c(0.01, 0.99), na.rm = TRUE)
  x.axis.limits <- x.data.range
  y.axis.limits <-
    quantile(risk.this.variable$risk.normalised, c(0.05, 0.95), na.rm = TRUE)

  # If there's a missing value risk in the graph above, expand the x-axis by
  # the same fraction of this variable's own data range, so that the axes
  # match. (The range is recomputed here rather than reusing whatever value was
  # left behind by the last variable of the previous loop.)
  if (any(cph.coeffs$var == variable & cph.coeffs$level == 'missing')) {
    x.axis.limits[2] <-
      x.axis.limits[2] + diff(x.data.range) * missing.padding
  }

  rf.vareff.plots[[variable]] <-
    ggplot(
      risk.this.variable,
      aes(x = val, y = log(risk.normalised))
    ) +
    # Individual curves are drawn almost transparent, one line per id
    geom_line(alpha = 0.003, aes(group = id)) +
    geom_line(data = risk.aggregated, colour = 'blue') +
    # The y limits were computed on the risk scale, so log them to match
    coord_cartesian(xlim = x.axis.limits, ylim = log(y.axis.limits)) +
    labs(x = variable) +
    theme(
      plot.margin = unit(c(0.2, 0.1, 0.2, 0.1), "cm"),
      axis.title.y = element_blank()
    )
}

# Assemble the figure: the discretised model panels for four variables on the
# top row (labelled A) and the random forest panels for the same variables on
# the bottom row (labelled B)
variable.effects.figure <-
  plot_grid(
    cox.discrete.plots[['age']],
    cox.discrete.plots[['haemoglobin_6mo']],
    cox.discrete.plots[['total_wbc_6mo']],
    cox.discrete.plots[['crea_6mo']],
    rf.vareff.plots[['age']],
    rf.vareff.plots[['haemoglobin_6mo']],
    rf.vareff.plots[['total_wbc_6mo']],
    rf.vareff.plots[['crea_6mo']],
    labels = c('A', rep('', 3), 'B', rep('', 3)),
    align = "v", ncol = 4
  )
# Display the figure explicitly, since an assigned value is not auto-printed
print(variable.effects.figure)

# Save the figure. It is passed to ggsave explicitly rather than relying on the
# default of the last plot displayed, which is not the assembled grid when the
# script is run without auto-printing (e.g. via source()).
ggsave(
  '../../output/variable-effects.pdf',
  plot = variable.effects.figure,
  width = 16,
  height = 10,
  units = 'cm',
  useDingbats = FALSE
)