Diff of /partyMod/R/MOB-Utils.R [000000] .. [fbf06f]

Switch to unified view

a b/partyMod/R/MOB-Utils.R
1
###########################
2
## convenience functions ##
3
###########################
4
5
## obtain the number/ID for all terminal nodes
6
terminal_nodeIDs <- function(node) {
7
  if(node$terminal) return(node$nodeID)
8
  ll <- terminal_nodeIDs(node$left)
9
  rr <- terminal_nodeIDs(node$right)
10
  return(c(ll, rr))
11
}
12
13
14
#########################
15
## workhorse functions ##
16
#########################
17
18
### determine which observations go left or right
19
mob_fit_childweights <- function(node, mf, weights) {
20
21
    partvar <- mf@get("part")
22
    xselect <- partvar[[node$psplit$variableID]]
23
24
    ## we need to coerce ordered factors to numeric
25
    ## this is what party C code does as well!
26
27
    if (class(node$psplit) == "orderedSplit") {
28
        leftweights <- (as.double(xselect) <= node$psplit$splitpoint) * weights
29
        rightweights <- (as.double(xselect) > node$psplit$splitpoint) * weights
30
    } else {
31
        leftweights <- (xselect %in%
32
            levels(xselect)[as.logical(node$psplit$splitpoint)]) * weights
33
        rightweights <- (!(xselect %in%
34
            levels(xselect)[as.logical(node$psplit$splitpoint)])) * weights
35
    }
36
37
    list(left = leftweights, right = rightweights)
38
}
39
40
### setup a new (inner or terminal) node of a tree
41
mob_fit_setupnode <- function(obj, mf, weights, control) {
42
43
    ### control parameters
44
    alpha <- control$alpha
45
    bonferroni <- control$bonferroni
46
    minsplit <- control$minsplit
47
    trim <- control$trim
48
    objfun <- control$objfun
49
    verbose <- control$verbose
50
    breakties <- control$breakties
51
    parm <- control$parm
52
53
    ### if too few observations: no split = return terminal node
54
    if (sum(weights) < 2 * minsplit) {
55
        node <- list(nodeID = NULL, weights = weights,
56
                     criterion = list(statistic = 0, criterion = 0, maxcriterion = 0),
57
                     terminal = TRUE, psplit = NULL, ssplits = NULL,
58
                     prediction = 0, left = NULL, right = NULL,
59
                     sumweights = as.double(sum(weights)))
60
        class(node) <- "TerminalModelNode"
61
        return(node)
62
    }
63
64
    ### variable selection via fluctuation tests
65
    test <- try(mob_fit_fluctests(obj, mf, minsplit = minsplit, trim = trim,
66
      breakties = breakties, parm = parm))
67
68
    if (!inherits(test, "try-error")) {
69
        if(bonferroni) {
70
          pval1 <- pmin(1, sum(!is.na(test$pval)) * test$pval)
71
          pval2 <- 1 - (1-test$pval)^sum(!is.na(test$pval))
72
          test$pval <- ifelse(!is.na(test$pval) & (test$pval > 0.01), pval2, pval1)
73
        }
74
75
        best <- test$best
76
        TERMINAL <- is.na(best) || test$pval[best] > alpha
77
78
        if (verbose) {
79
            cat("\n-------------------------------------------\nFluctuation tests of splitting variables:\n")
80
            print(rbind(statistic = test$stat, p.value = test$pval))
81
            cat("\nBest splitting variable: ")
82
            cat(names(test$stat)[best])
83
            cat("\nPerform split? ")
84
            cat(ifelse(TERMINAL, "no", "yes"))
85
            cat("\n-------------------------------------------\n")    
86
        }
87
    } else {
88
        TERMINAL <- TRUE
89
        test <- list(stat = NA, pval = NA)
90
    }
91
92
    ### splitting
93
    na_max <- function(x) {
94
      if(all(is.na(x))) NA else max(x, na.rm = TRUE)
95
    }
96
    if (TERMINAL) {
97
        node <- list(nodeID = NULL, weights = weights,
98
                 criterion = list(statistic = test$stat, 
99
                                      criterion = 1 - test$pval,
100
                      maxcriterion = na_max(1 - test$pval)),
101
                     terminal = TRUE, psplit = NULL, ssplits = NULL,
102
                     prediction = 0, left = NULL, right = NULL, 
103
                     sumweights = as.double(sum(weights)))
104
        class(node) <- "TerminalModelNode"
105
        return(node)
106
    } else {
107
        partvar <- mf@get("part")
108
        xselect <- partvar[[best]]
109
        thissplit <- mob_fit_splitnode(xselect, obj, mf, weights, minsplit = minsplit, 
110
                                       objfun = objfun, verbose = verbose)
111
112
        ## check if splitting was unsuccessful
113
        if (identical(FALSE, thissplit)) {
114
            node <- list(nodeID = NULL, weights = weights,
115
                         criterion = list(statistic = test$stat, 
116
                                          criterion = 1 - test$pval, 
117
                                          maxcriterion = na_max(1 - test$pval)),
118
                         terminal = TRUE, psplit = NULL, ssplits = NULL,
119
                         prediction = 0, left = NULL, right = NULL, 
120
                         sumweights = as.double(sum(weights)))
121
            class(node) <- "TerminalModelNode"  
122
            
123
            ### more confusion than information
124
        ### warning("no admissable split found", call. = FALSE)
125
        if(verbose)
126
          cat(paste("\nNo admissable split found in ", sQuote(names(test$stat)[best]), "\n", sep = ""))     
127
        return(node)
128
        }
129
130
        thissplit$variableID <- best
131
        thissplit$variableName <- names(partvar)[best]
132
        node <- list(nodeID = NULL, weights = weights, 
133
                     criterion = list(statistic = test$stat, 
134
                                      criterion = 1 - test$pval, 
135
                                      maxcriterion = na_max(1 - test$pval)),
136
                     terminal = FALSE,
137
                     psplit = thissplit, ssplits = NULL, 
138
                     prediction = 0, left = NULL, right = NULL, 
139
                     sumweights = as.double(sum(weights)))
140
        class(node) <- "SplittingNode"
141
    }
142
    
143
    node$variableID <- best
144
    if (verbose) {
145
        cat("\nNode properties:\n")
146
        print(node$psplit, left = TRUE)
147
        cat(paste("; criterion = ", round(node$criterion$maxcriterion, 3), 
148
              ", statistic = ", round(max(node$criterion$statistic), 3), "\n",
149
              collapse = "", sep = ""))
150
    }
151
    node
152
}
153
154
### variable selection:
155
### conduct all M-fluctuation tests of fitted obj 
156
### with respect to each variable from a set of
157
### potential partitioning variables in mf
158
mob_fit_fluctests <- function(obj, mf, minsplit, trim, breakties, parm) {
159
  ## Cramer-von Mises statistic might be supported in future versions
160
  CvM <- FALSE
161
  
162
  ## set up return values
163
  partvar <- mf@get("part")
164
  m <- NCOL(partvar)
165
  pval <- rep.int(0, m)
166
  stat <- rep.int(0, m)
167
  ifac <- rep.int(FALSE, m)
168
169
  ## extract estimating functions  
170
  process <- as.matrix(estfun(obj))
171
  k <- NCOL(process)
172
  
173
  ## extract weights
174
  ww <- weights(obj)
175
  if(is.null(ww)) ww <- rep(1, NROW(process))
176
  n <- sum(ww)
177
  
178
  ## drop observations with zero weight
179
  ww0 <- (ww > 0)
180
  process <- process[ww0, , drop = FALSE]
181
  partvar <- partvar[ww0, , drop = FALSE]
182
  ww <- ww[ww0]
183
  ## repeat observations with weight > 1
184
  process <- process/ww
185
  ww1 <- rep.int(1:length(ww), ww)
186
  process <- process[ww1, , drop = FALSE]
187
  stopifnot(NROW(process) == n)
188
189
  ## scale process
190
  process <- process/sqrt(n)
191
  J12 <- root.matrix(crossprod(process))
192
  process <- t(chol2inv(chol(J12)) %*% t(process))  
193
194
  ## select parameters to test
195
  if(!is.null(parm)) process <- process[, parm, drop = FALSE]
196
  k <- NCOL(process)
197
198
  ## get critical values for CvM statistic
199
  if(CvM) {
200
    if(k > 25) k <- 25 #Z# also issue warning
201
    critval <- get("sc.meanL2")[as.character(k), ]
202
  } else {
203
    from <- if(trim > 1) trim else ceiling(n * trim)
204
    from <- max(from, minsplit)
205
    to <- n - from
206
    lambda <- ((n-from)*to)/(from*(n-to))
207
208
    beta <- get("sc.beta.sup")
209
    logp.supLM <- function(x, k, lambda)
210
    {
211
      if(k > 40) {
212
        ## use Estrella (2003) asymptotic approximation
213
        logp_estrella2003 <- function(x, k, lambda)
214
          -lgamma(k/2) + k/2 * log(x/2) - x/2 + log(abs(log(lambda) * (1 - k/x) + 2/x))
215
        ## FIXME: Estrella only works well for large enough x
216
    ## hence require x > 1.5 * k for Estrella approximation and
217
    ## use an ad hoc interpolation for larger p-values
218
    p <- ifelse(x <= 1.5 * k, (x/(1.5 * k))^sqrt(k) * logp_estrella2003(1.5 * k, k, lambda), logp_estrella2003(x, k, lambda))
219
      } else {
220
        ## use Hansen (1997) approximation
221
        m <- ncol(beta)-1
222
        if(lambda<1) tau <- lambda
223
        else tau <- 1/(1+sqrt(lambda))
224
        beta <- beta[(((k-1)*25 +1):(k*25)),]
225
        dummy <- beta[,(1:m)]%*%x^(0:(m-1))
226
        dummy <- dummy*(dummy>0)
227
        pp <- pchisq(dummy, beta[,(m+1)], lower.tail = FALSE, log.p = TRUE)
228
        if(tau==0.5)
229
          p <- pchisq(x, k, lower.tail = FALSE, log.p = TRUE)
230
        else if(tau <= 0.01)
231
          p <- pp[25]
232
        else if(tau >= 0.49)
233
          p <- log((exp(log(0.5-tau) + pp[1]) + exp(log(tau-0.49) + pchisq(x,k,lower.tail = FALSE, log.p = TRUE)))*100)
234
        else
235
        {
236
          taua <- (0.51-tau)*50
237
          tau1 <- floor(taua)
238
          p <- log(exp(log(tau1 + 1 - taua) + pp[tau1]) + exp(log(taua-tau1) + pp[tau1+1]))
239
        }
240
      }
241
      return(as.vector(p))
242
    }
243
  }
244
245
  ## compute statistic and p-value for each ordering
246
  for(i in 1:m) {
247
    pvi <- partvar[,i]
248
    pvi <- pvi[ww1]
249
    if(is.factor(pvi)) {
250
      proci <- process[ORDER(pvi), , drop = FALSE]
251
      ifac[i] <- TRUE
252
253
      # re-apply factor() added to drop unused levels
254
      pvi <- factor(pvi[ORDER(pvi)])
255
      # compute segment weights
256
      segweights <- as.vector(table(pvi))/n ## tapply(ww, pvi, sum)/n      
257
258
      # compute statistic only if at least two levels are left
259
      if(length(segweights) < 2) {
260
        stat[i] <- 0
261
    pval[i] <- NA
262
      } else {      
263
        stat[i] <- sum(sapply(1:k, function(j) (tapply(proci[,j], pvi, sum)^2)/segweights))
264
        pval[i] <- pchisq(stat[i], k*(length(levels(pvi))-1), log.p = TRUE, lower.tail = FALSE)
265
      }
266
    } else {
267
      oi <- if(breakties) {
268
        mm <- sort(unique(pvi))
269
    mm <- ifelse(length(mm) > 1, min(diff(mm))/10, 1)
270
    ORDER(pvi + runif(length(pvi), min = -mm, max = +mm))
271
      } else {
272
        ORDER(pvi)
273
      }    
274
      proci <- process[oi, , drop = FALSE]
275
      proci <- apply(proci, 2, cumsum)
276
      stat[i] <- if(CvM) sum((proci)^2)/n 
277
        else if(from < to) {
278
      xx <- rowSums(proci^2)
279
      xx <- xx[from:to]
280
      tt <- (from:to)/n
281
      max(xx/(tt * (1-tt)))   
282
    } else {
283
      0
284
    }
285
      pval[i] <- if(CvM) log(approx(c(0, critval), c(1, 1-as.numeric(names(critval))), stat[i], rule=2)$y)
286
        else if(from < to) logp.supLM(stat[i], k, lambda) else NA
287
    }
288
  }
289
290
  ## select variable with minimal p-value
291
  best <- which.min(pval)
292
  if(length(best) < 1) best <- NA
293
  rval <- list(pval = exp(pval), stat = stat, best = best)
294
  names(rval$pval) <- names(partvar)
295
  names(rval$stat) <- names(partvar)
296
  if (!all(is.na(rval$best)))
297
      names(rval$best) <- names(partvar)[rval$best]
298
  return(rval)
299
}
300
301
### split in variable x, either ordered or nominal
302
mob_fit_splitnode <- function(x, obj, mf, weights, minsplit, objfun, verbose = TRUE) {
303
304
    ## process minsplit (to minimal number of observations)
305
    if (minsplit > 0.5 & minsplit < 1) minsplit <- 1 - minsplit
306
    if (minsplit < 0.5)
307
        minsplit <- ceiling(sum(weights) * minsplit)
308
   
309
    if (is.numeric(x)) {
310
    ### for numerical variables
311
        ux <- sort(unique(x))
312
        if (length(ux) == 0) stop("cannot find admissible split point in x")
313
        dev <- vector(mode = "numeric", length = length(ux))
314
315
        for (i in 1:length(ux)) {
316
            xs <- x <= ux[i]
317
            if (mob_fit_checksplit(xs, weights, minsplit)) {
318
                dev[i] <- Inf
319
            } else {
320
                dev[i] <- mob_fit_getobjfun(obj, mf, weights, xs, objfun = objfun)
321
            }
322
        }
323
324
        ## maybe none of the possible splits is admissible
325
        if (all(!is.finite(dev))) return(FALSE)
326
327
        split <- list(variableID = NULL, ordered = TRUE, 
328
                      splitpoint = as.double(ux[which.min(dev)]),
329
                      splitstatistic = dev, toleft = TRUE)
330
        class(split) <- "orderedSplit"
331
    } else {
332
    ### for categorical variables
333
        al <- mob_fit_getlevels(x)
334
        dev <- apply(al, 1, function(w) {
335
                   xs <- x %in% levels(x)[w]
336
                   if (mob_fit_checksplit(xs, weights, minsplit)) {
337
                       return(Inf)
338
                   } else {
339
                       mob_fit_getobjfun(obj, mf, weights, xs, objfun = objfun)
340
                   }
341
               })
342
343
        if (verbose) {
344
            cat(paste("\nSplitting ", if(is.ordered(x)) "ordered ",
345
                  "factor variable, objective function: \n", sep = ""))
346
            print(dev)
347
        }
348
349
        if (all(!is.finite(dev))) return(FALSE)
350
351
        ## ordered factors are of storage mode "numeric" in party!
352
        ## initVariableFrame coerces ordered factors to storage.mode "numeric"
353
        ## the following is consistent with party
354
        
355
        if (is.ordered(x)) {
356
            split <- list(variableID = NULL, ordered = TRUE,
357
                          splitpoint = as.double(which.min(dev)),
358
                          splitstatistic = dev, toleft = TRUE)
359
            class(split) <- "orderedSplit"
360
            attr(split$splitpoint, "levels") <- levels(x)
361
        }  else {
362
            tab <- as.integer(table(x[weights > 0]) > 0)
363
            split <- list(variableID = NULL, ordered = FALSE,
364
                          splitpoint = as.integer(al[which.min(dev),]), 
365
                          splitstatistic = dev, 
366
                          toleft = TRUE, table = tab)
367
            attr(split$splitpoint, "levels") <- levels(x)
368
            class(split) <- "nominalSplit"
369
        }
370
    }
371
    split
372
}
373
374
### get partitioned objective function for a particular split
375
mob_fit_getobjfun <- function(obj, mf, weights, left, objfun = deviance) {
376
  ## mf is the model frame
377
  ## weights are the observation weights
378
  ## left is 1 (if left of splitpoint) or 0
379
  weightsleft <- weights * left
380
  weightsright <- weights * (1 - left)
381
382
  ### fit left / right model 
383
  fmleft <- reweight(obj, weights = weightsleft)
384
  fmright <- reweight(obj, weights = weightsright)
385
386
  return(objfun(fmleft) + objfun(fmright))
387
}
388
389
### determine all possible splits for a factor, both nominal and ordinal
390
mob_fit_getlevels <- function(x) {
391
    nl <- nlevels(x)
392
    if (inherits(x, "ordered")) {
393
        indx <- diag(nl)
394
        indx[lower.tri(indx)] <- 1
395
        indx <- indx[-nl,]
396
    rownames(indx) <- levels(x)[-nl]
397
    } else {
398
        mi <- 2^(nl - 1) - 1
399
        indx <- matrix(0, nrow = mi, ncol = nl)
400
        for (i in 1:mi) { # go though all splits #
401
            ii <- i
402
            for (l in 1:nl) {
403
                indx[i, l] <- ii%%2;
404
                ii <- ii %/% 2   
405
            }
406
        }
407
        rownames(indx) <- apply(indx, 1, function(z) paste(levels(x)[z > 0], collapse = "+"))
408
    }
409
    colnames(indx) <- as.character(levels(x))
410
    storage.mode(indx) <- "logical"
411
    indx
412
}
413
414
### check split
415
mob_fit_checksplit <- function(split, weights, minsplit)
416
    (sum(split * weights) < minsplit || sum((1 - split) * weights) < minsplit)