Diff of /code.R [000000] .. [17d6aa]

Switch to unified view

a b/code.R
1
###################################################################
2
## Code for Workshop 4: Predictive Modeling on Data with Severe 
3
## Class Imbalance: Applications on Electronic Health Records.  
4
## The course was conducted for the  International Conference on 
5
## Health Policy Statistics (ICHPS) on Wed, Oct 7, from 
6
## 10:15 AM - 12:15 PM.
7
8
###################################################################
9
## Example Data
10
11
load("emr.RData")
12
13
###################################################################
14
## Training/Test Split
15
16
library(caret)
17
18
set.seed(1732)
19
in_train <- createDataPartition(emr$Class, p = .75, list = FALSE)
20
training <- emr[ in_train,]
21
testing  <- emr[-in_train,]
22
23
mean(training$Class == "event")
24
mean(testing$Class == "event")
25
26
table(training$Class)
27
table(testing$Class)
28
29
###################################################################
30
## Overfitting to the Majority Class
31
32
library(partykit)
33
library(rpart)
34
35
rpart_small <- rpart(Class ~ ., data = training,
36
                    control = rpart.control(cp = 0.0062))
37
38
plot(as.party(rpart_small))
39
40
###################################################################
41
## Subsampling for class imbalances
42
43
## Define the resampling method and how we calculate performance
44
45
ctrl <- trainControl(method = "repeatedcv",
46
                     repeats = 5,
47
                     classProbs = TRUE,
48
                     savePredictions = TRUE,
49
                     summaryFunction = twoClassSummary)
50
51
## Tune random forest models over this grid
52
mtry_grid <- data.frame(mtry = c(1:15, (4:9)*5))
53
54
###################################################################
55
## The basic random forest model with no adaptations
56
57
set.seed(1537)
58
rf_mod <- train(Class ~ ., 
59
                data = training,
60
                method = "rf",
61
                metric = "ROC",
62
                tuneGrid = mtry_grid,
63
                ntree = 1000,
64
                trControl = ctrl)
65
66
###################################################################
67
## This function is used to take the out of sample predictions and
68
## create an approximate ROC curve from them
69
70
roc_train <- function(object, best_only = TRUE, ...) {
71
  caret:::requireNamespaceQuietStop("pROC")
72
  caret:::requireNamespaceQuietStop("plyr")
73
  
74
  if(object$modelType != "Classification")
75
    stop("ROC curves are only availible for classification models")
76
  if(!any(names(object$modelInfo) == "levels"))
77
    stop(paste("The model's code is required to have a 'levels' module.",
78
               "See http://topepo.github.io/caret/custom_models.html#Components"))
79
  lvs <- object$modelInfo$levels(object$finalModel)
80
  if(length(lvs) != 2) 
81
    stop("ROC curves are only implemented here for two class problems")
82
  
83
  ## check for predictions
84
  if(is.null(object$pred)) 
85
    stop(paste("The out of sample predictions are required.",
86
               "See the `savePredictions` argument of `trainControl`"))
87
  
88
  if(best_only) {
89
    object$pred <- merge(object$pred, object$bestTune)
90
  }
91
  ## find tuning parameter names
92
  p_names <- as.character(object$modelInfo$parameters$parameter)
93
  p_combos <- object$pred[, p_names, drop = FALSE]
94
  
95
  ## average probabilities across resamples
96
  object$pred <- plyr::ddply(.data = object$pred, 
97
                             .variables = c("obs", "rowIndex", p_names),
98
                             .fun = function(dat, lvls = lvs) {
99
                               out <- mean(dat[, lvls[1]])
100
                               names(out) <- lvls[1]
101
                               out
102
                             })
103
  
104
  make_roc <- function(x, lvls = lvs, nms = NULL, ...) {
105
    out <- pROC::roc(response = x$obs,
106
                     predictor = x[, lvls[1]],
107
                     levels = rev(lvls))
108
    
109
    out$model_param <- x[1,nms,drop = FALSE]
110
    out
111
  }
112
  out <- plyr::dlply(.data = object$pred, 
113
                     .variables = p_names,
114
                     .fun = make_roc,
115
                     lvls = lvs,
116
                     nms = p_names)
117
  if(length(out) == 1)  out <- out[[1]]
118
  out
119
}
120
121
###################################################################
122
## Some plots of the data
123
124
ggplot(rf_mod)
125
126
plot(roc_train(rf_mod), 
127
     legacy.axes = TRUE,
128
     print.thres = .5,
129
     print.thres.pattern="   <- default %.1f threshold")
