Diff of /partyMod/src/Predict.c [000000] .. [fbf06f]

Switch to side-by-side view

--- a
+++ b/partyMod/src/Predict.c
@@ -0,0 +1,560 @@
+
+/**
+    Node splitting and prediction
+    *\file Predict.c
+    *\author $Author$
+    *\date $Date$
+*/
+                
+#include "party.h"
+
+
+/**
+    Split a node according to a splitting rule \n
+    *\param node the current node with primary split specified
+    *\param learnsample learning sample
+    *\param control an object of class `TreeControl'
+    *\todo outplace the splitting since there are at least 3 functions
+           with nearly identical code
+*/
+                
+void C_splitnode(SEXP node, SEXP learnsample, SEXP control) {
+
+    SEXP weights, leftnode, rightnode, split;
+    SEXP responses, inputs, whichNA;
+    double cutpoint, *dx, *dweights, *leftweights, *rightweights;
+    double sleft = 0.0, sright = 0.0;
+    int *ix, *levelset, *iwhichNA;
+    int nobs, i, nna;
+                    
+    weights = S3get_nodeweights(node);
+    dweights = REAL(weights);
+    responses = GET_SLOT(learnsample, PL2_responsesSym);
+    inputs = GET_SLOT(learnsample, PL2_inputsSym);
+    nobs = get_nobs(learnsample);
+            
+    /* set up memory for the left daughter */
+    SET_VECTOR_ELT(node, S3_LEFT, leftnode = allocVector(VECSXP, NODE_LENGTH));
+    C_init_node(leftnode, nobs, 
+        get_ninputs(learnsample), get_maxsurrogate(get_splitctrl(control)),
+        ncol(get_predict_trafo(responses)));
+    leftweights = REAL(S3get_nodeweights(leftnode));
+
+    /* set up memory for the right daughter */
+    SET_VECTOR_ELT(node, S3_RIGHT, 
+                   rightnode = allocVector(VECSXP, NODE_LENGTH));
+    C_init_node(rightnode, nobs, 
+        get_ninputs(learnsample), get_maxsurrogate(get_splitctrl(control)),
+        ncol(get_predict_trafo(responses)));
+    rightweights = REAL(S3get_nodeweights(rightnode));
+
+    /* split according to the primary split */
+    split = S3get_primarysplit(node);
+    if (has_missings(inputs, S3get_variableID(split))) {
+        whichNA = get_missings(inputs, S3get_variableID(split));
+        iwhichNA = INTEGER(whichNA);
+        nna = LENGTH(whichNA);
+    } else {
+        nna = 0;
+        whichNA = R_NilValue;
+        iwhichNA = NULL;
+    }
+    
+    if (S3is_ordered(split)) {
+        cutpoint = REAL(S3get_splitpoint(split))[0];
+        dx = REAL(get_variable(inputs, S3get_variableID(split)));
+        for (i = 0; i < nobs; i++) {
+            if (nna > 0) {
+                if (i_in_set(i + 1, iwhichNA, nna)) continue;
+            }
+            if (dx[i] <= cutpoint) 
+                leftweights[i] = dweights[i]; 
+            else 
+                leftweights[i] = 0.0;
+            rightweights[i] = dweights[i] - leftweights[i];
+            sleft += leftweights[i];
+            sright += rightweights[i];
+        }
+    } else {
+        levelset = INTEGER(S3get_splitpoint(split));
+        ix = INTEGER(get_variable(inputs, S3get_variableID(split)));
+
+        for (i = 0; i < nobs; i++) {
+            if (nna > 0) {
+                if (i_in_set(i + 1, iwhichNA, nna)) continue;
+            }
+            if (levelset[ix[i] - 1])
+                leftweights[i] = dweights[i];
+            else 
+                leftweights[i] = 0.0;
+            rightweights[i] = dweights[i] - leftweights[i];
+            sleft += leftweights[i];
+            sright += rightweights[i];
+        }
+    }
+    
+    /* for the moment: NA's go with majority */
+    if (nna > 0) {
+        for (i = 0; i < nna; i++) {
+            if (sleft > sright) {
+                leftweights[iwhichNA[i] - 1] = dweights[iwhichNA[i] - 1];
+                rightweights[iwhichNA[i] - 1] = 0.0;
+            } else {
+                rightweights[iwhichNA[i] - 1] = dweights[iwhichNA[i] - 1];
+                leftweights[iwhichNA[i] - 1] = 0.0;
+            }
+        }
+    }
+}
+
+
+/**
+    Get the terminal node for obs. number `numobs' of `newinputs' \n
+    *\param subtree a tree
+    *\param newinputs an object of class `VariableFrame'
+    *\param mincriterion overwrites mincriterion used for tree growing
+    *\param numobs observation number
+    *\param varperm which variable shall be permuted?
+    *\todo handle surrogate splits
+*/
+
+SEXP C_get_node(SEXP subtree, SEXP newinputs, 
+                double mincriterion, int numobs, int varperm) {
+
+    SEXP split, whichNA, ssplit, surrsplit;
+    double cutpoint, x, swleft, swright;
+    int level, *levelset, i, ns;
+
+    if (S3get_nodeterminal(subtree) || 
+        REAL(S3get_maxcriterion(subtree))[0] < mincriterion) 
+        return(subtree);
+    
+    split = S3get_primarysplit(subtree);
+
+    /* Maybe store the proportions left / right in each node? */
+    swleft = S3get_sumweights(S3get_leftnode(subtree));
+    swright = S3get_sumweights(S3get_rightnode(subtree));
+
+    /* splits based on variable varperm are random */    
+    if (S3get_variableID(split) == varperm) {
+        if (unif_rand() < swleft / (swleft + swright)) {
+            return(C_get_node(S3get_leftnode(subtree),
+                       newinputs, mincriterion, numobs, varperm));
+        } else {
+            return(C_get_node(S3get_rightnode(subtree),
+                       newinputs, mincriterion, numobs, varperm));
+        }
+    }
+                   
+    /* missing values */
+    if (has_missings(newinputs, S3get_variableID(split))) {
+        whichNA = get_missings(newinputs, S3get_variableID(split));
+    
+        /* numobs 0 ... n - 1 but whichNA has 1:n */
+        if (C_i_in_set(numobs + 1, whichNA)) {
+        
+            surrsplit = S3get_surrogatesplits(subtree);
+            ns = 0;
+            i = numobs;      
+
+            /* try to find a surrogate split */
+            while(TRUE) {
+    
+                if (ns >= LENGTH(surrsplit)) break;
+            
+                ssplit = VECTOR_ELT(surrsplit, ns);
+                if (has_missings(newinputs, S3get_variableID(ssplit))) {
+                    if (INTEGER(get_missings(newinputs, 
+                                             S3get_variableID(ssplit)))[i]) {
+                        ns++;
+                        continue;
+                    }
+                }
+
+                cutpoint = REAL(S3get_splitpoint(ssplit))[0];
+                x = REAL(get_variable(newinputs, S3get_variableID(ssplit)))[i];
+                     
+                if (S3get_toleft(ssplit)) {
+                    if (x <= cutpoint) {
+                        return(C_get_node(S3get_leftnode(subtree),
+                                          newinputs, mincriterion, numobs, varperm));
+                    } else {
+                        return(C_get_node(S3get_rightnode(subtree),
+                               newinputs, mincriterion, numobs, varperm));
+                    }
+                } else {
+                    if (x <= cutpoint) {
+                        return(C_get_node(S3get_rightnode(subtree),
+                                          newinputs, mincriterion, numobs, varperm));
+                    } else {
+                        return(C_get_node(S3get_leftnode(subtree),
+                               newinputs, mincriterion, numobs, varperm));
+                    }
+                }
+                break;
+            }
+
+            /* if this was not successful, we go with the majority */
+            if (swleft > swright) {
+                return(C_get_node(S3get_leftnode(subtree), 
+                                  newinputs, mincriterion, numobs, varperm));
+            } else {
+                return(C_get_node(S3get_rightnode(subtree), 
+                                  newinputs, mincriterion, numobs, varperm));
+            }
+        }
+    }
+    
+    if (S3is_ordered(split)) {
+        cutpoint = REAL(S3get_splitpoint(split))[0];
+        x = REAL(get_variable(newinputs, 
+                     S3get_variableID(split)))[numobs];
+        if (x <= cutpoint) {
+            return(C_get_node(S3get_leftnode(subtree), 
+                              newinputs, mincriterion, numobs, varperm));
+        } else {
+            return(C_get_node(S3get_rightnode(subtree), 
+                              newinputs, mincriterion, numobs, varperm));
+        }
+    } else {
+        levelset = INTEGER(S3get_splitpoint(split));
+        level = INTEGER(get_variable(newinputs, 
+                            S3get_variableID(split)))[numobs];
+        /* level is in 1, ..., K */
+        if (levelset[level - 1]) {
+            return(C_get_node(S3get_leftnode(subtree), newinputs, 
+                              mincriterion, numobs, varperm));
+        } else {
+            return(C_get_node(S3get_rightnode(subtree), newinputs, 
+                              mincriterion, numobs, varperm));
+        }
+    }
+}
+
+
+/**
+    R-Interface to C_get_node \n
+    *\param subtree a tree
+    *\param newinputs an object of class `VariableFrame'
+    *\param mincriterion overwrites mincriterion used for tree growing
+    *\param numobs observation number
+*/
+
+SEXP R_get_node(SEXP subtree, SEXP newinputs, SEXP mincriterion, 
+                SEXP numobs, SEXP varperm) {
+    return(C_get_node(subtree, newinputs, REAL(mincriterion)[0],
+                      INTEGER(numobs)[0] - 1, INTEGER(varperm)[0]));
+}
+
+
+/**
+    Get the node with nodeID `nodenum' \n
+    *\param subtree a tree
+    *\param nodenum a nodeID
+*/
+
+SEXP C_get_nodebynum(SEXP subtree, int nodenum) {
+    
+    if (nodenum == S3get_nodeID(subtree)) return(subtree);
+
+    if (S3get_nodeterminal(subtree)) 
+        error("no node with number %d\n", nodenum);
+
+    if (nodenum < S3get_nodeID(S3get_rightnode(subtree))) {
+        return(C_get_nodebynum(S3get_leftnode(subtree), nodenum));
+    } else {
+        return(C_get_nodebynum(S3get_rightnode(subtree), nodenum));
+    }
+}
+
+
+/**
+    R-Interface to C_get_nodenum \n
+    *\param subtree a tree
+    *\param nodenum a nodeID
+*/
+
+SEXP R_get_nodebynum(SEXP subtree, SEXP nodenum) {
+    return(C_get_nodebynum(subtree, INTEGER(nodenum)[0]));
+}
+
+
+/**
+    Get the prediction of a new observation\n
+    *\param subtree a tree
+    *\param newinputs an object of class `VariableFrame'
+    *\param mincriterion overwrites mincriterion used for tree growing
+    *\param numobs observation number
+    *\param varperm which variable shall be permuted?
+*/
+
+SEXP C_get_prediction(SEXP subtree, SEXP newinputs, 
+                      double mincriterion, int numobs, int varperm) {
+    return(S3get_prediction(C_get_node(subtree, newinputs, 
+                            mincriterion, numobs, varperm)));
+}
+
+
+/**
+    Get the weights for a new observation \n
+    *\param subtree a tree
+    *\param newinputs an object of class `VariableFrame'
+    *\param mincriterion overwrites mincriterion used for tree growing
+    *\param numobs observation number
+*/
+
+SEXP C_get_nodeweights(SEXP subtree, SEXP newinputs, 
+                       double mincriterion, int numobs) {
+    return(S3get_nodeweights(C_get_node(subtree, newinputs, 
+                             mincriterion, numobs, -1)));
+}
+
+
+/**
+    Get the nodeID for a new observation \n
+    *\param subtree a tree
+    *\param newinputs an object of class `VariableFrame'
+    *\param mincriterion overwrites mincriterion used for tree growing
+    *\param numobs observation number
+    *\param varperm which variable shall be permuted?
+*/
+
+int C_get_nodeID(SEXP subtree, SEXP newinputs,
+                  double mincriterion, int numobs, int varperm) {
+     return(S3get_nodeID(C_get_node(subtree, newinputs, 
+            mincriterion, numobs, varperm)));
+}
+
+
+/**
+    R-Interface to C_get_nodeID \n
+    *\param tree a tree
+    *\param newinputs an object of class `VariableFrame'
+    *\param mincriterion overwrites mincriterion used for tree growing
+*/
+
+SEXP R_get_nodeID(SEXP tree, SEXP newinputs, SEXP mincriterion, SEXP varperm) {
+
+    SEXP ans;
+    int nobs, i, *dans;
+            
+    nobs = get_nobs(newinputs);
+    PROTECT(ans = allocVector(INTSXP, nobs));
+    dans = INTEGER(ans);
+    for (i = 0; i < nobs; i++)
+         dans[i] = C_get_nodeID(tree, newinputs, REAL(mincriterion)[0], i, INTEGER(varperm)[0]);
+    UNPROTECT(1);
+    return(ans);
+}
+
+
+/**
+    Get all predictions for `newinputs' \n
+    *\param tree a tree
+    *\param newinputs an object of class `VariableFrame'
+    *\param mincriterion overwrites mincriterion used for tree growing
+    *\param varperm which variable shall be permuted?
+    *\param ans return value
+*/
+
+void C_predict(SEXP tree, SEXP newinputs, double mincriterion, 
+               int varperm, SEXP ans) {
+    
+    int nobs, i;
+    
+    nobs = get_nobs(newinputs);    
+    if (LENGTH(ans) != nobs) 
+        error("ans is not of length %d\n", nobs);
+        
+    for (i = 0; i < nobs; i++)
+        SET_VECTOR_ELT(ans, i, C_get_prediction(tree, newinputs, 
+                       mincriterion, i, varperm));
+}
+
+
+/**
+    R-Interface to C_predict \n
+    *\param tree a tree
+    *\param newinputs an object of class `VariableFrame'
+    *\param mincriterion overwrites mincriterion used for tree growing
+    *\param varperm which variable shall be permuted? -1 for no permutation
+*/
+
+SEXP R_predict(SEXP tree, SEXP newinputs, SEXP mincriterion,
+               SEXP varperm) {
+
+    SEXP ans;
+    int nobs;
+
+    nobs = get_nobs(newinputs);
+    PROTECT(ans = allocVector(VECSXP, nobs));
+    GetRNGstate();
+    C_predict(tree, newinputs, REAL(mincriterion)[0], 
+              INTEGER(varperm)[0], ans);
+    PutRNGstate();
+    UNPROTECT(1);
+    return(ans);
+}
+
+/**
+    Get the predictions from `where' nodes\n
+    *\param tree a tree
+    *\param where vector of nodeID's
+    *\param ans return value
+*/
+
+void C_getpredictions(SEXP tree, SEXP where, SEXP ans) {
+
+    int nobs, i, *iwhere;
+    
+    nobs = LENGTH(where);
+    iwhere = INTEGER(where);
+    if (LENGTH(ans) != nobs)
+        error("ans is not of length %d\n", nobs);
+        
+    for (i = 0; i < nobs; i++)
+        SET_VECTOR_ELT(ans, i, S3get_prediction(
+            C_get_nodebynum(tree, iwhere[i])));
+}
+
+
+/**
+    R-Interface to C_getpredictions\n
+    *\param tree a tree
+    *\param where vector of nodeID's
+*/
+            
+SEXP R_getpredictions(SEXP tree, SEXP where) {
+
+    SEXP ans;
+    int nobs;
+            
+    nobs = LENGTH(where);
+    PROTECT(ans = allocVector(VECSXP, nobs));
+    C_getpredictions(tree, where, ans);
+    UNPROTECT(1);
+    return(ans);
+}                        
+
+/**
+    Predictions weights from RandomForest objects
+    *\param forest a list of trees
+    *\param where list (length b) of integer vectors (length n) containing terminal node numbers    
+    *\param weights list (length b) of bootstrap case weights
+    *\param newinputs an object of class `VariableFrame'
+    *\param mincriterion overwrites mincriterion used for tree growing
+    *\param oobpred a logical indicating out-of-bag predictions
+*/
+
+SEXP R_predictRF_weights(SEXP forest, SEXP where, SEXP weights, 
+                         SEXP newinputs, SEXP mincriterion, SEXP oobpred, SEXP expand) {
+
+    SEXP ans, tree, bw, expand_exp;
+    int ntrees, nobs, i, b, j, iwhere, oob = 0, count = 0, ntrain;
+    int errorOccurred;
+    
+    if (LOGICAL(oobpred)[0]) oob = 1;
+    
+    nobs = get_nobs(newinputs);
+    ntrees = LENGTH(forest);
+
+    if (oob) {
+        if (LENGTH(VECTOR_ELT(weights, 0)) != nobs)
+            error("number of observations don't match");
+    }    
+    
+    tree = VECTOR_ELT(forest, 0);
+    ntrain = LENGTH(VECTOR_ELT(weights, 0));
+    
+    PROTECT(ans = allocVector(VECSXP, nobs));
+    
+    for (i = 0; i < nobs; i++) {
+        count = 0;
+        SET_VECTOR_ELT(ans, i, bw = allocVector(REALSXP, ntrain));
+        for (j = 0; j < ntrain; j++)
+            REAL(bw)[j] = 0.0;
+        for (b = 0; b < ntrees; b++) {
+            tree = VECTOR_ELT(forest, b);
+            PROTECT(expand_exp = lang2(expand, tree));
+            tree = R_tryEval(expand_exp, R_GlobalEnv, &errorOccurred);
+            UNPROTECT(1);
+            if(errorOccurred) { 
+               Rprintf("error calling expand\n");
+               break;
+            } 
+            PROTECT(tree);
+            
+            if (oob && 
+                REAL(VECTOR_ELT(weights, b))[i] > 0.0) {
+                UNPROTECT(1);
+                continue;
+            }
+
+            iwhere = C_get_nodeID(tree, newinputs, REAL(mincriterion)[0], i, -1);
+            
+            for (j = 0; j < ntrain; j++) {
+                if (iwhere == INTEGER(VECTOR_ELT(where, b))[j])
+                    REAL(bw)[j] += REAL(VECTOR_ELT(weights, b))[j];
+            }
+            count++;
+            UNPROTECT(1);
+        }
+        if(errorOccurred)
+            break;
+        
+        if (count == 0) 
+            error("cannot compute out-of-bag predictions for observation number %d", i + 1);
+    }
+    UNPROTECT(1);
+    
+    if(errorOccurred) {
+        return NULL;
+    }
+    
+    return(ans);
+}
+
+
+/**
+    Proximity matrix for random forests
+    *\param where list (length b) of integer vectors (length n) containing terminal node numbers
+*/
+
+SEXP R_proximity(SEXP where) {
+
+    SEXP ans, bw, bin;
+    int ntrees, nobs, i, b, j, iwhere;
+    
+    ntrees = LENGTH(where);
+    nobs = LENGTH(VECTOR_ELT(where, 0));
+    
+    PROTECT(ans = allocVector(VECSXP, nobs));
+    PROTECT(bin = allocVector(INTSXP, nobs));
+     
+    for (i = 0; i < nobs; i++) {
+        SET_VECTOR_ELT(ans, i, bw = allocVector(REALSXP, nobs));
+        for (j = 0; j < nobs; j++) {
+            REAL(bw)[j] = 0.0;
+            INTEGER(bin)[j] = 0;
+        }
+        for (b = 0; b < ntrees; b++) {
+            /* don't look at out-of-bag observations */
+            if (INTEGER(VECTOR_ELT(where, b))[i] == 0)
+                continue;
+            iwhere = INTEGER(VECTOR_ELT(where, b))[i];
+            for (j = 0; j < nobs; j++) {
+                if (iwhere == INTEGER(VECTOR_ELT(where, b))[j])
+                    /* only count the number of trees; no weights */
+                    REAL(bw)[j]++;
+                if (INTEGER(VECTOR_ELT(where, b))[j] > 0)
+                    /* count the number of bootstrap samples
+                    containing both i and j */
+                    INTEGER(bin)[j]++;
+            }
+        }
+        for (j = 0; j < nobs; j++)
+            REAL(bw)[j] = REAL(bw)[j] / INTEGER(bin)[j];
+    }
+    UNPROTECT(2);
+    return(ans);
+}