# lib/rfsrc-cv-mtry-nsplit-logical.R
bootstraps <- 3
split.rule <- 'logrank'
n.threads <- 20

# Cross-validation variables
ns.splits <- c(0, 5, 10, 15, 20, 30)
ms.try <- c(50, 100, 200, 300, 400)
n.trees.cv <- 500
n.imputations <- 3
cv.n.folds <- 3
n.trees.final <- 2000
n.data <- NA # Size of the full dataset; further rows may be excluded in prep
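
# The grid below (expand.grid of ns.splits and ms.try) has 6 x 5 = 30
# hyperparameter combinations; at cv.n.folds = 3 folds each, the calibration
# stage therefore fits 90 random forests.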

calibration.filename <- paste0(output.filename.base, '-calibration.csv')
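
# `cvFolds` and `handyTimer` are project helpers defined elsewhere; the guarded
# sketches below are assumptions about their interfaces, and are only defined
# if the real versions haven't been sourced. cvFolds is assumed to return a
# list of n.folds integer vectors of held-out row indices; handyTimer with no
# argument is assumed to return a start time, and with a start time to return
# elapsed seconds.
if(!exists('cvFolds')) {
  cvFolds <- function(n, n.folds) {
    # Shuffle the row indices, then deal them round-robin into n.folds folds
    split(sample(1:n), rep(1:n.folds, length.out = n))
  }
}
if(!exists('handyTimer')) {
  handyTimer <- function(start = NULL) {
    if(is.null(start)) return(Sys.time())
    as.numeric(difftime(Sys.time(), start, units = 'secs'))
  }
}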

# If we've not already done a calibration, then do one
if(!file.exists(calibration.filename)) {
  # Create an empty data frame to aggregate stats per fold
  cv.performance <- data.frame()

  # Items to cross-validate over
  cv.vars <- expand.grid(ns.splits, ms.try)
  names(cv.vars) <- c('n.splits', 'm.try')

  COHORT.cv <- COHORT.bigdata[-test.set, ]

  # Run cross-validations. No need to parallelise because rfsrc is parallelised
  for(i in 1:nrow(cv.vars)) {
    cat(
      'Calibration', i, '...\n'
    )

    # Get folds for cross-validation
    cv.folds <- cvFolds(nrow(COHORT.cv), cv.n.folds)

    cv.fold.performance <- data.frame()

    for(j in 1:cv.n.folds) {
      time.start <- handyTimer()
      # Fit model to the training set
      surv.model.fit <-
        survivalFit(
          surv.predict,
          COHORT.cv[-cv.folds[[j]],],
          model.type = 'rfsrc',
          n.trees = n.trees.cv,
          split.rule = split.rule,
          n.threads = n.threads,
          nsplit = cv.vars$n.splits[i],
          nimpute = n.imputations,
          na.action = 'na.impute',
          mtry = cv.vars$m.try[i]
        )
      time.learn <- handyTimer(time.start)

      time.start <- handyTimer()
      # Get C-index on validation set
      c.index.val <-
        cIndex(
          surv.model.fit, COHORT.cv[cv.folds[[j]],],
          na.action = 'na.impute'
        )
      time.c.index <- handyTimer(time.start)

      time.start <- handyTimer()
      # Get calibration score on validation set
      calibration.score <-
        calibrationScore(
          calibrationTable(
            surv.model.fit, COHORT.cv[cv.folds[[j]],], na.action = 'na.impute'
          )
        )
      time.calibration <- handyTimer(time.start)

      # Append the stats we've obtained from this fold
      cv.fold.performance <-
        rbind(
          cv.fold.performance,
          data.frame(
            calibration = i,
            cv.fold = j,
            n.splits = cv.vars$n.splits[i],
            m.try = cv.vars$m.try[i],
            c.index.val,
            time.learn,
            time.c.index,
            time.calibration
          )
        )

    } # End cross-validation loop (j)

    # rbind the performance by fold
    cv.performance <-
      rbind(
        cv.performance,
        cv.fold.performance
      )

    # Save output at the end of each loop
    write.csv(cv.performance, calibration.filename)

  } # End calibration loop (i)

} else { # If we did previously calibrate, load it
  cv.performance <- read.csv(calibration.filename)
}
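
# Note that the file check above is all-or-nothing: if a previous run was
# interrupted mid-calibration, calibration.filename holds a partial grid but
# will still be loaded as though complete. Delete it to force a full
# re-calibration.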

# Find the best calibration...
# First, average performance across cross-validation folds
cv.performance.average <-
  aggregate(
    c.index.val ~ calibration,
    data = cv.performance,
    mean
  )
# Find the highest value
best.calibration <-
  cv.performance.average$calibration[
    which.max(cv.performance.average$c.index.val)
  ]
# And finally, find the first row of that calibration to get the n.splits and
# m.try values
best.calibration.row1 <-
  min(which(cv.performance$calibration == best.calibration))
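
# Optional sanity check: print the winning hyperparameter combination
print(cv.performance[best.calibration.row1, c('n.splits', 'm.try')])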

#' ## Fit the final model
#'
#' This may take some time, so we'll cache it if possible...

#+ fit_final_model

surv.model.fit <-
  survivalFit(
    surv.predict,
    COHORT.bigdata[-test.set,],
    model.type = 'rfsrc',
    n.trees = n.trees.final,
    split.rule = split.rule,
    n.threads = n.threads,
    nimpute = n.imputations,
    nsplit = cv.performance[best.calibration.row1, 'n.splits'],
    mtry = cv.performance[best.calibration.row1, 'm.try'],
    na.action = 'na.impute',
    importance = 'permute'
  )
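
# Since importance = 'permute' was requested, permutation variable importance
# should be available on the fitted object. Assuming survivalFit returns the
# underlying randomForestSRC fit (an assumption; it may wrap it), the top
# variables could be inspected with something like:
# head(sort(surv.model.fit$importance, decreasing = TRUE), 10)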

# Save the fit object
saveRDS(
  surv.model.fit,
  paste0(output.filename.base, '-surv-model.rds')
)

surv.model.fit.boot <-
  survivalBootstrap(
    surv.predict,
    COHORT.bigdata[-test.set,], # Training set
    COHORT.bigdata[test.set,],  # Test set
    model.type = 'rfsrc',
    n.trees = n.trees.final,
    split.rule = split.rule,
    n.threads = n.threads,
    nimpute = n.imputations,
    nsplit = cv.performance[best.calibration.row1, 'n.splits'],
    mtry = cv.performance[best.calibration.row1, 'm.try'],
    na.action = 'na.impute',
    bootstraps = bootstraps
  )

# Save the fit object
saveRDS(
  surv.model.fit.boot,
  paste0(output.filename.base, '-surv-model-bootstraps.rds')
)

# Get C-indices for training and test sets, plus calibration scores, with
# bootstrapped 95% confidence intervals
surv.model.fit.coeffs <- bootStats(surv.model.fit.boot, uncertainty = '95ci')
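
# bootStats is a project helper; the indexing below assumes it returns a table
# with one row per statistic (including 'c.test' and 'calibration.score') and
# columns 'val', 'lower' and 'upper' for the point estimate and 95% CI bounds.
# With bootstraps <- 3 these intervals rest on only three replicates, so treat
# them as indicative; increase bootstraps for meaningful CIs.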

# Save them to the all-models comparison table
varsToTable(
  data.frame(
    model = 'rfbigdata',
    imputation = FALSE,
    discretised = FALSE,
    c.index = surv.model.fit.coeffs['c.test', 'val'],
    c.index.lower = surv.model.fit.coeffs['c.test', 'lower'],
    c.index.upper = surv.model.fit.coeffs['c.test', 'upper'],
    calibration.score = surv.model.fit.coeffs['calibration.score', 'val'],
    calibration.score.lower =
      surv.model.fit.coeffs['calibration.score', 'lower'],
    calibration.score.upper =
      surv.model.fit.coeffs['calibration.score', 'upper']
  ),
  performance.file,
  index.cols = c('model', 'imputation', 'discretised')
)