|
a |
|
b/partyMod/R/MOB-Utils.R |
|
|
1 |
########################### |
|
|
2 |
## convenience functions ## |
|
|
3 |
########################### |
|
|
4 |
|
|
|
5 |
## obtain the number/ID for all terminal nodes |
|
|
6 |
terminal_nodeIDs <- function(node) { |
|
|
7 |
if(node$terminal) return(node$nodeID) |
|
|
8 |
ll <- terminal_nodeIDs(node$left) |
|
|
9 |
rr <- terminal_nodeIDs(node$right) |
|
|
10 |
return(c(ll, rr)) |
|
|
11 |
} |
|
|
12 |
|
|
|
13 |
|
|
|
14 |
######################### |
|
|
15 |
## workhorse functions ## |
|
|
16 |
######################### |
|
|
17 |
|
|
|
18 |
### determine which observations go left or right |
|
|
19 |
mob_fit_childweights <- function(node, mf, weights) { |
|
|
20 |
|
|
|
21 |
partvar <- mf@get("part") |
|
|
22 |
xselect <- partvar[[node$psplit$variableID]] |
|
|
23 |
|
|
|
24 |
## we need to coerce ordered factors to numeric |
|
|
25 |
## this is what party C code does as well! |
|
|
26 |
|
|
|
27 |
if (class(node$psplit) == "orderedSplit") { |
|
|
28 |
leftweights <- (as.double(xselect) <= node$psplit$splitpoint) * weights |
|
|
29 |
rightweights <- (as.double(xselect) > node$psplit$splitpoint) * weights |
|
|
30 |
} else { |
|
|
31 |
leftweights <- (xselect %in% |
|
|
32 |
levels(xselect)[as.logical(node$psplit$splitpoint)]) * weights |
|
|
33 |
rightweights <- (!(xselect %in% |
|
|
34 |
levels(xselect)[as.logical(node$psplit$splitpoint)])) * weights |
|
|
35 |
} |
|
|
36 |
|
|
|
37 |
list(left = leftweights, right = rightweights) |
|
|
38 |
} |
|
|
39 |
|
|
|
40 |
### setup a new (inner or terminal) node of a tree |
|
|
41 |
mob_fit_setupnode <- function(obj, mf, weights, control) { |
|
|
42 |
|
|
|
43 |
### control parameters |
|
|
44 |
alpha <- control$alpha |
|
|
45 |
bonferroni <- control$bonferroni |
|
|
46 |
minsplit <- control$minsplit |
|
|
47 |
trim <- control$trim |
|
|
48 |
objfun <- control$objfun |
|
|
49 |
verbose <- control$verbose |
|
|
50 |
breakties <- control$breakties |
|
|
51 |
parm <- control$parm |
|
|
52 |
|
|
|
53 |
### if too few observations: no split = return terminal node |
|
|
54 |
if (sum(weights) < 2 * minsplit) { |
|
|
55 |
node <- list(nodeID = NULL, weights = weights, |
|
|
56 |
criterion = list(statistic = 0, criterion = 0, maxcriterion = 0), |
|
|
57 |
terminal = TRUE, psplit = NULL, ssplits = NULL, |
|
|
58 |
prediction = 0, left = NULL, right = NULL, |
|
|
59 |
sumweights = as.double(sum(weights))) |
|
|
60 |
class(node) <- "TerminalModelNode" |
|
|
61 |
return(node) |
|
|
62 |
} |
|
|
63 |
|
|
|
64 |
### variable selection via fluctuation tests |
|
|
65 |
test <- try(mob_fit_fluctests(obj, mf, minsplit = minsplit, trim = trim, |
|
|
66 |
breakties = breakties, parm = parm)) |
|
|
67 |
|
|
|
68 |
if (!inherits(test, "try-error")) { |
|
|
69 |
if(bonferroni) { |
|
|
70 |
pval1 <- pmin(1, sum(!is.na(test$pval)) * test$pval) |
|
|
71 |
pval2 <- 1 - (1-test$pval)^sum(!is.na(test$pval)) |
|
|
72 |
test$pval <- ifelse(!is.na(test$pval) & (test$pval > 0.01), pval2, pval1) |
|
|
73 |
} |
|
|
74 |
|
|
|
75 |
best <- test$best |
|
|
76 |
TERMINAL <- is.na(best) || test$pval[best] > alpha |
|
|
77 |
|
|
|
78 |
if (verbose) { |
|
|
79 |
cat("\n-------------------------------------------\nFluctuation tests of splitting variables:\n") |
|
|
80 |
print(rbind(statistic = test$stat, p.value = test$pval)) |
|
|
81 |
cat("\nBest splitting variable: ") |
|
|
82 |
cat(names(test$stat)[best]) |
|
|
83 |
cat("\nPerform split? ") |
|
|
84 |
cat(ifelse(TERMINAL, "no", "yes")) |
|
|
85 |
cat("\n-------------------------------------------\n") |
|
|
86 |
} |
|
|
87 |
} else { |
|
|
88 |
TERMINAL <- TRUE |
|
|
89 |
test <- list(stat = NA, pval = NA) |
|
|
90 |
} |
|
|
91 |
|
|
|
92 |
### splitting |
|
|
93 |
na_max <- function(x) { |
|
|
94 |
if(all(is.na(x))) NA else max(x, na.rm = TRUE) |
|
|
95 |
} |
|
|
96 |
if (TERMINAL) { |
|
|
97 |
node <- list(nodeID = NULL, weights = weights, |
|
|
98 |
criterion = list(statistic = test$stat, |
|
|
99 |
criterion = 1 - test$pval, |
|
|
100 |
maxcriterion = na_max(1 - test$pval)), |
|
|
101 |
terminal = TRUE, psplit = NULL, ssplits = NULL, |
|
|
102 |
prediction = 0, left = NULL, right = NULL, |
|
|
103 |
sumweights = as.double(sum(weights))) |
|
|
104 |
class(node) <- "TerminalModelNode" |
|
|
105 |
return(node) |
|
|
106 |
} else { |
|
|
107 |
partvar <- mf@get("part") |
|
|
108 |
xselect <- partvar[[best]] |
|
|
109 |
thissplit <- mob_fit_splitnode(xselect, obj, mf, weights, minsplit = minsplit, |
|
|
110 |
objfun = objfun, verbose = verbose) |
|
|
111 |
|
|
|
112 |
## check if splitting was unsuccessful |
|
|
113 |
if (identical(FALSE, thissplit)) { |
|
|
114 |
node <- list(nodeID = NULL, weights = weights, |
|
|
115 |
criterion = list(statistic = test$stat, |
|
|
116 |
criterion = 1 - test$pval, |
|
|
117 |
maxcriterion = na_max(1 - test$pval)), |
|
|
118 |
terminal = TRUE, psplit = NULL, ssplits = NULL, |
|
|
119 |
prediction = 0, left = NULL, right = NULL, |
|
|
120 |
sumweights = as.double(sum(weights))) |
|
|
121 |
class(node) <- "TerminalModelNode" |
|
|
122 |
|
|
|
123 |
### more confusion than information |
|
|
124 |
### warning("no admissable split found", call. = FALSE) |
|
|
125 |
if(verbose) |
|
|
126 |
cat(paste("\nNo admissable split found in ", sQuote(names(test$stat)[best]), "\n", sep = "")) |
|
|
127 |
return(node) |
|
|
128 |
} |
|
|
129 |
|
|
|
130 |
thissplit$variableID <- best |
|
|
131 |
thissplit$variableName <- names(partvar)[best] |
|
|
132 |
node <- list(nodeID = NULL, weights = weights, |
|
|
133 |
criterion = list(statistic = test$stat, |
|
|
134 |
criterion = 1 - test$pval, |
|
|
135 |
maxcriterion = na_max(1 - test$pval)), |
|
|
136 |
terminal = FALSE, |
|
|
137 |
psplit = thissplit, ssplits = NULL, |
|
|
138 |
prediction = 0, left = NULL, right = NULL, |
|
|
139 |
sumweights = as.double(sum(weights))) |
|
|
140 |
class(node) <- "SplittingNode" |
|
|
141 |
} |
|
|
142 |
|
|
|
143 |
node$variableID <- best |
|
|
144 |
if (verbose) { |
|
|
145 |
cat("\nNode properties:\n") |
|
|
146 |
print(node$psplit, left = TRUE) |
|
|
147 |
cat(paste("; criterion = ", round(node$criterion$maxcriterion, 3), |
|
|
148 |
", statistic = ", round(max(node$criterion$statistic), 3), "\n", |
|
|
149 |
collapse = "", sep = "")) |
|
|
150 |
} |
|
|
151 |
node |
|
|
152 |
} |
|
|
153 |
|
|
|
154 |
### variable selection: |
|
|
155 |
### conduct all M-fluctuation tests of fitted obj |
|
|
156 |
### with respect to each variable from a set of |
|
|
157 |
### potential partitioning variables in mf |
|
|
158 |
mob_fit_fluctests <- function(obj, mf, minsplit, trim, breakties, parm) { |
|
|
159 |
## Cramer-von Mises statistic might be supported in future versions |
|
|
160 |
CvM <- FALSE |
|
|
161 |
|
|
|
162 |
## set up return values |
|
|
163 |
partvar <- mf@get("part") |
|
|
164 |
m <- NCOL(partvar) |
|
|
165 |
pval <- rep.int(0, m) |
|
|
166 |
stat <- rep.int(0, m) |
|
|
167 |
ifac <- rep.int(FALSE, m) |
|
|
168 |
|
|
|
169 |
## extract estimating functions |
|
|
170 |
process <- as.matrix(estfun(obj)) |
|
|
171 |
k <- NCOL(process) |
|
|
172 |
|
|
|
173 |
## extract weights |
|
|
174 |
ww <- weights(obj) |
|
|
175 |
if(is.null(ww)) ww <- rep(1, NROW(process)) |
|
|
176 |
n <- sum(ww) |
|
|
177 |
|
|
|
178 |
## drop observations with zero weight |
|
|
179 |
ww0 <- (ww > 0) |
|
|
180 |
process <- process[ww0, , drop = FALSE] |
|
|
181 |
partvar <- partvar[ww0, , drop = FALSE] |
|
|
182 |
ww <- ww[ww0] |
|
|
183 |
## repeat observations with weight > 1 |
|
|
184 |
process <- process/ww |
|
|
185 |
ww1 <- rep.int(1:length(ww), ww) |
|
|
186 |
process <- process[ww1, , drop = FALSE] |
|
|
187 |
stopifnot(NROW(process) == n) |
|
|
188 |
|
|
|
189 |
## scale process |
|
|
190 |
process <- process/sqrt(n) |
|
|
191 |
J12 <- root.matrix(crossprod(process)) |
|
|
192 |
process <- t(chol2inv(chol(J12)) %*% t(process)) |
|
|
193 |
|
|
|
194 |
## select parameters to test |
|
|
195 |
if(!is.null(parm)) process <- process[, parm, drop = FALSE] |
|
|
196 |
k <- NCOL(process) |
|
|
197 |
|
|
|
198 |
## get critical values for CvM statistic |
|
|
199 |
if(CvM) { |
|
|
200 |
if(k > 25) k <- 25 #Z# also issue warning |
|
|
201 |
critval <- get("sc.meanL2")[as.character(k), ] |
|
|
202 |
} else { |
|
|
203 |
from <- if(trim > 1) trim else ceiling(n * trim) |
|
|
204 |
from <- max(from, minsplit) |
|
|
205 |
to <- n - from |
|
|
206 |
lambda <- ((n-from)*to)/(from*(n-to)) |
|
|
207 |
|
|
|
208 |
beta <- get("sc.beta.sup") |
|
|
209 |
logp.supLM <- function(x, k, lambda) |
|
|
210 |
{ |
|
|
211 |
if(k > 40) { |
|
|
212 |
## use Estrella (2003) asymptotic approximation |
|
|
213 |
logp_estrella2003 <- function(x, k, lambda) |
|
|
214 |
-lgamma(k/2) + k/2 * log(x/2) - x/2 + log(abs(log(lambda) * (1 - k/x) + 2/x)) |
|
|
215 |
## FIXME: Estrella only works well for large enough x |
|
|
216 |
## hence require x > 1.5 * k for Estrella approximation and |
|
|
217 |
## use an ad hoc interpolation for larger p-values |
|
|
218 |
p <- ifelse(x <= 1.5 * k, (x/(1.5 * k))^sqrt(k) * logp_estrella2003(1.5 * k, k, lambda), logp_estrella2003(x, k, lambda)) |
|
|
219 |
} else { |
|
|
220 |
## use Hansen (1997) approximation |
|
|
221 |
m <- ncol(beta)-1 |
|
|
222 |
if(lambda<1) tau <- lambda |
|
|
223 |
else tau <- 1/(1+sqrt(lambda)) |
|
|
224 |
beta <- beta[(((k-1)*25 +1):(k*25)),] |
|
|
225 |
dummy <- beta[,(1:m)]%*%x^(0:(m-1)) |
|
|
226 |
dummy <- dummy*(dummy>0) |
|
|
227 |
pp <- pchisq(dummy, beta[,(m+1)], lower.tail = FALSE, log.p = TRUE) |
|
|
228 |
if(tau==0.5) |
|
|
229 |
p <- pchisq(x, k, lower.tail = FALSE, log.p = TRUE) |
|
|
230 |
else if(tau <= 0.01) |
|
|
231 |
p <- pp[25] |
|
|
232 |
else if(tau >= 0.49) |
|
|
233 |
p <- log((exp(log(0.5-tau) + pp[1]) + exp(log(tau-0.49) + pchisq(x,k,lower.tail = FALSE, log.p = TRUE)))*100) |
|
|
234 |
else |
|
|
235 |
{ |
|
|
236 |
taua <- (0.51-tau)*50 |
|
|
237 |
tau1 <- floor(taua) |
|
|
238 |
p <- log(exp(log(tau1 + 1 - taua) + pp[tau1]) + exp(log(taua-tau1) + pp[tau1+1])) |
|
|
239 |
} |
|
|
240 |
} |
|
|
241 |
return(as.vector(p)) |
|
|
242 |
} |
|
|
243 |
} |
|
|
244 |
|
|
|
245 |
## compute statistic and p-value for each ordering |
|
|
246 |
for(i in 1:m) { |
|
|
247 |
pvi <- partvar[,i] |
|
|
248 |
pvi <- pvi[ww1] |
|
|
249 |
if(is.factor(pvi)) { |
|
|
250 |
proci <- process[ORDER(pvi), , drop = FALSE] |
|
|
251 |
ifac[i] <- TRUE |
|
|
252 |
|
|
|
253 |
# re-apply factor() added to drop unused levels |
|
|
254 |
pvi <- factor(pvi[ORDER(pvi)]) |
|
|
255 |
# compute segment weights |
|
|
256 |
segweights <- as.vector(table(pvi))/n ## tapply(ww, pvi, sum)/n |
|
|
257 |
|
|
|
258 |
# compute statistic only if at least two levels are left |
|
|
259 |
if(length(segweights) < 2) { |
|
|
260 |
stat[i] <- 0 |
|
|
261 |
pval[i] <- NA |
|
|
262 |
} else { |
|
|
263 |
stat[i] <- sum(sapply(1:k, function(j) (tapply(proci[,j], pvi, sum)^2)/segweights)) |
|
|
264 |
pval[i] <- pchisq(stat[i], k*(length(levels(pvi))-1), log.p = TRUE, lower.tail = FALSE) |
|
|
265 |
} |
|
|
266 |
} else { |
|
|
267 |
oi <- if(breakties) { |
|
|
268 |
mm <- sort(unique(pvi)) |
|
|
269 |
mm <- ifelse(length(mm) > 1, min(diff(mm))/10, 1) |
|
|
270 |
ORDER(pvi + runif(length(pvi), min = -mm, max = +mm)) |
|
|
271 |
} else { |
|
|
272 |
ORDER(pvi) |
|
|
273 |
} |
|
|
274 |
proci <- process[oi, , drop = FALSE] |
|
|
275 |
proci <- apply(proci, 2, cumsum) |
|
|
276 |
stat[i] <- if(CvM) sum((proci)^2)/n |
|
|
277 |
else if(from < to) { |
|
|
278 |
xx <- rowSums(proci^2) |
|
|
279 |
xx <- xx[from:to] |
|
|
280 |
tt <- (from:to)/n |
|
|
281 |
max(xx/(tt * (1-tt))) |
|
|
282 |
} else { |
|
|
283 |
0 |
|
|
284 |
} |
|
|
285 |
pval[i] <- if(CvM) log(approx(c(0, critval), c(1, 1-as.numeric(names(critval))), stat[i], rule=2)$y) |
|
|
286 |
else if(from < to) logp.supLM(stat[i], k, lambda) else NA |
|
|
287 |
} |
|
|
288 |
} |
|
|
289 |
|
|
|
290 |
## select variable with minimal p-value |
|
|
291 |
best <- which.min(pval) |
|
|
292 |
if(length(best) < 1) best <- NA |
|
|
293 |
rval <- list(pval = exp(pval), stat = stat, best = best) |
|
|
294 |
names(rval$pval) <- names(partvar) |
|
|
295 |
names(rval$stat) <- names(partvar) |
|
|
296 |
if (!all(is.na(rval$best))) |
|
|
297 |
names(rval$best) <- names(partvar)[rval$best] |
|
|
298 |
return(rval) |
|
|
299 |
} |
|
|
300 |
|
|
|
301 |
### split in variable x, either ordered or nominal |
|
|
302 |
mob_fit_splitnode <- function(x, obj, mf, weights, minsplit, objfun, verbose = TRUE) { |
|
|
303 |
|
|
|
304 |
## process minsplit (to minimal number of observations) |
|
|
305 |
if (minsplit > 0.5 & minsplit < 1) minsplit <- 1 - minsplit |
|
|
306 |
if (minsplit < 0.5) |
|
|
307 |
minsplit <- ceiling(sum(weights) * minsplit) |
|
|
308 |
|
|
|
309 |
if (is.numeric(x)) { |
|
|
310 |
### for numerical variables |
|
|
311 |
ux <- sort(unique(x)) |
|
|
312 |
if (length(ux) == 0) stop("cannot find admissible split point in x") |
|
|
313 |
dev <- vector(mode = "numeric", length = length(ux)) |
|
|
314 |
|
|
|
315 |
for (i in 1:length(ux)) { |
|
|
316 |
xs <- x <= ux[i] |
|
|
317 |
if (mob_fit_checksplit(xs, weights, minsplit)) { |
|
|
318 |
dev[i] <- Inf |
|
|
319 |
} else { |
|
|
320 |
dev[i] <- mob_fit_getobjfun(obj, mf, weights, xs, objfun = objfun) |
|
|
321 |
} |
|
|
322 |
} |
|
|
323 |
|
|
|
324 |
## maybe none of the possible splits is admissible |
|
|
325 |
if (all(!is.finite(dev))) return(FALSE) |
|
|
326 |
|
|
|
327 |
split <- list(variableID = NULL, ordered = TRUE, |
|
|
328 |
splitpoint = as.double(ux[which.min(dev)]), |
|
|
329 |
splitstatistic = dev, toleft = TRUE) |
|
|
330 |
class(split) <- "orderedSplit" |
|
|
331 |
} else { |
|
|
332 |
### for categorical variables |
|
|
333 |
al <- mob_fit_getlevels(x) |
|
|
334 |
dev <- apply(al, 1, function(w) { |
|
|
335 |
xs <- x %in% levels(x)[w] |
|
|
336 |
if (mob_fit_checksplit(xs, weights, minsplit)) { |
|
|
337 |
return(Inf) |
|
|
338 |
} else { |
|
|
339 |
mob_fit_getobjfun(obj, mf, weights, xs, objfun = objfun) |
|
|
340 |
} |
|
|
341 |
}) |
|
|
342 |
|
|
|
343 |
if (verbose) { |
|
|
344 |
cat(paste("\nSplitting ", if(is.ordered(x)) "ordered ", |
|
|
345 |
"factor variable, objective function: \n", sep = "")) |
|
|
346 |
print(dev) |
|
|
347 |
} |
|
|
348 |
|
|
|
349 |
if (all(!is.finite(dev))) return(FALSE) |
|
|
350 |
|
|
|
351 |
## ordered factors are of storage mode "numeric" in party! |
|
|
352 |
## initVariableFrame coerces ordered factors to storage.mode "numeric" |
|
|
353 |
## the following is consistent with party |
|
|
354 |
|
|
|
355 |
if (is.ordered(x)) { |
|
|
356 |
split <- list(variableID = NULL, ordered = TRUE, |
|
|
357 |
splitpoint = as.double(which.min(dev)), |
|
|
358 |
splitstatistic = dev, toleft = TRUE) |
|
|
359 |
class(split) <- "orderedSplit" |
|
|
360 |
attr(split$splitpoint, "levels") <- levels(x) |
|
|
361 |
} else { |
|
|
362 |
tab <- as.integer(table(x[weights > 0]) > 0) |
|
|
363 |
split <- list(variableID = NULL, ordered = FALSE, |
|
|
364 |
splitpoint = as.integer(al[which.min(dev),]), |
|
|
365 |
splitstatistic = dev, |
|
|
366 |
toleft = TRUE, table = tab) |
|
|
367 |
attr(split$splitpoint, "levels") <- levels(x) |
|
|
368 |
class(split) <- "nominalSplit" |
|
|
369 |
} |
|
|
370 |
} |
|
|
371 |
split |
|
|
372 |
} |
|
|
373 |
|
|
|
374 |
### get partitioned objective function for a particular split |
|
|
375 |
mob_fit_getobjfun <- function(obj, mf, weights, left, objfun = deviance) { |
|
|
376 |
## mf is the model frame |
|
|
377 |
## weights are the observation weights |
|
|
378 |
## left is 1 (if left of splitpoint) or 0 |
|
|
379 |
weightsleft <- weights * left |
|
|
380 |
weightsright <- weights * (1 - left) |
|
|
381 |
|
|
|
382 |
### fit left / right model |
|
|
383 |
fmleft <- reweight(obj, weights = weightsleft) |
|
|
384 |
fmright <- reweight(obj, weights = weightsright) |
|
|
385 |
|
|
|
386 |
return(objfun(fmleft) + objfun(fmright)) |
|
|
387 |
} |
|
|
388 |
|
|
|
389 |
### determine all possible splits for a factor, both nominal and ordinal |
|
|
390 |
mob_fit_getlevels <- function(x) { |
|
|
391 |
nl <- nlevels(x) |
|
|
392 |
if (inherits(x, "ordered")) { |
|
|
393 |
indx <- diag(nl) |
|
|
394 |
indx[lower.tri(indx)] <- 1 |
|
|
395 |
indx <- indx[-nl,] |
|
|
396 |
rownames(indx) <- levels(x)[-nl] |
|
|
397 |
} else { |
|
|
398 |
mi <- 2^(nl - 1) - 1 |
|
|
399 |
indx <- matrix(0, nrow = mi, ncol = nl) |
|
|
400 |
for (i in 1:mi) { # go though all splits # |
|
|
401 |
ii <- i |
|
|
402 |
for (l in 1:nl) { |
|
|
403 |
indx[i, l] <- ii%%2; |
|
|
404 |
ii <- ii %/% 2 |
|
|
405 |
} |
|
|
406 |
} |
|
|
407 |
rownames(indx) <- apply(indx, 1, function(z) paste(levels(x)[z > 0], collapse = "+")) |
|
|
408 |
} |
|
|
409 |
colnames(indx) <- as.character(levels(x)) |
|
|
410 |
storage.mode(indx) <- "logical" |
|
|
411 |
indx |
|
|
412 |
} |
|
|
413 |
|
|
|
414 |
### check split |
|
|
415 |
mob_fit_checksplit <- function(split, weights, minsplit) |
|
|
416 |
(sum(split * weights) < minsplit || sum((1 - split) * weights) < minsplit) |