Diff of /partyMod/src/Node.c [000000] .. [fbf06f]

Switch to unified view

a b/partyMod/src/Node.c
1
2
/**
3
    Node computations
4
    *\file Node.c
5
    *\author $Author$
6
    *\date $Date$
7
*/
8
                
9
#include "party.h"
10
11
12
/**
13
    Compute prediction of a node
14
    *\param y the response variable (raw numeric values or dummy encoded factor)
15
    *\param n number of observations
16
    *\param q number of columns of y
17
    *\param weights case weights
18
    *\param sweights sum of case weights
19
    *\param ans return value; the q-dimensional predictions
20
*/
21
        
22
void C_prediction(const double *y, int n, int q, const double *weights, 
23
                  const double sweights, double *ans) {
24
25
    int i, j, jn;
26
    
27
    for (j = 0; j < q; j++) {
28
        ans[j] = 0.0;
29
        jn = j * n;
30
        for (i = 0; i < n; i++) 
31
            ans[j] += weights[i] * y[jn + i];
32
        ans[j] = ans[j] / sweights;
33
    }
34
}
35
36
37
void mask_pvalue(double *pvalue, int *variables_to_ignore, int ninputs) {
38
  int i;
39
  if(variables_to_ignore == NULL) {
40
    return;
41
  }
42
  
43
  for(i=0;i<ninputs;i++) {
44
    if(variables_to_ignore[i]) {
45
        pvalue[i] = R_NegInf;
46
    }
47
  }
48
}
49
50
51
/**
52
    The main function for all node computations
53
    *\param node an initialized node (an S3 object!)
54
    *\param learnsample an object of class `LearningSample'
55
    *\param weights case weights
56
    *\param fitmem an object of class `TreeFitMemory'
57
    *\param controls an object of class `TreeControl'
58
    *\param TERMINAL logical indicating if this node will
59
                     be a terminal node
60
    *\param depth an integer giving the depth of the current node
61
*/
62
63
void C_Node(SEXP node, SEXP learnsample, SEXP weights, 
64
            SEXP fitmem, SEXP controls, int TERMINAL, int depth, int *variables_to_ignore) {
65
    
66
    int nobs, ninputs, jselect, q, j, k, i;
67
    double mincriterion, sweights, *dprediction;
68
    double *teststat, *pvalue, smax, cutpoint = 0.0, maxstat = 0.0;
69
    double *standstat, *splitstat;
70
    SEXP responses, inputs, x, expcovinf, linexpcov;
71
    SEXP varctrl, splitctrl, gtctrl, tgctrl, split, testy, predy;
72
    double *dxtransf, *thisweights;
73
    int *itable;
74
    
75
    nobs = get_nobs(learnsample);
76
    ninputs = get_ninputs(learnsample);
77
    varctrl = get_varctrl(controls);
78
    splitctrl = get_splitctrl(controls);
79
    gtctrl = get_gtctrl(controls);
80
    tgctrl = get_tgctrl(controls);
81
    mincriterion = get_mincriterion(gtctrl);
82
    responses = GET_SLOT(learnsample, PL2_responsesSym);
83
    inputs = GET_SLOT(learnsample, PL2_inputsSym);
84
    testy = get_test_trafo(responses);
85
    predy = get_predict_trafo(responses);
86
    q = ncol(testy);
87
88
    /* <FIXME> we compute C_GlobalTest even for TERMINAL nodes! </FIXME> */
89
90
    /* compute the test statistics and the node criteria for each input */        
91
    C_GlobalTest(learnsample, weights, fitmem, varctrl,
92
                 gtctrl, get_minsplit(splitctrl), 
93
                 REAL(S3get_teststat(node)), REAL(S3get_criterion(node)), depth);
94
    
95
    /* sum of weights: C_GlobalTest did nothing if sweights < mincriterion */
96
    sweights = REAL(GET_SLOT(GET_SLOT(fitmem, PL2_expcovinfSym), 
97
                             PL2_sumweightsSym))[0];
98
    REAL(VECTOR_ELT(node, S3_SUMWEIGHTS))[0] = sweights;
99
100
    /* compute the prediction of this node */
101
    dprediction = REAL(S3get_prediction(node));
102
103
    /* <FIXME> feed raw numeric values OR dummy encoded factors as y 
104
       Problem: what happens for survival times ? */
105
    C_prediction(REAL(predy), nobs, ncol(predy), REAL(weights), 
106
                     sweights, dprediction);
107
    /* </FIXME> */
108
109
    teststat = REAL(S3get_teststat(node));
110
    pvalue = REAL(S3get_criterion(node));
111
112
    mask_pvalue(pvalue, variables_to_ignore, ninputs);
113
114
    /* try the two out of ninputs best inputs variables */
115
    /* <FIXME> be more flexible and add a parameter controlling
116
               the number of inputs tried </FIXME> */
117
    for (j = 0; j < 2; j++) {
118
119
        smax = C_max(pvalue, ninputs);
120
        REAL(S3get_maxcriterion(node))[0] = smax;
121
    
122
        /* if the global null hypothesis was rejected */
123
        if (smax > mincriterion && !TERMINAL) {
124
125
            /* the input variable with largest association to the response */
126
            jselect = C_whichmax(pvalue, teststat, ninputs) + 1;
127
128
            /* get the raw numeric values or the codings of a factor */
129
            x = get_variable(inputs, jselect);
130
            if (has_missings(inputs, jselect)) {
131
                expcovinf = GET_SLOT(get_varmemory(fitmem, jselect), 
132
                                    PL2_expcovinfSym);
133
                thisweights = C_tempweights(jselect, weights, fitmem, inputs);
134
            } else {
135
                expcovinf = GET_SLOT(fitmem, PL2_expcovinfSym);
136
                thisweights = REAL(weights);
137
            }
138
139
            /* <FIXME> handle ordered factors separatly??? </FIXME> */
140
            if (!is_nominal(inputs, jselect)) {
141
            
142
                /* search for a split in a ordered variable x */
143
                split = S3get_primarysplit(node);
144
                
145
                /* check if the n-vector of splitstatistics 
146
                   should be returned for each primary split */
147
                if (get_savesplitstats(tgctrl)) {
148
                    C_init_orderedsplit(split, nobs);
149
                    splitstat = REAL(S3get_splitstatistics(split));
150
                } else {
151
                    C_init_orderedsplit(split, 0);
152
                    splitstat = REAL(get_splitstatistics(fitmem));
153
                }
154
155
                C_split(REAL(x), 1, REAL(testy), q, thisweights, nobs,
156
                        INTEGER(get_ordering(inputs, jselect)), splitctrl, 
157
                        GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
158
                        expcovinf, REAL(S3get_splitpoint(split)), &maxstat,
159
                        splitstat);
160
                S3set_variableID(split, jselect);
161
             } else {
162
           
163
                 /* search of a set of levels (split) in a numeric variable x */
164
                 split = S3get_primarysplit(node);
165
                 
166
                /* check if the n-vector of splitstatistics 
167
                   should be returned for each primary split */
168
                if (get_savesplitstats(tgctrl)) {
169
                    C_init_nominalsplit(split, 
170
                        LENGTH(get_levels(inputs, jselect)), 
171
                        nobs);
172
                    splitstat = REAL(S3get_splitstatistics(split));
173
                } else {
174
                    C_init_nominalsplit(split, 
175
                        LENGTH(get_levels(inputs, jselect)), 
176
                        0);
177
                    splitstat = REAL(get_splitstatistics(fitmem));
178
                }
179
          
180
                 linexpcov = get_varmemory(fitmem, jselect);
181
                 standstat = Calloc(get_dimension(linexpcov), double);
182
                 C_standardize(REAL(GET_SLOT(linexpcov, 
183
                                             PL2_linearstatisticSym)),
184
                               REAL(GET_SLOT(linexpcov, PL2_expectationSym)),
185
                               REAL(GET_SLOT(linexpcov, PL2_covarianceSym)),
186
                               get_dimension(linexpcov), get_tol(splitctrl), 
187
                               standstat);
188
 
189
                 C_splitcategorical(INTEGER(x), 
190
                                    LENGTH(get_levels(inputs, jselect)), 
191
                                    REAL(testy), q, thisweights, 
192
                                    nobs, standstat, splitctrl, 
193
                                    GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
194
                                    expcovinf, &cutpoint, 
195
                                    INTEGER(S3get_splitpoint(split)),
196
                                    &maxstat, splitstat);
197
198
                 /* compute which levels of a factor are available in this node 
199
                    (for printing) later on. A real `table' for this node would
200
                    induce too much overhead here. Maybe later. */
201
                    
202
                 itable = INTEGER(S3get_table(split));
203
                 dxtransf = REAL(get_transformation(inputs, jselect));
204
                 for (k = 0; k < LENGTH(get_levels(inputs, jselect)); k++) {
205
                     itable[k] = 0;
206
                     for (i = 0; i < nobs; i++) {
207
                         if (dxtransf[k * nobs + i] * thisweights[i] > 0) {
208
                             itable[k] = 1;
209
                             continue;
210
                         }
211
                     }
212
                 }
213
214
                 Free(standstat);
215
            }
216
            if (maxstat == 0) {
217
                if (j == 1) {          
218
                    S3set_nodeterminal(node);
219
                } else {
220
                    /* do not look at jselect in next iteration */
221
                    pvalue[jselect - 1] = R_NegInf;
222
                }
223
            } else {
224
                S3set_variableID(split, jselect);
225
                break;
226
            }
227
        } else {
228
            S3set_nodeterminal(node);
229
            break;
230
        }
231
    }
232
}       
233
234
235
/**
236
    R-interface to C_Node
237
    *\param learnsample an object of class `LearningSample'
238
    *\param weights case weights
239
    *\param fitmem an object of class `TreeFitMemory'
240
    *\param controls an object of class `TreeControl'
241
*/
242
243
SEXP R_Node(SEXP learnsample, SEXP weights, SEXP fitmem, SEXP controls) {
244
            
245
     SEXP ans;
246
     
247
     PROTECT(ans = allocVector(VECSXP, NODE_LENGTH));
248
     C_init_node(ans, get_nobs(learnsample), get_ninputs(learnsample), 
249
                 get_maxsurrogate(get_splitctrl(controls)),
250
                 ncol(get_predict_trafo(GET_SLOT(learnsample, PL2_responsesSym))));
251
252
     C_Node(ans, learnsample, weights, fitmem, controls, 0, 1, NULL);
253
     UNPROTECT(1);
254
     return(ans);
255
}