|
a |
|
b/tests/testthat/test-hsstan.R |
|
|
1 |
# Checks the structure of a fitted Gaussian model. `hs.gauss` is a fixture
# created outside this file.
test_that("hsstan", {
  # Class of the fit and of the wrapped Stan object.
  expect_s3_class(hs.gauss, "hsstan")
  expect_s4_class(hs.gauss$stanfit, "stanfit")

  # Top-level components, in order.
  top.fields <- c(
    "stanfit", "betas", "call", "data", "model.terms", "family",
    "hsstan.settings"
  )
  expect_named(hs.gauss, top.fields)

  # No "lambda[1]" entry should be present among the stanfit names.
  expect_false("lambda[1]" %in% names(hs.gauss$stanfit))
  expect_equal(hs.gauss$family, gaussian())

  # Coefficient names for the unpenalized and penalized parts.
  unpenalized.names <- c("(Intercept)", "X1b", "X1c", "X2", "X3")
  expect_named(hs.gauss$betas$unpenalized, unpenalized.names)
  expect_named(hs.gauss$betas$penalized, hs.gauss$model.terms$penalized)
  expect_length(hs.gauss$betas, 2)

  # Settings recorded with the fit, and two of their values.
  setting.names <- c(
    "adapt.delta", "qr", "seed", "scale.u", "regularized", "nu",
    "par.ratio", "global.scale", "global.df", "slab.scale",
    "slab.df"
  )
  expect_named(hs.gauss$hsstan.settings, setting.names)
  expect_equal(hs.gauss$hsstan.settings$global.scale, 0.007071068)
  expect_equal(hs.gauss$hsstan.settings$adapt.delta, 0.99)
})
|
|
27 |
|
|
|
28 |
# Checks a model fitted without penalized predictors. `hs.base` is a fixture
# created outside this file.
test_that("hsstan with no penalized predictors", {
  expect_null(hs.base$betas$penalized)
  # NOTE(review): if `hs.base` has no top-level `penalized` element, this
  # extracts NULL and the check is trivially true -- confirm whether
  # `hs.base$model.terms$penalized` was intended.
  expect_length(hs.base$penalized, 0)

  # Only the settings listed here are expected to be recorded for this fit.
  base.settings <- c("adapt.delta", "qr", "seed", "scale.u")
  expect_named(hs.base$hsstan.settings, base.settings)
  expect_equal(hs.base$hsstan.settings$adapt.delta, 0.95)
})
|
|
38 |
|
|
|
39 |
# Passing the categorical variable `X1` as penalized predictor should give
# the dummy-coded names `X1b` and `X1c` among the penalized coefficients.
test_that("hsstan handles categorical variables in the penalized predictors", {
  # SW() is a helper defined outside this file.
  SW({
    fit.cat <- hsstan(
      df, y.gauss ~ X2 + X3, "X1",
      iter=250, keep.hs.pars=TRUE, refresh=0
    )
  })

  expect_named(fit.cat$betas$unpenalized, c("(Intercept)", "X2", "X3"))
  expect_named(fit.cat$betas$penalized, c("X1b", "X1c"))
})
|
|
50 |
|
|
|
51 |
# Fits the same model twice, once with the penalized predictor `X2` also in
# the formula and once without it, and checks that both fits agree.
test_that("hsstan handles penalized predictors appearing in the formula", {
  # SW() is a helper defined outside this file.
  SW({
    hs.1 <- hsstan(df, y.gauss ~ X1 + X2 + X3, "X2", iter=250,
                   keep.hs.pars=TRUE, refresh=0)
    hs.2 <- hsstan(df, y.gauss ~ X1 + X3, "X2", iter=250,
                   keep.hs.pars=TRUE, refresh=0)
  })

  # The coefficient component is named "betas". `[[` matches names exactly, so
  # the former "beta" extracted NULL from both fits and compared NULL to NULL.
  # NOTE(review): equality of `betas` presumes both fits run with the same
  # default seed -- confirm.
  for (val in c("betas", "data", "model.terms")) {
    # Guard against a vacuous comparison of two missing components.
    expect_false(is.null(hs.1[[val]]))
    expect_equal(hs.1[[val]], hs.2[[val]])
  }
})
|
|
63 |
|
|
|
64 |
# Compares the fixture `hs.inter` with a fit whose interaction columns are
# spelled out by hand with "_" in their names. Both `hs.inter` and the `hs()`
# helper are defined outside this file.
test_that("hsstan handles interaction terms correctly", {
  SW({
    fit.flat <- hs(
      y.gauss ~ X1 + X3 + X2 + X1b_X3 + X1c_X3 + X3_X2,
      gaussian
    )
  })

  # Names must agree once "_" is replaced by ":" in the hand-built fit.
  flat.names <- gsub("_", ":", names(fit.flat$betas$unpenalized))
  expect_equal(names(hs.inter$betas$unpenalized), flat.names)

  # Values are compared ignoring attributes, since the names differ.
  expect_equivalent(hs.inter$betas$unpenalized, fit.flat$betas$unpenalized)
  expect_equal(hs.inter$betas$penalized, fit.flat$betas$penalized)
})
|
|
76 |
|
|
|
77 |
# With only `X1:X3` in the formula, one interaction coefficient per level of
# `X1` is expected next to the intercept.
test_that("hsstan handles interaction term without main effects", {
  SW({
    fit.inter.only <- hsstan(df, y.gauss ~ X1:X3, iter=200, refresh=0)
  })

  expected.names <- c("(Intercept)", "X1a:X3", "X1b:X3", "X1c:X3")
  expect_named(fit.inter.only$betas$unpenalized, expected.names)
})
|
|
85 |
|
|
|
86 |
# Requests qr=TRUE on only five rows of data and checks that the recorded
# setting ends up FALSE. `mod.gauss` and `pen` are defined outside this file.
test_that("hsstan doesn't use the QR decomposition if P > N", {
  SW({
    fit.small <- hsstan(
      df[1:5, ], mod.gauss, pen,
      iter=100, qr=TRUE, keep.hs.pars=TRUE, refresh=0
    )
  })

  expect_false(fit.small$hsstan.settings$qr)

  # With keep.hs.pars=TRUE at least one stanfit name should match each of
  # "lambda" and "tau".
  stan.names <- names(fit.small$stanfit)
  expect_match(stan.names, "lambda", all=FALSE)
  expect_match(stan.names, "tau", all=FALSE)
})
|
|
98 |
|
|
|
99 |
# Checks the structure of cross-validation results. `cv.gauss`, `cv.binom`,
# `folds` and `N` are fixtures defined outside this file.
test_that("kfold", {
  cv.fields <- c("estimates", "pointwise", "fits", "data")
  measures <- c("elpd_kfold", "p_kfold", "kfoldic")

  # Gaussian model: class, printout and shape of the estimates.
  expect_s3_class(cv.gauss, c("kfold", "loo"))
  expect_output(print(cv.gauss), "Based on 2-fold cross-validation")
  expect_named(cv.gauss, cv.fields)
  expect_equal(rownames(cv.gauss$estimates), measures)
  expect_equal(colnames(cv.gauss$estimates), c("Estimate", "SE"))

  # One pointwise row per observation; the "p_kfold" column is all NA.
  expect_equal(nrow(cv.gauss$pointwise), N)
  expect_equal(colnames(cv.gauss$pointwise), measures)
  expect_true(all(is.na(cv.gauss$pointwise[, "p_kfold"])))
  expect_length(cv.gauss$fits[[1]]$stanfit@stan_args, 2)

  # Binomial model: stored fits.
  expect_named(cv.binom, cv.fields)
  for (fold in 1:max(folds)) {
    expect_s3_class(cv.binom$fits[[fold]], "hsstan")
  }
  expect_silent(validate.samples(cv.binom$fits[[1]]))
  expect_equal(nrow(cv.binom$fits), 2)
  expect_length(cv.binom$fits[[1]]$stanfit@stan_args, 1)

  # `fits` holds two entries per fold: the first max(folds) are the fitted
  # models, and the next one is compared to the indices of the first fold.
  expect_length(cv.binom$fits, max(folds) * 2)
  expect_equal(cv.binom$fits[[max(folds) + 1]], which(folds == 1))
})
|
|
134 |
|
|
|
135 |
# Without stored fits only the two summary components should remain.
# `cv.nofit` is a fixture defined outside this file.
test_that("kfold with store.fits=FALSE", {
  expect_named(cv.nofit, c("estimates", "pointwise"))
})
|
|
140 |
|
|
|
141 |
# Each invalid sampler argument must be rejected with its specific message.
# `df` and `mod.gauss` are defined outside this file.
test_that("hsstan with invalid inputs", {
  iter.msg <- "'iter' must be a positive integer"
  warmup.msg <- "'warmup' must be a positive integer"

  # adapt.delta
  expect_error(
    hsstan(df, mod.gauss, adapt.delta=1),
    "'adapt.delta' must be less than 1"
  )

  # iter: zero and negative
  expect_error(hsstan(df, mod.gauss, iter=0), iter.msg)
  expect_error(hsstan(df, mod.gauss, iter=-1), iter.msg)

  # warmup: zero and negative
  expect_error(hsstan(df, mod.gauss, warmup=0), warmup.msg)
  expect_error(hsstan(df, mod.gauss, warmup=-1), warmup.msg)

  # warmup equal to iter
  expect_error(
    hsstan(df, mod.gauss, iter=1000, warmup=1000),
    "'warmup' must be smaller than 'iter'"
  )

  # chains=0 is reported as a sampling failure
  expect_error(hsstan(df, mod.gauss, chains=0), "rstan::sampling failed")
})
|
|
162 |
|
|
|
163 |
# kfold() must reject a zero or non-integer number of chains.
# `hs.gauss` and `folds` are fixtures defined outside this file.
test_that("kfold with invalid inputs", {
  chains.msg <- "'chains' must be a positive integer"

  expect_error(kfold(hs.gauss, folds, chains=0), chains.msg)
  expect_error(kfold(hs.gauss, folds, chains=4.4), chains.msg)
})