Switch to side-by-side view

--- a
+++ b/partyMod/src/SurrogateSplits.c
@@ -0,0 +1,264 @@
+
+/**
+    Suggorgate splits
+    *\file SurrogateSplits.c
+    *\author $Author$
+    *\date $Date$
+*/
+                
+#include "party.h"
+
+/**
+    Search for surrogate splits for bypassing the primary split \n
+    *\param node the current node with primary split specified
+    *\param learnsample learning sample
+    *\param weights the weights associated with the current node
+    *\param controls an object of class `TreeControl'
+    *\param fitmem an object of class `TreeFitMemory'
+    *\todo enable nominal surrogate split variables as well
+*/
+
+void C_surrogates(SEXP node, SEXP learnsample, SEXP weights, SEXP controls, 
+                  SEXP fitmem) {
+
+    SEXP x, y, expcovinf; 
+    SEXP splitctrl, inputs; 
+    SEXP split, thiswhichNA;
+    int nobs, ninputs, i, j, k, jselect, maxsurr, *order, nvar = 0;
+    double ms, cp, *thisweights, *cutpoint, *maxstat, 
+           *splitstat, *dweights, *tweights, *dx, *dy;
+    double cut, *twotab, *ytmp, sumw = 0.0;
+    
+    nobs = get_nobs(learnsample);
+    ninputs = get_ninputs(learnsample);
+    splitctrl = get_splitctrl(controls);
+    maxsurr = get_maxsurrogate(splitctrl);
+    inputs = GET_SLOT(learnsample, PL2_inputsSym);
+    jselect = S3get_variableID(S3get_primarysplit(node));
+    
+    /* (weights > 0) in left node are the new `response' to be approximated */
+    y = S3get_nodeweights(VECTOR_ELT(node, S3_LEFT));
+    ytmp = Calloc(nobs, double);
+    for (i = 0; i < nobs; i++) {
+        ytmp[i] = REAL(y)[i];
+        if (ytmp[i] > 1.0) ytmp[i] = 1.0;
+    }
+
+    for (j = 0; j < ninputs; j++) {
+        if (is_nominal(inputs, j + 1)) continue;
+        nvar++;
+    }
+    nvar--;
+
+    if (maxsurr != LENGTH(S3get_surrogatesplits(node)))
+        error("nodes does not have %d surrogate splits", maxsurr);
+    if (maxsurr > nvar)
+        error("cannot set up %d surrogate splits with only %d ordered input variable(s)", 
+              maxsurr, nvar);
+
+    tweights = Calloc(nobs, double);
+    dweights = REAL(weights);
+    for (i = 0; i < nobs; i++) tweights[i] = dweights[i];
+    if (has_missings(inputs, jselect)) {
+        thiswhichNA = get_missings(inputs, jselect);
+        for (k = 0; k < LENGTH(thiswhichNA); k++)
+            tweights[INTEGER(thiswhichNA)[k] - 1] = 0.0;
+    }
+
+    /* check if sum(weights) > 1 */
+    sumw = 0.0;
+    for (i = 0; i < nobs; i++) sumw += tweights[i];
+    if (sumw < 2.0)
+        error("can't implement surrogate splits, not enough observations available");
+
+    expcovinf = GET_SLOT(fitmem, PL2_expcovinfssSym);
+    C_ExpectCovarInfluence(ytmp, 1, tweights, nobs, expcovinf);
+    
+    splitstat = REAL(get_splitstatistics(fitmem));
+    /* <FIXME> extend `TreeFitMemory' to those as well ... */
+    maxstat = Calloc(ninputs, double);
+    cutpoint = Calloc(ninputs, double);
+    order = Calloc(ninputs, int);
+    /* <FIXME> */
+    
+    /* this is essentially an exhaustive search */
+    /* <FIXME>: we don't want to do this for random forest like trees 
+       </FIXME>
+     */
+    for (j = 0; j < ninputs; j++) {
+    
+         order[j] = j + 1;
+         maxstat[j] = 0.0;
+         cutpoint[j] = 0.0;
+
+         /* ordered input variables only (for the moment) */
+         if ((j + 1) == jselect || is_nominal(inputs, j + 1))
+             continue;
+
+         x = get_variable(inputs, j + 1);
+
+         if (has_missings(inputs, j + 1)) {
+
+             thisweights = C_tempweights(j + 1, weights, fitmem, inputs);
+
+             /* check if sum(weights) > 1 */
+             sumw = 0.0;
+             for (i = 0; i < nobs; i++) sumw += thisweights[i];
+             if (sumw < 2.0) continue;
+                 
+             C_ExpectCovarInfluence(ytmp, 1, thisweights, nobs, expcovinf);
+             
+             C_split(REAL(x), 1, ytmp, 1, thisweights, nobs,
+                     INTEGER(get_ordering(inputs, j + 1)), splitctrl,
+                     GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
+                     expcovinf, &cp, &ms, splitstat);
+         } else {
+         
+             C_split(REAL(x), 1, ytmp, 1, tweights, nobs,
+             INTEGER(get_ordering(inputs, j + 1)), splitctrl,
+             GET_SLOT(fitmem, PL2_linexpcov2sampleSym),
+             expcovinf, &cp, &ms, splitstat);
+         }
+
+         maxstat[j] = -ms;
+         cutpoint[j] = cp;
+    }
+
+
+    /* order with respect to maximal statistic */
+    rsort_with_index(maxstat, order, ninputs);
+    
+    twotab = Calloc(4, double);
+    
+    /* the best `maxsurr' ones are implemented */
+    for (j = 0; j < maxsurr; j++) {
+
+        if (is_nominal(inputs, order[j])) continue;
+        
+        for (i = 0; i < 4; i++) twotab[i] = 0.0;
+        cut = cutpoint[order[j] - 1];
+        SET_VECTOR_ELT(S3get_surrogatesplits(node), j, 
+                       split = allocVector(VECSXP, SPLIT_LENGTH));
+        C_init_orderedsplit(split, 0);
+        S3set_variableID(split, order[j]);
+        REAL(S3get_splitpoint(split))[0] = cut;
+        dx = REAL(get_variable(inputs, order[j]));
+        dy = REAL(y);
+
+        /* OK, this is a dirty hack: determine if the split 
+           goes left or right by the Pearson residual of a 2x2 table.
+           I don't want to use the big caliber here 
+        */
+        for (i = 0; i < nobs; i++) {
+            twotab[0] += ((dy[i] == 1) && (dx[i] <= cut)) * tweights[i];
+            twotab[1] += (dy[i] == 1) * tweights[i];
+            twotab[2] += (dx[i] <= cut) * tweights[i];
+            twotab[3] += tweights[i];
+        }
+        S3set_toleft(split, (int) (twotab[0] - twotab[1] * twotab[2] / 
+                     twotab[3]) > 0);
+    }
+    
+    Free(maxstat);
+    Free(cutpoint);
+    Free(order);
+    Free(tweights);
+    Free(twotab);
+    Free(ytmp);
+}
+
+/**
+    R-interface to C_surrogates \n
+    *\param node the current node with primary split specified
+    *\param learnsample learning sample
+    *\param weights the weights associated with the current node
+    *\param controls an object of class `TreeControl'
+    *\param fitmem an object of class `TreeFitMemory'
+*/
+
+
+SEXP R_surrogates(SEXP node, SEXP learnsample, SEXP weights, SEXP controls, 
+                  SEXP fitmem) {
+
+    C_surrogates(node, learnsample, weights, controls, fitmem);
+    return(S3get_surrogatesplits(node));
+    
+}
+
+/**
+    Split with missing values \n
+    *\param node the current node with primary and surrogate splits 
+                 specified
+    *\param learnsample learning sample
+*/
+
+void C_splitsurrogate(SEXP node, SEXP learnsample) {
+
+    SEXP weights, split, surrsplit;
+    SEXP inputs, whichNA, whichNAns;
+    double cutpoint, *dx, *dweights, *leftweights, *rightweights;
+    int *iwhichNA, k;
+    int i, nna, ns;
+                    
+    weights = S3get_nodeweights(node);
+    dweights = REAL(weights);
+    inputs = GET_SLOT(learnsample, PL2_inputsSym);
+            
+    leftweights = REAL(S3get_nodeweights(S3get_leftnode(node)));
+    rightweights = REAL(S3get_nodeweights(S3get_rightnode(node)));
+    surrsplit = S3get_surrogatesplits(node);
+
+    /* if the primary split has any missings */
+    split = S3get_primarysplit(node);
+    if (has_missings(inputs, S3get_variableID(split))) {
+
+        /* where are the missings? */
+        whichNA = get_missings(inputs, S3get_variableID(split));
+        iwhichNA = INTEGER(whichNA);
+        nna = LENGTH(whichNA);
+
+        /* for all missing values ... */
+        for (k = 0; k < nna; k++) {
+            ns = 0;
+            i = iwhichNA[k] - 1;
+            if (dweights[i] == 0) continue;
+            
+            /* loop over surrogate splits until an appropriate one is found */
+            while(TRUE) {
+            
+                if (ns >= LENGTH(surrsplit)) break;
+
+                split = VECTOR_ELT(surrsplit, ns);
+                if (has_missings(inputs, S3get_variableID(split))) {
+                    whichNAns = get_missings(inputs, S3get_variableID(split));
+                    if (C_i_in_set(i + 1, whichNAns)) {
+                        ns++;
+                        continue;
+                    }
+                }
+
+                cutpoint = REAL(S3get_splitpoint(split))[0];
+                dx = REAL(get_variable(inputs, S3get_variableID(split)));
+
+                if (S3get_toleft(split)) {
+                    if (dx[i] <= cutpoint) {
+                        leftweights[i] = dweights[i];
+                        rightweights[i] = 0.0;
+                    } else {
+                        rightweights[i] = dweights[i];
+                        leftweights[i] = 0.0;
+                    }
+                } else {
+                    if (dx[i] <= cutpoint) {
+                        rightweights[i] = dweights[i];
+                        leftweights[i] = 0.0;
+                    } else {
+                        leftweights[i] = dweights[i];
+                        rightweights[i] = 0.0;
+                    }
+                }
+                break;
+            }
+        }
+    }
+}