#+ knitr_setup, include = FALSE

# Whether to cache the intensive code sections. Set to FALSE to recalculate
# everything afresh.
cacheoption <- TRUE
# Disable lazy caching globally, because it fails for large objects, and all the
# objects we wish to cache are large...
opts_chunk$set(cache.lazy = FALSE)

#' # Variable selection in data-driven health records with discretised Cox models
#'
#' Having extracted around 600 variables which occur most frequently in patient
#' records, let's try to narrow these down using a methodology based on varSelRF
#' combined with survival modelling. We'll rank variables by the p-value of a
#' logrank test comparing survival curves across the categories of each
#' variable, and then iteratively discard the least important variables,
#' cross-validating to find the optimum number to keep.
#'
#' ## User variables
#'
#+ user_variables

output.filename.base <- '../../output/cox-bigdata-varsellogrank-01'

cv.n.folds <- 3
vars.drop.frac <- 0.2 # Fraction of variables to drop at each iteration
bootstraps <- 100

n.data <- NA # Number of rows to use; NA means all (sampled after exclusions in prep)

n.threads <- 20

#' ## Data set-up
#'
#+ data_setup

data.filename.big <- '../../data/cohort-datadriven-02.csv'

surv.predict.old <- c('age', 'smokstatus', 'imd_score', 'gender')
untransformed.vars <- c('time_death', 'endpoint_death', 'exclude')

source('../lib/shared.R')
require(xtable)

# Define these after shared.R or they will be overwritten!
exclude.vars <-
  c(
    # Entity type 4 is smoking status, which we already have
    "clinical.values.4_data1", "clinical.values.4_data5",
    "clinical.values.4_data6",
    # Entity 13 data2 is the patient's weight centile, and not a single one is
    # entered, but they come out as 0 so the algorithm, looking for NAs, thinks
    # it's a useful column
    "clinical.values.13_data2",
    # Entities 148 and 149 are to do with death certification. I'm not sure how
    # they made it into the dataset, but since all the datapoints in this are
    # looking back in time, they're all NA. This causes rfsrc to fail.
    "clinical.values.148_data1", "clinical.values.148_data2",
    "clinical.values.148_data3", "clinical.values.148_data4",
    "clinical.values.148_data5",
    "clinical.values.149_data1", "clinical.values.149_data2"
  )

COHORT <- fread(data.filename.big)

bigdata.prefixes <-
  c(
    'hes.icd.',
    'hes.opcs.',
    'tests.enttype.',
    'clinical.history.',
    'clinical.values.',
    'bnf.'
  )

bigdata.columns <-
  colnames(COHORT)[
    which(
      # Does it start with one of the data column names?
      startsWithAny(names(COHORT), bigdata.prefixes) &
        # And it's not one of the columns we want to exclude?
        !(colnames(COHORT) %in% exclude.vars)
    )
  ]

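#' startsWithAny is a helper from ../lib/shared.R; a minimal stand-in (an
#' illustrative assumption, not its actual definition) might be:
#+ startswithany_sketch, eval = FALSE
startsWithAnySketch <- function(x, prefixes) {
  # TRUE wherever x starts with at least one of the given prefixes
  Reduce(`|`, lapply(prefixes, function(prefix) startsWith(x, prefix)))
}
startsWithAnySketch(c('bnf.0212000', 'age'), c('bnf.', 'hes.icd.')) # TRUE FALSE
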
COHORT.bigdata <-
  COHORT[, c(
    untransformed.vars, surv.predict.old, bigdata.columns
  ),
  with = FALSE
  ]

# Get the missingness before we start removing missing values
missingness <- sort(sapply(COHORT.bigdata, percentMissing))
# Remove values for the 'untransformed.vars' above, which are the survival
# values plus the exclude column
missingness <- missingness[!(names(missingness) %in% untransformed.vars)]

# Deal appropriately with missing data
# Most of the variables are number of days since the first record of that type
time.based.vars <-
  names(COHORT.bigdata)[
    startsWithAny(
      names(COHORT.bigdata),
      c('hes.icd.', 'hes.opcs.', 'clinical.history.')
    )
  ]
# We're dealing with these as logicals, so we want non-NA values to be TRUE,
# i.e. there is something in the history
for (j in time.based.vars) {
  set(COHORT.bigdata, j = j, value = !is.na(COHORT.bigdata[[j]]))
}

# Again, taking these as logicals, set any non-NA value to TRUE.
prescriptions.vars <- names(COHORT.bigdata)[startsWith(names(COHORT.bigdata), 'bnf.')]
for (j in prescriptions.vars) {
  set(COHORT.bigdata, j = j, value = !is.na(COHORT.bigdata[[j]]))
}

# This leaves tests and clinical.values, which are test results and should be
# imputed.

# Manually fix clinical values items...
#
# "clinical.values.1_data1"  "clinical.values.1_data2"
# These are just blood pressure values...fine to impute
#
# "clinical.values.13_data1" "clinical.values.13_data3"
# These are weight and BMI...also fine to impute
#
# Entity 5 is alcohol consumption status, 1 = Yes, 2 = No, 3 = Ex, so should be
# a factor, and NA can be a factor level
COHORT.bigdata$clinical.values.5_data1 <-
  factorNAfix(factor(COHORT.bigdata$clinical.values.5_data1), NAval = 'missing')

# Both gender and smokstatus are factors...fix that
COHORT.bigdata$gender <- factor(COHORT.bigdata$gender)
COHORT.bigdata$smokstatus <-
  factorNAfix(factor(COHORT.bigdata$smokstatus), NAval = 'missing')

# Exclude invalid patients
COHORT.bigdata <- COHORT.bigdata[!COHORT.bigdata$exclude]
COHORT.bigdata$exclude <- NULL

# Remove negative survival times
COHORT.bigdata <- subset(COHORT.bigdata, time_death > 0)

# Define test set
test.set <- testSetIndices(COHORT.bigdata, random.seed = 78361)

# If n.data was specified, trim the data table down to size
if(!is.na(n.data)) {
  COHORT.bigdata <- sample.df(COHORT.bigdata, n.data)
}

# Create an appropriate survival column
COHORT.bigdata <-
  prepSurvCol(
    data.frame(COHORT.bigdata), 'time_death', 'endpoint_death', 'Death'
  )

# Start by predicting survival with all the variables provided
surv.predict <- c(surv.predict.old, bigdata.columns)

# Set up a csv file to store calibration data, or retrieve previous data
calibration.filename <- paste0(output.filename.base, '-varselcalibration.csv')

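#' prepSurvCol comes from ../lib/shared.R; a minimal stand-in (an illustrative
#' assumption, not its actual definition) might standardise the survival
#' columns like this:
#+ prepsurvcol_sketch, eval = FALSE
prepSurvColSketch <- function(df, col.time, col.event, event.yes) {
  # Standardised survival time and event indicator columns
  df$surv_time  <- df[, col.time]
  df$surv_event <- df[, col.event] == event.yes
  # Drop the originals so they can't leak into the predictor list
  df[, !(names(df) %in% c(col.time, col.event))]
}
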
varLogrankTest <- function(df, var) {
  # If there's only one category, this is a single-valued variable so you can't
  # do a logrank test on different values of it...
  if(length(unique(NArm(df[, var]))) == 1) {
    return(NA)
  }

  # If it's a logical, make an extra column for consistency of later code
  if(is.logical(df[, var])) {
    df$groups <- factor(ifelse(df[, var], 'A', 'B'))
  # If it's numeric (or integer), split it into four quartiles
  } else if(is.numeric(df[, var])) {
    # First, discard all rows where the value is missing
    df <- df[!is.na(df[, var]), ]
    # Then, assign quartiles
    df$groups <-
      factor(
        findInterval(
          df[, var],
          quantile(df[, var], probs = c(0, 0.25, 0.5, 0.75, 1)),
          # Close the top interval so the maximum value falls into the fourth
          # quartile rather than a fifth group of its own
          rightmost.closed = TRUE
        )
      )
  } else {
    # Otherwise, it's a factor, so leave it as-is
    df$groups <- df[, var]
  }

  # Perform a logrank test on the data
  lr.test <-
    survdiff(
      Surv(surv_time, surv_event) ~ groups,
      df
    )
  # Return the p-value of the logrank test (chi-squared, groups - 1 df)
  pchisq(lr.test$chisq, length(lr.test$n) - 1, lower.tail = FALSE)
}

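#' As a quick sanity check, varLogrankTest can be exercised on a small
#' synthetic data frame (a hypothetical example, not part of the pipeline; it
#' relies on the survival package and helpers already sourced from
#' ../lib/shared.R):
#+ varlogrank_example, eval = FALSE
toy <- data.frame(
  surv_time  = rexp(200, rate = 0.1),
  surv_event = rbinom(200, 1, 0.7),
  treated    = sample(c(TRUE, FALSE), 200, replace = TRUE)
)
# A small p-value would indicate differing survival curves between groups
varLogrankTest(toy, 'treated')
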
# Don't use the output variables in our list
vars.to.check <-
  names(COHORT.bigdata)[!(names(COHORT.bigdata) %in% c('surv_time', 'surv_event'))]

var.logrank.p <-
  sapply(
    X = vars.to.check, FUN = varLogrankTest,
    df = COHORT.bigdata[-test.set, ]
  )

# Sort them in ascending order, because small p-values indicate differing
# survival curves
var.logrank.p <- sort(var.logrank.p, na.last = TRUE)

# Create process settings

# Variables to leave alone, including those whose logrank p-value is NA, because
# that means there is only one value in the column and so it can't be
# discretised properly anyway
vars.noprocess <- c('surv_time', 'surv_event', names(var.logrank.p)[is.na(var.logrank.p)])
process.settings <-
  list(
    var        = vars.noprocess,
    method     = rep(NA, length(vars.noprocess)),
    settings   = rep(list(NA), length(vars.noprocess))
  )
# Find continuous variables which will need discretising
continuous.vars <- names(COHORT.bigdata)[sapply(COHORT.bigdata, class) %in% c('integer', 'numeric')]
# Remove those variables already explicitly excluded, mainly those whose
# logrank score was NA
continuous.vars <- continuous.vars[!(continuous.vars %in% process.settings$var)]
process.settings$var <- c(process.settings$var, continuous.vars)
process.settings$method <-
  c(process.settings$method,
    rep('binByQuantile', length(continuous.vars))
  )
process.settings$settings <-
  c(
    process.settings$settings,
    rep(
      list(
        seq(
          # Quantiles are obviously between 0 and 1
          0, 1,
          # All variables get the same number of bins
          length.out = 10
        )
      ),
      length(continuous.vars)
    )
  )

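#' binByQuantile comes from ../lib/shared.R; a minimal stand-in (an
#' illustrative assumption, not its actual definition) showing the kind of
#' decile-based discretisation these settings request:
#+ binbyquantile_sketch, eval = FALSE
binByQuantileSketch <- function(x, probs = seq(0, 1, length.out = 10)) {
  # unique() guards against duplicated break points in heavily tied data
  breaks <- unique(quantile(x, probs = probs, na.rm = TRUE))
  # Keep NA as an explicit 'missing' level rather than imputing
  factorNAfix(cut(x, breaks = breaks, include.lowest = TRUE), NAval = 'missing')
}
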
COHORT.prep <-
  prepData(
    # Prep the whole cohort; the cross-validation below excludes the test set
    COHORT.bigdata,
    names(COHORT.bigdata),
    process.settings,
    'surv_time', 'surv_event',
    TRUE
  )

# Kludge...remove surv_time.1 and rename surv_event.1
COHORT.prep$surv_time.1 <- NULL
names(COHORT.prep)[names(COHORT.prep) == 'surv_event.1'] <- 'surv_event'

#' ## Run variable selection
#'
#' If there's not already a calibration file, we run our variable selection
#' algorithm:
#'   1. Perform logrank tests on survival curves of subsets of the data to find
#'      those variables which seemingly have the largest effect on survival.
#'   2. Cross-validate as the number of most important variables kept is reduced.
#'
#' (If there is already a calibration file, just load the previous work.)
#'
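#' getVarNums comes from ../lib/shared.R; as an illustrative assumption about
#' its behaviour (not its actual definition), it presumably yields a
#' decreasing sequence of variable counts, dropping vars.drop.frac of the
#' remaining variables per step:
#+ getvarnums_sketch, eval = FALSE
getVarNumsSketch <- function(n, min = 1, frac = vars.drop.frac) {
  out <- n
  # Shrink by frac each step, stopping once the minimum is reached
  while(tail(out, 1) > min) {
    out <- c(out, max(min, round(tail(out, 1) * (1 - frac))))
  }
  out
}
getVarNumsSketch(600, min = 50) # 600, 480, 384, ..., 50
#'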
#+ cox_var_sel_calibration

# If we've not already done a calibration, then do one
if(!file.exists(calibration.filename)) {
  # Create an empty data frame to aggregate stats per fold
  cv.performance <- data.frame()

  # Cross-validate over the number of variables to try
  cv.vars <-
    getVarNums(
      length(var.logrank.p),
      # No point going lower than the point at which all the p-values are 0,
      # because below this the order is alphabetical and therefore meaningless!
      min = sum(var.logrank.p == 0, na.rm = TRUE)
    )

  COHORT.cv <- COHORT.prep[-test.set, ]

  # Run crossvalidations. No need to parallelise this loop because survivalFit
  # is itself multithreaded (see n.threads)
  for(i in seq_along(cv.vars)) {
    # Get the subset of most important variables to use
    surv.predict.partial <- names(var.logrank.p)[1:cv.vars[i]]

    # Get folds for cross-validation
    cv.folds <- cvFolds(nrow(COHORT.cv), cv.n.folds)

    cv.fold.performance <- data.frame()

    for(j in 1:cv.n.folds) {
      time.start <- handyTimer()
      # Fit model to the training set
      surv.model.fit <-
        survivalFit(
          surv.predict.partial,
          COHORT.cv[-cv.folds[[j]],],
          model.type = 'survreg',
          n.threads = n.threads
        )
      time.learn <- handyTimer(time.start)

      time.start <- handyTimer()
      # Get C-index on the validation set
      c.index.val <-
        cIndex(
          surv.model.fit, COHORT.cv[cv.folds[[j]],]
        )
      time.c.index <- handyTimer(time.start)

      time.start <- handyTimer()
      # Get calibration score on the validation set
      calibration.score <-
        calibrationScore(
          calibrationTable(
            surv.model.fit, COHORT.cv[cv.folds[[j]],]
          )
        )
      time.calibration <- handyTimer(time.start)

      # Append the stats we've obtained from this fold
      cv.fold.performance <-
        rbind(
          cv.fold.performance,
          data.frame(
            calibration = i,
            cv.fold = j,
            n.vars = cv.vars[i],
            c.index.val,
            calibration.score,
            time.learn,
            time.c.index,
            time.calibration
          )
        )

    } # End cross-validation loop (j)

    # rbind the performance by fold
    cv.performance <-
      rbind(
        cv.performance,
        cv.fold.performance
      )

    # Save output at the end of each loop
    write.csv(cv.performance, calibration.filename)

  } # End calibration loop (i)
} else {
  cv.performance <- read.csv(calibration.filename)
}

#' ## Find the best model from the calibrations
#'
#' ### Plot model performance
#'
#+ model_performance

# Find the best calibration...
# First, average performance across cross-validation folds
cv.performance.average <-
  aggregate(
    c.index.val ~ n.vars,
    data = cv.performance,
    mean
  )

cv.calibration.average <-
  aggregate(
    area ~ n.vars,
    data = cv.performance,
    mean
  )

ggplot(cv.performance.average, aes(x = n.vars, y = c.index.val)) +
  geom_line() +
  geom_point(data = cv.performance) +
  ggtitle(label = 'C-index by n.vars')

ggplot(cv.calibration.average, aes(x = n.vars, y = area)) +
  geom_line() +
  geom_point(data = cv.performance) +
  ggtitle(label = 'Calibration performance by n.vars')

# Find the number of variables with the highest average C-index
n.vars <-
  cv.performance.average$n.vars[
    which.max(cv.performance.average$c.index.val)
  ]

# Fit a full model with the variables provided
surv.predict.partial <- names(var.logrank.p)[1:n.vars]

#' ## Best model
#'
#' The best model contained `r n.vars` variables. Let's see what those were...
#'
#+ variables_used

vars.df <-
  data.frame(
    vars = surv.predict.partial
  )

vars.df$descriptions <- lookUpDescriptions(surv.predict.partial)

vars.df$missingness <- missingness[surv.predict.partial]

#+ variables_table, results='asis'

print(
  xtable(vars.df),
  type = 'html',
  include.rownames = FALSE
)

#' ## Perform the final fit
#'
#' Having found the best number of variables by cross-validation, let's perform
#' the final fit with the full training set.
#'
#+ final_fit

time.start <- handyTimer()
surv.model.fit.final <-
  survivalFit(
    surv.predict.partial,
    COHORT.prep[-test.set,],
    model.type = 'survreg'
  )
time.fit.final <- handyTimer(time.start)

saveRDS(surv.model.fit.final, paste0(output.filename.base, '-finalmodel.rds'))

#' Final model fitted in `r round(time.fit.final)` seconds!
#'
#' Also bootstrap this final fitting stage. A fully proper bootstrap would
#' iterate over the whole model-building process, including variable selection,
#' but that would be prohibitive in terms of computational time.
#'
#+ bootstrap_final

time.start <- handyTimer()
surv.model.params.boot <-
  survivalFitBoot(
    surv.predict.partial,
    COHORT.prep[-test.set,], # Training set
    COHORT.prep[test.set,],  # Test set
    model.type = 'survreg',
    bootstraps = bootstraps,
    n.threads = n.threads,
    filename = paste0(output.filename.base, '-boot-all.csv')
  )
time.boot.final <- handyTimer(time.start)

#' `r bootstraps` bootstrap fits completed in `r round(time.boot.final)` seconds!

# Get coefficients and variable importances from bootstrap fits
surv.model.fit.coeffs <- bootStatsDf(surv.model.params.boot)

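#' bootStatsDf comes from ../lib/shared.R; as an illustrative assumption about
#' its behaviour (not its actual definition), it presumably reduces each
#' bootstrapped statistic to a point estimate plus a percentile confidence
#' interval, along these lines:
#+ bootstats_sketch, eval = FALSE
boot.replicates <- rnorm(bootstraps, mean = 0.78, sd = 0.01) # stand-in values
c(
  val   = mean(boot.replicates),
  lower = quantile(boot.replicates, 0.025, names = FALSE),
  upper = quantile(boot.replicates, 0.975, names = FALSE)
)
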
# Save performance results
varsToTable(
  data.frame(
    model = 'cox-logrank',
    imputation = FALSE,
    discretised = TRUE,
    c.index = surv.model.fit.coeffs['c.index', 'val'],
    c.index.lower = surv.model.fit.coeffs['c.index', 'lower'],
    c.index.upper = surv.model.fit.coeffs['c.index', 'upper'],
    calibration.score = surv.model.fit.coeffs['calibration.score', 'val'],
    calibration.score.lower =
      surv.model.fit.coeffs['calibration.score', 'lower'],
    calibration.score.upper =
      surv.model.fit.coeffs['calibration.score', 'upper']
  ),
  performance.file,
  index.cols = c('model', 'imputation', 'discretised')
)

#' ## Performance
#'
#' ### C-index
#'
#' C-index is **`r round(surv.model.fit.coeffs['c.index', 'val'], 3)`
#' (`r round(surv.model.fit.coeffs['c.index', 'lower'], 3)` -
#' `r round(surv.model.fit.coeffs['c.index', 'upper'], 3)`)**
#' on the held-out test set.
#'
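#' (For reference: the pipeline's cIndex helper comes from ../lib/shared.R. As
#' an illustration of the statistic itself, rather than of that helper, a
#' concordance index can also be computed directly with the survival package:)
#+ cindex_reference, eval = FALSE
fit.example <- survival::coxph(
  survival::Surv(surv_time, surv_event) ~ age,
  data = COHORT.prep[-test.set, ]
)
# Proportion of comparable patient pairs whose predicted and observed
# orderings agree; 0.5 is chance, 1 is perfect
survival::concordance(fit.example)$concordance
#'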
#' ### Calibration
#'
#' The bootstrapped calibration score is
#' **`r round(surv.model.fit.coeffs['calibration.score', 'val'], 3)`
#' (`r round(surv.model.fit.coeffs['calibration.score', 'lower'], 3)` -
#' `r round(surv.model.fit.coeffs['calibration.score', 'upper'], 3)`)**.
#'
#' Let's draw a representative curve from the unbootstrapped fit... (It would be
#' better to draw all the curves from the bootstrap fit to get an idea of
#' variability, but I've not implemented this yet.)
#'
#+ calibration_plot

calibration.table <-
  calibrationTable(surv.model.fit.final, COHORT.prep[test.set,])

calibration.score <- calibrationScore(calibration.table)

calibrationPlot(calibration.table, show.censored = TRUE)

# Save the calibration table for plotting later
write.csv(
  calibration.table,
  paste0(output.filename.base, '-calibration-table.csv')
)