Switch to unified view

a b/partyMod/src/RandomForest.c
1
2
/**
3
    Random forest with conditional inference trees
4
    *\file RandomForest.c
5
    *\author $Author$
6
    *\date $Date$
7
*/
8
9
#include "party.h"
10
11
/**
12
    An experimental implementation of random forest like algorithms \n
13
    *\param learnsample an object of class `LearningSample'
14
    *\param weights a vector of case weights
15
    *\param bwhere integer matrix (n x ntree) for terminal node numbers
16
    *\param bweights double matrix (n x ntree) for bootstrap case weights
17
    *\param fitmem an object of class `TreeFitMemory'
18
    *\param controls an object of class `TreeControl'
19
*/
20
21
22
SEXP R_Ensemble(SEXP learnsample, SEXP weights, SEXP bwhere, SEXP bweights, 
23
                SEXP fitmem, SEXP controls) {
24
            
25
     SEXP nweights, tree, where, ans, bw, compress_exp;
26
     double *dnweights, *dweights, sw = 0.0, *prob, tmp;
27
     int nobs, i, b, B , nodenum = 1, *iweights, *iweightstmp, 
28
         *iwhere, replace, fraction, wgrzero = 0, realweights = 0;
29
     int j, k, l, swi = 0;
30
     int errorOccurred;
31
     int *variables_to_ignore = NULL;
32
     int ninputs;
33
     
34
     B = get_ntree(controls);
35
     nobs = get_nobs(learnsample);
36
     ninputs = get_ninputs(learnsample);
37
     
38
     PROTECT(ans = allocVector(VECSXP, B));
39
40
     iweights = Calloc(nobs, int);
41
     iweightstmp = Calloc(nobs, int);
42
     prob = Calloc(nobs, double);
43
     dweights = REAL(weights);
44
     
45
     int varOnce = get_only_use_variable_once(get_tgctrl(controls));
46
     
47
//     printf("R_Ensemble: varOnce=%d\n", varOnce);
48
     if (varOnce) {
49
       variables_to_ignore = Calloc(ninputs, int);
50
     }
51
52
     for (i = 0; i < nobs; i++) {
53
         /* sum of weights */
54
         sw += dweights[i];
55
         /* number of weights > 0 */
56
         if (dweights[i] > 0) wgrzero++;
57
         /* case weights or real weights? */
58
         if (dweights[i] - ftrunc(dweights[i]) > 0) 
59
             realweights = 1;
60
     }
61
     for (i = 0; i < nobs; i++)
62
         prob[i] = dweights[i]/sw;
63
     swi = (int) ftrunc(sw);
64
65
     replace = get_replace(controls);
66
     /* fraction of number of obs with weight > 0 */
67
     if (realweights) {
68
         /* fraction of number of obs with weight > 0 for real weights*/
69
         tmp = (get_fraction(controls) * wgrzero);
70
     } else {
71
         /* fraction of sum of weights for case weights */
72
         tmp = (get_fraction(controls) * sw);
73
     }
74
     fraction = (int) ftrunc(tmp);
75
     if (ftrunc(tmp) < tmp) fraction++;
76
77
     if (!replace) {
78
         if (fraction < 10)
79
             error("fraction of %f is too small", fraction);
80
     }
81
82
     /* <FIXME> can we call those guys ONCE? what about the deeper
83
         calls??? </FIXME> */
84
     GetRNGstate();
85
  
86
     if (get_trace(controls))
87
         Rprintf("\n");
88
     for (b  = 0; b < B; b++) {
89
         SET_VECTOR_ELT(ans, b, tree = allocVector(VECSXP, NODE_LENGTH + 1));
90
         SET_VECTOR_ELT(bwhere, b, where = allocVector(INTSXP, nobs));
91
         SET_VECTOR_ELT(bweights, b, bw = allocVector(REALSXP, nobs));
92
         
93
         iwhere = INTEGER(where);
94
         for (i = 0; i < nobs; i++) iwhere[i] = 0;
95
     
96
         C_init_node(tree, nobs, get_ninputs(learnsample), 
97
                     get_maxsurrogate(get_splitctrl(controls)),
98
                     ncol(get_predict_trafo(GET_SLOT(learnsample, 
99
                                                   PL2_responsesSym))));
100
101
         /* generate altered weights for perturbation */
102
         if (replace) {
103
             /* weights for a bootstrap sample */
104
             rmultinom(swi, prob, nobs, iweights);
105
         } else {
106
             /* weights for sample splitting */
107
             C_SampleSplitting(nobs, prob, iweights, fraction);
108
         }
109
110
         nweights = S3get_nodeweights(tree);
111
         dnweights = REAL(nweights);
112
         for (i = 0; i < nobs; i++) {
113
             REAL(bw)[i] = (double) iweights[i];
114
             dnweights[i] = REAL(bw)[i];
115
         }
116
     
117
         C_TreeGrow(tree, learnsample, fitmem, controls, iwhere, &nodenum, 1, variables_to_ignore);
118
         nodenum = 1;
119
         int dropcriterion = get_dropcriterion(controls);
120
         C_remove_weights(tree, dropcriterion);
121
122
         PROTECT(compress_exp = lang2(get_compress(controls), tree));
123
         SET_VECTOR_ELT(ans, b, R_tryEval(compress_exp, R_GlobalEnv, &errorOccurred) );
124
         if(errorOccurred) { 
125
           Rprintf("error calling compress\n");
126
         } else {
127
//           Rprintf("no error\n");
128
         }        
129
         UNPROTECT(1);
130
         
131
         if (get_trace(controls)) {
132
             /* progress bar; inspired by 
133
             http://avinashjoshi.co.in/2009/10/13/creating-a-progress-bar-in-c/ */
134
             Rprintf("[");
135
             /* Print the = until the current percentage */
136
             l = (int) ceil( ((double) b * 50.0) / B);
137
             for (j = 0; j < l; j++)
138
                 Rprintf("=");
139
             Rprintf(">");
140
             for (k = j; k < 50; k++)
141
                 Rprintf(" ");
142
             Rprintf("]");
143
             /* % completed */
144
                 Rprintf(" %3d%% completed", j * 2);
145
             /* To delete the previous line */
146
             Rprintf("\r");
147
             /* Flush all char in buffer */
148
             /* fflush(stdout); */
149
         }
150
     }
151
     if (get_trace(controls))
152
         Rprintf("\n");
153
154
     PutRNGstate();
155
156
     Free(prob); Free(iweights); Free(iweightstmp);
157
     if(variables_to_ignore != NULL) {
158
       Free(variables_to_ignore);
159
     }
160
     UNPROTECT(1);
161
     return(ans);
162
}
163
164
/**
165
    An experimental implementation of random forest like algorithms \n
166
    *\param learnsample an object of class `LearningSample'
167
    *\param weights a vector of case weights
168
    *\param bwhere integer matrix (n x ntree) for terminal node numbers
169
    *\param bweights double matrix (n x ntree) for bootstrap case weights
170
    *\param fitmem an object of class `TreeFitMemory'
171
    *\param controls an object of class `TreeControl'
172
*/
173
174
175
SEXP R_Ensemble_weights(SEXP learnsample, SEXP bwhere, SEXP bweights, 
176
                SEXP fitmem, SEXP controls) {
177
            
178
     SEXP nweights, tree, where, ans, compress_exp;
179
     double *dnweights, *dweights;
180
     int nobs, i, b, B , nodenum = 1, *iwhere;
181
     int j, k, l;
182
     int errorOccurred;
183
184
     int *variables_to_ignore = NULL;
185
     int ninputs;
186
     int varOnce;
187
     
188
     B = get_ntree(controls);
189
     nobs = get_nobs(learnsample);
190
     ninputs = get_ninputs(learnsample);
191
192
     varOnce = get_only_use_variable_once(get_tgctrl(controls));
193
//     printf("R_Ensemble_weight: varOnce=%d\n", varOnce);
194
     
195
     if (varOnce) {
196
       variables_to_ignore = Calloc(ninputs, int);
197
     }
198
199
     PROTECT(ans = allocVector(VECSXP, B));
200
201
     /* <FIXME> can we call those guys ONCE? what about the deeper
202
         calls??? </FIXME> */
203
     GetRNGstate();
204
  
205
     if (get_trace(controls))
206
         Rprintf("\n");
207
     for (b  = 0; b < B; b++) {
208
         SET_VECTOR_ELT(ans, b, tree = allocVector(VECSXP, NODE_LENGTH + 1));
209
         SET_VECTOR_ELT(bwhere, b, where = allocVector(INTSXP, nobs));
210
         
211
         iwhere = INTEGER(where);
212
         for (i = 0; i < nobs; i++) iwhere[i] = 0;
213
     
214
         C_init_node(tree, nobs, get_ninputs(learnsample), 
215
                     get_maxsurrogate(get_splitctrl(controls)),
216
                     ncol(get_predict_trafo(GET_SLOT(learnsample, 
217
                                                   PL2_responsesSym))));
218
219
         nweights = S3get_nodeweights(tree);
220
         dnweights = REAL(nweights);
221
         dweights = REAL(VECTOR_ELT(bweights, b));
222
         for (i = 0; i < nobs; i++) {
223
             dnweights[i] = dweights[i];
224
         }
225
     
226
         C_TreeGrow(tree, learnsample, fitmem, controls, iwhere, &nodenum, 1, variables_to_ignore);
227
         nodenum = 1;
228
         int dropcriterion = get_dropcriterion(controls);
229
         C_remove_weights(tree, dropcriterion);
230
         
231
         PROTECT(compress_exp = lang2(get_compress(controls), tree));
232
         SET_VECTOR_ELT(ans, b, R_tryEval(compress_exp, R_GlobalEnv, &errorOccurred) );
233
         if(errorOccurred) { 
234
           Rprintf("error calling compress\n");
235
         } else {
236
//           Rprintf("no error\n");
237
         }        
238
         UNPROTECT(1);
239
         
240
         if (get_trace(controls)) {
241
             /* progress bar; inspired by 
242
             http://avinashjoshi.co.in/2009/10/13/creating-a-progress-bar-in-c/ */
243
             Rprintf("[");
244
             /* Print the = until the current percentage */
245
             l = (int) ceil( ((double) b * 50.0) / B);
246
             for (j = 0; j < l; j++)
247
                 Rprintf("=");
248
             Rprintf(">");
249
             for (k = j; k < 50; k++)
250
                 Rprintf(" ");
251
             Rprintf("]");
252
             /* % completed */
253
                 Rprintf(" %3d%% completed", j * 2);
254
             /* To delete the previous line */
255
             Rprintf("\r");
256
             /* Flush all char in buffer */
257
             /* fflush(stdout); */
258
         }
259
     }
260
     if (get_trace(controls))
261
         Rprintf("\n");
262
263
     PutRNGstate();
264
265
     UNPROTECT(1);
266
     if(variables_to_ignore != NULL) {
267
       Free(variables_to_ignore);
268
     }
269
270
     return(ans);
271
}