## partyMod/R/varimp.R

# For the current variable of interest, xname,
# create the list of variables to condition on:

create_cond_list <- function(cond, threshold, xname, input) {

    stopifnot(is.logical(cond))
    if (!cond) return(NULL)
    if (threshold > 0 && threshold < 1) {
        ## fit a single conditional inference stump of xname on all other predictors
        ctrl <- ctree_control(teststat = "quad", testtype = "Univariate", stump = TRUE)
        xnames <- names(input)
        xnames <- xnames[xnames != xname]
        ct <- ctree(as.formula(paste(xname, "~", paste(xnames, collapse = " + "))),
                    data = input, controls = ctrl)
        ## criterion[[2]] holds 1 - p-value of the association test for each predictor
        crit <- ct@tree$criterion[[2]]
        crit[which(is.na(crit))] <- 0
        return(xnames[crit > threshold])
    }
    stop("'threshold' must be in (0, 1)")
}
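
## A minimal usage sketch of create_cond_list(); the data frame 'dat' and the
## variable names below are illustrative assumptions, not objects defined in
## this package:
##   dat <- data.frame(x1 = rnorm(100), x2 = rnorm(100), x3 = rnorm(100))
##   dat$x2 <- dat$x2 + dat$x1                # make x1 and x2 associated
##   create_cond_list(TRUE, 0.2, "x1", dat)   # likely "x2": predictors whose
##                                            # 1 - p-value exceeds the threshold
##   create_cond_list(FALSE, 0.2, "x1", dat)  # NULL: no conditioning requested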


## mincriterion = 0 so that complete tree is evaluated;
## regulate size of considered tree here via, e.g., mincriterion = 0.95
## or when building the forest in the first place via cforest_control(mincriterion = 0.95)
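##
## A usage sketch (the fitted forest 'cf' below is an assumption for
## illustration, not an object defined in this file):
##   cf <- cforest(Species ~ ., data = iris,
##                 controls = cforest_unbiased(mtry = 2, ntree = 50))
##   varimp(cf)                        # unconditional permutation importance
##   varimp(cf, mincriterion = 0.95)   # evaluate trees pruned to splits with criterion >= 0.95
##   varimp(cf, conditional = TRUE)    # conditional importance (threshold = 0.2)
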
varimp <- function (object, mincriterion = 0, conditional = FALSE,
                    threshold = 0.2, nperm = 1, OOB = TRUE, pre1.0_0 = conditional)
{

    response <- object@responses
    if (length(response@variables) == 1 &&
        inherits(response@variables[[1]], "Surv"))
        return(varimpsurv(object, mincriterion, conditional, threshold, nperm, OOB, pre1.0_0))
    input <- object@data@get("input")
    xnames <- colnames(input)
    inp <- initVariableFrame(input, trafo = NULL)
    y <- object@responses@variables[[1]]
    if (length(response@variables) != 1)
        stop("cannot compute variable importance measure for multivariate response")

    if (conditional || pre1.0_0) {
        if (!all(complete.cases(inp@variables)))
            stop("cannot compute variable importance measure with missing values")
    }
    CLASS <- all(response@is_nominal)
    ORDERED <- all(response@is_ordinal)
    if (CLASS) {
        ## misclassification rate on the selected observations
        error <- function(x, oob) mean((levels(y)[sapply(x, which.max)] != y)[oob])
    } else {
        if (ORDERED) {
            error <- function(x, oob) mean((sapply(x, which.max) != y)[oob])
        } else {
            ## mean squared error for continuous responses
            error <- function(x, oob) mean((unlist(x) - y)[oob]^2)
        }
    }

    w <- object@initweights
    if (max(abs(w - 1)) > sqrt(.Machine$double.eps))
        warning(sQuote("varimp"), " with non-unity weights might give misleading results")

    ## one row per permutation and tree; the matrix is initialized with 0 so that
    ## a tree that does not contain the current variable adds importance 0 to its
    ## average importance
    perror <- matrix(0, nrow = nperm * length(object@ensemble), ncol = length(xnames))
    colnames(perror) <- xnames

    for (b in 1:length(object@ensemble)) {
        tree <- object@expand(object@ensemble[[b]])

        ## if OOB == TRUE use only the out-of-bag observations,
        ## otherwise use all observations in the learning sample
        if (OOB) {
            oob <- object@weights[[b]] == 0
        } else {
            oob <- rep(TRUE, length(y))
        }
        p <- .Call("R_predict", tree, inp, mincriterion, -1L, PACKAGE = "atlantisPartyMod")
        eoob <- error(p, oob)

        ## for all variables used for splitting in this tree
        for (j in unique(varIDs(tree))) {
            for (per in 1:nperm) {

                if (conditional || pre1.0_0) {
                    tmp <- inp
                    ccl <- create_cond_list(conditional, threshold, xnames[j], input)
                    if (is.null(ccl)) {
                        perm <- sample(which(oob))
                    } else {
                        perm <- conditional_perm(ccl, xnames, input, tree, oob)
                    }
                    tmp@variables[[j]][which(oob)] <- tmp@variables[[j]][perm]
                    p <- .Call("R_predict", tree, tmp, mincriterion, -1L,
                               PACKAGE = "atlantisPartyMod")
                } else {
                    p <- .Call("R_predict", tree, inp, mincriterion, as.integer(j),
                               PACKAGE = "atlantisPartyMod")
                }
                ## importance contribution: increase in prediction error after permuting variable j
                perror[(per + (b - 1) * nperm), j] <- (error(p, oob) - eoob)

            } ## end of for (per in 1:nperm)
        } ## end of for (j in unique(varIDs(tree)))
    } ## end of for (b in 1:length(object@ensemble))

    perror <- as.data.frame(perror)
    #return(MeanDecreaseAccuracy = perror) ## return the whole matrix (= nperm*ntree values per variable)
    return(MeanDecreaseAccuracy = colMeans(perror)) ## return only averages over permutations and trees
}

varimpsurv <- function (object, mincriterion = 0, conditional = FALSE,
                        threshold = 0.2, nperm = 1, OOB = TRUE, pre1.0_0 = conditional)
{

    cat("\n")
    cat("Variable importance for survival forests; this feature is _experimental_\n\n")
    response <- object@responses
    input <- object@data@get("input")
    xnames <- colnames(input)
    inp <- initVariableFrame(input, trafo = NULL)
    y <- object@responses@variables[[1]]
    weights <- object@initweights
    stopifnot(inherits(y, "Surv"))

    if (conditional || pre1.0_0) {
        if (!all(complete.cases(inp@variables)))
            stop("cannot compute variable importance measure with missing values")
    }
    stopifnot(require("ipred", quietly = TRUE))
    ## integrated Brier score on the selected observations
    error <- function(x, oob) sbrier(y[oob, , drop = FALSE], x[oob])

    ## predict a weighted Kaplan-Meier curve (mysurvfit) for the terminal node each
    ## observation of newinp falls into; j is forwarded to the node-ID routine
    ## (-1L = default behaviour)
    pred <- function(tree, newinp, j = -1L) {

        where <- R_get_nodeID(tree, inp, mincriterion)
        wh <- .Call("R_get_nodeID", tree, newinp, mincriterion, as.integer(j), PACKAGE = "atlantisPartyMod")
        swh <- sort(unique(wh))
        RET <- vector(mode = "list", length = length(wh))
        for (i in 1:length(swh)) {
            w <- weights * (where == swh[i])
            RET[wh == swh[i]] <- list(mysurvfit(y, weights = w))
        }
        return(RET)
    }

    w <- object@initweights
    if (max(abs(w - 1)) > sqrt(.Machine$double.eps))
        warning(sQuote("varimp"), " with non-unity weights might give misleading results")

    ## one row per permutation and tree; the matrix is initialized with 0 so that
    ## a tree that does not contain the current variable adds importance 0 to its
    ## average importance
    perror <- matrix(0, nrow = nperm * length(object@ensemble), ncol = length(xnames))
    colnames(perror) <- xnames

    for (b in 1:length(object@ensemble)) {
        tree <- object@ensemble[[b]]

        ## if OOB == TRUE use only the out-of-bag observations,
        ## otherwise use all observations in the learning sample
        if (OOB) {
            oob <- object@weights[[b]] == 0
        } else {
            oob <- rep(TRUE, length(y))
        }
        p <- pred(tree, inp)
        eoob <- error(p, oob)

        ## for all variables used for splitting in this tree
        for (j in unique(varIDs(tree))) {
            for (per in 1:nperm) {

                if (conditional || pre1.0_0) {
                    tmp <- inp
                    ccl <- create_cond_list(conditional, threshold, xnames[j], input)
                    if (is.null(ccl)) {
                        perm <- sample(which(oob))
                    } else {
                        perm <- conditional_perm(ccl, xnames, input, tree, oob)
                    }
                    tmp@variables[[j]][which(oob)] <- tmp@variables[[j]][perm]
                    p <- pred(tree, tmp, -1L)
                } else {
                    p <- pred(tree, inp, as.integer(j))
                }

                ## importance contribution: increase in prediction error after permuting variable j
                perror[(per + (b - 1) * nperm), j] <- (error(p, oob) - eoob)

            } ## end of for (per in 1:nperm)
        } ## end of for (j in unique(varIDs(tree)))
    } ## end of for (b in 1:length(object@ensemble))

    perror <- as.data.frame(perror)
    #return(MeanDecreaseAccuracy = perror) ## return the whole matrix (= nperm*ntree values per variable)
    return(MeanDecreaseAccuracy = colMeans(perror)) ## return only averages over permutations and trees
}

# cutpoints_list() returns:
# - a vector of cutpoints (length = number of cutpoints)
#   if the variable is continuous
# - a vector of indicators (length = number of categories x number of cutpoints)
#   if the variable is categorical (nominal or ordered)
cutpoints_list <- function(tree, variableID) {

    cutp <- function(node) {
       ## node[[4]]: terminal-node indicator
       if (node[[4]]) return(NULL)
       cp <- NULL
       ## node[[5]]: primary split ([[1]] variable ID, [[3]] cutpoint or split indicators)
       if (node[[5]][[1]] == variableID)
           cp <- node[[5]][[3]]
       ## recurse into the left ([[8]]) and right ([[9]]) daughter nodes
       nl <- cutp(node[[8]])
       nr <- cutp(node[[9]])
       return(c(cp, nl, nr))
    }
    return(cutp(tree))
}

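## For illustration (the tree 'tr' below is hypothetical, not an object defined
## in this file): if 'tr' splits twice on the continuous variable with ID 3,
## at 2.5 and 4.1, then
##   cutpoints_list(tr, 3)   # would return c(2.5, 4.1)
## whereas for a nominal variable the call returns the concatenated 0/1
## indicator vectors (one block of length nlevels(x) per cutpoint).
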
conditional_perm <- function(cond, xnames, input, tree, oob) {

    ## get the cutpoints of all variables to condition on for the current variable
    ## of interest and turn them into grouping factors ("blocks") from which the
    ## permutation design matrix is generated below
    blocks <- vector(mode = "list", length = length(cond))

    for (i in 1:length(cond)) {

        ## varID is the variable index or column number of input (the predictor matrix),
        ## not the variable name!
        varID <- which(xnames == cond[i])

        ## if the conditioning variable is not used for splitting in the current tree,
        ## proceed with the next conditioning variable
        cl <- cutpoints_list(tree, varID)
        if (is.null(cl)) next

        ## process the cutpoints for the different types of variables
        x <- input[, varID]
        xclass <- class(x)[1]
        if (xclass == "integer") xclass <- "numeric"

        block <- switch(xclass,
                        "numeric" = cut(x, breaks = c(-Inf, sort(unique(cl)), Inf)),
                        "ordered" = cut(as.numeric(x), breaks = c(-Inf, sort(unique(cl)), Inf)),
                        "factor" = {
                            ## one column of CL per cutpoint, one row per factor level
                            CL <- matrix(as.logical(cl), nrow = nlevels(x))
                            rs <- rowSums(CL)
                            dlev <- (1:nrow(CL))[rs %in% rs[duplicated(rs)]]
                            fuse <- c()
                            for (ii in dlev) {
                                for (j in dlev[dlev > ii]) {
                                    if (all(CL[ii, ] == CL[j, ])) fuse <- rbind(fuse, c(ii, j))
                                }
                            }
                            ## fuse levels that are never separated by any cutpoint
                            xlev <- 1:nlevels(x)
                            newl <- nlevels(x) + 1
                            block <- as.integer(x)
                            for (l in xlev) {
                                if (NROW(fuse) == 0) break
                                if (any(fuse[, 1] == l)) {
                                    f <- c(l, fuse[fuse[, 1] == l, 2])
                                    fuse <- fuse[!fuse[, 1] %in% f, , drop = FALSE]
                                    block[block %in% f] <- newl
                                    newl <- newl + 1
                                }
                            }
                            as.factor(block)
                        })
        blocks[[i]] <- block
    }

    ## drop conditioning variables that are not used for splitting in this tree
    names(blocks) <- cond
    blocks <- blocks[!sapply(blocks, is.null)]

    ## if none of the conditioning variables are used in the tree,
    ## fall back to an unconditional permutation of the selected observations
    if (!length(blocks) > 0) {
        perm <- sample(which(oob))
        return(perm)
    } else {
        blocks <- as.data.frame(blocks)
        ## from the block factors, create the design matrix of their full interaction
        f <- paste("~ - 1 + ", paste(colnames(blocks), collapse = ":"))
        des <- model.matrix(as.formula(f), data = blocks)

        ## one conditional permutation: permute the selected observations only
        ## within each cell of the grid spanned by the conditioning variables
        perm <- 1:nrow(input)
        for (l in 1:ncol(des)) {
            index <- which(des[, l] > 0 & oob)
            if (length(index) > 1)
                perm[index] <- sample(index)
        }
        return(perm[oob])
    }
}

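## The permutation scheme above, sketched on a toy example (all objects here,
## z1, z2 and oob, are illustrative assumptions): observations are shuffled only
## within cells of the grid spanned by the conditioning variables, e.g.
##   grid <- interaction(cut(z1, c(-Inf, 0, Inf)), cut(z2, c(-Inf, 1, Inf)))
##   perm <- seq_along(z1)
##   for (g in levels(grid)) {
##       idx <- which(grid == g & oob)
##       if (length(idx) > 1) perm[idx] <- sample(idx)
##   }
## which breaks the association between the variable of interest and the response
## while preserving its association with the conditioning variables.
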
varimpAUC <- function(object, mincriterion = 0, conditional = FALSE,
                      threshold = 0.2, nperm = 1, OOB = TRUE, pre1.0_0 = conditional)
{

    response <- object@responses
    input <- object@data@get("input")
    xnames <- colnames(input)
    inp <- initVariableFrame(input, trafo = NULL)
    y <- object@responses@variables[[1]]
    if (length(response@variables) != 1)
        stop("cannot compute variable importance measure for multivariate response")

    if (conditional || pre1.0_0) {
        if (!all(complete.cases(inp@variables)))
            stop("cannot compute variable importance measure with missing values")
    }
    CLASS <- all(response@is_nominal)
    ORDERED <- all(response@is_ordinal)
    if (CLASS) {
        if (nlevels(y) > 2) {
            warning("the AUC-based importance works only for binary y;\n the error rate is used instead of the AUC")
            error <- function(x, oob) mean((levels(y)[sapply(x, which.max)] != y)[oob])
        } else {
            ## error = 1 - AUC, where the AUC is estimated as the proportion of
            ## (class 1, class 2) pairs of OOB observations in which the class-1
            ## observation receives the larger predicted probability for level 1
            error <- function(x, oob) {
                xoob <- sapply(x, function(x) x[1])[oob]
                yoob <- y[oob]
                which1 <- which(yoob == levels(y)[1])
                noob1 <- length(which1)
                noob <- length(yoob)
                ## the AUC cannot be computed if all OOB observations are from one class
                if (noob1 == 0 || noob1 == noob) return(NA)
                return(1 - sum(kronecker(xoob[which1], xoob[-which1], ">")) / (noob1 * (noob - noob1)))
            }
        }
    } else {
        if (ORDERED) {
            error <- function(x, oob) mean((sapply(x, which.max) != y)[oob])
        } else {
            error <- function(x, oob) mean((unlist(x) - y)[oob]^2)
        }
    }

    w <- object@initweights
    if (max(abs(w - 1)) > sqrt(.Machine$double.eps))
        warning(sQuote("varimp"), " with non-unity weights might give misleading results")

    ## one row per permutation and tree, initialized with 0 (see varimp above)
    perror <- matrix(0, nrow = nperm * length(object@ensemble), ncol = length(xnames))
    colnames(perror) <- xnames

    for (b in 1:length(object@ensemble)) {
        tree <- object@ensemble[[b]]

        ## if OOB == TRUE use only the out-of-bag observations,
        ## otherwise use all observations in the learning sample
        if (OOB) {
            oob <- object@weights[[b]] == 0
        } else {
            oob <- rep(TRUE, length(y))  ## one entry per observation
        }
        p <- .Call("R_predict", tree, inp, mincriterion, -1L, PACKAGE = "atlantisPartyMod")
        eoob <- error(p, oob)

        for (j in unique(varIDs(tree))) {
            for (per in 1:nperm) {

                if (conditional || pre1.0_0) {
                    tmp <- inp
                    ccl <- create_cond_list(conditional, threshold, xnames[j], input)
                    if (is.null(ccl)) {
                        perm <- sample(which(oob))
                    } else {
                        perm <- conditional_perm(ccl, xnames, input, tree, oob)
                    }
                    tmp@variables[[j]][which(oob)] <- tmp@variables[[j]][perm]
                    p <- .Call("R_predict", tree, tmp, mincriterion, -1L,
                               PACKAGE = "atlantisPartyMod")
                } else {
                    p <- .Call("R_predict", tree, inp, mincriterion, as.integer(j),
                               PACKAGE = "atlantisPartyMod")
                }
                perror[(per + (b - 1) * nperm), j] <- (error(p, oob) - eoob)

            } ## end of for (per in 1:nperm)
        } ## end of for (j in unique(varIDs(tree)))
    } ## end of for (b in 1:length(object@ensemble))

    perror <- as.data.frame(perror)
    ## na.rm = TRUE because with the AUC-based importance NA values occur whenever
    ## a tree's OOB observations are all from the same class
    return(MeanDecreaseAccuracy = colMeans(perror, na.rm = TRUE))
}