|
a |
|
b/partyMod/src/TreeGrow.c |
|
|
1 |
|
|
|
2 |
/** |
|
|
3 |
The tree growing recursion |
|
|
4 |
*\file TreeGrow.c |
|
|
5 |
*\author $Author$ |
|
|
6 |
*\date $Date$ |
|
|
7 |
*/ |
|
|
8 |
|
|
|
9 |
#include "party.h" |
|
|
10 |
|
|
|
11 |
|
|
|
12 |
int *copy_with_ignoring(int *variables_to_ignore, int ninputs, int index) { |
|
|
13 |
int i; |
|
|
14 |
int *new_mask; |
|
|
15 |
|
|
|
16 |
if(variables_to_ignore == NULL) { |
|
|
17 |
// printf("copy_with_ignoring(NULL, %d, %d)\n", ninputs, index); |
|
|
18 |
return NULL; |
|
|
19 |
} |
|
|
20 |
|
|
|
21 |
/* printf("copy_with_ignoring(["); |
|
|
22 |
for(i = 0;i<ninputs;i++) { |
|
|
23 |
if(i>0) { |
|
|
24 |
printf(", "); |
|
|
25 |
} |
|
|
26 |
printf("%d", variables_to_ignore[i]); |
|
|
27 |
} |
|
|
28 |
printf("], %d, %d)\n", ninputs, index); |
|
|
29 |
*/ |
|
|
30 |
|
|
|
31 |
new_mask = Calloc(ninputs, int); |
|
|
32 |
for(i = 0;i<ninputs;i++) { |
|
|
33 |
new_mask[i] = i == index ? 1 : variables_to_ignore[i]; |
|
|
34 |
} |
|
|
35 |
return new_mask; |
|
|
36 |
} |
|
|
37 |
|
|
|
38 |
/** |
|
|
39 |
The main tree growing function, handles the recursion. \n |
|
|
40 |
*\param node a list representing the current node |
|
|
41 |
*\param learnsample an object of class `LearningSample' |
|
|
42 |
*\param fitmem an object of class `TreeFitMemory' |
|
|
43 |
*\param controls an object of class `TreeControl' |
|
|
44 |
*\param where a pointer to an integer vector of n-elements |
|
|
45 |
*\param nodenum a pointer to a integer vector of length 1 |
|
|
46 |
*\param depth an integer giving the depth of the current node |
|
|
47 |
*/ |
|
|
48 |
|
|
|
49 |
void C_TreeGrow(SEXP node, SEXP learnsample, SEXP fitmem, |
|
|
50 |
SEXP controls, int *where, int *nodenum, int depth, int *variables_to_ignore) { |
|
|
51 |
|
|
|
52 |
SEXP weights; |
|
|
53 |
int nobs, i, stop; |
|
|
54 |
double *dweights; |
|
|
55 |
SEXP split; |
|
|
56 |
int ninputs; |
|
|
57 |
int *child_variables_to_ignore; |
|
|
58 |
|
|
|
59 |
ninputs = get_ninputs(learnsample); |
|
|
60 |
weights = S3get_nodeweights(node); |
|
|
61 |
|
|
|
62 |
/* stop if either stumps have been requested or |
|
|
63 |
the maximum depth is exceeded */ |
|
|
64 |
stop = (nodenum[0] == 2 || nodenum[0] == 3) && |
|
|
65 |
get_stump(get_tgctrl(controls)); |
|
|
66 |
stop = stop || !check_depth(get_tgctrl(controls), depth); |
|
|
67 |
|
|
|
68 |
if (stop) |
|
|
69 |
C_Node(node, learnsample, weights, fitmem, controls, 1, depth, variables_to_ignore); |
|
|
70 |
else |
|
|
71 |
C_Node(node, learnsample, weights, fitmem, controls, 0, depth, variables_to_ignore); |
|
|
72 |
|
|
|
73 |
S3set_nodeID(node, nodenum[0]); |
|
|
74 |
|
|
|
75 |
if (!S3get_nodeterminal(node)) { |
|
|
76 |
|
|
|
77 |
C_splitnode(node, learnsample, controls); |
|
|
78 |
|
|
|
79 |
/* determine surrogate splits and split missing values */ |
|
|
80 |
if (get_maxsurrogate(get_splitctrl(controls)) > 0) { |
|
|
81 |
C_surrogates(node, learnsample, weights, controls, fitmem); |
|
|
82 |
C_splitsurrogate(node, learnsample); |
|
|
83 |
} |
|
|
84 |
|
|
|
85 |
split = S3get_primarysplit(node); |
|
|
86 |
child_variables_to_ignore = copy_with_ignoring(variables_to_ignore, ninputs, S3get_variableID(split)-1); |
|
|
87 |
nodenum[0] += 1; |
|
|
88 |
C_TreeGrow(S3get_leftnode(node), learnsample, fitmem, |
|
|
89 |
controls, where, nodenum, depth + 1, child_variables_to_ignore); |
|
|
90 |
|
|
|
91 |
nodenum[0] += 1; |
|
|
92 |
C_TreeGrow(S3get_rightnode(node), learnsample, fitmem, |
|
|
93 |
controls, where, nodenum, depth + 1, child_variables_to_ignore); |
|
|
94 |
|
|
|
95 |
Free(child_variables_to_ignore); |
|
|
96 |
|
|
|
97 |
} else { |
|
|
98 |
dweights = REAL(weights); |
|
|
99 |
nobs = get_nobs(learnsample); |
|
|
100 |
for (i = 0; i < nobs; i++) |
|
|
101 |
if (dweights[i] > 0) where[i] = nodenum[0]; |
|
|
102 |
} |
|
|
103 |
} |
|
|
104 |
|
|
|
105 |
|
|
|
106 |
/** |
|
|
107 |
R-interface to C_TreeGrow\n |
|
|
108 |
*\param learnsample an object of class `LearningSample' |
|
|
109 |
*\param weights a vector of case weights |
|
|
110 |
*\param fitmem an object of class `TreeFitMemory' |
|
|
111 |
*\param controls an object of class `TreeControl' |
|
|
112 |
*\param where a vector of node indices for each observation |
|
|
113 |
*/ |
|
|
114 |
|
|
|
115 |
SEXP R_TreeGrow(SEXP learnsample, SEXP weights, SEXP fitmem, SEXP controls, SEXP where) { |
|
|
116 |
|
|
|
117 |
SEXP ans, nweights; |
|
|
118 |
double *dnweights, *dweights; |
|
|
119 |
int nobs, i, nodenum = 1; |
|
|
120 |
|
|
|
121 |
GetRNGstate(); |
|
|
122 |
|
|
|
123 |
nobs = get_nobs(learnsample); |
|
|
124 |
PROTECT(ans = allocVector(VECSXP, NODE_LENGTH)); |
|
|
125 |
C_init_node(ans, nobs, get_ninputs(learnsample), get_maxsurrogate(get_splitctrl(controls)), |
|
|
126 |
ncol(get_predict_trafo(GET_SLOT(learnsample, PL2_responsesSym)))); |
|
|
127 |
|
|
|
128 |
nweights = S3get_nodeweights(ans); |
|
|
129 |
dnweights = REAL(nweights); |
|
|
130 |
dweights = REAL(weights); |
|
|
131 |
for (i = 0; i < nobs; i++) dnweights[i] = dweights[i]; |
|
|
132 |
|
|
|
133 |
C_TreeGrow(ans, learnsample, fitmem, controls, INTEGER(where), &nodenum, 1, NULL); |
|
|
134 |
|
|
|
135 |
PutRNGstate(); |
|
|
136 |
|
|
|
137 |
UNPROTECT(1); |
|
|
138 |
return(ans); |
|
|
139 |
} |