
random-forest/rf-varselrf-eqv.R

#+ knitr_setup, include = FALSE

# Whether to cache the intensive code sections. Set to FALSE to recalculate
# everything afresh.
cacheoption <- TRUE
# Disable lazy caching globally, because it fails for large objects, and all
# the objects we wish to cache are large...
opts_chunk$set(cache.lazy = FALSE)

#' # Variable selection in data-driven health records
#' 
#' Having extracted around 600 variables which occur most frequently in patient
#' records, let's try to narrow these down using the methodology of varSelRF,
#' with a couple of custom additions: potentially multiple variable-importance
#' calculations at the outset, and cross-validation to choose the best number
#' of variables at the end.
#' 
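#' As a rough, self-contained sketch of that backward-elimination idea (toy
#' stand-ins only, not the real pipeline: the script below uses `survivalFit`
#' and `vimp` from randomForestSRC, and cross-validates performance at each
#' variable count rather than just shrinking the set):
#' 
#+ varsel_sketch, eval = FALSE

# Toy fit and importance functions so this sketch runs on its own
fitToy <- function(vars, data) lm(reformulate(vars, 'y'), data)
impToy <- function(fit) abs(coef(fit))[-1]

set.seed(1)
toy <- data.frame(matrix(rnorm(500), ncol = 5), y = rnorm(100))
names(toy)[1:5] <- paste0('x', 1:5)

vars.current <- paste0('x', 1:5)
while(length(vars.current) > 1) {
  # Rank variables by importance, then drop the least important 20%
  # (cf. vars.drop.frac below)
  vi <- sort(impToy(fitToy(vars.current, toy)), decreasing = TRUE)
  n.keep <- max(1, floor(length(vars.current) * (1 - 0.2)))
  vars.current <- names(vi)[1:n.keep]
}
vars.current # the surviving variables
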
#' ## User variables
#' 
#+ user_variables

output.filename.base <- '../../output/rf-bigdata-try12-varselrf2'

nsplit <- 20
n.trees.initial <- 500
n.forests.initial <- 5
n.trees.cv <- 500
n.trees.final <- 2000
split.rule <- 'logrank'
n.imputations <- 3
cv.n.folds <- 3
vars.drop.frac <- 0.2 # Fraction of variables to drop at each iteration
bootstraps <- 200

n.data <- NA # This is after any variables being excluded in prep

n.threads <- 20

#' ## Data set-up
#' 
#+ data_setup

data.filename.big <- '../../data/cohort-datadriven-02.csv'

surv.predict.old <- c('age', 'smokstatus', 'imd_score', 'gender')
untransformed.vars <- c('time_death', 'endpoint_death', 'exclude')

source('../lib/shared.R')
require(xtable)

# Define these after shared.R or they will be overwritten!
exclude.vars <-
  c(
    # Entity type 4 is smoking status, which we already have
    "clinical.values.4_data1", "clinical.values.4_data5",
    "clinical.values.4_data6",
    # Entity 13 data2 is the patient's weight centile, and not a single one is
    # entered, but they come out as 0 so the algorithm, looking for NAs, thinks
    # it's a useful column
    "clinical.values.13_data2",
    # Entities 148 and 149 are to do with death certification. I'm not sure how
    # they made it into the dataset, but since all the datapoints here look
    # back in time, they're all NA. This causes rfsrc to fail.
    "clinical.values.148_data1", "clinical.values.148_data2",
    "clinical.values.148_data3", "clinical.values.148_data4",
    "clinical.values.148_data5",
    "clinical.values.149_data1", "clinical.values.149_data2"
  )

COHORT <- fread(data.filename.big)

l2df <- function(x) {
  # Simple function to turn a list of named vectors into a data frame, so we
  # can compare variable importances across iterations of the initial forests
  rownames <- unique(unlist(lapply(x, names)))
  df <-
    data.frame(
      matrix(
        unlist(lapply(x, FUN = function(x){x[rownames]})),
        ncol = length(x), byrow = FALSE
      )
    )
  rownames(df) <- rownames
  df
}
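
# For example (illustrative only), two importance vectors with partially
# overlapping names line up by variable, with NA where a variable is absent:
#   l2df(list(c(a = 1, b = 2), c(b = 3, c = 4)))
#   #   X1 X2
#   # a  1 NA
#   # b  2  3
#   # c NA  4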

bigdata.prefixes <-
  c(
    'hes.icd.',
    'hes.opcs.',
    'tests.enttype.',
    'clinical.history.',
    'clinical.values.',
    'bnf.'
  )

bigdata.columns <-
  colnames(COHORT)[
    which(
      # Does it start with one of the data column prefixes?
      startsWithAny(names(COHORT), bigdata.prefixes) &
        # And it's not one of the columns we want to exclude?
        !(colnames(COHORT) %in% exclude.vars)
    )
  ]

COHORT.bigdata <-
  COHORT[, c(
    untransformed.vars, surv.predict.old, bigdata.columns
  ),
  with = FALSE
  ]

# Deal appropriately with missing data
# Most of the variables are number of days since the first record of that type
time.based.vars <-
  names(COHORT.bigdata)[
    startsWithAny(
      names(COHORT.bigdata),
      c('hes.icd.', 'hes.opcs.', 'clinical.history.')
    )
  ]
# We're dealing with these as logicals, so we want non-NA values to be TRUE,
# i.e. there is something in the history
for (j in time.based.vars) {
  set(COHORT.bigdata, j = j, value = !is.na(COHORT.bigdata[[j]]))
}
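# e.g. a days-since-first-record column c(120, NA, 3) becomes the
# 'ever recorded' flags c(TRUE, FALSE, TRUE)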

# Again, taking these as logicals, set any non-NA value to TRUE.
prescriptions.vars <-
  names(COHORT.bigdata)[startsWith(names(COHORT.bigdata), 'bnf.')]
for (j in prescriptions.vars) {
  set(COHORT.bigdata, j = j, value = !is.na(COHORT.bigdata[[j]]))
}

# This leaves tests and clinical.values, which are test results and should be
# imputed.

# Manually fix clinical values items...
#
# "clinical.values.1_data1"  "clinical.values.1_data2"
# These are just blood pressure values...fine to impute
#
# "clinical.values.13_data1" "clinical.values.13_data3"
# These are weight and BMI...also fine to impute
#
# Entity 5 is alcohol consumption status, 1 = Yes, 2 = No, 3 = Ex, so it should
# be a factor, and NA can be a factor level
COHORT.bigdata$clinical.values.5_data1 <-
  factorNAfix(factor(COHORT.bigdata$clinical.values.5_data1), NAval = 'missing')

# Both gender and smokstatus are factors...fix that
COHORT.bigdata$gender <- factor(COHORT.bigdata$gender)
COHORT.bigdata$smokstatus <-
  factorNAfix(factor(COHORT.bigdata$smokstatus), NAval = 'missing')

# Exclude invalid patients
COHORT.bigdata <- COHORT.bigdata[!COHORT.bigdata$exclude]
COHORT.bigdata$exclude <- NULL

COHORT.bigdata <-
  prepSurvCol(data.frame(COHORT.bigdata), 'time_death', 'endpoint_death', 'Death')

# If n.data was specified, trim the data table down to size
if(!is.na(n.data)) {
  COHORT.bigdata <- sample.df(COHORT.bigdata, n.data)
}

# Define test set
test.set <- testSetIndices(COHORT.bigdata, random.seed = 78361)

# Start by predicting survival with all the variables provided
surv.predict <- c(surv.predict.old, bigdata.columns)

# Set up a csv file to store calibration data, or retrieve previous data
calibration.filename <- paste0(output.filename.base, '-varselcalibration.csv')

#' ## Run random forest calibration
#' 
#' If there's not already a calibration file, we run the varSelRF-style
#' methodology:
#'   1. Fit big forests to the whole dataset to obtain variable importances.
#'   2. Cross-validate as the number of most important variables kept is
#'      reduced.
#' 
#' (If there is already a calibration file, just load the previous work.)
#' 
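#' `handyTimer()` comes from `../lib/shared.R`. Judging by its use below
#' (called with no argument to start a timer, then with the start value to get
#' elapsed seconds), a minimal sketch might look like this; an assumption, not
#' the actual implementation:
#' 
#+ handy_timer_sketch, eval = FALSE

handyTimerSketch <- function(start = NULL) {
  if(is.null(start)) {
    Sys.time() # no argument: return a start time
  } else {
    # one argument: elapsed seconds since that start time
    as.numeric(difftime(Sys.time(), start, units = 'secs'))
  }
}
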
#+ rf_var_sel_calibration

# If we've not already done a calibration, then do one
if(!file.exists(calibration.filename)) {
  # Start by fitting several forests to all the variables
  vimps.initial <- list()
  time.fit <- c()
  time.vimp <- c()
  
  # If we haven't already made the initial forests...
  if(!file.exists(
    paste0(output.filename.base, '-varimp-', n.forests.initial, '.rds')
  )) {
    for(i in 1:n.forests.initial) {
      time.start <- handyTimer()
      surv.model.fit.full <-
        survivalFit(
          surv.predict,
          COHORT.bigdata[-test.set,],
          model.type = 'rfsrc',
          n.trees = n.trees.initial,
          split.rule = split.rule,
          n.threads = n.threads,
          nimpute = n.imputations,
          nsplit = nsplit,
          na.action = 'na.impute'
        )
      time.fit <- c(time.fit, handyTimer(time.start))
      
      # Save the model
      saveRDS(
        surv.model.fit.full,
        paste0(output.filename.base, '-initialmodel-', i, '.rds')
      )
      
      # Calculate variable importance
      time.start <- handyTimer()
      var.imp <- vimp(surv.model.fit.full, importance = 'permute.ensemble')
      time.vimp <- c(time.vimp, handyTimer(time.start))
      # Save it
      saveRDS(var.imp, paste0(output.filename.base, '-varimp-', i, '.rds'))
      
      # Make a vector of variable importances and append it to the list
      vimps.initial[[i]] <- sort(var.imp$importance, decreasing = TRUE)
    }
    
    cat('Total initial vimp time = ', sum(time.fit) + sum(time.vimp), '\n')
    cat('Average fit time = ', mean(time.fit), '\n')
    cat('Average vimp time = ', mean(time.vimp), '\n')
    
  } else {
    # If we already made the initial forests, just load their importances
    for(i in 1:n.forests.initial) {
      var.imp <- readRDS(paste0(output.filename.base, '-varimp-', i, '.rds'))
      vimps.initial[[i]] <- sort(var.imp$importance, decreasing = TRUE)
    }
  }
  
  # Convert the vimps.initial list to a data frame: one column per forest, one
  # row per variable
  vimps.initial <- l2df(vimps.initial)
  
  # Average across forests (row means) to get one importance per variable
  var.importances <- sort(apply(vimps.initial, 1, mean), decreasing = TRUE)
  
  # Save the result
  saveRDS(var.importances, paste0(output.filename.base, '-varimp.rds'))

  # Create an empty data frame to aggregate stats per fold
  cv.performance <- data.frame()
  # Cross-validate over the number of variables to try
  cv.vars <- getVarNums(length(var.importances))
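  
  # getVarNums() is defined in shared.R. Given vars.drop.frac above, it
  # presumably returns a descending ladder of variable counts, dropping a
  # fixed fraction at each step; an illustrative sketch (an assumption, not
  # the real implementation):
  getVarNumsSketch <- function(n, drop.frac = vars.drop.frac, min.vars = 2) {
    counts <- n
    while(tail(counts, 1) > min.vars) {
      counts <-
        c(counts, max(min.vars, floor(tail(counts, 1) * (1 - drop.frac))))
    }
    counts
  }
  # e.g. getVarNumsSketch(600) would give 600, 480, 384, 307, ..., 2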
  COHORT.cv <- COHORT.bigdata[-test.set, ]
  # Run cross-validations. No need to parallelise at this level, because rfsrc
  # is itself parallelised
  for(i in 1:length(cv.vars)) {
    # Get the subset of most important variables to use
    surv.predict.partial <- names(var.importances)[1:cv.vars[i]]
    
    # Get folds for cross-validation
    cv.folds <- cvFolds(nrow(COHORT.cv), cv.n.folds)
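    # (cvFolds() is also from shared.R; presumably it partitions the row
    # indices 1:nrow into cv.n.folds random, disjoint folds, along the lines
    # of split(sample(1:n), rep(1:k, length.out = n)).)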
    
    cv.fold.performance <- data.frame()
    
    for(j in 1:cv.n.folds) {
      time.start <- handyTimer()
      # Fit model to the training set
      surv.model.fit <-
        survivalFit(
          surv.predict.partial,
          COHORT.cv[-cv.folds[[j]],],
          model.type = 'rfsrc',
          n.trees = n.trees.cv,
          split.rule = split.rule,
          n.threads = n.threads,
          nsplit = nsplit,
          nimpute = n.imputations,
          na.action = 'na.impute'
        )
      time.learn <- handyTimer(time.start)
      
      time.start <- handyTimer()
      # Get C-index on the validation set
      c.index.val <-
        cIndex(
          surv.model.fit, COHORT.cv[cv.folds[[j]],],
          na.action = 'na.impute'
        )
      time.c.index <- handyTimer(time.start)
      
      time.start <- handyTimer()
      # Get calibration score on the validation set
      calibration.score <-
        calibrationScore(
          calibrationTable(
            surv.model.fit, COHORT.cv[cv.folds[[j]],], na.action = 'na.impute'
          )
        )
      time.calibration <- handyTimer(time.start)
      
      # Append the stats we've obtained from this fold
      cv.fold.performance <-
        rbind(
          cv.fold.performance,
          data.frame(
            calibration = i,
            cv.fold = j,
            n.vars = cv.vars[i],
            c.index.val,
            calibration.score,
            time.learn,
            time.c.index,
            time.calibration
          )
        )
      
    } # End cross-validation loop (j)
    
    # rbind the performance by fold
    cv.performance <-
      rbind(
        cv.performance,
        cv.fold.performance
      )
    
    # Save output at the end of each loop
    write.csv(cv.performance, calibration.filename)
    
  } # End calibration loop (i)
} else {
  cv.performance <- read.csv(calibration.filename)
  var.importances <- readRDS(paste0(output.filename.base, '-varimp.rds'))
}

#' ## Find the best model from the calibrations
#' 
#' ### Plot model performance
#' 
#+ model_performance

# Find the best calibration...
# First, average performance across cross-validation folds
cv.performance.average <-
  aggregate(
    c.index.val ~ n.vars,
    data = cv.performance,
    mean
  )

cv.calibration.average <-
  aggregate(
    area ~ n.vars,
    data = cv.performance,
    mean
  )

ggplot(cv.performance.average, aes(x = n.vars, y = c.index.val)) +
  geom_line() +
  geom_point(data = cv.performance) +
  ggtitle(label = 'C-index by n.vars')

ggplot(cv.calibration.average, aes(x = n.vars, y = area)) +
  geom_line() +
  geom_point(data = cv.performance) +
  ggtitle(label = 'Calibration performance by n.vars')

# Find the number of variables with the highest average C-index
n.vars <-
  cv.performance.average$n.vars[
    which.max(cv.performance.average$c.index.val)
  ]

# Take the top n.vars variables for the final fit
surv.predict.partial <- names(var.importances)[1:n.vars]

#' ## Best model
#' 
#' The best model contained `r n.vars` variables. Let's see what they were...
#' 
#+ variables_used

vars.df <-
  data.frame(
    vars = surv.predict.partial
  )

vars.df$descriptions <- lookUpDescriptions(surv.predict.partial)

vars.df$vimp <- var.importances[1:n.vars] / max(var.importances)

missingness <- sapply(COHORT, percentMissing)
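# (percentMissing() is from shared.R; presumably something like
#  function(x) 100 * mean(is.na(x)), i.e. the percentage of NAs per column.)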

vars.df$missingness <- missingness[names(var.importances)[1:n.vars]]

#+ variables_table, results='asis'

print(
  xtable(vars.df),
  type = 'html',
  include.rownames = FALSE
)

#' ## Perform the final fit
#' 
#' Having found the best number of variables by cross-validation, let's perform
#' the final fit with the full training set and `r n.trees.final` trees.
#' 
#+ final_fit

time.start <- handyTimer()
surv.model.fit.final <-
  survivalFit(
    surv.predict.partial,
    COHORT.bigdata[-test.set,],
    model.type = 'rfsrc',
    n.trees = n.trees.final,
    split.rule = split.rule,
    n.threads = n.threads,
    nimpute = n.imputations,
    nsplit = nsplit,
    na.action = 'na.impute'
  )
time.fit.final <- handyTimer(time.start)

saveRDS(surv.model.fit.final, paste0(output.filename.base, '-finalmodel.rds'))

#' Final model of `r n.trees.final` trees fitted in `r round(time.fit.final)`
#' seconds!
#' 
#' Also bootstrap this final fitting stage. A fully proper bootstrap would
#' iterate over the whole model-building process, including variable selection,
#' but that would be computationally prohibitive.
#' 
#+ bootstrap_final

surv.model.fit.boot <-
  survivalBootstrap(
    surv.predict.partial,
    COHORT.bigdata[-test.set,], # Training set
    COHORT.bigdata[test.set,],  # Test set
    model.type = 'rfsrc',
    n.trees = n.trees.final,
    split.rule = split.rule,
    n.threads = n.threads,
    nimpute = n.imputations,
    nsplit = nsplit,
    na.action = 'na.impute',
    bootstraps = bootstraps
  )

COHORT.bigdata$surv_time_round <-
  round_any(COHORT.bigdata$surv_time, tod.round)
# Create a survival formula with the provided variable names...
surv.formula <-
  formula(
    paste0(
      # Survival object made in-formula
      'Surv(surv_time_round, surv_event) ~ ',
      # Predictor variables then make up the other side
      paste(surv.predict.partial, collapse = '+')
    )
  )
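# e.g. Surv(surv_time_round, surv_event) ~ age + gender + ... for the selected
# variables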

# randomForestSRC, if compiled with OpenMP support, controls its threading via
# the rf.cores option (or the RF_CORES environment variable)
options(rf.cores = n.threads)

# Bootstrap the final fit manually with boot(); note that this overwrites the
# survivalBootstrap result above
surv.model.fit.boot <-
  boot(
    formula = surv.formula,
    data = COHORT.bigdata[-test.set, ],
    statistic = bootstrapFitRfsrc,
    R = bootstraps,
    parallel = 'no',
    ncpus = 1, # boot-level parallelism off, because rfsrc is already parallel
    n.trees = n.trees.final,
    test.data = COHORT.bigdata[test.set, ],
    # boot() requires named arguments, so we can't use ... here. This slight
    # kludge means the call will fail unless these three variables are
    # explicitly specified.
    nimpute = n.imputations,
    nsplit = nsplit,
    na.action = 'na.impute'
  )

# Get coefficients and variable importances from the bootstrap fits
surv.model.fit.coeffs <- bootStats(surv.model.fit.boot, uncertainty = '95ci')
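
# bootStats() with uncertainty = '95ci' is from shared.R; presumably it
# summarises the boot object into a value plus 95% CI per statistic, roughly
# like this sketch (an assumption, not the actual implementation):
bootStatsSketch <- function(b) {
  data.frame(
    val   = b$t0,
    lower = apply(b$t, 2, quantile, probs = 0.025, na.rm = TRUE),
    upper = apply(b$t, 2, quantile, probs = 0.975, na.rm = TRUE),
    row.names = names(b$t0)
  )
}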

#' ## Performance
#' 
#' ### C-index
#' 
#' C-indices are **`r round(surv.model.fit.coeffs['c.train', 'val'], 3)`
#' (`r round(surv.model.fit.coeffs['c.train', 'lower'], 3)` - 
#' `r round(surv.model.fit.coeffs['c.train', 'upper'], 3)`)**
#' on the training set and
#' **`r round(surv.model.fit.coeffs['c.test', 'val'], 3)`
#' (`r round(surv.model.fit.coeffs['c.test', 'lower'], 3)` - 
#' `r round(surv.model.fit.coeffs['c.test', 'upper'], 3)`)** on the test set.
#' 
#' ### Calibration
#' 
#' The bootstrapped calibration score is
#' **`r round(surv.model.fit.coeffs['calibration.score', 'val'], 3)`
#' (`r round(surv.model.fit.coeffs['calibration.score', 'lower'], 3)` - 
#' `r round(surv.model.fit.coeffs['calibration.score', 'upper'], 3)`)**.
#' 
#' Let's draw a representative curve from the unbootstrapped fit... (It would
#' be better to draw all the curves from the bootstrap fit to get an idea of
#' variability, but I've not implemented this yet.)
#' 
#+ calibration_plot

calibration.table <-
  calibrationTable(
    # Standard calibration options
    surv.model.fit.final, COHORT.bigdata[test.set,],
    # Always need to specify NA imputation for rfsrc
    na.action = 'na.impute'
  )

calibration.score <- calibrationScore(calibration.table)

calibrationPlot(calibration.table)

#' The area between the calibration curve and the diagonal is 
#' **`r round(calibration.score[['area']], 3)`** +/-
#' **`r round(calibration.score[['se']], 3)`**.
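#' 
#' (calibrationScore() comes from shared.R. The 'area' is presumably the
#' integrated absolute difference between predicted and observed risk over the
#' calibration table; a sketch under that assumption, with hypothetical column
#' names `risk.predicted` and `risk.observed`:)
#' 
#+ calibration_area_sketch, eval = FALSE

with(calibration.table, mean(abs(risk.predicted - risk.observed)))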
#'
#+ save_results

# Save performance results
varsToTable(
  data.frame(
    model = 'rf-varselrf2',
    imputation = FALSE,
    discretised = FALSE,
    c.index = surv.model.fit.coeffs['c.train', 'val'],
    c.index.lower = surv.model.fit.coeffs['c.train', 'lower'],
    c.index.upper = surv.model.fit.coeffs['c.train', 'upper'],
    calibration.score = surv.model.fit.coeffs['calibration.score', 'val'],
    calibration.score.lower =
      surv.model.fit.coeffs['calibration.score', 'lower'],
    calibration.score.upper =
      surv.model.fit.coeffs['calibration.score', 'upper']
  ),
  performance.file,
  index.cols = c('model', 'imputation', 'discretised')
)