Diff of /partyMod/src/TreeGrow.c [000000] .. [fbf06f]

Switch to side-by-side view

--- a
+++ b/partyMod/src/TreeGrow.c
@@ -0,0 +1,139 @@
+
+/**
+    The tree growing recursion
+    *\file TreeGrow.c
+    *\author $Author$
+    *\date $Date$
+*/
+
+#include "party.h"
+
+
+int *copy_with_ignoring(int *variables_to_ignore, int ninputs, int index) {
+   int i;
+   int *new_mask;
+   
+   if(variables_to_ignore == NULL) {
+//     printf("copy_with_ignoring(NULL, %d, %d)\n", ninputs, index);
+     return NULL;
+   }
+   
+/*   printf("copy_with_ignoring([");
+   for(i = 0;i<ninputs;i++) {
+     if(i>0) { 
+       printf(", ");
+     }
+     printf("%d", variables_to_ignore[i]);
+   }
+   printf("], %d, %d)\n", ninputs, index);
+ */
+   
+   new_mask = Calloc(ninputs, int);
+   for(i = 0;i<ninputs;i++) {
+     new_mask[i] = i == index ? 1 : variables_to_ignore[i];
+   }
+   return new_mask;
+}
+
+/**
+    The main tree growing function, handles the recursion. \n
+    *\param node  a list representing the current node
+    *\param learnsample an object of class `LearningSample'
+    *\param fitmem an object of class `TreeFitMemory'
+    *\param controls an object of class `TreeControl'
+    *\param where a pointer to an integer vector of n-elements
+    *\param nodenum a pointer to a integer vector of length 1
+    *\param depth an integer giving the depth of the current node
+*/
+
+void C_TreeGrow(SEXP node, SEXP learnsample, SEXP fitmem, 
+                SEXP controls, int *where, int *nodenum, int depth, int *variables_to_ignore) {
+
+    SEXP weights;
+    int nobs, i, stop;
+    double *dweights;
+    SEXP split;
+    int ninputs;
+    int *child_variables_to_ignore;
+    
+    ninputs = get_ninputs(learnsample);
+    weights = S3get_nodeweights(node);
+    
+    /* stop if either stumps have been requested or 
+       the maximum depth is exceeded */
+    stop = (nodenum[0] == 2 || nodenum[0] == 3) && 
+           get_stump(get_tgctrl(controls));
+    stop = stop || !check_depth(get_tgctrl(controls), depth);
+    
+    if (stop)
+        C_Node(node, learnsample, weights, fitmem, controls, 1, depth, variables_to_ignore);
+    else
+        C_Node(node, learnsample, weights, fitmem, controls, 0, depth, variables_to_ignore);
+    
+    S3set_nodeID(node, nodenum[0]);    
+    
+    if (!S3get_nodeterminal(node)) {
+
+        C_splitnode(node, learnsample, controls);
+
+        /* determine surrogate splits and split missing values */
+        if (get_maxsurrogate(get_splitctrl(controls)) > 0) {
+            C_surrogates(node, learnsample, weights, controls, fitmem);
+            C_splitsurrogate(node, learnsample);
+        }
+
+        split = S3get_primarysplit(node);
+        child_variables_to_ignore = copy_with_ignoring(variables_to_ignore, ninputs, S3get_variableID(split)-1);
+        nodenum[0] += 1;
+        C_TreeGrow(S3get_leftnode(node), learnsample, fitmem, 
+                   controls, where, nodenum, depth + 1, child_variables_to_ignore);
+
+        nodenum[0] += 1;                                      
+        C_TreeGrow(S3get_rightnode(node), learnsample, fitmem, 
+                   controls, where, nodenum, depth + 1, child_variables_to_ignore);
+                   
+        Free(child_variables_to_ignore);
+                   
+    } else {
+        dweights = REAL(weights);
+        nobs = get_nobs(learnsample);
+        for (i = 0; i < nobs; i++)
+            if (dweights[i] > 0) where[i] = nodenum[0];
+    } 
+}
+
+
+/**
+    R-interface to C_TreeGrow\n
+    *\param learnsample an object of class `LearningSample'
+    *\param weights a vector of case weights
+    *\param fitmem an object of class `TreeFitMemory'
+    *\param controls an object of class `TreeControl'
+    *\param where a vector of node indices for each observation
+*/
+
+SEXP R_TreeGrow(SEXP learnsample, SEXP weights, SEXP fitmem, SEXP controls, SEXP where) {
+            
+     SEXP ans, nweights;
+     double *dnweights, *dweights;
+     int nobs, i, nodenum = 1;
+
+     GetRNGstate();
+     
+     nobs = get_nobs(learnsample);
+     PROTECT(ans = allocVector(VECSXP, NODE_LENGTH));
+     C_init_node(ans, nobs, get_ninputs(learnsample), get_maxsurrogate(get_splitctrl(controls)),
+                 ncol(get_predict_trafo(GET_SLOT(learnsample, PL2_responsesSym))));
+
+     nweights = S3get_nodeweights(ans);
+     dnweights = REAL(nweights);
+     dweights = REAL(weights);
+     for (i = 0; i < nobs; i++) dnweights[i] = dweights[i];
+     
+     C_TreeGrow(ans, learnsample, fitmem, controls, INTEGER(where), &nodenum, 1, NULL);
+     
+     PutRNGstate();
+     
+     UNPROTECT(1);
+     return(ans);
+}