a b/partyMod/src/SurrogateSplits.c
1
2
/**
    Surrogate splits
    *\file SurrogateSplits.c
    *\author $Author$
    *\date $Date$
*/
8
                
9
#include "party.h"
10
11
/**
12
    Search for surrogate splits for bypassing the primary split \n
13
    *\param node the current node with primary split specified
14
    *\param learnsample learning sample
15
    *\param weights the weights associated with the current node
16
    *\param controls an object of class `TreeControl'
17
    *\param fitmem an object of class `TreeFitMemory'
18
    *\todo enable nominal surrogate split variables as well
19
*/
20
21
void C_surrogates(SEXP node, SEXP learnsample, SEXP weights, SEXP controls, 
22
                  SEXP fitmem) {
23
24
    SEXP x, y, expcovinf; 
25
    SEXP splitctrl, inputs; 
26
    SEXP split, thiswhichNA;
27
    int nobs, ninputs, i, j, k, jselect, maxsurr, *order, nvar = 0;
28
    double ms, cp, *thisweights, *cutpoint, *maxstat, 
29
           *splitstat, *dweights, *tweights, *dx, *dy;
30
    double cut, *twotab, *ytmp, sumw = 0.0;
31
    
32
    nobs = get_nobs(learnsample);
33
    ninputs = get_ninputs(learnsample);
34
    splitctrl = get_splitctrl(controls);
35
    maxsurr = get_maxsurrogate(splitctrl);
36
    inputs = GET_SLOT(learnsample, PL2_inputsSym);
37
    jselect = S3get_variableID(S3get_primarysplit(node));
38
    
39
    /* (weights > 0) in left node are the new `response' to be approximated */
40
    y = S3get_nodeweights(VECTOR_ELT(node, S3_LEFT));
41
    ytmp = Calloc(nobs, double);
42
    for (i = 0; i < nobs; i++) {
43
        ytmp[i] = REAL(y)[i];
44
        if (ytmp[i] > 1.0) ytmp[i] = 1.0;
45
    }
46
47
    for (j = 0; j < ninputs; j++) {
48
        if (is_nominal(inputs, j + 1)) continue;
49
        nvar++;
50
    }
51
    nvar--;
52
53
    if (maxsurr != LENGTH(S3get_surrogatesplits(node)))
54
        error("nodes does not have %d surrogate splits", maxsurr);
55
    if (maxsurr > nvar)
56
        error("cannot set up %d surrogate splits with only %d ordered input variable(s)", 
57
              maxsurr, nvar);
58
59
    tweights = Calloc(nobs, double);
60
    dweights = REAL(weights);
61
    for (i = 0; i < nobs; i++) tweights[i] = dweights[i];
62
    if (has_missings(inputs, jselect)) {
63
        thiswhichNA = get_missings(inputs, jselect);
64
        for (k = 0; k < LENGTH(thiswhichNA); k++)
65
            tweights[INTEGER(thiswhichNA)[k] - 1] = 0.0;
66
    }
67
68
    /* check if sum(weights) > 1 */
69
    sumw = 0.0;
70
    for (i = 0; i < nobs; i++) sumw += tweights[i];
71
    if (sumw < 2.0)
72
        error("can't implement surrogate splits, not enough observations available");
73
74
    expcovinf = GET_SLOT(fitmem, PL2_expcovinfssSym);
75
    C_ExpectCovarInfluence(ytmp, 1, tweights, nobs, expcovinf);
76
    
77
    splitstat = REAL(get_splitstatistics(fitmem));
78
    /* <FIXME> extend `TreeFitMemory' to those as well ... */
79
    maxstat = Calloc(ninputs, double);
80
    cutpoint = Calloc(ninputs, double);
81
    order = Calloc(ninputs, int);
82
    /* <FIXME> */
83
    
84
    /* this is essentially an exhaustive search */
85
    /* <FIXME>: we don't want to do this for random forest like trees 
86
       </FIXME>
87
     */
88
    for (j = 0; j < ninputs; j++) {
89
    
90
         order[j] = j + 1;
91
         maxstat[j] = 0.0;
92
         cutpoint[j] = 0.0;
93
94
         /* ordered input variables only (for the moment) */
95
         if ((j + 1) == jselect || is_nominal(inputs, j + 1))
96
             continue;
97
98
         x = get_variable(inputs, j + 1);
99
100
         if (has_missings(inputs, j + 1)) {
101
102
             thisweights = C_tempweights(j + 1, weights, fitmem, inputs);
103
104
             /* check if sum(weights) > 1 */
105
             sumw = 0.0;
106
             for (i = 0; i < nobs; i++) sumw += thisweights[i];
107
             if (sumw < 2.0) continue;
108
                 
109
             C_ExpectCovarInfluence(ytmp, 1, thisweights, nobs, expcovinf);
110
             
111
             C_split(REAL(x), 1, ytmp, 1, thisweights, nobs,
112
                     INTEGER(get_ordering(inputs, j + 1)), splitctrl,
113
                     GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
114
                     expcovinf, &cp, &ms, splitstat);
115
         } else {
116
         
117
             C_split(REAL(x), 1, ytmp, 1, tweights, nobs,
118
             INTEGER(get_ordering(inputs, j + 1)), splitctrl,
119
             GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
120
             expcovinf, &cp, &ms, splitstat);
121
         }
122
123
         maxstat[j] = -ms;
124
         cutpoint[j] = cp;
125
    }
126
127
128
    /* order with respect to maximal statistic */
129
    rsort_with_index(maxstat, order, ninputs);
130
    
131
    twotab = Calloc(4, double);
132
    
133
    /* the best `maxsurr' ones are implemented */
134
    for (j = 0; j < maxsurr; j++) {
135
136
        if (is_nominal(inputs, order[j])) continue;
137
        
138
        for (i = 0; i < 4; i++) twotab[i] = 0.0;
139
        cut = cutpoint[order[j] - 1];
140
        SET_VECTOR_ELT(S3get_surrogatesplits(node), j, 
141
                       split = allocVector(VECSXP, SPLIT_LENGTH));
142
        C_init_orderedsplit(split, 0);
143
        S3set_variableID(split, order[j]);
144
        REAL(S3get_splitpoint(split))[0] = cut;
145
        dx = REAL(get_variable(inputs, order[j]));
146
        dy = REAL(y);
147
148
        /* OK, this is a dirty hack: determine if the split 
149
           goes left or right by the Pearson residual of a 2x2 table.
150
           I don't want to use the big caliber here 
151
        */
152
        for (i = 0; i < nobs; i++) {
153
            twotab[0] += ((dy[i] == 1) && (dx[i] <= cut)) * tweights[i];
154
            twotab[1] += (dy[i] == 1) * tweights[i];
155
            twotab[2] += (dx[i] <= cut) * tweights[i];
156
            twotab[3] += tweights[i];
157
        }
158
        S3set_toleft(split, (int) (twotab[0] - twotab[1] * twotab[2] / 
159
                     twotab[3]) > 0);
160
    }
161
    
162
    Free(maxstat);
163
    Free(cutpoint);
164
    Free(order);
165
    Free(tweights);
166
    Free(twotab);
167
    Free(ytmp);
168
}
169
170
/**
171
    R-interface to C_surrogates \n
172
    *\param node the current node with primary split specified
173
    *\param learnsample learning sample
174
    *\param weights the weights associated with the current node
175
    *\param controls an object of class `TreeControl'
176
    *\param fitmem an object of class `TreeFitMemory'
177
*/
178
179
180
SEXP R_surrogates(SEXP node, SEXP learnsample, SEXP weights, SEXP controls, 
181
                  SEXP fitmem) {
182
183
    C_surrogates(node, learnsample, weights, controls, fitmem);
184
    return(S3get_surrogatesplits(node));
185
    
186
}
187
188
/**
189
    Split with missing values \n
190
    *\param node the current node with primary and surrogate splits 
191
                 specified
192
    *\param learnsample learning sample
193
*/
194
195
void C_splitsurrogate(SEXP node, SEXP learnsample) {
196
197
    SEXP weights, split, surrsplit;
198
    SEXP inputs, whichNA, whichNAns;
199
    double cutpoint, *dx, *dweights, *leftweights, *rightweights;
200
    int *iwhichNA, k;
201
    int i, nna, ns;
202
                    
203
    weights = S3get_nodeweights(node);
204
    dweights = REAL(weights);
205
    inputs = GET_SLOT(learnsample, PL2_inputsSym);
206
            
207
    leftweights = REAL(S3get_nodeweights(S3get_leftnode(node)));
208
    rightweights = REAL(S3get_nodeweights(S3get_rightnode(node)));
209
    surrsplit = S3get_surrogatesplits(node);
210
211
    /* if the primary split has any missings */
212
    split = S3get_primarysplit(node);
213
    if (has_missings(inputs, S3get_variableID(split))) {
214
215
        /* where are the missings? */
216
        whichNA = get_missings(inputs, S3get_variableID(split));
217
        iwhichNA = INTEGER(whichNA);
218
        nna = LENGTH(whichNA);
219
220
        /* for all missing values ... */
221
        for (k = 0; k < nna; k++) {
222
            ns = 0;
223
            i = iwhichNA[k] - 1;
224
            if (dweights[i] == 0) continue;
225
            
226
            /* loop over surrogate splits until an appropriate one is found */
227
            while(TRUE) {
228
            
229
                if (ns >= LENGTH(surrsplit)) break;
230
231
                split = VECTOR_ELT(surrsplit, ns);
232
                if (has_missings(inputs, S3get_variableID(split))) {
233
                    whichNAns = get_missings(inputs, S3get_variableID(split));
234
                    if (C_i_in_set(i + 1, whichNAns)) {
235
                        ns++;
236
                        continue;
237
                    }
238
                }
239
240
                cutpoint = REAL(S3get_splitpoint(split))[0];
241
                dx = REAL(get_variable(inputs, S3get_variableID(split)));
242
243
                if (S3get_toleft(split)) {
244
                    if (dx[i] <= cutpoint) {
245
                        leftweights[i] = dweights[i];
246
                        rightweights[i] = 0.0;
247
                    } else {
248
                        rightweights[i] = dweights[i];
249
                        leftweights[i] = 0.0;
250
                    }
251
                } else {
252
                    if (dx[i] <= cutpoint) {
253
                        rightweights[i] = dweights[i];
254
                        leftweights[i] = 0.0;
255
                    } else {
256
                        leftweights[i] = dweights[i];
257
                        rightweights[i] = 0.0;
258
                    }
259
                }
260
                break;
261
            }
262
        }
263
    }
264
}