--- a +++ b/partyMod/src/Predict.c @@ -0,0 +1,560 @@ + +/** + Node splitting and prediction + *\file Predict.c + *\author $Author$ + *\date $Date$ +*/ + +#include "party.h" + + +/** + Split a node according to a splitting rule \n + *\param node the current node with primary split specified + *\param learnsample learning sample + *\param control an object of class `TreeControl' + *\todo outplace the splitting since there are at least 3 functions + with nearly identical code +*/ + +void C_splitnode(SEXP node, SEXP learnsample, SEXP control) { + + SEXP weights, leftnode, rightnode, split; + SEXP responses, inputs, whichNA; + double cutpoint, *dx, *dweights, *leftweights, *rightweights; + double sleft = 0.0, sright = 0.0; + int *ix, *levelset, *iwhichNA; + int nobs, i, nna; + + weights = S3get_nodeweights(node); + dweights = REAL(weights); + responses = GET_SLOT(learnsample, PL2_responsesSym); + inputs = GET_SLOT(learnsample, PL2_inputsSym); + nobs = get_nobs(learnsample); + + /* set up memory for the left daughter */ + SET_VECTOR_ELT(node, S3_LEFT, leftnode = allocVector(VECSXP, NODE_LENGTH)); + C_init_node(leftnode, nobs, + get_ninputs(learnsample), get_maxsurrogate(get_splitctrl(control)), + ncol(get_predict_trafo(responses))); + leftweights = REAL(S3get_nodeweights(leftnode)); + + /* set up memory for the right daughter */ + SET_VECTOR_ELT(node, S3_RIGHT, + rightnode = allocVector(VECSXP, NODE_LENGTH)); + C_init_node(rightnode, nobs, + get_ninputs(learnsample), get_maxsurrogate(get_splitctrl(control)), + ncol(get_predict_trafo(responses))); + rightweights = REAL(S3get_nodeweights(rightnode)); + + /* split according to the primary split */ + split = S3get_primarysplit(node); + if (has_missings(inputs, S3get_variableID(split))) { + whichNA = get_missings(inputs, S3get_variableID(split)); + iwhichNA = INTEGER(whichNA); + nna = LENGTH(whichNA); + } else { + nna = 0; + whichNA = R_NilValue; + iwhichNA = NULL; + } + + if (S3is_ordered(split)) { + cutpoint = REAL(S3get_splitpoint(split))[0]; + dx = REAL(get_variable(inputs, S3get_variableID(split))); + for (i = 0; i < nobs; i++) { + if (nna > 0) { + if (i_in_set(i + 1, iwhichNA, nna)) continue; + } + if (dx[i] <= cutpoint) + leftweights[i] = dweights[i]; + else + leftweights[i] = 0.0; + rightweights[i] = dweights[i] - leftweights[i]; + sleft += leftweights[i]; + sright += rightweights[i]; + } + } else { + levelset = INTEGER(S3get_splitpoint(split)); + ix = INTEGER(get_variable(inputs, S3get_variableID(split))); + + for (i = 0; i < nobs; i++) { + if (nna > 0) { + if (i_in_set(i + 1, iwhichNA, nna)) continue; + } + if (levelset[ix[i] - 1]) + leftweights[i] = dweights[i]; + else + leftweights[i] = 0.0; + rightweights[i] = dweights[i] - leftweights[i]; + sleft += leftweights[i]; + sright += rightweights[i]; + } + } + + /* for the moment: NA's go with majority */ + if (nna > 0) { + for (i = 0; i < nna; i++) { + if (sleft > sright) { + leftweights[iwhichNA[i] - 1] = dweights[iwhichNA[i] - 1]; + rightweights[iwhichNA[i] - 1] = 0.0; + } else { + rightweights[iwhichNA[i] - 1] = dweights[iwhichNA[i] - 1]; + leftweights[iwhichNA[i] - 1] = 0.0; + } + } + } +} + + +/** + Get the terminal node for obs. number `numobs' of `newinputs' \n + *\param subtree a tree + *\param newinputs an object of class `VariableFrame' + *\param mincriterion overwrites mincriterion used for tree growing + *\param numobs observation number + *\param varperm which variable shall be permuted? + *\todo handle surrogate splits +*/ + +SEXP C_get_node(SEXP subtree, SEXP newinputs, + double mincriterion, int numobs, int varperm) { + + SEXP split, whichNA, ssplit, surrsplit; + double cutpoint, x, swleft, swright; + int level, *levelset, i, ns; + + if (S3get_nodeterminal(subtree) || + REAL(S3get_maxcriterion(subtree))[0] < mincriterion) + return(subtree); + + split = S3get_primarysplit(subtree); + + /* Maybe store the proportions left / right in each node? */ + swleft = S3get_sumweights(S3get_leftnode(subtree)); + swright = S3get_sumweights(S3get_rightnode(subtree)); + + /* splits based on variable varperm are random */ + if (S3get_variableID(split) == varperm) { + if (unif_rand() < swleft / (swleft + swright)) { + return(C_get_node(S3get_leftnode(subtree), + newinputs, mincriterion, numobs, varperm)); + } else { + return(C_get_node(S3get_rightnode(subtree), + newinputs, mincriterion, numobs, varperm)); + } + } + + /* missing values */ + if (has_missings(newinputs, S3get_variableID(split))) { + whichNA = get_missings(newinputs, S3get_variableID(split)); + + /* numobs 0 ... n - 1 but whichNA has 1:n */ + if (C_i_in_set(numobs + 1, whichNA)) { + + surrsplit = S3get_surrogatesplits(subtree); + ns = 0; + i = numobs; + + /* try to find a surrogate split */ + while(TRUE) { + + if (ns >= LENGTH(surrsplit)) break; + + ssplit = VECTOR_ELT(surrsplit, ns); + if (has_missings(newinputs, S3get_variableID(ssplit))) { + if (INTEGER(get_missings(newinputs, + S3get_variableID(ssplit)))[i]) { + ns++; + continue; + } + } + + cutpoint = REAL(S3get_splitpoint(ssplit))[0]; + x = REAL(get_variable(newinputs, S3get_variableID(ssplit)))[i]; + + if (S3get_toleft(ssplit)) { + if (x <= cutpoint) { + return(C_get_node(S3get_leftnode(subtree), + newinputs, mincriterion, numobs, varperm)); + } else { + return(C_get_node(S3get_rightnode(subtree), + newinputs, mincriterion, numobs, varperm)); + } + } else { + if (x <= cutpoint) { + return(C_get_node(S3get_rightnode(subtree), + newinputs, mincriterion, numobs, varperm)); + } else { + return(C_get_node(S3get_leftnode(subtree), + newinputs, mincriterion, numobs, varperm)); + } + } + break; + } + + /* if this was not successful, we go with the majority */ + if (swleft > swright) { + return(C_get_node(S3get_leftnode(subtree), + newinputs, mincriterion, numobs, varperm)); + } else { + return(C_get_node(S3get_rightnode(subtree), + newinputs, mincriterion, numobs, varperm)); + } + } + } + + if (S3is_ordered(split)) { + cutpoint = REAL(S3get_splitpoint(split))[0]; + x = REAL(get_variable(newinputs, + S3get_variableID(split)))[numobs]; + if (x <= cutpoint) { + return(C_get_node(S3get_leftnode(subtree), + newinputs, mincriterion, numobs, varperm)); + } else { + return(C_get_node(S3get_rightnode(subtree), + newinputs, mincriterion, numobs, varperm)); + } + } else { + levelset = INTEGER(S3get_splitpoint(split)); + level = INTEGER(get_variable(newinputs, + S3get_variableID(split)))[numobs]; + /* level is in 1, ..., K */ + if (levelset[level - 1]) { + return(C_get_node(S3get_leftnode(subtree), newinputs, + mincriterion, numobs, varperm)); + } else { + return(C_get_node(S3get_rightnode(subtree), newinputs, + mincriterion, numobs, varperm)); + } + } +} + + +/** + R-Interface to C_get_node \n + *\param subtree a tree + *\param newinputs an object of class `VariableFrame' + *\param mincriterion overwrites mincriterion used for tree growing + *\param numobs observation number +*/ + +SEXP R_get_node(SEXP subtree, SEXP newinputs, SEXP mincriterion, + SEXP numobs, SEXP varperm) { + return(C_get_node(subtree, newinputs, REAL(mincriterion)[0], + INTEGER(numobs)[0] - 1, INTEGER(varperm)[0])); +} + + +/** + Get the node with nodeID `nodenum' \n + *\param subtree a tree + *\param nodenum a nodeID +*/ + +SEXP C_get_nodebynum(SEXP subtree, int nodenum) { + + if (nodenum == S3get_nodeID(subtree)) return(subtree); + + if (S3get_nodeterminal(subtree)) + error("no node with number %d\n", nodenum); + + if (nodenum < S3get_nodeID(S3get_rightnode(subtree))) { + return(C_get_nodebynum(S3get_leftnode(subtree), nodenum)); + } else { + return(C_get_nodebynum(S3get_rightnode(subtree), nodenum)); + } +} + + +/** + R-Interface to C_get_nodenum \n + *\param subtree a tree + *\param nodenum a nodeID +*/ + +SEXP R_get_nodebynum(SEXP subtree, SEXP nodenum) { + return(C_get_nodebynum(subtree, INTEGER(nodenum)[0])); +} + + +/** + Get the prediction of a new observation\n + *\param subtree a tree + *\param newinputs an object of class `VariableFrame' + *\param mincriterion overwrites mincriterion used for tree growing + *\param numobs observation number + *\param varperm which variable shall be permuted? +*/ + +SEXP C_get_prediction(SEXP subtree, SEXP newinputs, + double mincriterion, int numobs, int varperm) { + return(S3get_prediction(C_get_node(subtree, newinputs, + mincriterion, numobs, varperm))); +} + + +/** + Get the weights for a new observation \n + *\param subtree a tree + *\param newinputs an object of class `VariableFrame' + *\param mincriterion overwrites mincriterion used for tree growing + *\param numobs observation number +*/ + +SEXP C_get_nodeweights(SEXP subtree, SEXP newinputs, + double mincriterion, int numobs) { + return(S3get_nodeweights(C_get_node(subtree, newinputs, + mincriterion, numobs, -1))); +} + + +/** + Get the nodeID for a new observation \n + *\param subtree a tree + *\param newinputs an object of class `VariableFrame' + *\param mincriterion overwrites mincriterion used for tree growing + *\param numobs observation number + *\param varperm which variable shall be permuted? +*/ + +int C_get_nodeID(SEXP subtree, SEXP newinputs, + double mincriterion, int numobs, int varperm) { + return(S3get_nodeID(C_get_node(subtree, newinputs, + mincriterion, numobs, varperm))); +} + + +/** + R-Interface to C_get_nodeID \n + *\param tree a tree + *\param newinputs an object of class `VariableFrame' + *\param mincriterion overwrites mincriterion used for tree growing +*/ + +SEXP R_get_nodeID(SEXP tree, SEXP newinputs, SEXP mincriterion, SEXP varperm) { + + SEXP ans; + int nobs, i, *dans; + + nobs = get_nobs(newinputs); + PROTECT(ans = allocVector(INTSXP, nobs)); + dans = INTEGER(ans); + for (i = 0; i < nobs; i++) + dans[i] = C_get_nodeID(tree, newinputs, REAL(mincriterion)[0], i, INTEGER(varperm)[0]); + UNPROTECT(1); + return(ans); +} + + +/** + Get all predictions for `newinputs' \n + *\param tree a tree + *\param newinputs an object of class `VariableFrame' + *\param mincriterion overwrites mincriterion used for tree growing + *\param varperm which variable shall be permuted? + *\param ans return value +*/ + +void C_predict(SEXP tree, SEXP newinputs, double mincriterion, + int varperm, SEXP ans) { + + int nobs, i; + + nobs = get_nobs(newinputs); + if (LENGTH(ans) != nobs) + error("ans is not of length %d\n", nobs); + + for (i = 0; i < nobs; i++) + SET_VECTOR_ELT(ans, i, C_get_prediction(tree, newinputs, + mincriterion, i, varperm)); +} + + +/** + R-Interface to C_predict \n + *\param tree a tree + *\param newinputs an object of class `VariableFrame' + *\param mincriterion overwrites mincriterion used for tree growing + *\param varperm which variable shall be permuted? -1 for no permutation +*/ + +SEXP R_predict(SEXP tree, SEXP newinputs, SEXP mincriterion, + SEXP varperm) { + + SEXP ans; + int nobs; + + nobs = get_nobs(newinputs); + PROTECT(ans = allocVector(VECSXP, nobs)); + GetRNGstate(); + C_predict(tree, newinputs, REAL(mincriterion)[0], + INTEGER(varperm)[0], ans); + PutRNGstate(); + UNPROTECT(1); + return(ans); +} + +/** + Get the predictions from `where' nodes\n + *\param tree a tree + *\param where vector of nodeID's + *\param ans return value +*/ + +void C_getpredictions(SEXP tree, SEXP where, SEXP ans) { + + int nobs, i, *iwhere; + + nobs = LENGTH(where); + iwhere = INTEGER(where); + if (LENGTH(ans) != nobs) + error("ans is not of length %d\n", nobs); + + for (i = 0; i < nobs; i++) + SET_VECTOR_ELT(ans, i, S3get_prediction( + C_get_nodebynum(tree, iwhere[i]))); +} + + +/** + R-Interface to C_getpredictions\n + *\param tree a tree + *\param where vector of nodeID's +*/ + +SEXP R_getpredictions(SEXP tree, SEXP where) { + + SEXP ans; + int nobs; + + nobs = LENGTH(where); + PROTECT(ans = allocVector(VECSXP, nobs)); + C_getpredictions(tree, where, ans); + UNPROTECT(1); + return(ans); +} + +/** + Predictions weights from RandomForest objects + *\param forest a list of trees + *\param where list (length b) of integer vectors (length n) containing terminal node numbers + *\param weights list (length b) of bootstrap case weights + *\param newinputs an object of class `VariableFrame' + *\param mincriterion overwrites mincriterion used for tree growing + *\param oobpred a logical indicating out-of-bag predictions +*/ + +SEXP R_predictRF_weights(SEXP forest, SEXP where, SEXP weights, + SEXP newinputs, SEXP mincriterion, SEXP oobpred, SEXP expand) { + + SEXP ans, tree, bw, expand_exp; + int ntrees, nobs, i, b, j, iwhere, oob = 0, count = 0, ntrain; + int errorOccurred; + + if (LOGICAL(oobpred)[0]) oob = 1; + + nobs = get_nobs(newinputs); + ntrees = LENGTH(forest); + + if (oob) { + if (LENGTH(VECTOR_ELT(weights, 0)) != nobs) + error("number of observations don't match"); + } + + tree = VECTOR_ELT(forest, 0); + ntrain = LENGTH(VECTOR_ELT(weights, 0)); + + PROTECT(ans = allocVector(VECSXP, nobs)); + + for (i = 0; i < nobs; i++) { + count = 0; + SET_VECTOR_ELT(ans, i, bw = allocVector(REALSXP, ntrain)); + for (j = 0; j < ntrain; j++) + REAL(bw)[j] = 0.0; + for (b = 0; b < ntrees; b++) { + tree = VECTOR_ELT(forest, b); + PROTECT(expand_exp = lang2(expand, tree)); + tree = R_tryEval(expand_exp, R_GlobalEnv, &errorOccurred); + UNPROTECT(1); + if(errorOccurred) { + Rprintf("error calling expand\n"); + break; + } + PROTECT(tree); + + if (oob && + REAL(VECTOR_ELT(weights, b))[i] > 0.0) { + UNPROTECT(1); + continue; + } + + iwhere = C_get_nodeID(tree, newinputs, REAL(mincriterion)[0], i, -1); + + for (j = 0; j < ntrain; j++) { + if (iwhere == INTEGER(VECTOR_ELT(where, b))[j]) + REAL(bw)[j] += REAL(VECTOR_ELT(weights, b))[j]; + } + count++; + UNPROTECT(1); + } + if(errorOccurred) + break; + + if (count == 0) + error("cannot compute out-of-bag predictions for observation number %d", i + 1); + } + UNPROTECT(1); + + if(errorOccurred) { + return NULL; + } + + return(ans); +} + + +/** + Proximity matrix for random forests + *\param where list (length b) of integer vectors (length n) containing terminal node numbers +*/ + +SEXP R_proximity(SEXP where) { + + SEXP ans, bw, bin; + int ntrees, nobs, i, b, j, iwhere; + + ntrees = LENGTH(where); + nobs = LENGTH(VECTOR_ELT(where, 0)); + + PROTECT(ans = allocVector(VECSXP, nobs)); + PROTECT(bin = allocVector(INTSXP, nobs)); + + for (i = 0; i < nobs; i++) { + SET_VECTOR_ELT(ans, i, bw = allocVector(REALSXP, nobs)); + for (j = 0; j < nobs; j++) { + REAL(bw)[j] = 0.0; + INTEGER(bin)[j] = 0; + } + for (b = 0; b < ntrees; b++) { + /* don't look at out-of-bag observations */ + if (INTEGER(VECTOR_ELT(where, b))[i] == 0) + continue; + iwhere = INTEGER(VECTOR_ELT(where, b))[i]; + for (j = 0; j < nobs; j++) { + if (iwhere == INTEGER(VECTOR_ELT(where, b))[j]) + /* only count the number of trees; no weights */ + REAL(bw)[j]++; + if (INTEGER(VECTOR_ELT(where, b))[j] > 0) + /* count the number of bootstrap samples + containing both i and j */ + INTEGER(bin)[j]++; + } + } + for (j = 0; j < nobs; j++) + REAL(bw)[j] = REAL(bw)[j] / INTEGER(bin)[j]; + } + UNPROTECT(2); + return(ans); +}