Switch to unified view

a b/tests/testthat/test-hsstan.R
1
test_that("hsstan",
2
{
3
    expect_s3_class(hs.gauss,
4
                    "hsstan")
5
    expect_s4_class(hs.gauss$stanfit,
6
                    "stanfit")
7
    expect_named(hs.gauss,
8
                 c("stanfit", "betas", "call", "data", "model.terms", "family",
9
                   "hsstan.settings"))
10
    expect_false("lambda[1]" %in% names(hs.gauss$stanfit))
11
    expect_equal(hs.gauss$family,
12
                 gaussian())
13
    expect_named(hs.gauss$betas$unpenalized,
14
                 c("(Intercept)", "X1b", "X1c", "X2", "X3"))
15
    expect_named(hs.gauss$betas$penalized,
16
                 hs.gauss$model.terms$penalized)
17
    expect_length(hs.gauss$betas, 2)
18
    expect_named(hs.gauss$hsstan.settings,
19
                 c("adapt.delta", "qr", "seed", "scale.u", "regularized", "nu",
20
                   "par.ratio", "global.scale", "global.df", "slab.scale",
21
                   "slab.df"))
22
    expect_equal(hs.gauss$hsstan.settings$global.scale,
23
                 0.007071068)
24
    expect_equal(hs.gauss$hsstan.settings$adapt.delta,
25
                 0.99)
26
})
27
28
test_that("hsstan with no penalized predictors",
29
{
30
    expect_null(hs.base$betas$penalized)
31
    expect_length(hs.base$penalized,
32
                  0)
33
    expect_named(hs.base$hsstan.settings,
34
                 c("adapt.delta", "qr", "seed", "scale.u"))
35
    expect_equal(hs.base$hsstan.settings$adapt.delta,
36
                 0.95)
37
})
38
39
test_that("hsstan handles categorical variables in the penalized predictors",
40
{
41
    SW({
42
        hs.1 <- hsstan(df, y.gauss ~ X2 + X3, "X1", iter=250,
43
                       keep.hs.pars=TRUE, refresh=0)
44
    })
45
    expect_named(hs.1$betas$unpenalized,
46
                 c("(Intercept)", "X2", "X3"))
47
    expect_named(hs.1$betas$penalized,
48
                 c("X1b", "X1c"))
49
})
50
51
test_that("hsstan handles penalized predictors appearing in the formula",
52
{
53
    SW({
54
        hs.1 <- hsstan(df, y.gauss ~ X1 + X2 + X3, "X2", iter=250,
55
                       keep.hs.pars=TRUE, refresh=0)
56
        hs.2 <- hsstan(df, y.gauss ~ X1 + X3, "X2", iter=250,
57
                       keep.hs.pars=TRUE, refresh=0)
58
    })
59
    for (val in c("beta", "data", "model.terms"))
60
        expect_equal(hs.1[[val]],
61
                     hs.2[[val]])
62
})
63
64
test_that("hsstan handles interaction terms correctly",
65
{
66
    SW({
67
        hs.int.2 <- hs(y.gauss ~ X1 + X3 + X2 + X1b_X3 + X1c_X3 + X3_X2, gaussian)
68
    })
69
    expect_equal(names(hs.inter$betas$unpenalized),
70
                 gsub("_", ":", names(hs.int.2$betas$unpenalized)))
71
    expect_equivalent(hs.inter$betas$unpenalized,
72
                      hs.int.2$betas$unpenalized)
73
    expect_equal(hs.inter$betas$penalized,
74
                 hs.int.2$betas$penalized)
75
})
76
77
test_that("hsstan handles interaction term without main effects",
78
{
79
    SW({
80
        hs.int.0 <- hsstan(df, y.gauss ~ X1:X3, iter=200, refresh=0)
81
    })
82
    expect_named(hs.int.0$betas$unpenalized,
83
                 c("(Intercept)", "X1a:X3", "X1b:X3", "X1c:X3"))
84
})
85
86
test_that("hsstan doesn't use the QR decomposition if P > N",
87
{
88
    SW({
89
        hs.noqr <- hsstan(df[1:5, ], mod.gauss, pen, iter=100, qr=TRUE,
90
                          keep.hs.pars=TRUE, refresh=0)
91
    })
92
    expect_false(hs.noqr$hsstan.settings$qr)
93
    expect_match(names(hs.noqr$stanfit),
94
                 "lambda", all=FALSE)
95
    expect_match(names(hs.noqr$stanfit),
96
                 "tau", all=FALSE)
97
})
98
99
test_that("kfold",
100
{
101
    expect_s3_class(cv.gauss,
102
                    c("kfold", "loo"))
103
    expect_output(print(cv.gauss),
104
                  "Based on 2-fold cross-validation")
105
    expect_named(cv.gauss,
106
                 c("estimates", "pointwise", "fits", "data"))
107
    expect_equal(rownames(cv.gauss$estimates),
108
                 c("elpd_kfold", "p_kfold", "kfoldic"))
109
    expect_equal(colnames(cv.gauss$estimates),
110
                 c("Estimate", "SE"))
111
    expect_equal(nrow(cv.gauss$pointwise),
112
                 N)
113
    expect_equal(colnames(cv.gauss$pointwise),
114
                 c("elpd_kfold", "p_kfold", "kfoldic"))
115
    expect_true(all(is.na(cv.gauss$pointwise[, "p_kfold"])))
116
    expect_length(cv.gauss$fits[[1]]$stanfit@stan_args,
117
                  2)
118
119
    expect_named(cv.binom,
120
                 c("estimates", "pointwise", "fits", "data"))
121
    for (i in 1:max(folds))
122
        expect_s3_class(cv.binom$fits[[i]],
123
                        "hsstan")
124
    expect_silent(validate.samples(cv.binom$fits[[1]]))
125
    expect_equal(nrow(cv.binom$fits),
126
                 2)
127
    expect_length(cv.binom$fits[[1]]$stanfit@stan_args,
128
                  1)
129
    expect_length(cv.binom$fits,
130
                  max(folds) * 2)
131
    expect_equal(cv.binom$fits[[max(folds) + 1]],
132
                 which(folds == 1))
133
})
134
135
test_that("kfold with store.fits=FALSE",
136
{
137
    expect_named(cv.nofit,
138
                 c("estimates", "pointwise"))
139
})
140
141
test_that("hsstan with invalid inputs",
142
{
143
    expect_error(hsstan(df, mod.gauss, adapt.delta=1),
144
                 "'adapt.delta' must be less than 1")
145
146
    expect_error(hsstan(df, mod.gauss, iter=0),
147
                 "'iter' must be a positive integer")
148
    expect_error(hsstan(df, mod.gauss, iter=-1),
149
                 "'iter' must be a positive integer")
150
151
    expect_error(hsstan(df, mod.gauss, warmup=0),
152
                 "'warmup' must be a positive integer")
153
    expect_error(hsstan(df, mod.gauss, warmup=-1),
154
                 "'warmup' must be a positive integer")
155
156
    expect_error(hsstan(df, mod.gauss, iter=1000, warmup=1000),
157
                 "'warmup' must be smaller than 'iter'")
158
159
    expect_error(hsstan(df, mod.gauss, chains=0),
160
                 "rstan::sampling failed")
161
})
162
163
test_that("kfold with invalid inputs",
164
{
165
    expect_error(kfold(hs.gauss, folds, chains=0),
166
                 "'chains' must be a positive integer")
167
    expect_error(kfold(hs.gauss, folds, chains=4.4),
168
                 "'chains' must be a positive integer")
169
})