130
131
plot(roc_train(rf_mod), 
132
     legacy.axes = TRUE,
133
     print.thres.pattern = "Cutoff: %.2f (Sp = %.2f, Sn = %.2f)",
134
     print.thres = "best")
135
136
###################################################################
137
## Internal down-sampling
138
139
set.seed(1537)
140
rf_down_int <- train(Class ~ ., 
141
                     data = training,
142
                     method = "rf",
143
                     metric = "ROC",
144
                     strata = training$Class,
145
                     sampsize = rep(sum(training$Class == "event"), 2),
146
                     ntree = 1000,
147
                     tuneGrid = mtry_grid,
148
                     trControl = ctrl)
149
150
ggplot(rf_mod$results, aes(x = mtry, y = ROC)) + geom_point() + geom_line() + 
151
  geom_point(data = rf_down_int$results, aes(x = mtry, y = ROC), col = mod_cols[2]) + 
152
  geom_line(data = rf_down_int$results, aes(x = mtry, y = ROC), col = mod_cols[2]) + 
153
  theme_bw() + 
154
  xlab("#Randomly Selected Predictors") + 
155
  ylab("ROC (Repeated Cross-Validation)")
156
157
###################################################################
158
## External down-sampling
159
160
ctrl$sampling <- "down"
161
set.seed(1537)
162
rf_down <- train(Class ~ ., 
163
                 data = training,
164
                 method = "rf",
165
                 metric = "ROC",
166
                 tuneGrid = mtry_grid,
167
                 ntree = 1000,
168
                 trControl = ctrl)
169
170
geom_point(data = rf_down$results, aes(x = mtry, y = ROC), col = mod_cols[1]) + 
171
  geom_line(data = rf_down$results, aes(x = mtry, y = ROC), col = mod_cols[1]) + 
172
  theme_bw() + 
173
  xlab("#Randomly Selected Predictors") + 
174
  ylab("ROC (Repeated Cross-Validation)")
175
176
###################################################################
177
## Up-sampling
178
179
ctrl$sampling <- "up"
180
set.seed(1537)
181
rf_up <- train(Class ~ ., 
182
               data = training,
183
               method = "rf",
184
               tuneGrid = mtry_grid,
185
               ntree = 1000,
186
               metric = "ROC",
187
               trControl = ctrl)
188
189
ggplot(rf_mod$results, aes(x = mtry, y = ROC)) + geom_point() + geom_line() + 
190
  geom_point(data = rf_up$results, aes(x = mtry, y = ROC), col = mod_cols[3]) + 
191
  geom_line(data = rf_up$results, aes(x = mtry, y = ROC), col = mod_cols[3]) + 
192
  theme_bw() + 
193
  xlab("#Randomly Selected Predictors") + 
194
  ylab("ROC (Repeated Cross-Validation)")
195
196
###################################################################
197
## Up-sampling done **wrong**
198
199
ctrl2 <- trainControl(method = "repeatedcv",
200
                      repeats = 5,
201
                      classProbs = TRUE,
202
                      savePredictions = TRUE,
203
                      summaryFunction = twoClassSummary)
204
upped <- upSample(x = training[, -1], y = training$Class)
205
206
set.seed(1537)
207
rf_wrong <- train(Class ~ ., 
208
                  data = upped,
209
                  method = "rf",
210
                  tuneGrid = mtry_grid,
211
                  ntree = 1000,
212
                  metric = "ROC",
213
                  trControl = ctrl2)
