Diff of /partyMod/src/Predict.c [000000] .. [fbf06f]

Switch to unified view

a b/partyMod/src/Predict.c
1
2
/**
3
    Node splitting and prediction
4
    *\file Predict.c
5
    *\author $Author$
6
    *\date $Date$
7
*/
8
                
9
#include "party.h"
10
11
12
/**
13
    Split a node according to a splitting rule \n
14
    *\param node the current node with primary split specified
15
    *\param learnsample learning sample
16
    *\param control an object of class `TreeControl'
17
    *\todo outplace the splitting since there are at least 3 functions
18
           with nearly identical code
19
*/
20
                
21
void C_splitnode(SEXP node, SEXP learnsample, SEXP control) {
22
23
    SEXP weights, leftnode, rightnode, split;
24
    SEXP responses, inputs, whichNA;
25
    double cutpoint, *dx, *dweights, *leftweights, *rightweights;
26
    double sleft = 0.0, sright = 0.0;
27
    int *ix, *levelset, *iwhichNA;
28
    int nobs, i, nna;
29
                    
30
    weights = S3get_nodeweights(node);
31
    dweights = REAL(weights);
32
    responses = GET_SLOT(learnsample, PL2_responsesSym);
33
    inputs = GET_SLOT(learnsample, PL2_inputsSym);
34
    nobs = get_nobs(learnsample);
35
            
36
    /* set up memory for the left daughter */
37
    SET_VECTOR_ELT(node, S3_LEFT, leftnode = allocVector(VECSXP, NODE_LENGTH));
38
    C_init_node(leftnode, nobs, 
39
        get_ninputs(learnsample), get_maxsurrogate(get_splitctrl(control)),
40
        ncol(get_predict_trafo(responses)));
41
    leftweights = REAL(S3get_nodeweights(leftnode));
42
43
    /* set up memory for the right daughter */
44
    SET_VECTOR_ELT(node, S3_RIGHT, 
45
                   rightnode = allocVector(VECSXP, NODE_LENGTH));
46
    C_init_node(rightnode, nobs, 
47
        get_ninputs(learnsample), get_maxsurrogate(get_splitctrl(control)),
48
        ncol(get_predict_trafo(responses)));
49
    rightweights = REAL(S3get_nodeweights(rightnode));
50
51
    /* split according to the primary split */
52
    split = S3get_primarysplit(node);
53
    if (has_missings(inputs, S3get_variableID(split))) {
54
        whichNA = get_missings(inputs, S3get_variableID(split));
55
        iwhichNA = INTEGER(whichNA);
56
        nna = LENGTH(whichNA);
57
    } else {
58
        nna = 0;
59
        whichNA = R_NilValue;
60
        iwhichNA = NULL;
61
    }
62
    
63
    if (S3is_ordered(split)) {
64
        cutpoint = REAL(S3get_splitpoint(split))[0];
65
        dx = REAL(get_variable(inputs, S3get_variableID(split)));
66
        for (i = 0; i < nobs; i++) {
67
            if (nna > 0) {
68
                if (i_in_set(i + 1, iwhichNA, nna)) continue;
69
            }
70
            if (dx[i] <= cutpoint) 
71
                leftweights[i] = dweights[i]; 
72
            else 
73
                leftweights[i] = 0.0;
74
            rightweights[i] = dweights[i] - leftweights[i];
75
            sleft += leftweights[i];
76
            sright += rightweights[i];
77
        }
78
    } else {
79
        levelset = INTEGER(S3get_splitpoint(split));
80
        ix = INTEGER(get_variable(inputs, S3get_variableID(split)));
81
82
        for (i = 0; i < nobs; i++) {
83
            if (nna > 0) {
84
                if (i_in_set(i + 1, iwhichNA, nna)) continue;
85
            }
86
            if (levelset[ix[i] - 1])
87
                leftweights[i] = dweights[i];
88
            else 
89
                leftweights[i] = 0.0;
90
            rightweights[i] = dweights[i] - leftweights[i];
91
            sleft += leftweights[i];
92
            sright += rightweights[i];
93
        }
94
    }
95
    
96
    /* for the moment: NA's go with majority */
97
    if (nna > 0) {
98
        for (i = 0; i < nna; i++) {
99
            if (sleft > sright) {
100
                leftweights[iwhichNA[i] - 1] = dweights[iwhichNA[i] - 1];
101
                rightweights[iwhichNA[i] - 1] = 0.0;
102
            } else {
103
                rightweights[iwhichNA[i] - 1] = dweights[iwhichNA[i] - 1];
104
                leftweights[iwhichNA[i] - 1] = 0.0;
105
            }
106
        }
107
    }
108
}
109
110
111
/**
112
    Get the terminal node for obs. number `numobs' of `newinputs' \n
113
    *\param subtree a tree
114
    *\param newinputs an object of class `VariableFrame'
115
    *\param mincriterion overwrites mincriterion used for tree growing
116
    *\param numobs observation number
117
    *\param varperm which variable shall be permuted?
118
    *\todo handle surrogate splits
119
*/
120
121
SEXP C_get_node(SEXP subtree, SEXP newinputs, 
122
                double mincriterion, int numobs, int varperm) {
123
124
    SEXP split, whichNA, ssplit, surrsplit;
125
    double cutpoint, x, swleft, swright;
126
    int level, *levelset, i, ns;
127
128
    if (S3get_nodeterminal(subtree) || 
129
        REAL(S3get_maxcriterion(subtree))[0] < mincriterion) 
130
        return(subtree);
131
    
132
    split = S3get_primarysplit(subtree);
133
134
    /* Maybe store the proportions left / right in each node? */
135
    swleft = S3get_sumweights(S3get_leftnode(subtree));
136
    swright = S3get_sumweights(S3get_rightnode(subtree));
137
138
    /* splits based on variable varperm are random */    
139
    if (S3get_variableID(split) == varperm) {
140
        if (unif_rand() < swleft / (swleft + swright)) {
141
            return(C_get_node(S3get_leftnode(subtree),
142
                       newinputs, mincriterion, numobs, varperm));
143
        } else {
144
            return(C_get_node(S3get_rightnode(subtree),
145
                       newinputs, mincriterion, numobs, varperm));
146
        }
147
    }
148
                   
149
    /* missing values */
150
    if (has_missings(newinputs, S3get_variableID(split))) {
151
        whichNA = get_missings(newinputs, S3get_variableID(split));
152
    
153
        /* numobs 0 ... n - 1 but whichNA has 1:n */
154
        if (C_i_in_set(numobs + 1, whichNA)) {
155
        
156
            surrsplit = S3get_surrogatesplits(subtree);
157
            ns = 0;
158
            i = numobs;      
159
160
            /* try to find a surrogate split */
161
            while(TRUE) {
162
    
163
                if (ns >= LENGTH(surrsplit)) break;
164
            
165
                ssplit = VECTOR_ELT(surrsplit, ns);
166
                if (has_missings(newinputs, S3get_variableID(ssplit))) {
167
                    if (INTEGER(get_missings(newinputs, 
168
                                             S3get_variableID(ssplit)))[i]) {
169
                        ns++;
170
                        continue;
171
                    }
172
                }
173
174
                cutpoint = REAL(S3get_splitpoint(ssplit))[0];
175
                x = REAL(get_variable(newinputs, S3get_variableID(ssplit)))[i];
176
                     
177
                if (S3get_toleft(ssplit)) {
178
                    if (x <= cutpoint) {
179
                        return(C_get_node(S3get_leftnode(subtree),
180
                                          newinputs, mincriterion, numobs, varperm));
181
                    } else {
182
                        return(C_get_node(S3get_rightnode(subtree),
183
                               newinputs, mincriterion, numobs, varperm));
184
                    }
185
                } else {
186
                    if (x <= cutpoint) {
187
                        return(C_get_node(S3get_rightnode(subtree),
188
                                          newinputs, mincriterion, numobs, varperm));
189
                    } else {
190
                        return(C_get_node(S3get_leftnode(subtree),
191
                               newinputs, mincriterion, numobs, varperm));
192
                    }
193
                }
194
                break;
195
            }
196
197
            /* if this was not successful, we go with the majority */
198
            if (swleft > swright) {
199
                return(C_get_node(S3get_leftnode(subtree), 
200
                                  newinputs, mincriterion, numobs, varperm));
201
            } else {
202
                return(C_get_node(S3get_rightnode(subtree), 
203
                                  newinputs, mincriterion, numobs, varperm));
204
            }
205
        }
206
    }
207
    
208
    if (S3is_ordered(split)) {
209
        cutpoint = REAL(S3get_splitpoint(split))[0];
210
        x = REAL(get_variable(newinputs, 
211
                     S3get_variableID(split)))[numobs];
212
        if (x <= cutpoint) {
213
            return(C_get_node(S3get_leftnode(subtree), 
214
                              newinputs, mincriterion, numobs, varperm));
215
        } else {
216
            return(C_get_node(S3get_rightnode(subtree), 
217
                              newinputs, mincriterion, numobs, varperm));
218
        }
219
    } else {
220
        levelset = INTEGER(S3get_splitpoint(split));
221
        level = INTEGER(get_variable(newinputs, 
222
                            S3get_variableID(split)))[numobs];
223
        /* level is in 1, ..., K */
224
        if (levelset[level - 1]) {
225
            return(C_get_node(S3get_leftnode(subtree), newinputs, 
226
                              mincriterion, numobs, varperm));
227
        } else {
228
            return(C_get_node(S3get_rightnode(subtree), newinputs, 
229
                              mincriterion, numobs, varperm));
230
        }
231
    }
232
}
233
234
235
/**
236
    R-Interface to C_get_node \n
237
    *\param subtree a tree
238
    *\param newinputs an object of class `VariableFrame'
239
    *\param mincriterion overwrites mincriterion used for tree growing
240
    *\param numobs observation number
241
*/
242
243
SEXP R_get_node(SEXP subtree, SEXP newinputs, SEXP mincriterion, 
244
                SEXP numobs, SEXP varperm) {
245
    return(C_get_node(subtree, newinputs, REAL(mincriterion)[0],
246
                      INTEGER(numobs)[0] - 1, INTEGER(varperm)[0]));
247
}
248
249
250
/**
251
    Get the node with nodeID `nodenum' \n
252
    *\param subtree a tree
253
    *\param nodenum a nodeID
254
*/
255
256
SEXP C_get_nodebynum(SEXP subtree, int nodenum) {
257
    
258
    if (nodenum == S3get_nodeID(subtree)) return(subtree);
259
260
    if (S3get_nodeterminal(subtree)) 
261
        error("no node with number %d\n", nodenum);
262
263
    if (nodenum < S3get_nodeID(S3get_rightnode(subtree))) {
264
        return(C_get_nodebynum(S3get_leftnode(subtree), nodenum));
265
    } else {
266
        return(C_get_nodebynum(S3get_rightnode(subtree), nodenum));
267
    }
268
}
269
270
271
/**
272
    R-Interface to C_get_nodenum \n
273
    *\param subtree a tree
274
    *\param nodenum a nodeID
275
*/
276
277
SEXP R_get_nodebynum(SEXP subtree, SEXP nodenum) {
278
    return(C_get_nodebynum(subtree, INTEGER(nodenum)[0]));
279
}
280
281
282
/**
283
    Get the prediction of a new observation\n
284
    *\param subtree a tree
285
    *\param newinputs an object of class `VariableFrame'
286
    *\param mincriterion overwrites mincriterion used for tree growing
287
    *\param numobs observation number
288
    *\param varperm which variable shall be permuted?
289
*/
290
291
SEXP C_get_prediction(SEXP subtree, SEXP newinputs, 
292
                      double mincriterion, int numobs, int varperm) {
293
    return(S3get_prediction(C_get_node(subtree, newinputs, 
294
                            mincriterion, numobs, varperm)));
295
}
296
297
298
/**
299
    Get the weights for a new observation \n
300
    *\param subtree a tree
301
    *\param newinputs an object of class `VariableFrame'
302
    *\param mincriterion overwrites mincriterion used for tree growing
303
    *\param numobs observation number
304
*/
305
306
SEXP C_get_nodeweights(SEXP subtree, SEXP newinputs, 
307
                       double mincriterion, int numobs) {
308
    return(S3get_nodeweights(C_get_node(subtree, newinputs, 
309
                             mincriterion, numobs, -1)));
310
}
311
312
313
/**
314
    Get the nodeID for a new observation \n
315
    *\param subtree a tree
316
    *\param newinputs an object of class `VariableFrame'
317
    *\param mincriterion overwrites mincriterion used for tree growing
318
    *\param numobs observation number
319
    *\param varperm which variable shall be permuted?
320
*/
321
322
int C_get_nodeID(SEXP subtree, SEXP newinputs,
323
                  double mincriterion, int numobs, int varperm) {
324
     return(S3get_nodeID(C_get_node(subtree, newinputs, 
325
            mincriterion, numobs, varperm)));
326
}
327
328
329
/**
330
    R-Interface to C_get_nodeID \n
331
    *\param tree a tree
332
    *\param newinputs an object of class `VariableFrame'
333
    *\param mincriterion overwrites mincriterion used for tree growing
334
*/
335
336
SEXP R_get_nodeID(SEXP tree, SEXP newinputs, SEXP mincriterion, SEXP varperm) {
337
338
    SEXP ans;
339
    int nobs, i, *dans;
340
            
341
    nobs = get_nobs(newinputs);
342
    PROTECT(ans = allocVector(INTSXP, nobs));
343
    dans = INTEGER(ans);
344
    for (i = 0; i < nobs; i++)
345
         dans[i] = C_get_nodeID(tree, newinputs, REAL(mincriterion)[0], i, INTEGER(varperm)[0]);
346
    UNPROTECT(1);
347
    return(ans);
348
}
349
350
351
/**
352
    Get all predictions for `newinputs' \n
353
    *\param tree a tree
354
    *\param newinputs an object of class `VariableFrame'
355
    *\param mincriterion overwrites mincriterion used for tree growing
356
    *\param varperm which variable shall be permuted?
357
    *\param ans return value
358
*/
359
360
void C_predict(SEXP tree, SEXP newinputs, double mincriterion, 
361
               int varperm, SEXP ans) {
362
    
363
    int nobs, i;
364
    
365
    nobs = get_nobs(newinputs);    
366
    if (LENGTH(ans) != nobs) 
367
        error("ans is not of length %d\n", nobs);
368
        
369
    for (i = 0; i < nobs; i++)
370
        SET_VECTOR_ELT(ans, i, C_get_prediction(tree, newinputs, 
371
                       mincriterion, i, varperm));
372
}
373
374
375
/**
376
    R-Interface to C_predict \n
377
    *\param tree a tree
378
    *\param newinputs an object of class `VariableFrame'
379
    *\param mincriterion overwrites mincriterion used for tree growing
380
    *\param varperm which variable shall be permuted? -1 for no permutation
381
*/
382
383
SEXP R_predict(SEXP tree, SEXP newinputs, SEXP mincriterion,
384
               SEXP varperm) {
385
386
    SEXP ans;
387
    int nobs;
388
389
    nobs = get_nobs(newinputs);
390
    PROTECT(ans = allocVector(VECSXP, nobs));
391
    GetRNGstate();
392
    C_predict(tree, newinputs, REAL(mincriterion)[0], 
393
              INTEGER(varperm)[0], ans);
394
    PutRNGstate();
395
    UNPROTECT(1);
396
    return(ans);
397
}
398
399
/**
400
    Get the predictions from `where' nodes\n
401
    *\param tree a tree
402
    *\param where vector of nodeID's
403
    *\param ans return value
404
*/
405
406
void C_getpredictions(SEXP tree, SEXP where, SEXP ans) {
407
408
    int nobs, i, *iwhere;
409
    
410
    nobs = LENGTH(where);
411
    iwhere = INTEGER(where);
412
    if (LENGTH(ans) != nobs)
413
        error("ans is not of length %d\n", nobs);
414
        
415
    for (i = 0; i < nobs; i++)
416
        SET_VECTOR_ELT(ans, i, S3get_prediction(
417
            C_get_nodebynum(tree, iwhere[i])));
418
}
419
420
421
/**
422
    R-Interface to C_getpredictions\n
423
    *\param tree a tree
424
    *\param where vector of nodeID's
425
*/
426
            
