[fbf06f]: / partyMod / src / TreeGrow.c

Download this file

140 lines (108 with data), 4.4 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
/**
The tree growing recursion
*\file TreeGrow.c
*\author $Author$
*\date $Date$
*/
#include "party.h"
/**
    Duplicate an ignore mask and additionally flag one position in the copy. \n
    *\param variables_to_ignore the current mask (one int per input variable), or NULL
    *\param ninputs the number of entries in the mask
    *\param index zero-based position that is set to 1 in the copy
    *\return NULL if no mask was supplied; otherwise a newly Calloc'ed array
             of ninputs ints that the caller has to release with Free
*/
int *copy_with_ignoring(int *variables_to_ignore, int ninputs, int index) {

    int *mask;
    int j;

    /* no mask in use: nothing to copy, the children get none either */
    if (!variables_to_ignore)
        return NULL;

    mask = Calloc(ninputs, int);
    for (j = 0; j < ninputs; j++) {
        if (j == index)
            mask[j] = 1;
        else
            mask[j] = variables_to_ignore[j];
    }

    return mask;
}
/**
    Grow the tree below `node' by recursive partitioning. \n
    *\param node a list representing the current node
    *\param learnsample an object of class `LearningSample'
    *\param fitmem an object of class `TreeFitMemory'
    *\param controls an object of class `TreeControl'
    *\param where a pointer to an integer vector of n-elements
    *\param nodenum a pointer to an integer vector of length 1
    *\param depth an integer giving the depth of the current node
    *\param variables_to_ignore mask over the input variables that is handed
            to C_Node, may be NULL; the children receive a copy in which the
            split variable of this node is flagged as well
*/
void C_TreeGrow(SEXP node, SEXP learnsample, SEXP fitmem,
                SEXP controls, int *where, int *nodenum, int depth, int *variables_to_ignore) {

    SEXP weights, primary;
    double *w;
    int *child_mask;
    int ninputs, nobs, force_terminal, obs;

    ninputs = get_ninputs(learnsample);
    weights = S3get_nodeweights(node);

    /* a node is forced to be terminal when stumps have been requested
       (nodes 2 and 3 are the children of the root) or when the maximum
       depth is exceeded */
    force_terminal = (nodenum[0] == 2 || nodenum[0] == 3) &&
                     get_stump(get_tgctrl(controls));
    force_terminal = force_terminal || !check_depth(get_tgctrl(controls), depth);

    /* force_terminal is exactly 0 or 1 at this point */
    C_Node(node, learnsample, weights, fitmem, controls, force_terminal,
           depth, variables_to_ignore);

    S3set_nodeID(node, nodenum[0]);

    if (S3get_nodeterminal(node)) {
        /* leaf: record this node's ID for every observation carrying
           positive weight */
        w = REAL(weights);
        nobs = get_nobs(learnsample);
        for (obs = 0; obs < nobs; obs++) {
            if (w[obs] > 0)
                where[obs] = nodenum[0];
        }
        return;
    }

    C_splitnode(node, learnsample, controls);

    /* determine surrogate splits and split missing values */
    if (get_maxsurrogate(get_splitctrl(controls)) > 0) {
        C_surrogates(node, learnsample, weights, controls, fitmem);
        C_splitsurrogate(node, learnsample);
    }

    /* both children share one mask; variable IDs are one-based */
    primary = S3get_primarysplit(node);
    child_mask = copy_with_ignoring(variables_to_ignore, ninputs,
                                    S3get_variableID(primary) - 1);

    /* the node counter is advanced before descending into each child */
    nodenum[0] += 1;
    C_TreeGrow(S3get_leftnode(node), learnsample, fitmem,
               controls, where, nodenum, depth + 1, child_mask);

    nodenum[0] += 1;
    C_TreeGrow(S3get_rightnode(node), learnsample, fitmem,
               controls, where, nodenum, depth + 1, child_mask);

    Free(child_mask);
}
/**
    R-interface to C_TreeGrow: sets up the root node and grows the tree\n
    *\param learnsample an object of class `LearningSample'
    *\param weights a vector of case weights
    *\param fitmem an object of class `TreeFitMemory'
    *\param controls an object of class `TreeControl'
    *\param where a vector of node indices for each observation,
            filled in place while the tree is grown
    *\return the root node, a list of length NODE_LENGTH
*/
SEXP R_TreeGrow(SEXP learnsample, SEXP weights, SEXP fitmem, SEXP controls, SEXP where) {

    SEXP root, rootweights;
    double *dst, *src;
    int nobs, k;
    int nodenum = 1;    /* the root gets ID 1 */

    GetRNGstate();

    nobs = get_nobs(learnsample);

    PROTECT(root = allocVector(VECSXP, NODE_LENGTH));
    C_init_node(root, nobs, get_ninputs(learnsample), get_maxsurrogate(get_splitctrl(controls)),
                ncol(get_predict_trafo(GET_SLOT(learnsample, PL2_responsesSym))));

    /* seed the root node's weights with the case weights */
    rootweights = S3get_nodeweights(root);
    dst = REAL(rootweights);
    src = REAL(weights);
    for (k = 0; k < nobs; k++)
        dst[k] = src[k];

    /* start at depth 1 without an ignore mask */
    C_TreeGrow(root, learnsample, fitmem, controls, INTEGER(where), &nodenum, 1, NULL);

    PutRNGstate();

    UNPROTECT(1);
    return root;
}