|
a |
|
b/tests/testthat/setup.R |
|
|
1 |
## silence output and warnings |
|
|
2 |
SW <- function(expr) suppressMessages(suppressWarnings(expr)) |
|
|
3 |
|
|
|
4 |
## dataset |
|
|
5 |
set.seed(1) |
|
|
6 |
N <- 50 |
|
|
7 |
P <- 10 |
|
|
8 |
U <- 3 |
|
|
9 |
x <- matrix(rnorm(N * P), nrow=N, ncol=P) |
|
|
10 |
b <- runif(P) - 0.5 |
|
|
11 |
y.gauss <- rnorm(N, mean=x %*% b, sd=runif(1, 1, 2)) |
|
|
12 |
y.binom <- factor(rbinom(N, 1, binomial()$linkinv(x %*% b))) |
|
|
13 |
df <- data.frame(x, y.gauss=y.gauss, y.binom=y.binom) |
|
|
14 |
df[, 1] <- factor(letters[rbinom(N, 2, 0.5) + 1]) |
|
|
15 |
df$X1b_X3 <- df$X3 * (df$X1 == "b") |
|
|
16 |
df$X1c_X3 <- df$X3 * (df$X1 == "c") |
|
|
17 |
df$X3_X2 <- df$X3 * df$X2 |
|
|
18 |
|
|
|
19 |
## model options |
|
|
20 |
unp <- paste0("X", 1:U) |
|
|
21 |
pen <- setdiff(paste0("X", 1:P), unp) |
|
|
22 |
mod.gauss <- reformulate(unp, "y.gauss") |
|
|
23 |
mod.binom <- reformulate(unp, "y.binom") |
|
|
24 |
folds <- c(rep(1, N / 2), rep(2, N / 2)) |
|
|
25 |
iters <- 500 |
|
|
26 |
chains <- 2 |
|
|
27 |
|
|
|
28 |
## numerical tolerance |
|
|
29 |
tol <- 0.000001 |
|
|
30 |
|
|
|
31 |
## wrapper to set commonly used options |
|
|
32 |
hs <- function(model, family, ...) |
|
|
33 |
hsstan(df, model, pen, iter=iters, chains=chains, family=family, |
|
|
34 |
refresh=0, ...) |
|
|
35 |
|
|
|
36 |
message("Running hsstan models...") |
|
|
37 |
SW({ |
|
|
38 |
hs.base <- hsstan(df, mod.gauss, iter=iters, chains=chains, refresh=0) |
|
|
39 |
hs.gauss <- hs(mod.gauss, gaussian) |
|
|
40 |
hs.binom <- hs(mod.binom, binomial) |
|
|
41 |
hs.inter <- hs(y.gauss ~ X1 * X3 + X2 * X3, gaussian) |
|
|
42 |
}) |
|
|
43 |
|
|
|
44 |
message("Running cross-validated hsstan models...") |
|
|
45 |
SW({ |
|
|
46 |
cv.gauss <- kfold(hs.gauss, folds=folds, chains=2) |
|
|
47 |
cv.binom <- kfold(hs.binom, folds=folds) |
|
|
48 |
cv.nofit <- cv.gauss |
|
|
49 |
cv.nofit$fits <- cv.nofit$data <- NULL |
|
|
50 |
}) |