427
SEXP R_getpredictions(SEXP tree, SEXP where) {
428
429
    SEXP ans;
430
    int nobs;
431
            
432
    nobs = LENGTH(where);
433
    PROTECT(ans = allocVector(VECSXP, nobs));
434
    C_getpredictions(tree, where, ans);
435
    UNPROTECT(1);
436
    return(ans);
437
}                        
438
439
/**
440
    Predictions weights from RandomForest objects
441
    *\param forest a list of trees
442
    *\param where list (length b) of integer vectors (length n) containing terminal node numbers    
443
    *\param weights list (length b) of bootstrap case weights
444
    *\param newinputs an object of class `VariableFrame'
445
    *\param mincriterion overwrites mincriterion used for tree growing
446
    *\param oobpred a logical indicating out-of-bag predictions
447
*/
448
449
SEXP R_predictRF_weights(SEXP forest, SEXP where, SEXP weights, 
450
                         SEXP newinputs, SEXP mincriterion, SEXP oobpred, SEXP expand) {
451
452
    SEXP ans, tree, bw, expand_exp;
453
    int ntrees, nobs, i, b, j, iwhere, oob = 0, count = 0, ntrain;
454
    int errorOccurred;
455
    
456
    if (LOGICAL(oobpred)[0]) oob = 1;
457
    
458
    nobs = get_nobs(newinputs);
459
    ntrees = LENGTH(forest);
460
461
    if (oob) {
462
        if (LENGTH(VECTOR_ELT(weights, 0)) != nobs)
463
            error("number of observations don't match");
464
    }    
465
    
466
    tree = VECTOR_ELT(forest, 0);
467
    ntrain = LENGTH(VECTOR_ELT(weights, 0));
468
    
469
    PROTECT(ans = allocVector(VECSXP, nobs));
470
    
471
    for (i = 0; i < nobs; i++) {
472
        count = 0;
473
        SET_VECTOR_ELT(ans, i, bw = allocVector(REALSXP, ntrain));
474
        for (j = 0; j < ntrain; j++)
475
            REAL(bw)[j] = 0.0;
476
        for (b = 0; b < ntrees; b++) {
477
            tree = VECTOR_ELT(forest, b);
478
            PROTECT(expand_exp = lang2(expand, tree));
479
            tree = R_tryEval(expand_exp, R_GlobalEnv, &errorOccurred);
480
            UNPROTECT(1);
481
            if(errorOccurred) { 
482
               Rprintf("error calling expand\n");
483
               break;
484
            } 
485
            PROTECT(tree);
486
            
487
            if (oob && 
488
                REAL(VECTOR_ELT(weights, b))[i] > 0.0) {
489
                UNPROTECT(1);
490
                continue;
491
            }
492
493
            iwhere = C_get_nodeID(tree, newinputs, REAL(mincriterion)[0], i, -1);
494
            
495
            for (j = 0; j < ntrain; j++) {
496
                if (iwhere == INTEGER(VECTOR_ELT(where, b))[j])
497
                    REAL(bw)[j] += REAL(VECTOR_ELT(weights, b))[j];
498
            }
499
            count++;
500
            UNPROTECT(1);
501
        }
502
        if(errorOccurred)
503
            break;
504
        
505
        if (count == 0) 
506
            error("cannot compute out-of-bag predictions for observation number %d", i + 1);
507
    }
508
    UNPROTECT(1);
509
    
510
    if(errorOccurred) {
511
        return NULL;
512
    }
513
    
514
    return(ans);
515
}
516
517
518
/**
519
    Proximity matrix for random forests
520
    *\param where list (length b) of integer vectors (length n) containing terminal node numbers
521
*/
522
523
SEXP R_proximity(SEXP where) {
524
525
    SEXP ans, bw, bin;
526
    int ntrees, nobs, i, b, j, iwhere;
527
    
528
    ntrees = LENGTH(where);
529
    nobs = LENGTH(VECTOR_ELT(where, 0));
530
    
531
    PROTECT(ans = allocVector(VECSXP, nobs));
532
    PROTECT(bin = allocVector(INTSXP, nobs));
533
     
534
    for (i = 0; i < nobs; i++) {
535
        SET_VECTOR_ELT(ans, i, bw = allocVector(REALSXP, nobs));
536
        for (j = 0; j < nobs; j++) {
537
            REAL(bw)[j] = 0.0;
538
            INTEGER(bin)[j] = 0;
539
        }
540
        for (b = 0; b < ntrees; b++) {
541
            /* don't look at out-of-bag observations */
542
            if (INTEGER(VECTOR_ELT(where, b))[i] == 0)
543
                continue;
544
            iwhere = INTEGER(VECTOR_ELT(where, b))[i];
545
            for (j = 0; j < nobs; j++) {
546
                if (iwhere == INTEGER(VECTOR_ELT(where, b))[j])
547
                    /* only count the number of trees; no weights */
548
                    REAL(bw)[j]++;
549
                if (INTEGER(VECTOR_ELT(where, b))[j] > 0)
550
                    /* count the number of bootstrap samples
551
                    containing both i and j */
552
                    INTEGER(bin)[j]++;
553
            }
554
        }
555
        for (j = 0; j < nobs; j++)
556
            REAL(bw)[j] = REAL(bw)[j] / INTEGER(bin)[j];
557
    }
558
    UNPROTECT(2);
559
    return(ans);
560
}