Diff of /partyMod/R/RandomForest.R [000000] .. [fbf06f]

Switch to side-by-side view

--- a
+++ b/partyMod/R/RandomForest.R
@@ -0,0 +1,345 @@
+
+# $Id$
+
+### the fitting procedure
+cforestfit <- function(object, controls, weights = NULL, fitmem = NULL, ...) {
+
+    if (!extends(class(object), "LearningSample"))
+        stop(sQuote("object"), " is not of class ", sQuote("LearningSample"))
+    if (!extends(class(controls), "ForestControl"))
+        stop(sQuote("controls"), " is not of class ", sQuote("ForestControl"))
+
+    if (is.null(fitmem)) 
+        fitmem <- ctree_memory(object, TRUE)
+    if (!extends(class(fitmem), "TreeFitMemory"))
+        stop(sQuote("fitmem"), " is not of class ", sQuote("TreeFitMemory"))
+
+    if (is.null(weights))
+        weights <- object@weights
+    storage.mode(weights) <- "double"
+    USER_WEIGHTS <- is.matrix(weights)
+    if (USER_WEIGHTS) {
+        controls@ntree <- ncol(weights)
+        weights <- as.data.frame(weights)
+        if (nrow(weights) != object@nobs)
+            stop(sQuote("weights"), " are not a double matrix of ", 
+                 object@nobs, " rows")
+        bweights <- weights
+        bwhere <- vector(mode = "list", length = controls@ntree)
+        ### grow the tree
+        ensemble <- .Call("R_Ensemble_weights", object, bwhere, bweights, fitmem, controls,
+                          PACKAGE = "atlantisPartyMod")
+    } else {
+        if (length(weights) != object@nobs || storage.mode(weights) != "double")
+            stop(sQuote("weights"), " are not a double vector of ", 
+                 object@nobs, " elements")
+        bweights <- vector(mode = "list", length = controls@ntree)
+        bwhere <- vector(mode = "list", length = controls@ntree)
+        ### grow the tree
+        ensemble <- .Call("R_Ensemble", object, weights, bwhere, bweights, fitmem, controls,
+                          PACKAGE = "atlantisPartyMod")
+    }
+
+    ### prepare the returned object
+    RET <- new("RandomForest")
+    RET@ensemble <- ensemble
+    RET@where <- bwhere
+    RET@weights <- bweights
+    RET@expand <- controls@expand
+    if (USER_WEIGHTS) {
+        RET@initweights <- as.double(rep(1.0, object@nobs)) ### <FIXME>
+    } else {
+        RET@initweights <- as.double(weights)
+    }
+    RET@responses <- object@responses
+    if (inherits(object, "LearningSampleFormula"))
+        RET@data <- object@menv
+
+    RET@update <- function(weights = NULL) {
+        cforestfit(object = object, controls = controls,
+                   weights = weights, fitmem = fitmem, ...)
+    }
+
+
+    ### (estimated) conditional distribution of the response given the
+    ### covariates
+    RET@cond_distr_response <- function(newdata = NULL, mincriterion = 0, ...) { 
+        
+        pw <- RET@prediction_weights(newdata = newdata, mincriterion =
+                                     mincriterion, ...)
+
+        response <- object@responses
+
+        ### survival: estimated Kaplan-Meier
+        if (any(response@is_censored)) {
+            resp <- response@variables[[1]]
+            RET <- lapply(pw, function(w) 
+                mysurvfit(resp, weights = w))
+            return(RET)
+        }
+
+        ### classification: estimated class probabilities
+        ### regression: the means, not really a distribution
+        RET <- lapply(pw, function(w) w %*% response@predict_trafo / sum(w))
+        return(RET)
+    }
+
+    ### predict in the response space, always!
+    RET@predict_response <- function(newdata = NULL, mincriterion = 0, 
+        type = c("response", "prob"), ...) { 
+
+        type <- match.arg(type)
+        cdresp <- RET@cond_distr_response(newdata = newdata, 
+                                          mincriterion = mincriterion, ...)
+        if (type == "prob" || object@responses@ninputs > 1)
+            return(cdresp)
+
+        response <- object@responses
+        ### classification: classes
+        if (all(response@is_nominal || response@is_ordinal)) {
+            lev <- levels(response@variables[[1]])
+            RET <- factor(lev[unlist(lapply(cdresp, which.max))],
+                          levels = levels(response@variables[[1]]))
+            return(RET)
+        }
+
+        ### survival: median survival time
+        if (any(response@is_censored)) {
+            RET <- sapply(cdresp, mst)
+            return(RET)
+        }
+
+        ### regression: mean (median would be possible)
+        RET <- matrix(unlist(cdresp),
+                      nrow = length(cdresp), byrow = TRUE)
+        ### <FIXME> what about multivariate responses?
+        colnames(RET) <- names(response@variables)
+        ### </FIXME>
+        return(RET)
+    }
+
+    ### get terminal node numbers
+    RET@get_where <- function(newdata = NULL, mincriterion = 0, ...) {
+
+        if (is.null(newdata) && mincriterion == 0) {
+            if (all(where > 0)) return(RET@where)
+        }
+
+        newinp <- newinputs(object, newdata)
+
+        lapply(ensemble, function(e) 
+            R_get_nodeID(RET@expand(e), newinp, mincriterion))
+    }
+
+    RET@prediction_weights <- function(newdata = NULL, 
+                                       mincriterion = 0, OOB = FALSE) {
+
+        newinp <- newinputs(object, newdata)
+
+        RET <- .Call("R_predictRF_weights", ensemble, bwhere, bweights, 
+                     newinp, mincriterion, OOB && is.null(newdata), RET@expand,
+                     PACKAGE = "atlantisPartyMod")
+        names(RET) <- rownames(newinp@variables)
+        RET
+    }
+    return(RET)
+}
+
+### the unfitted forest, an object of class `StatModel'
+### see package `modeltools'
+RandomForest <- new("StatModel",
+                    capabilities = new("StatModelCapabilities"),
+                    name = "random forest",
+                    dpp = ctreedpp,
+                    fit = cforestfit,
+                    predict = function(object, ...) 
+                        object@predict_response(...))
+
+cforest_control <- function(teststat = "max", 
+                            testtype = "Teststatistic",
+                            mincriterion = qnorm(0.9),
+                            savesplitstats = FALSE,
+                            ntree = 500, mtry = 5, replace = TRUE, 
+                            fraction = 0.632, 
+                            trace = FALSE, dropcriterion=FALSE, compress=identity, expand=identity, varOnce=F, ...) {
+
+    if (is.null(mtry)) mtry <- 0
+    tmp <- ctree_control(teststat = teststat, testtype = testtype,
+                         mincriterion = mincriterion, 
+                         savesplitstats = savesplitstats, 
+                         mtry = mtry, varOnce=varOnce, ...)
+    RET <- new("ForestControl")
+    RET@ntree <- as.integer(ntree)
+    RET@replace <- as.logical(replace)
+    RET@fraction <- as.double(fraction)
+    RET@trace <- as.logical(trace)
+    RET@dropcriterion <- as.logical(dropcriterion)
+    RET@compress <- as.function(compress)
+    RET@expand <- as.function(expand)
+    RET <- copyslots(tmp, RET)
+    if (!validObject(RET))
+        stop("RET is not a valid object of class", class(RET))
+    RET
+}
+
+cforest_classical <- function(...) cforest_control(teststat = "max",
+                            testtype = "Teststatistic",
+                            mincriterion = qnorm(0.9), 
+                            replace = TRUE, ...)
+
+cforest_unbiased <- function(...) cforest_control(teststat = "quad", 
+                            testtype = "Univ",
+                            mincriterion = 0,
+                            replace = FALSE, 
+                            fraction = 0.632, ...) 
+    
+### the top-level convenience function
+cforest <- function(formula, data = list(), subset = NULL, weights = NULL, 
+                    controls = cforest_unbiased(),
+                    xtrafo = ptrafo, ytrafo = ptrafo, scores = NULL) {
+
+    ### setup learning sample
+    ls <- dpp(RandomForest, formula, data, subset, xtrafo = xtrafo, 
+              ytrafo = ytrafo, scores = scores)
+
+    ### setup memory
+    fitmem <- ctree_memory(ls, TRUE)
+
+    ### fit and return a conditional tree
+    fit(RandomForest, ls, controls = controls, weights = weights, 
+        fitmem = fitmem)
+}
+
+###
+### variable importance for `cforest'
+###
+### see ?importance (in `randomForest'), too
+###
+###
+
+### extract ID of _all_ variables the tree uses for splitting
+varIDs <- function(node) {
+
+    v <- c()
+    foo <- function(node) {
+        if (node[[4]]) return(NULL)
+        v <<- c(v, node[[5]][[1]])
+        foo(node[[8]])
+        foo(node[[9]])
+    }
+    foo(node)
+    return(v)
+}
+
+### calculate proximity matrix: p[i,j] = number of times obs i and j are 
+### in the same terminal node
+proximity <- function(object, newdata = NULL) {
+
+    if (is.null(newdata)) {
+        wh <- object@where
+        rn <- rownames(object@data@get("response"))
+    } else {
+        wh <- object@get_where(newdata = newdata)
+        rn <- rownames(newdata)
+    }
+    ### extract prediction weights
+    prox <- .Call("R_proximity", wh, package = "atlantisPartyMod")
+    prox <- matrix(unlist(prox), ncol = length(prox))
+    rownames(prox) <- rn
+    colnames(prox) <- rn
+    prox
+}
+
+
+### FIXME: newdata may be missing, reuse weights
+### partialPlot.BinaryTree?
+partialPlot.party <-
+    function (x, newdata, x.var, which.class, weights, plot = TRUE, add = FALSE,
+              n.pt = min(length(unique(newdata[, xname])), 51), rug = TRUE,
+              xlab = deparse(substitute(x.var)), ylab = "",
+              main = paste("Partial Dependence on", deparse(substitute(x.var))),
+              ...) 
+{
+    classRF <- all(x@responses@is_nominal || x@responses@is_ordinal)
+
+    x.var <- substitute(x.var)
+    xname <- if (is.character(x.var)) x.var else {
+        if (is.name(x.var)) deparse(x.var) else {
+            eval(x.var)
+        }
+    }
+
+    if (!xname %in% names(newdata))
+        stop("variable", " ", xname, " ", "not known")
+    xv <- newdata[, xname]
+    n <- nrow(newdata)
+
+    if (missing(weights)) weights <- rep(1, n)
+
+    if (classRF) {
+        if (missing(which.class)) {
+            focus <- 1
+        }
+        else {
+            focus <- charmatch(which.class, levels(x@responses@variables[[1]]))
+            if (is.na(focus)) 
+                stop(which.class, " ", "is not one of the class labels.")
+        }
+    }
+    if (is.factor(xv)) { ### includes ordered
+        x.pt <- levels(xv)
+        y.pt <- numeric(length(x.pt))
+        for (i in seq(along = x.pt)) {
+            x.data <- newdata
+            x.data[, xname] <- factor(rep(x.pt[i], n), levels = x.pt, ordered = is.ordered(xv))
+            if (classRF) {
+                pr <- treeresponse(x, newdata = x.data)
+                pr <- matrix(unlist(pr), nrow = length(pr), byrow = TRUE)
+                y.pt[i] <- weighted.mean(log(ifelse(pr[, focus] > 0,
+                                                    pr[, focus], 1)) -
+                                         rowMeans(log(ifelse(pr > 0, pr, 1))),
+                                         weights, na.rm=TRUE)
+            } else y.pt[i] <- weighted.mean(predict(x, newdata = x.data), weights, na.rm = TRUE)
+
+        }
+        if (add) {
+            points(1:length(x.pt), y.pt, type="h", lwd=2, ...)
+        } else {
+            if (plot) barplot(y.pt, width=rep(1, length(y.pt)), col="blue",
+                              xlab = xlab, ylab = ylab, main=main,
+                              names.arg=x.pt, ...)
+        }
+    } else {
+        x.pt <- seq(min(xv), max(xv), length = n.pt)
+        y.pt <- numeric(length(x.pt))
+        for (i in seq(along = x.pt)) {
+            x.data <- newdata
+            x.data[, xname] <- rep(x.pt[i], n)
+            class(x.data[, xname]) <- class(newdata[, xname])
+            storage.mode(x.data[, xname]) <- storage.mode(newdata[, xname])
+            if (classRF) {
+                pr <- treeresponse(x, newdata = x.data)
+                pr <- matrix(unlist(pr), nrow = length(pr), byrow = TRUE)
+                y.pt[i] <- weighted.mean(log(ifelse(pr[, focus] == 0, 1, pr[, focus]))
+                                         - rowMeans(log(ifelse(pr == 0, 1, pr))),
+                                         weights, na.rm=TRUE)
+            } else {
+                y.pt[i] <- weighted.mean(predict(x, newdata = x.data), weights, na.rm=TRUE)
+            }
+        }
+        if (add) {
+            lines(x.pt, y.pt, ...)
+        } else {
+            if (plot) plot(x.pt, y.pt, type = "l", xlab=xlab, ylab=ylab,
+                           main = main, ...)
+        }
+        if (rug && plot) {
+            if (n.pt > 10) {
+                rug(quantile(xv, seq(0.1, 0.9, by = 0.1)), side = 1)
+            } else {
+                rug(unique(xv, side = 1))
+            }
+        }
+    }
+    invisible(list(x = x.pt, y = y.pt))
+}