#+ knitr_setup, include = FALSE

# Whether to cache the intensive code sections. Set to FALSE to recalculate
# everything afresh.
cacheoption <- TRUE
# Disable lazy caching globally, because it fails for large objects, and all the
# objects we wish to cache are large...
opts_chunk$set(cache.lazy = FALSE)

#' # Variable selection in data-driven health records
#' 
#' Having extracted the roughly 600 variables which occur most frequently in
#' patient records, let's try to narrow them down using a methodology based on
#' varSelRF combined with survival modelling. We'll score each variable's
#' predictive power by the p-value of a logrank test comparing survival curves
#' across categories of that variable, then iteratively discard the least
#' important variables, cross-validating to find the optimum number to keep.
#' 
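#' As a minimal illustration of the per-variable test (run on the survival
#' package's built-in lung dataset rather than our cohort; eval = FALSE, so
#' it's not part of this analysis):
#' 
#+ logrank_illustration, eval = FALSE

library(survival)
# Compare survival curves between the two sexes in the lung dataset
lr.example <- survdiff(Surv(time, status) ~ sex, data = lung)
# Convert the chi-squared statistic to a p-value, exactly as varLogrankTest
# does later in this script
pchisq(lr.example$chisq, df = length(lr.example$n) - 1, lower.tail = FALSE)
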
#' ## User variables
#' 
#+ user_variables

output.filename.base <- '../../output/rf-bigdata-varsellogrank-02'

nsplit <- 20
n.trees.cv <- 500
n.trees.final <- 500
split.rule <- 'logrank'
n.imputations <- 3
cv.n.folds <- 3
vars.drop.frac <- 0.2 # Fraction of variables to drop at each iteration
bootstraps <- 20

n.data <- NA # This is after any variables have been excluded in prep

n.threads <- 40

#' ## Data set-up
#' 
#+ data_setup

data.filename.big <- '../../data/cohort-datadriven-02.csv'

surv.predict.old <- c('age', 'smokstatus', 'imd_score', 'gender')
untransformed.vars <- c('time_death', 'endpoint_death', 'exclude')

source('../lib/shared.R')
require(xtable)

# Define these after shared.R or they will be overwritten!
exclude.vars <-
  c(
    # Entity type 4 is smoking status, which we already have
    "clinical.values.4_data1", "clinical.values.4_data5",
    "clinical.values.4_data6",
    # Entity 13 data2 is the patient's weight centile. Not a single one is
    # entered, but missing values come out as 0 rather than NA, so the
    # algorithm, which looks for NAs, thinks it's a useful column
    "clinical.values.13_data2",
    # Entities 148 and 149 relate to death certification. I'm not sure how
    # they made it into the dataset, but since all the datapoints here look
    # back in time, they're all NA. This causes rfsrc to fail.
    "clinical.values.148_data1", "clinical.values.148_data2",
    "clinical.values.148_data3", "clinical.values.148_data4",
    "clinical.values.148_data5",
    "clinical.values.149_data1", "clinical.values.149_data2"
  )

COHORT <- fread(data.filename.big)

bigdata.prefixes <-
  c(
    'hes.icd.',
    'hes.opcs.',
    'tests.enttype.',
    'clinical.history.',
    'clinical.values.',
    'bnf.'
  )

bigdata.columns <-
  colnames(COHORT)[
    which(
      # Does it start with one of the data column names?
      startsWithAny(names(COHORT), bigdata.prefixes) &
        # And it's not one of the columns we want to exclude?
        !(colnames(COHORT) %in% exclude.vars)
    )
  ]

COHORT.bigdata <-
  COHORT[, c(
    untransformed.vars, surv.predict.old, bigdata.columns
  ),
  with = FALSE
  ]

# Get the missingness before we start removing missing values
missingness <- sort(sapply(COHORT.bigdata, percentMissing))
# Remove the 'untransformed.vars' above, which are the survival columns plus
# the exclude column
missingness <- missingness[!(names(missingness) %in% untransformed.vars)]

# Deal appropriately with missing data
# Most of the variables are the number of days since the first record of that type
time.based.vars <-
  names(COHORT.bigdata)[
    startsWithAny(
      names(COHORT.bigdata),
      c('hes.icd.', 'hes.opcs.', 'clinical.history.')
    )
  ]
# We're treating these as logicals, so we want non-NA values to be TRUE,
# i.e. there is something in the history
for (j in time.based.vars) {
  set(COHORT.bigdata, j = j, value = !is.na(COHORT.bigdata[[j]]))
}

# Again, treating these as logicals, set any non-NA value to TRUE.
prescriptions.vars <-
  names(COHORT.bigdata)[startsWith(names(COHORT.bigdata), 'bnf.')]
for (j in prescriptions.vars) {
  set(COHORT.bigdata, j = j, value = !is.na(COHORT.bigdata[[j]]))
}

# This leaves tests and clinical.values, which are test results and should be
# imputed.

# Manually fix clinical values items...
#
# "clinical.values.1_data1"  "clinical.values.1_data2"
# These are just blood pressure values...fine to impute
#
# "clinical.values.13_data1" "clinical.values.13_data3"
# These are weight and BMI...also fine to impute
#
# Entity 5 is alcohol consumption status, 1 = Yes, 2 = No, 3 = Ex, so it should
# be a factor, and NA can be a factor level
COHORT.bigdata$clinical.values.5_data1 <-
  factorNAfix(factor(COHORT.bigdata$clinical.values.5_data1), NAval = 'missing')

# Both gender and smokstatus are factors...fix that
COHORT.bigdata$gender <- factor(COHORT.bigdata$gender)
COHORT.bigdata$smokstatus <-
  factorNAfix(factor(COHORT.bigdata$smokstatus), NAval = 'missing')

# Exclude invalid patients
COHORT.bigdata <- COHORT.bigdata[!COHORT.bigdata$exclude]
COHORT.bigdata$exclude <- NULL

COHORT.bigdata <-
  prepSurvCol(data.frame(COHORT.bigdata), 'time_death', 'endpoint_death', 'Death')

# If n.data was specified, trim the data table down to size
if(!is.na(n.data)) {
  COHORT.bigdata <- sample.df(COHORT.bigdata, n.data)
}

# Define test set
test.set <- testSetIndices(COHORT.bigdata, random.seed = 78361)

# Start by predicting survival with all the variables provided
surv.predict <- c(surv.predict.old, bigdata.columns)

# Set up a csv file to store calibration data, or retrieve previous data
calibration.filename <- paste0(output.filename.base, '-varselcalibration.csv')

varLogrankTest <- function(df, var) {
  # If there's only one category, this is a single-valued variable, so we can't
  # do a logrank test on different values of it...
  if(length(unique(NArm(df[, var]))) == 1) {
    return(NA)
  }
  
  # If it's a logical, make an extra column for consistency with the later code
  if(is.logical(df[, var])) {
    df$groups <- factor(ifelse(df[, var], 'A', 'B'))
  # If it's numeric (or integer), split it into quartiles
  } else if(is.numeric(df[, var])) {
    # First, discard all rows where the value is missing
    df <- df[!is.na(df[, var]), ]
    # Then, assign quartiles; all.inside = TRUE keeps the maximum value in the
    # fourth quartile rather than giving it a group of its own
    df$groups <-
      factor(
        findInterval(
          df[, var],
          quantile(df[, var], probs = c(0, 0.25, 0.5, 0.75, 1)),
          all.inside = TRUE
        )
      )
  } else {
    # Otherwise, it's a factor, so leave it as-is
    df$groups <- df[, var]
  }
  
  # Perform a logrank test on the data
  lr.test <-
    survdiff(
      Surv(surv_time, surv_event) ~ groups,
      df
    )
  # Return the p-value of the logrank test
  pchisq(lr.test$chisq, length(lr.test$n) - 1, lower.tail = FALSE)
}

# Don't use the output variables in our list
vars.to.check <-
  names(COHORT.bigdata)[
    !(names(COHORT.bigdata) %in% c('surv_time', 'surv_event'))
  ]

var.logrank.p <-
  sapply(
    X = vars.to.check, FUN = varLogrankTest,
    df = COHORT.bigdata[-test.set, ]
  )

# Sort them in ascending order, because small p-values indicate differing
# survival curves
var.logrank.p <- sort(var.logrank.p)

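#' As a quick, self-contained illustration of what varLogrankTest is doing
#' (a hypothetical toy data frame, not evaluated as part of this analysis):
#' 
#+ varlogrank_example, eval = FALSE

toy <- data.frame(
  surv_time = rexp(100),
  surv_event = rbinom(100, 1, 0.7),
  x = rnorm(100)
)
# x is numeric, so it will be split into quartiles and the four resulting
# survival curves compared by the logrank test
varLogrankTest(toy, 'x')
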
#' ## Run random forest calibration
#' 
#' If there's not already a calibration file, we run the varSelRF-style
#' methodology:
#' 
#'   1. Fit a big forest to the whole dataset to obtain variable importances.
#'   2. Cross-validate as the number of most important variables kept is
#'      reduced.
#' 
#' (If there is already a calibration file, just load the previous work.)
#' 
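#' The sequence of variable counts to try comes from getVarNums (defined in
#' shared.R). As a rough sketch of the kind of geometric sequence such a helper
#' might produce, assuming it drops vars.drop.frac of the variables per step
#' down to some minimum (the function below is hypothetical, for illustration
#' only; the real getVarNums may differ):
#' 
#+ get_var_nums_sketch, eval = FALSE

getVarNumsSketch <- function(n, min = 1, frac = vars.drop.frac) {
  var.nums <- n
  # Keep removing frac of the variables until we hit the minimum
  while(tail(var.nums, 1) > min) {
    var.nums <- c(var.nums, max(min, round(tail(var.nums, 1) * (1 - frac))))
  }
  var.nums
}
getVarNumsSketch(600, min = 10)
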
#+ rf_var_sel_calibration

# If we've not already done a calibration, then do one
if(!file.exists(calibration.filename)) {
  # Create an empty data frame to aggregate stats per fold
  cv.performance <- data.frame()
  
  # Cross-validate over the number of variables to try
  cv.vars <-
    getVarNums(
      length(var.logrank.p),
      # No point going lower than the number of variables whose p-values are 0,
      # because below that the ranking is alphabetical and therefore
      # meaningless!
      min = sum(var.logrank.p == 0)
    )
  
  COHORT.cv <- COHORT.bigdata[-test.set, ]
  
  # Run cross-validations. No need to parallelise because rfsrc is parallelised
  for(i in 1:length(cv.vars)) {
    # Get the subset of most important variables to use
    surv.predict.partial <- names(var.logrank.p)[1:cv.vars[i]]
    
    # Get folds for cross-validation
    cv.folds <- cvFolds(nrow(COHORT.cv), cv.n.folds)
    
    cv.fold.performance <- data.frame()
    
    for(j in 1:cv.n.folds) {
      time.start <- handyTimer()
      # Fit model to the training set
      surv.model.fit <-
        survivalFit(
          surv.predict.partial,
          COHORT.cv[-cv.folds[[j]], ],
          model.type = 'rfsrc',
          n.trees = n.trees.cv,
          split.rule = split.rule,
          n.threads = n.threads,
          nsplit = nsplit,
          nimpute = n.imputations,
          na.action = 'na.impute'
        )
      time.learn <- handyTimer(time.start)
      
      time.start <- handyTimer()
      # Get C-index on validation set
      c.index.val <-
        cIndex(
          surv.model.fit, COHORT.cv[cv.folds[[j]], ],
          na.action = 'na.impute'
        )
      time.c.index <- handyTimer(time.start)
      
      time.start <- handyTimer()
      # Get calibration score on the validation set
      calibration.score <-
        calibrationScore(
          calibrationTable(
            surv.model.fit, COHORT.cv[cv.folds[[j]], ], na.action = 'na.impute'
          )
        )
      time.calibration <- handyTimer(time.start)
      
      # Append the stats we've obtained from this fold
      cv.fold.performance <-
        rbind(
          cv.fold.performance,
          data.frame(
            calibration = i,
            cv.fold = j,
            n.vars = cv.vars[i],
            c.index.val,
            calibration.score,
            time.learn,
            time.c.index,
            time.calibration
          )
        )
      
    } # End cross-validation loop (j)
    
    # rbind the performance by fold
    cv.performance <-
      rbind(
        cv.performance,
        cv.fold.performance
      )
    
    # Save output at the end of each loop
    write.csv(cv.performance, calibration.filename)
    
  } # End calibration loop (i)
} else {
  cv.performance <- read.csv(calibration.filename)
}

#' ## Find the best model from the calibrations
#' 
#' ### Plot model performance
#' 
#+ model_performance

# Find the best calibration...
# First, average performance across cross-validation folds
cv.performance.average <-
  aggregate(
    c.index.val ~ n.vars,
    data = cv.performance,
    mean
  )

cv.calibration.average <-
  aggregate(
    area ~ n.vars,
    data = cv.performance,
    mean
  )

ggplot(cv.performance.average, aes(x = n.vars, y = c.index.val)) +
  geom_line() +
  geom_point(data = cv.performance) +
  ggtitle(label = 'C-index by n.vars')

ggplot(cv.calibration.average, aes(x = n.vars, y = area)) +
  geom_line() +
  geom_point(data = cv.performance) +
  ggtitle(label = 'Calibration performance by n.vars')

# Find the number of variables which gave the highest average C-index
n.vars <-
  cv.performance.average$n.vars[
    which.max(cv.performance.average$c.index.val)
  ]

# Select the best model's variables, ready to fit a full model with them
surv.predict.partial <- names(var.logrank.p)[1:n.vars]

#' ## Best model
#' 
#' The best model contained `r n.vars` variables. Let's see what those were...
#' 
#+ variables_used

vars.df <-
  data.frame(
    vars = surv.predict.partial
  )

vars.df$descriptions <- lookUpDescriptions(surv.predict.partial)

vars.df$missingness <- missingness[surv.predict.partial]

#+ variables_table, results='asis'

print(
  xtable(vars.df),
  type = 'html',
  include.rownames = FALSE
)

#' ## Perform the final fit
#' 
#' Having found the best number of variables by cross-validation, let's perform
#' the final fit on the full training set with `r n.trees.final` trees.
#' 
#+ final_fit

time.start <- handyTimer()
surv.model.fit.final <-
  survivalFit(
    surv.predict.partial,
    COHORT.bigdata[-test.set, ],
    model.type = 'rfsrc',
    n.trees = n.trees.final,
    split.rule = split.rule,
    n.threads = n.threads,
    nimpute = n.imputations,
    nsplit = nsplit,
    na.action = 'na.impute'
  )
time.fit.final <- handyTimer(time.start)

saveRDS(surv.model.fit.final, paste0(output.filename.base, '-finalmodel.rds'))

#' Final model of `r n.trees.final` trees fitted in `r round(time.fit.final)`
#' seconds!
#' 
#' We also bootstrap this final fitting stage. A fully proper bootstrap would
#' iterate over the whole model-building process, including variable selection,
#' but that would be prohibitive in terms of computational time.
#' 
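#' For reference, here's a sketch of what one iteration of such a full
#' bootstrap might look like (illustration only, not evaluated; the real
#' resampling would wrap the whole variable-selection and cross-validation
#' pipeline above):
#' 
#+ full_bootstrap_sketch, eval = FALSE

# Resample the training set with replacement
COHORT.train <- COHORT.bigdata[-test.set, ]
COHORT.boot <- COHORT.train[sample(nrow(COHORT.train), replace = TRUE), ]
# ...then rerun the logrank variable selection and cross-validation on
# COHORT.boot, refit the final model, and score it on the untouched test set
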
#+ bootstrap_final

time.start <- handyTimer()
surv.model.params.boot <-
  survivalFitBoot(
    surv.predict.partial,
    COHORT.bigdata[-test.set, ], # Training set
    COHORT.bigdata[test.set, ],  # Test set
    model.type = 'rfsrc',
    n.threads = n.threads,
    bootstraps = bootstraps,
    filename = paste0(output.filename.base, '-boot-all.csv'),
    na.action = 'na.impute'
  )
time.fit.boot <- handyTimer(time.start)

#' `r bootstraps` bootstraps completed in `r round(time.fit.boot)` seconds!

# Get coefficients and variable importances from bootstrap fits
surv.model.fit.coeffs <- bootStatsDf(surv.model.params.boot)

# Save performance results
varsToTable(
  data.frame(
    model = 'rf-varsellr',
    imputation = FALSE,
    discretised = FALSE,
    c.index = surv.model.fit.coeffs['c.index', 'val'],
    c.index.lower = surv.model.fit.coeffs['c.index', 'lower'],
    c.index.upper = surv.model.fit.coeffs['c.index', 'upper'],
    calibration.score = surv.model.fit.coeffs['calibration.score', 'val'],
    calibration.score.lower =
      surv.model.fit.coeffs['calibration.score', 'lower'],
    calibration.score.upper =
      surv.model.fit.coeffs['calibration.score', 'upper']
  ),
  performance.file,
  index.cols = c('model', 'imputation', 'discretised')
)

write.csv(
  surv.model.fit.coeffs,
  paste0(output.filename.base, '-boot-summary.csv')
)

#' ## Performance
#' 
#' ### C-index
#' 
#' C-indices are **`r round(surv.model.fit.coeffs['c.train', 'val'], 3)`
#' (`r round(surv.model.fit.coeffs['c.train', 'lower'], 3)` -
#' `r round(surv.model.fit.coeffs['c.train', 'upper'], 3)`)**
#' on the training set and
#' **`r round(surv.model.fit.coeffs['c.test', 'val'], 3)`
#' (`r round(surv.model.fit.coeffs['c.test', 'lower'], 3)` -
#' `r round(surv.model.fit.coeffs['c.test', 'upper'], 3)`)** on the test set.
#' 
#' ### Calibration
#' 
#' The bootstrapped calibration score is
#' **`r round(surv.model.fit.coeffs['calibration.score', 'val'], 3)`
#' (`r round(surv.model.fit.coeffs['calibration.score', 'lower'], 3)` -
#' `r round(surv.model.fit.coeffs['calibration.score', 'upper'], 3)`)**.
#' 
#' Let's draw a representative curve from the unbootstrapped fit... (It would
#' be better to draw all the curves from the bootstrap fits to get an idea of
#' variability, but I've not implemented this yet.)
#' 
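#' (The calibration score reported below is, in essence, the area between the
#' calibration curve and the diagonal of perfect calibration. A minimal sketch
#' of that idea, assuming columns est and obs for predicted and observed risk;
#' the real calibrationScore in shared.R may differ:)
#' 
#+ calibration_score_sketch, eval = FALSE

# Trapezoidal integration of |observed - predicted| over the risk range
with(
  data.frame(est = seq(0, 1, 0.01), obs = seq(0, 1, 0.01)^1.1),
  sum(diff(est) * (abs(obs - est)[-1] + abs(obs - est)[-length(est)]) / 2)
)
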
#+ calibration_plot

calibration.table <-
  calibrationTable(
    # Standard calibration options
    surv.model.fit.final, COHORT.bigdata[test.set, ],
    # Always need to specify NA imputation for rfsrc
    na.action = 'na.impute'
  )

calibration.score <- calibrationScore(calibration.table)

calibrationPlot(calibration.table)

#' The area between the calibration curve and the diagonal is
#' **`r round(calibration.score, 3)`**.