[81de4e]: / nsc.R

Download this file

78 lines (64 with data), 2.6 kB

 1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# Nearest-shrunken-centroid style fit over a grid of shrinkage thresholds.
#
# Computes the per-class centroids of `x`, expresses them as offsets from the
# overall centroid standardised by an offset-inflated pooled within-class
# standard deviation, and then, for each of `n.threshold` equally spaced
# threshold values, shrinks those offsets with soft.shrink() and scores the
# columns of `x` themselves with diag.disc() / softmax() / safe.exp().
#
# Arguments:
#   x              numeric matrix, one row per feature and one column per
#                  sample (it is multiplied by the n-by-K class indicator).
#   y              class labels, one per column of `x`.
#   n.threshold    number of threshold values, spaced evenly from 0 to
#                  max(abs(delta)).
#   offset.percent percentile (0-100) of the per-feature standard deviations
#                  that is added to every standard deviation.
#   prior          class prior handed to diag.disc(), one entry per class in
#                  the order of table(y). NULL (the default) uses the sample
#                  class proportions.
#   remove.zeros   if TRUE, `prob`, `threshold` and `nonzero` are cut off at
#                  the first threshold whose `nonzero` value is 0.
#
# Value: a list with components y, yhat (data frame, one column per
#   threshold), prob (sample x class x threshold array), centroids,
#   centroid.overall, sd, threshold, nonzero, se.scale, call, prior, offset
#   and, when `y` is not NULL, errors (misclassification count per threshold).
#
# Stops with an error if any class has exactly one sample or if `prior` does
# not have one entry per class.
nsc <-
function (x, y, n.threshold = 30, offset.percent = 50, prior = NULL,
          remove.zeros = TRUE)
{
  this.call <- match.call()

  n.class <- table(y)
  if (min(n.class) == 1) {
    # Previously stop(warning(...)): that raised a spurious warning and then
    # an error whose text claimed to be a warning.
    stop("a class contains only 1 sample")
  }
  n <- sum(n.class)

  # The class count must come from the data, not from `prior`: with the
  # default prior = NULL, length(prior) is 0, which broke the pooled variance
  # denominator, the shape of `prob` and the matrix products below.
  K <- length(n.class)
  if (is.null(prior)) {
    prior <- n.class/n
  } else if (length(prior) != K) {
    stop("'prior' must have one entry per class")
  }

  # n-by-K indicator matrix of class membership.
  Y <- model.matrix(~factor(y) - 1, data = list(y = y))
  dimnames(Y) <- list(NULL, names(n.class))

  # The training data doubles as the data being classified.
  xtest <- x
  ytest <- y
  ntest <- ncol(xtest)

  centroids <- scale(x %*% Y, FALSE, n.class)  # per-class feature means
  xdif <- x - centroids %*% t(Y)               # within-class residuals
  sd <- (xdif^2) %*% rep(1/(n - K), n)
  sd <- drop(sqrt(sd))                         # pooled within-class sd
  offset <- quantile(sd, offset.percent/100)
  sd <- sd + offset
  centroid.overall <- drop(x %*% rep(1/n, n))  # overall feature means
  se.scale <- sqrt(1/n.class - 1/n)            # per-class scale factor
  delta <- (centroids - centroid.overall)/sd
  delta <- scale(delta, FALSE, se.scale)       # standardised class offsets

  threshold <- seq(0, max(abs(delta)), length.out = n.threshold)
  nonzero <- seq_len(n.threshold)              # overwritten in the loop
  errors <- numeric(n.threshold)
  yhat <- vector("list", n.threshold)
  prob <- array(0, c(ntest, K, n.threshold))

  # Loop invariant: test data centred and scaled like the centroids.
  xtest.std <- (xtest - centroid.overall)/sd

  for (ii in seq_len(n.threshold)) {
    cat(ii)                                    # progress indicator
    delta.shrunk <- soft.shrink(delta, threshold[ii])
    # Read the count before t() and `*` so it does not depend on attribute
    # propagation through those operations.
    nonzero[ii] <- attr(delta.shrunk, "nonzero")
    # Undo the se.scale standardisation column by column.
    delta.shrunk <- t(t(delta.shrunk) * as.numeric(se.scale))
    # Features whose shrunken offset is non-zero for at least one class.
    posid <- drop(abs(delta.shrunk) %*% rep(1, K)) > 0
    dd <- diag.disc(xtest.std, delta.shrunk, prior, weight = posid)
    yhat[[ii]] <- softmax(dd)
    dd <- safe.exp(dd)
    prob[, , ii] <- dd/drop(dd %*% rep(1, K))  # normalise rows to sum to 1
    if (!is.null(ytest)) {
      errors[ii] <- sum(yhat[[ii]] != ytest)
    }
  }

  thresh.names <- format(round(threshold, 3))
  names(yhat) <- thresh.names
  attr(yhat, "row.names") <- paste(seq(ntest))
  class(yhat) <- "data.frame"

  if (remove.zeros) {
    # Index of the first threshold with nonzero == 0; if there is none,
    # match() returns n.threshold and everything is kept.
    n.threshold <- match(0, nonzero, n.threshold)
  }
  keep <- seq(n.threshold)
  dimnames(prob) <- list(paste(seq(ntest)), names(n.class), thresh.names)

  # NOTE(review): `yhat` and `errors` keep every threshold, while `prob`,
  # `threshold` and `nonzero` are truncated by remove.zeros. Left as-is so the
  # return shape does not change; confirm whether callers expect this.
  object <- list(y = ytest, yhat = yhat, prob = prob[, , keep],
                 centroids = centroids, centroid.overall = centroid.overall,
                 sd = sd, threshold = threshold[keep],
                 nonzero = nonzero[keep], se.scale = se.scale,
                 call = this.call, prior = prior, offset = offset)
  if (!is.null(ytest)) {
    object$errors <- errors
  }
  # class(object) <- "nsc"   (class assignment deliberately left disabled)
  object
}