a b/partyMod/src/TreeGrow.c
1
2
/**
    The tree growing recursion
    *\file TreeGrow.c
    *\author $Author$
    *\date $Date$
*/
8
9
#include "party.h"
10
11
12
int *copy_with_ignoring(int *variables_to_ignore, int ninputs, int index) {
13
   int i;
14
   int *new_mask;
15
   
16
   if(variables_to_ignore == NULL) {
17
//     printf("copy_with_ignoring(NULL, %d, %d)\n", ninputs, index);
18
     return NULL;
19
   }
20
   
21
/*   printf("copy_with_ignoring([");
22
   for(i = 0;i<ninputs;i++) {
23
     if(i>0) { 
24
       printf(", ");
25
     }
26
     printf("%d", variables_to_ignore[i]);
27
   }
28
   printf("], %d, %d)\n", ninputs, index);
29
 */
30
   
31
   new_mask = Calloc(ninputs, int);
32
   for(i = 0;i<ninputs;i++) {
33
     new_mask[i] = i == index ? 1 : variables_to_ignore[i];
34
   }
35
   return new_mask;
36
}
37
38
/**
39
    The main tree growing function, handles the recursion. \n
40
    *\param node  a list representing the current node
41
    *\param learnsample an object of class `LearningSample'
42
    *\param fitmem an object of class `TreeFitMemory'
43
    *\param controls an object of class `TreeControl'
44
    *\param where a pointer to an integer vector of n-elements
45
    *\param nodenum a pointer to a integer vector of length 1
46
    *\param depth an integer giving the depth of the current node
47
*/
48
49
void C_TreeGrow(SEXP node, SEXP learnsample, SEXP fitmem, 
50
                SEXP controls, int *where, int *nodenum, int depth, int *variables_to_ignore) {
51
52
    SEXP weights;
53
    int nobs, i, stop;
54
    double *dweights;
55
    SEXP split;
56
    int ninputs;
57
    int *child_variables_to_ignore;
58
    
59
    ninputs = get_ninputs(learnsample);
60
    weights = S3get_nodeweights(node);
61
    
62
    /* stop if either stumps have been requested or 
63
       the maximum depth is exceeded */
64
    stop = (nodenum[0] == 2 || nodenum[0] == 3) && 
65
           get_stump(get_tgctrl(controls));
66
    stop = stop || !check_depth(get_tgctrl(controls), depth);
67
    
68
    if (stop)
69
        C_Node(node, learnsample, weights, fitmem, controls, 1, depth, variables_to_ignore);
70
    else
71
        C_Node(node, learnsample, weights, fitmem, controls, 0, depth, variables_to_ignore);
72
    
73
    S3set_nodeID(node, nodenum[0]);    
74
    
75
    if (!S3get_nodeterminal(node)) {
76
77
        C_splitnode(node, learnsample, controls);
78
79
        /* determine surrogate splits and split missing values */
80
        if (get_maxsurrogate(get_splitctrl(controls)) > 0) {
81
            C_surrogates(node, learnsample, weights, controls, fitmem);
82
            C_splitsurrogate(node, learnsample);
83
        }
84
85
        split = S3get_primarysplit(node);
86
        child_variables_to_ignore = copy_with_ignoring(variables_to_ignore, ninputs, S3get_variableID(split)-1);
87
        nodenum[0] += 1;
88
        C_TreeGrow(S3get_leftnode(node), learnsample, fitmem, 
89
                   controls, where, nodenum, depth + 1, child_variables_to_ignore);
90
91
        nodenum[0] += 1;                                      
92
        C_TreeGrow(S3get_rightnode(node), learnsample, fitmem, 
93
                   controls, where, nodenum, depth + 1, child_variables_to_ignore);
94
                   
95
        Free(child_variables_to_ignore);
96
                   
97
    } else {
98
        dweights = REAL(weights);
99
        nobs = get_nobs(learnsample);
100
        for (i = 0; i < nobs; i++)
101
            if (dweights[i] > 0) where[i] = nodenum[0];
102
    } 
103
}
104
105
106
/**
107
    R-interface to C_TreeGrow\n
108
    *\param learnsample an object of class `LearningSample'
109
    *\param weights a vector of case weights
110
    *\param fitmem an object of class `TreeFitMemory'
111
    *\param controls an object of class `TreeControl'
112
    *\param where a vector of node indices for each observation
113
*/
114
115
SEXP R_TreeGrow(SEXP learnsample, SEXP weights, SEXP fitmem, SEXP controls, SEXP where) {
116
            
117
     SEXP ans, nweights;
118
     double *dnweights, *dweights;
119
     int nobs, i, nodenum = 1;
120
121
     GetRNGstate();
122
     
123
     nobs = get_nobs(learnsample);
124
     PROTECT(ans = allocVector(VECSXP, NODE_LENGTH));
125
     C_init_node(ans, nobs, get_ninputs(learnsample), get_maxsurrogate(get_splitctrl(controls)),
126
                 ncol(get_predict_trafo(GET_SLOT(learnsample, PL2_responsesSym))));
127
128
     nweights = S3get_nodeweights(ans);
129
     dnweights = REAL(nweights);
130
     dweights = REAL(weights);
131
     for (i = 0; i < nobs; i++) dnweights[i] = dweights[i];
132
     
133
     C_TreeGrow(ans, learnsample, fitmem, controls, INTEGER(where), &nodenum, 1, NULL);
134
     
135
     PutRNGstate();
136
     
137
     UNPROTECT(1);
138
     return(ans);
139
}