|
a |
|
b/nsc.R |
|
|
1 |
#' Nearest shrunken centroid fit, evaluated on the training data
#'
#' Computes class centroids, shrinks the standardized centroid differences
#' over a grid of thresholds and classifies the training samples at every
#' threshold (`x` and `y` double as the evaluation set).
#'
#' @param x Numeric matrix with one row per feature and one column per
#'   sample (`x %*% Y` below requires `ncol(x) == length(y)`).
#' @param y Class labels, one per column of `x`.
#' @param n.threshold Number of thresholds, evenly spaced from 0 to the
#'   largest absolute standardized centroid difference.
#' @param offset.percent Percentile of the per-feature pooled standard
#'   deviations that is added to every standard deviation.
#' @param prior Class prior probabilities, one per class, in the order of
#'   `levels(factor(y))`. `NULL` (default) uses the observed class
#'   proportions.
#' @param remove.zeros If `TRUE`, keep thresholds only up to (and including)
#'   the first one at which no feature survives shrinkage.
#' @return A list with components `y`, `yhat` (data frame of predicted
#'   classes, one column per retained threshold), `prob` (samples x classes
#'   x retained thresholds array), `centroids`, `centroid.overall`, `sd`,
#'   `threshold`, `nonzero`, `se.scale`, `call`, `prior`, `offset` and
#'   `errors` (training misclassification count per retained threshold).
#'   All per-threshold components have the same length.
#' @note Calls `soft.shrink()`, `diag.disc()`, `softmax()` and `safe.exp()`,
#'   which are not defined in this function. Prints the threshold index via
#'   `cat()` as a progress indicator.
nsc <- function(x, y, n.threshold = 30, offset.percent = 50, prior = NULL,
                remove.zeros = TRUE) {
  this.call <- match.call()

  # The training data doubles as the evaluation set.
  xtest <- x
  ytest <- y

  # factor() drops unused levels so the class counts line up with the
  # columns of the indicator matrix built from the same factor below.
  n.class <- table(y = factor(y))
  if (min(n.class) < 2) {
    stop("each class must contain at least 2 samples", call. = FALSE)
  }
  Y <- model.matrix(~factor(y) - 1, data = list(y = y))

  n <- sum(n.class)
  ntest <- ncol(xtest)
  # K is the number of classes. It must not be taken from `prior`, which
  # defaults to NULL (that made K zero and broke every `rep(1, K)` below).
  K <- length(n.class)
  if (is.null(prior)) {
    prior <- n.class / n
  }
  if (length(prior) != K) {
    stop("'prior' must have one entry per class", call. = FALSE)
  }

  dimnames(Y) <- list(NULL, names(n.class))
  # Class centroids: column k is the mean of the samples in class k.
  centroids <- scale(x %*% Y, FALSE, n.class) ## WMEAN.G

  # Pooled within-class standard deviation per feature (n - K degrees of
  # freedom), inflated by the `offset.percent` percentile of itself.
  xdif <- x - centroids %*% t(Y)
  sd <- drop(sqrt((xdif^2) %*% rep(1 / (n - K), n))) # WSD.POOLED
  offset <- quantile(sd, offset.percent / 100)
  sd <- sd + offset

  centroid.overall <- drop(x %*% rep(1 / n, n)) ## WMEAN

  se.scale <- sqrt(1 / n.class - 1 / n) # mk

  # Standardized differences between class and overall centroids.
  delta <- (centroids - centroid.overall) / sd
  delta <- scale(delta, FALSE, se.scale) ## dik

  threshold <- seq(0, max(abs(delta)), length.out = n.threshold)

  nonzero <- seq_len(n.threshold)
  errors <- threshold
  yhat <- vector("list", n.threshold)
  prob <- array(0, c(ntest, K, n.threshold))
  # Loop-invariant: standardized evaluation data.
  xtest.std <- (xtest - centroid.overall) / sd

  for (ii in seq_len(n.threshold)) {
    cat(ii)
    delta.shrunk <- soft.shrink(delta, threshold[ii])
    # Multiply column k by se.scale[k], undoing the per-class scaling
    # applied to `delta` above.
    delta.shrunk <- t(t(delta.shrunk) * as.numeric(se.scale))

    # Presumably soft.shrink() attaches the count of surviving features
    # as attribute "nonzero" -- confirm against its definition.
    nonzero[ii] <- attr(delta.shrunk, "nonzero")
    # Features with at least one non-zero shrunken difference.
    posid <- drop(abs(delta.shrunk) %*% rep(1, K)) > 0
    dd <- diag.disc(xtest.std, delta.shrunk, prior, weight = posid)
    yhat[[ii]] <- softmax(dd)
    dd <- safe.exp(dd)
    # Normalize each sample's scores across classes.
    prob[, , ii] <- dd / drop(dd %*% rep(1, K))
    if (!is.null(ytest)) {
      errors[ii] <- sum(yhat[[ii]] != ytest)
    }
  }

  thresh.names <- format(round(threshold, 3))
  names(yhat) <- thresh.names
  dimnames(prob) <- list(paste(seq_len(ntest)), names(n.class),
                         thresh.names)

  # Optionally stop at the first threshold where nonzero == 0. Every
  # per-threshold output is truncated the same way so lengths agree.
  n.keep <- if (remove.zeros) match(0, nonzero, n.threshold) else n.threshold
  keep <- seq_len(n.keep)

  yhat <- yhat[keep]
  attr(yhat, "row.names") <- paste(seq_len(ntest))
  class(yhat) <- "data.frame"

  object <- list(
    y = ytest, yhat = yhat, prob = prob[, , keep, drop = FALSE],
    centroids = centroids, centroid.overall = centroid.overall,
    sd = sd, threshold = threshold[keep], nonzero = nonzero[keep],
    se.scale = se.scale, call = this.call, prior = prior, offset = offset
  )
  if (!is.null(ytest)) {
    object$errors <- errors[keep]
  }
  # class(object) <- "nsc"
  object
}