--- a +++ b/partyMod/src/TreeGrow.c @@ -0,0 +1,139 @@ + +/** + The tree growing recursion + *\file TreeGrow.c + *\author $Author$ + *\date $Date$ +*/ + +#include "party.h" + + +int *copy_with_ignoring(int *variables_to_ignore, int ninputs, int index) { + int i; + int *new_mask; + + if(variables_to_ignore == NULL) { +// printf("copy_with_ignoring(NULL, %d, %d)\n", ninputs, index); + return NULL; + } + +/* printf("copy_with_ignoring(["); + for(i = 0;i<ninputs;i++) { + if(i>0) { + printf(", "); + } + printf("%d", variables_to_ignore[i]); + } + printf("], %d, %d)\n", ninputs, index); + */ + + new_mask = Calloc(ninputs, int); + for(i = 0;i<ninputs;i++) { + new_mask[i] = i == index ? 1 : variables_to_ignore[i]; + } + return new_mask; +} + +/** + The main tree growing function, handles the recursion. \n + *\param node a list representing the current node + *\param learnsample an object of class `LearningSample' + *\param fitmem an object of class `TreeFitMemory' + *\param controls an object of class `TreeControl' + *\param where a pointer to an integer vector of n-elements + *\param nodenum a pointer to a integer vector of length 1 + *\param depth an integer giving the depth of the current node +*/ + +void C_TreeGrow(SEXP node, SEXP learnsample, SEXP fitmem, + SEXP controls, int *where, int *nodenum, int depth, int *variables_to_ignore) { + + SEXP weights; + int nobs, i, stop; + double *dweights; + SEXP split; + int ninputs; + int *child_variables_to_ignore; + + ninputs = get_ninputs(learnsample); + weights = S3get_nodeweights(node); + + /* stop if either stumps have been requested or + the maximum depth is exceeded */ + stop = (nodenum[0] == 2 || nodenum[0] == 3) && + get_stump(get_tgctrl(controls)); + stop = stop || !check_depth(get_tgctrl(controls), depth); + + if (stop) + C_Node(node, learnsample, weights, fitmem, controls, 1, depth, variables_to_ignore); + else + C_Node(node, learnsample, weights, fitmem, controls, 0, depth, variables_to_ignore); + + S3set_nodeID(node, nodenum[0]); + + if (!S3get_nodeterminal(node)) { + + C_splitnode(node, learnsample, controls); + + /* determine surrogate splits and split missing values */ + if (get_maxsurrogate(get_splitctrl(controls)) > 0) { + C_surrogates(node, learnsample, weights, controls, fitmem); + C_splitsurrogate(node, learnsample); + } + + split = S3get_primarysplit(node); + child_variables_to_ignore = copy_with_ignoring(variables_to_ignore, ninputs, S3get_variableID(split)-1); + nodenum[0] += 1; + C_TreeGrow(S3get_leftnode(node), learnsample, fitmem, + controls, where, nodenum, depth + 1, child_variables_to_ignore); + + nodenum[0] += 1; + C_TreeGrow(S3get_rightnode(node), learnsample, fitmem, + controls, where, nodenum, depth + 1, child_variables_to_ignore); + + Free(child_variables_to_ignore); + + } else { + dweights = REAL(weights); + nobs = get_nobs(learnsample); + for (i = 0; i < nobs; i++) + if (dweights[i] > 0) where[i] = nodenum[0]; + } +} + + +/** + R-interface to C_TreeGrow\n + *\param learnsample an object of class `LearningSample' + *\param weights a vector of case weights + *\param fitmem an object of class `TreeFitMemory' + *\param controls an object of class `TreeControl' + *\param where a vector of node indices for each observation +*/ + +SEXP R_TreeGrow(SEXP learnsample, SEXP weights, SEXP fitmem, SEXP controls, SEXP where) { + + SEXP ans, nweights; + double *dnweights, *dweights; + int nobs, i, nodenum = 1; + + GetRNGstate(); + + nobs = get_nobs(learnsample); + PROTECT(ans = allocVector(VECSXP, NODE_LENGTH)); + C_init_node(ans, nobs, get_ninputs(learnsample), get_maxsurrogate(get_splitctrl(controls)), + ncol(get_predict_trafo(GET_SLOT(learnsample, PL2_responsesSym)))); + + nweights = S3get_nodeweights(ans); + dnweights = REAL(nweights); + dweights = REAL(weights); + for (i = 0; i < nobs; i++) dnweights[i] = dweights[i]; + + C_TreeGrow(ans, learnsample, fitmem, controls, INTEGER(where), &nodenum, 1, NULL); + + PutRNGstate(); + + UNPROTECT(1); + return(ans); +}