# File: nsc.R
# Fit a nearest-shrunken-centroid classifier on (x, y) and evaluate it on the
# same data over a grid of shrinkage thresholds.
#
# Arguments:
#   x              Numeric matrix with one row per feature and one column per
#                  sample; ncol(x) must equal length(y).
#   y              Class labels, one per column of x. Every class needs at
#                  least two samples.
#   n.threshold    Number of shrinkage thresholds, evenly spaced between 0 and
#                  the largest absolute standardized centroid difference.
#   offset.percent Percentile (0-100) of the pooled within-class standard
#                  deviations that is added to every standard deviation.
#   prior          Class prior probabilities, one per class in the order of
#                  table(factor(y)). NULL (the default) uses the observed
#                  class proportions.
#   remove.zeros   If TRUE, prob, threshold and nonzero are truncated at the
#                  first threshold whose "nonzero" count is 0.
#
# Returns a list with components y, yhat (a data.frame with one column of
# predictions per threshold), prob (samples x classes x thresholds array),
# centroids, centroid.overall, sd (including the offset), threshold, nonzero,
# se.scale, call, prior, offset and errors (number of misclassified samples
# per threshold).
#
# Relies on soft.shrink(), diag.disc(), softmax() and safe.exp(), which are
# defined elsewhere.
nsc <- function(x, y, n.threshold = 30, offset.percent = 50, prior = NULL,
                remove.zeros = TRUE) {
  this.call <- match.call()

  if (is.null(ncol(x)) || ncol(x) != length(y)) {
    stop("'x' must be a matrix with one column per element of 'y'",
         call. = FALSE)
  }

  # factor(y) drops unused levels so the class table always lines up with the
  # columns of the indicator matrix built below.
  n.class <- table(factor(y))
  if (min(n.class) == 1) {
    stop("Warning: a class contains only 1 sample")
  }

  # Sample-by-class indicator matrix.
  Y <- model.matrix(~factor(y) - 1, data = list(y = y))
  dimnames(Y) <- list(NULL, names(n.class))

  # The classifier is evaluated on its own training data.
  xtest <- x
  ytest <- y

  n <- sum(n.class)
  ntest <- ncol(xtest)
  # The number of classes comes from the data, not from 'prior', so that the
  # default prior = NULL works.
  K <- length(n.class)
  if (is.null(prior)) {
    prior <- n.class / n
  } else if (length(prior) != K) {
    stop("'prior' must have one entry per class", call. = FALSE)
  }

  # Per-class feature means (features x classes).
  centroids <- scale(x %*% Y, FALSE, n.class)

  # Pooled within-class standard deviation per feature, plus the offset.
  xdif <- x - centroids %*% t(Y)
  sd <- sqrt(rowSums(xdif^2) / (n - K))
  offset <- quantile(sd, offset.percent / 100)
  sd <- sd + offset

  # Overall feature means.
  centroid.overall <- rowMeans(x)

  # Per-class standard-error factor m_k = sqrt(1/n_k - 1/n).
  se.scale <- sqrt(1 / n.class - 1 / n)

  # Standardized centroid differences d_ik.
  delta <- (centroids - centroid.overall) / sd
  delta <- scale(delta, FALSE, se.scale)

  threshold <- seq(0, max(abs(delta)), length = n.threshold)

  nonzero <- integer(n.threshold)
  errors <- numeric(n.threshold)
  yhat <- vector("list", n.threshold)
  prob <- array(0, c(ntest, K, n.threshold))

  # Standardized test data does not depend on the threshold.
  xtest.std <- (xtest - centroid.overall) / sd

  for (ii in seq_len(n.threshold)) {
    cat(ii)  # progress indicator
    delta.shrunk <- soft.shrink(delta, threshold[ii])
    # Undo the se.scale standardization, column by column.
    delta.shrunk <- t(t(delta.shrunk) * as.numeric(se.scale))

    # NOTE(review): "nonzero" is an attribute set by soft.shrink(); presumably
    # the number of features left with a nonzero shrunken difference.
    nonzero[ii] <- attr(delta.shrunk, "nonzero")

    # Features that still contribute to at least one class.
    posid <- rowSums(abs(delta.shrunk)) > 0
    dd <- diag.disc(xtest.std, delta.shrunk, prior, weight = posid)
    yhat[[ii]] <- softmax(dd)

    # Normalize the exponentiated scores to per-sample class probabilities.
    dd <- safe.exp(dd)
    prob[, , ii] <- dd / rowSums(dd)

    if (!is.null(ytest)) {
      errors[ii] <- sum(yhat[[ii]] != ytest)
    }
  }

  thresh.names <- format(round(threshold, 3))
  names(yhat) <- thresh.names
  attr(yhat, "row.names") <- paste(seq_len(ntest))
  class(yhat) <- "data.frame"

  if (remove.zeros) {
    # Index of the first threshold with nonzero == 0; unchanged if none.
    n.threshold <- match(0, nonzero, n.threshold)
  }
  keep <- seq_len(n.threshold)

  dimnames(prob) <- list(paste(seq_len(ntest)), names(n.class), thresh.names)

  # NOTE(review): yhat and errors are returned for every threshold, while
  # prob, threshold and nonzero are truncated to 'keep' -- confirm intended.
  object <- list(y = ytest, yhat = yhat,
                 prob = prob[, , keep, drop = FALSE],
                 centroids = centroids, centroid.overall = centroid.overall,
                 sd = sd, threshold = threshold[keep],
                 nonzero = nonzero[keep],
                 se.scale = se.scale, call = this.call, prior = prior,
                 offset = offset)
  if (!is.null(ytest)) {
    object$errors <- errors
  }
  #class(object) <- "nsc"
  object
}