--- a +++ b/partyMod/src/SurrogateSplits.c @@ -0,0 +1,264 @@ + +/** + Suggorgate splits + *\file SurrogateSplits.c + *\author $Author$ + *\date $Date$ +*/ + +#include "party.h" + +/** + Search for surrogate splits for bypassing the primary split \n + *\param node the current node with primary split specified + *\param learnsample learning sample + *\param weights the weights associated with the current node + *\param controls an object of class `TreeControl' + *\param fitmem an object of class `TreeFitMemory' + *\todo enable nominal surrogate split variables as well +*/ + +void C_surrogates(SEXP node, SEXP learnsample, SEXP weights, SEXP controls, + SEXP fitmem) { + + SEXP x, y, expcovinf; + SEXP splitctrl, inputs; + SEXP split, thiswhichNA; + int nobs, ninputs, i, j, k, jselect, maxsurr, *order, nvar = 0; + double ms, cp, *thisweights, *cutpoint, *maxstat, + *splitstat, *dweights, *tweights, *dx, *dy; + double cut, *twotab, *ytmp, sumw = 0.0; + + nobs = get_nobs(learnsample); + ninputs = get_ninputs(learnsample); + splitctrl = get_splitctrl(controls); + maxsurr = get_maxsurrogate(splitctrl); + inputs = GET_SLOT(learnsample, PL2_inputsSym); + jselect = S3get_variableID(S3get_primarysplit(node)); + + /* (weights > 0) in left node are the new `response' to be approximated */ + y = S3get_nodeweights(VECTOR_ELT(node, S3_LEFT)); + ytmp = Calloc(nobs, double); + for (i = 0; i < nobs; i++) { + ytmp[i] = REAL(y)[i]; + if (ytmp[i] > 1.0) ytmp[i] = 1.0; + } + + for (j = 0; j < ninputs; j++) { + if (is_nominal(inputs, j + 1)) continue; + nvar++; + } + nvar--; + + if (maxsurr != LENGTH(S3get_surrogatesplits(node))) + error("nodes does not have %d surrogate splits", maxsurr); + if (maxsurr > nvar) + error("cannot set up %d surrogate splits with only %d ordered input variable(s)", + maxsurr, nvar); + + tweights = Calloc(nobs, double); + dweights = REAL(weights); + for (i = 0; i < nobs; i++) tweights[i] = dweights[i]; + if (has_missings(inputs, jselect)) { + thiswhichNA = get_missings(inputs, jselect); + for (k = 0; k < LENGTH(thiswhichNA); k++) + tweights[INTEGER(thiswhichNA)[k] - 1] = 0.0; + } + + /* check if sum(weights) > 1 */ + sumw = 0.0; + for (i = 0; i < nobs; i++) sumw += tweights[i]; + if (sumw < 2.0) + error("can't implement surrogate splits, not enough observations available"); + + expcovinf = GET_SLOT(fitmem, PL2_expcovinfssSym); + C_ExpectCovarInfluence(ytmp, 1, tweights, nobs, expcovinf); + + splitstat = REAL(get_splitstatistics(fitmem)); + /* <FIXME> extend `TreeFitMemory' to those as well ... */ + maxstat = Calloc(ninputs, double); + cutpoint = Calloc(ninputs, double); + order = Calloc(ninputs, int); + /* <FIXME> */ + + /* this is essentially an exhaustive search */ + /* <FIXME>: we don't want to do this for random forest like trees + </FIXME> + */ + for (j = 0; j < ninputs; j++) { + + order[j] = j + 1; + maxstat[j] = 0.0; + cutpoint[j] = 0.0; + + /* ordered input variables only (for the moment) */ + if ((j + 1) == jselect || is_nominal(inputs, j + 1)) + continue; + + x = get_variable(inputs, j + 1); + + if (has_missings(inputs, j + 1)) { + + thisweights = C_tempweights(j + 1, weights, fitmem, inputs); + + /* check if sum(weights) > 1 */ + sumw = 0.0; + for (i = 0; i < nobs; i++) sumw += thisweights[i]; + if (sumw < 2.0) continue; + + C_ExpectCovarInfluence(ytmp, 1, thisweights, nobs, expcovinf); + + C_split(REAL(x), 1, ytmp, 1, thisweights, nobs, + INTEGER(get_ordering(inputs, j + 1)), splitctrl, + GET_SLOT(fitmem, PL2_linexpcov2sampleSym), + expcovinf, &cp, &ms, splitstat); + } else { + + C_split(REAL(x), 1, ytmp, 1, tweights, nobs, + INTEGER(get_ordering(inputs, j + 1)), splitctrl, + GET_SLOT(fitmem, PL2_linexpcov2sampleSym), + expcovinf, &cp, &ms, splitstat); + } + + maxstat[j] = -ms; + cutpoint[j] = cp; + } + + + /* order with respect to maximal statistic */ + rsort_with_index(maxstat, order, ninputs); + + twotab = Calloc(4, double); + + /* the best `maxsurr' ones are implemented */ + for (j = 0; j < maxsurr; j++) { + + if (is_nominal(inputs, order[j])) continue; + + for (i = 0; i < 4; i++) twotab[i] = 0.0; + cut = cutpoint[order[j] - 1]; + SET_VECTOR_ELT(S3get_surrogatesplits(node), j, + split = allocVector(VECSXP, SPLIT_LENGTH)); + C_init_orderedsplit(split, 0); + S3set_variableID(split, order[j]); + REAL(S3get_splitpoint(split))[0] = cut; + dx = REAL(get_variable(inputs, order[j])); + dy = REAL(y); + + /* OK, this is a dirty hack: determine if the split + goes left or right by the Pearson residual of a 2x2 table. + I don't want to use the big caliber here + */ + for (i = 0; i < nobs; i++) { + twotab[0] += ((dy[i] == 1) && (dx[i] <= cut)) * tweights[i]; + twotab[1] += (dy[i] == 1) * tweights[i]; + twotab[2] += (dx[i] <= cut) * tweights[i]; + twotab[3] += tweights[i]; + } + S3set_toleft(split, (int) (twotab[0] - twotab[1] * twotab[2] / + twotab[3]) > 0); + } + + Free(maxstat); + Free(cutpoint); + Free(order); + Free(tweights); + Free(twotab); + Free(ytmp); +} + +/** + R-interface to C_surrogates \n + *\param node the current node with primary split specified + *\param learnsample learning sample + *\param weights the weights associated with the current node + *\param controls an object of class `TreeControl' + *\param fitmem an object of class `TreeFitMemory' +*/ + + +SEXP R_surrogates(SEXP node, SEXP learnsample, SEXP weights, SEXP controls, + SEXP fitmem) { + + C_surrogates(node, learnsample, weights, controls, fitmem); + return(S3get_surrogatesplits(node)); + +} + +/** + Split with missing values \n + *\param node the current node with primary and surrogate splits + specified + *\param learnsample learning sample +*/ + +void C_splitsurrogate(SEXP node, SEXP learnsample) { + + SEXP weights, split, surrsplit; + SEXP inputs, whichNA, whichNAns; + double cutpoint, *dx, *dweights, *leftweights, *rightweights; + int *iwhichNA, k; + int i, nna, ns; + + weights = S3get_nodeweights(node); + dweights = REAL(weights); + inputs = GET_SLOT(learnsample, PL2_inputsSym); + + leftweights = REAL(S3get_nodeweights(S3get_leftnode(node))); + rightweights = REAL(S3get_nodeweights(S3get_rightnode(node))); + surrsplit = S3get_surrogatesplits(node); + + /* if the primary split has any missings */ + split = S3get_primarysplit(node); + if (has_missings(inputs, S3get_variableID(split))) { + + /* where are the missings? */ + whichNA = get_missings(inputs, S3get_variableID(split)); + iwhichNA = INTEGER(whichNA); + nna = LENGTH(whichNA); + + /* for all missing values ... */ + for (k = 0; k < nna; k++) { + ns = 0; + i = iwhichNA[k] - 1; + if (dweights[i] == 0) continue; + + /* loop over surrogate splits until an appropriate one is found */ + while(TRUE) { + + if (ns >= LENGTH(surrsplit)) break; + + split = VECTOR_ELT(surrsplit, ns); + if (has_missings(inputs, S3get_variableID(split))) { + whichNAns = get_missings(inputs, S3get_variableID(split)); + if (C_i_in_set(i + 1, whichNAns)) { + ns++; + continue; + } + } + + cutpoint = REAL(S3get_splitpoint(split))[0]; + dx = REAL(get_variable(inputs, S3get_variableID(split))); + + if (S3get_toleft(split)) { + if (dx[i] <= cutpoint) { + leftweights[i] = dweights[i]; + rightweights[i] = 0.0; + } else { + rightweights[i] = dweights[i]; + leftweights[i] = 0.0; + } + } else { + if (dx[i] <= cutpoint) { + rightweights[i] = dweights[i]; + leftweights[i] = 0.0; + } else { + leftweights[i] = dweights[i]; + rightweights[i] = 0.0; + } + } + break; + } + } + } +}