214
215
ggplot(rf_mod$results, aes(x = mtry, y = ROC)) + geom_point() + geom_line() + 
216
  geom_point(data = rf_wrong$results, aes(x = mtry, y = ROC), col = mod_cols[3]) + 
217
  geom_line(data = rf_wrong$results, aes(x = mtry, y = ROC), col = mod_cols[3]) + 
218
  theme_bw() + 
219
  xlab("#Randomly Selected Predictors") + 
220
  ylab("ROC (Repeated Cross-Validation)")
221
222
###################################################################
223
## SMOTE 
224
225
ctrl$sampling <- "smote"
226
set.seed(1537)
227
rf_smote <- train(Class ~ ., 
228
                  data = training,
229
                  method = "rf",
230
                  tuneGrid = mtry_grid,
231
                  ntree = 1000,
232
                  metric = "ROC",
233
                  trControl = ctrl)
234
235
ggplot(rf_mod$results, aes(x = mtry, y = ROC)) + geom_point() + geom_line() + 
236
  geom_point(data = rf_smote$results, aes(x = mtry, y = ROC), col = mod_cols[4]) + 
237
  geom_line(data = rf_smote$results, aes(x = mtry, y = ROC), col = mod_cols[4]) + 
238
  theme_bw() + 
239
  xlab("#Randomly Selected Predictors") + 
240
  ylab("ROC (Repeated Cross-Validation)")
241
242
###################################################################
243
## Make code to measure performance for cost-sensitive learning
244
245
246
fourStats <- function (data, lev = levels(data$obs), model = NULL) {
247
  accKapp <- postResample(data[, "pred"], data[, "obs"])
248
  out <- c(accKapp,
249
           sensitivity(data[, "pred"], data[, "obs"], lev[1]),
250
           specificity(data[, "pred"], data[, "obs"], lev[2]))
251
  names(out)[3:4] <- c("Sens", "Spec")
252
  out
253
}
254
255
ctrl_cost <- trainControl(method = "repeatedcv",
256
                          repeats = 5,
257
                          classProbs = FALSE,
258
                          savePredictions = TRUE,
259
                          summaryFunction = fourStats)
260
261
###################################################################
262
## Setup a custom tuning grid by first fitting a rpart model and
263
## getting the unique Cp values
264
265
rpart_init <- rpart(Class ~ ., data = training, cp = 0)$cptable
266
267
cost_grid <- expand.grid(cp = rpart_init[, "CP"],
268
                         Cost = 1:5)
269
set.seed(1537)
270
rpart_costs <- train(Class ~ ., data = training, 
271
                     method = "rpartCost",
272
                     tuneGrid = cost_grid,
273
                     metric = "Kappa",
274
                     trControl = ctrl_cost)
275
276
ggplot(rpart_costs) + 
277
  scale_x_log10(breaks = 10^pretty(log10(rpart_costs$results$cp), n = 5)) + 
278
  theme(legend.position = "top")
279
280
###################################################################
281
## C5.0 with costs
282
283
cost_grid <- expand.grid(trials = 1:3,
284
                         winnow = FALSE,
285
                         model = "tree",
286
                         cost = seq(1, 10, by = .25))
287
set.seed(1537)
288
c5_costs <- train(Class ~ ., data = training, 
289
                  method = "C5.0Cost",
290
                  tuneGrid = cost_grid,
291
                  metric = "Kappa",
292
                  trControl = ctrl_cost)
293
294
c5_costs_res <- subset(c5_costs$results, trials <= 3)
295
c5_costs_res$trials <- factor(c5_costs_res$trials)
296
297
298
ggplot(c5_costs_res, aes(x = cost, y = Kappa, group = trials)) +
299
  geom_point(aes(color = trials)) + 
300
  geom_line(aes(color = trials)) + 
301
  ylab("Kappa (Repeated Cross-Validation)")+ 
302
  theme(legend.position = "top")
303