|
a |
|
b/partyMod/src/SurrogateSplits.c |
|
|
1 |
|
|
|
2 |
/** |
|
|
3 |
Surrogate splits |
|
|
4 |
*\file SurrogateSplits.c |
|
|
5 |
*\author $Author$ |
|
|
6 |
*\date $Date$ |
|
|
7 |
*/ |
|
|
8 |
|
|
|
9 |
#include "party.h" |
|
|
10 |
|
|
|
11 |
/** |
|
|
12 |
Search for surrogate splits for bypassing the primary split \n |
|
|
13 |
*\param node the current node with primary split specified |
|
|
14 |
*\param learnsample learning sample |
|
|
15 |
*\param weights the weights associated with the current node |
|
|
16 |
*\param controls an object of class `TreeControl' |
|
|
17 |
*\param fitmem an object of class `TreeFitMemory' |
|
|
18 |
*\todo enable nominal surrogate split variables as well |
|
|
19 |
*/ |
|
|
20 |
|
|
|
21 |
void C_surrogates(SEXP node, SEXP learnsample, SEXP weights, SEXP controls, |
|
|
22 |
SEXP fitmem) { |
|
|
23 |
|
|
|
24 |
SEXP x, y, expcovinf; |
|
|
25 |
SEXP splitctrl, inputs; |
|
|
26 |
SEXP split, thiswhichNA; |
|
|
27 |
int nobs, ninputs, i, j, k, jselect, maxsurr, *order, nvar = 0; |
|
|
28 |
double ms, cp, *thisweights, *cutpoint, *maxstat, |
|
|
29 |
*splitstat, *dweights, *tweights, *dx, *dy; |
|
|
30 |
double cut, *twotab, *ytmp, sumw = 0.0; |
|
|
31 |
|
|
|
32 |
nobs = get_nobs(learnsample); |
|
|
33 |
ninputs = get_ninputs(learnsample); |
|
|
34 |
splitctrl = get_splitctrl(controls); |
|
|
35 |
maxsurr = get_maxsurrogate(splitctrl); |
|
|
36 |
inputs = GET_SLOT(learnsample, PL2_inputsSym); |
|
|
37 |
jselect = S3get_variableID(S3get_primarysplit(node)); |
|
|
38 |
|
|
|
39 |
/* (weights > 0) in left node are the new `response' to be approximated */ |
|
|
40 |
y = S3get_nodeweights(VECTOR_ELT(node, S3_LEFT)); |
|
|
41 |
ytmp = Calloc(nobs, double); |
|
|
42 |
for (i = 0; i < nobs; i++) { |
|
|
43 |
ytmp[i] = REAL(y)[i]; |
|
|
44 |
if (ytmp[i] > 1.0) ytmp[i] = 1.0; |
|
|
45 |
} |
|
|
46 |
|
|
|
47 |
for (j = 0; j < ninputs; j++) { |
|
|
48 |
if (is_nominal(inputs, j + 1)) continue; |
|
|
49 |
nvar++; |
|
|
50 |
} |
|
|
51 |
nvar--; |
|
|
52 |
|
|
|
53 |
if (maxsurr != LENGTH(S3get_surrogatesplits(node))) |
|
|
54 |
error("nodes does not have %d surrogate splits", maxsurr); |
|
|
55 |
if (maxsurr > nvar) |
|
|
56 |
error("cannot set up %d surrogate splits with only %d ordered input variable(s)", |
|
|
57 |
maxsurr, nvar); |
|
|
58 |
|
|
|
59 |
tweights = Calloc(nobs, double); |
|
|
60 |
dweights = REAL(weights); |
|
|
61 |
for (i = 0; i < nobs; i++) tweights[i] = dweights[i]; |
|
|
62 |
if (has_missings(inputs, jselect)) { |
|
|
63 |
thiswhichNA = get_missings(inputs, jselect); |
|
|
64 |
for (k = 0; k < LENGTH(thiswhichNA); k++) |
|
|
65 |
tweights[INTEGER(thiswhichNA)[k] - 1] = 0.0; |
|
|
66 |
} |
|
|
67 |
|
|
|
68 |
/* check if sum(weights) > 1 */ |
|
|
69 |
sumw = 0.0; |
|
|
70 |
for (i = 0; i < nobs; i++) sumw += tweights[i]; |
|
|
71 |
if (sumw < 2.0) |
|
|
72 |
error("can't implement surrogate splits, not enough observations available"); |
|
|
73 |
|
|
|
74 |
expcovinf = GET_SLOT(fitmem, PL2_expcovinfssSym); |
|
|
75 |
C_ExpectCovarInfluence(ytmp, 1, tweights, nobs, expcovinf); |
|
|
76 |
|
|
|
77 |
splitstat = REAL(get_splitstatistics(fitmem)); |
|
|
78 |
/* <FIXME> extend `TreeFitMemory' to those as well ... */ |
|
|
79 |
maxstat = Calloc(ninputs, double); |
|
|
80 |
cutpoint = Calloc(ninputs, double); |
|
|
81 |
order = Calloc(ninputs, int); |
|
|
82 |
/* <FIXME> */ |
|
|
83 |
|
|
|
84 |
/* this is essentially an exhaustive search */ |
|
|
85 |
/* <FIXME>: we don't want to do this for random forest like trees |
|
|
86 |
</FIXME> |
|
|
87 |
*/ |
|
|
88 |
for (j = 0; j < ninputs; j++) { |
|
|
89 |
|
|
|
90 |
order[j] = j + 1; |
|
|
91 |
maxstat[j] = 0.0; |
|
|
92 |
cutpoint[j] = 0.0; |
|
|
93 |
|
|
|
94 |
/* ordered input variables only (for the moment) */ |
|
|
95 |
if ((j + 1) == jselect || is_nominal(inputs, j + 1)) |
|
|
96 |
continue; |
|
|
97 |
|
|
|
98 |
x = get_variable(inputs, j + 1); |
|
|
99 |
|
|
|
100 |
if (has_missings(inputs, j + 1)) { |
|
|
101 |
|
|
|
102 |
thisweights = C_tempweights(j + 1, weights, fitmem, inputs); |
|
|
103 |
|
|
|
104 |
/* check if sum(weights) > 1 */ |
|
|
105 |
sumw = 0.0; |
|
|
106 |
for (i = 0; i < nobs; i++) sumw += thisweights[i]; |
|
|
107 |
if (sumw < 2.0) continue; |
|
|
108 |
|
|
|
109 |
C_ExpectCovarInfluence(ytmp, 1, thisweights, nobs, expcovinf); |
|
|
110 |
|
|
|
111 |
C_split(REAL(x), 1, ytmp, 1, thisweights, nobs, |
|
|
112 |
INTEGER(get_ordering(inputs, j + 1)), splitctrl, |
|
|
113 |
GET_SLOT(fitmem, PL2_linexpcov2sampleSym), |
|
|
114 |
expcovinf, &cp, &ms, splitstat); |
|
|
115 |
} else { |
|
|
116 |
|
|
|
117 |
C_split(REAL(x), 1, ytmp, 1, tweights, nobs, |
|
|
118 |
INTEGER(get_ordering(inputs, j + 1)), splitctrl, |
|
|
119 |
GET_SLOT(fitmem, PL2_linexpcov2sampleSym), |
|
|
120 |
expcovinf, &cp, &ms, splitstat); |
|
|
121 |
} |
|
|
122 |
|
|
|
123 |
maxstat[j] = -ms; |
|
|
124 |
cutpoint[j] = cp; |
|
|
125 |
} |
|
|
126 |
|
|
|
127 |
|
|
|
128 |
/* order with respect to maximal statistic */ |
|
|
129 |
rsort_with_index(maxstat, order, ninputs); |
|
|
130 |
|
|
|
131 |
twotab = Calloc(4, double); |
|
|
132 |
|
|
|
133 |
/* the best `maxsurr' ones are implemented */ |
|
|
134 |
for (j = 0; j < maxsurr; j++) { |
|
|
135 |
|
|
|
136 |
if (is_nominal(inputs, order[j])) continue; |
|
|
137 |
|
|
|
138 |
for (i = 0; i < 4; i++) twotab[i] = 0.0; |
|
|
139 |
cut = cutpoint[order[j] - 1]; |
|
|
140 |
SET_VECTOR_ELT(S3get_surrogatesplits(node), j, |
|
|
141 |
split = allocVector(VECSXP, SPLIT_LENGTH)); |
|
|
142 |
C_init_orderedsplit(split, 0); |
|
|
143 |
S3set_variableID(split, order[j]); |
|
|
144 |
REAL(S3get_splitpoint(split))[0] = cut; |
|
|
145 |
dx = REAL(get_variable(inputs, order[j])); |
|
|
146 |
dy = REAL(y); |
|
|
147 |
|
|
|
148 |
/* OK, this is a dirty hack: determine if the split |
|
|
149 |
goes left or right by the Pearson residual of a 2x2 table. |
|
|
150 |
I don't want to use the big caliber here |
|
|
151 |
*/ |
|
|
152 |
for (i = 0; i < nobs; i++) { |
|
|
153 |
twotab[0] += ((dy[i] == 1) && (dx[i] <= cut)) * tweights[i]; |
|
|
154 |
twotab[1] += (dy[i] == 1) * tweights[i]; |
|
|
155 |
twotab[2] += (dx[i] <= cut) * tweights[i]; |
|
|
156 |
twotab[3] += tweights[i]; |
|
|
157 |
} |
|
|
158 |
S3set_toleft(split, (int) (twotab[0] - twotab[1] * twotab[2] / |
|
|
159 |
twotab[3]) > 0); |
|
|
160 |
} |
|
|
161 |
|
|
|
162 |
Free(maxstat); |
|
|
163 |
Free(cutpoint); |
|
|
164 |
Free(order); |
|
|
165 |
Free(tweights); |
|
|
166 |
Free(twotab); |
|
|
167 |
Free(ytmp); |
|
|
168 |
} |
|
|
169 |
|
|
|
170 |
/** |
|
|
171 |
R-interface to C_surrogates \n |
|
|
172 |
*\param node the current node with primary split specified |
|
|
173 |
*\param learnsample learning sample |
|
|
174 |
*\param weights the weights associated with the current node |
|
|
175 |
*\param controls an object of class `TreeControl' |
|
|
176 |
*\param fitmem an object of class `TreeFitMemory' |
|
|
177 |
*/ |
|
|
178 |
|
|
|
179 |
|
|
|
180 |
SEXP R_surrogates(SEXP node, SEXP learnsample, SEXP weights, SEXP controls, |
|
|
181 |
SEXP fitmem) { |
|
|
182 |
|
|
|
183 |
C_surrogates(node, learnsample, weights, controls, fitmem); |
|
|
184 |
return(S3get_surrogatesplits(node)); |
|
|
185 |
|
|
|
186 |
} |
|
|
187 |
|
|
|
188 |
/** |
|
|
189 |
Split with missing values \n |
|
|
190 |
*\param node the current node with primary and surrogate splits |
|
|
191 |
specified |
|
|
192 |
*\param learnsample learning sample |
|
|
193 |
*/ |
|
|
194 |
|
|
|
195 |
void C_splitsurrogate(SEXP node, SEXP learnsample) { |
|
|
196 |
|
|
|
197 |
SEXP weights, split, surrsplit; |
|
|
198 |
SEXP inputs, whichNA, whichNAns; |
|
|
199 |
double cutpoint, *dx, *dweights, *leftweights, *rightweights; |
|
|
200 |
int *iwhichNA, k; |
|
|
201 |
int i, nna, ns; |
|
|
202 |
|
|
|
203 |
weights = S3get_nodeweights(node); |
|
|
204 |
dweights = REAL(weights); |
|
|
205 |
inputs = GET_SLOT(learnsample, PL2_inputsSym); |
|
|
206 |
|
|
|
207 |
leftweights = REAL(S3get_nodeweights(S3get_leftnode(node))); |
|
|
208 |
rightweights = REAL(S3get_nodeweights(S3get_rightnode(node))); |
|
|
209 |
surrsplit = S3get_surrogatesplits(node); |
|
|
210 |
|
|
|
211 |
/* if the primary split has any missings */ |
|
|
212 |
split = S3get_primarysplit(node); |
|
|
213 |
if (has_missings(inputs, S3get_variableID(split))) { |
|
|
214 |
|
|
|
215 |
/* where are the missings? */ |
|
|
216 |
whichNA = get_missings(inputs, S3get_variableID(split)); |
|
|
217 |
iwhichNA = INTEGER(whichNA); |
|
|
218 |
nna = LENGTH(whichNA); |
|
|
219 |
|
|
|
220 |
/* for all missing values ... */ |
|
|
221 |
for (k = 0; k < nna; k++) { |
|
|
222 |
ns = 0; |
|
|
223 |
i = iwhichNA[k] - 1; |
|
|
224 |
if (dweights[i] == 0) continue; |
|
|
225 |
|
|
|
226 |
/* loop over surrogate splits until an appropriate one is found */ |
|
|
227 |
while(TRUE) { |
|
|
228 |
|
|
|
229 |
if (ns >= LENGTH(surrsplit)) break; |
|
|
230 |
|
|
|
231 |
split = VECTOR_ELT(surrsplit, ns); |
|
|
232 |
if (has_missings(inputs, S3get_variableID(split))) { |
|
|
233 |
whichNAns = get_missings(inputs, S3get_variableID(split)); |
|
|
234 |
if (C_i_in_set(i + 1, whichNAns)) { |
|
|
235 |
ns++; |
|
|
236 |
continue; |
|
|
237 |
} |
|
|
238 |
} |
|
|
239 |
|
|
|
240 |
cutpoint = REAL(S3get_splitpoint(split))[0]; |
|
|
241 |
dx = REAL(get_variable(inputs, S3get_variableID(split))); |
|
|
242 |
|
|
|
243 |
if (S3get_toleft(split)) { |
|
|
244 |
if (dx[i] <= cutpoint) { |
|
|
245 |
leftweights[i] = dweights[i]; |
|
|
246 |
rightweights[i] = 0.0; |
|
|
247 |
} else { |
|
|
248 |
rightweights[i] = dweights[i]; |
|
|
249 |
leftweights[i] = 0.0; |
|
|
250 |
} |
|
|
251 |
} else { |
|
|
252 |
if (dx[i] <= cutpoint) { |
|
|
253 |
rightweights[i] = dweights[i]; |
|
|
254 |
leftweights[i] = 0.0; |
|
|
255 |
} else { |
|
|
256 |
leftweights[i] = dweights[i]; |
|
|
257 |
rightweights[i] = 0.0; |
|
|
258 |
} |
|
|
259 |
} |
|
|
260 |
break; |
|
|
261 |
} |
|
|
262 |
} |
|
|
263 |
} |
|
|
264 |
} |