Switch to unified view

a b/tests/testthat/test-summaries.R
1
test_that("summary.hsstan",
2
{
3
    expect_error(summary(hs.gauss, prob=c(0.2, 0.8)),
4
                 "'prob' must be a single value between 0 and 1")
5
    expect_error(summary(hs.gauss, prob=1),
6
                 "'prob' must be a single value between 0 and 1")
7
8
    out <- summary(hs.gauss)
9
    expect_is(out, "matrix")
10
    expect_equal(colnames(out),
11
                 c("mean", "sd", "2.5%", "97.5%", "n_eff", "Rhat"))
12
    expect_equal(nrow(out),
13
                 P + 1 + 1) # intercept and extra factor level for X1
14
15
    out <- summary(hs.gauss, pars="X1")
16
    expect_equal(rownames(out),
17
                 c("X1b", "X1c", "X10"))
18
19
    expect_equal(summary(hs.gauss, max.rows=0),
20
                 summary(hs.gauss))
21
    expect_equal(nrow(summary(hs.gauss, max.rows=5)), 5)
22
23
    out <- summary(hs.gauss, sort="n_eff", decreasing=FALSE)
24
    expect_true(all(diff(out[, "n_eff"]) > 0))
25
})
26
27
test_that("print.hsstan",
28
{
29
    expect_output(print(hs.gauss))
30
})
31
32
test_that("posterior_summary",
33
{
34
    out <- posterior_summary(1:100)
35
    expect_is(out,
36
              "matrix")
37
    expect_equal(nrow(out), 1)
38
    expect_equal(colnames(out),
39
                 c("mean", "sd", "2.5%", "97.5%"))
40
    expect_equal(as.numeric(out),
41
                 c(50.5, 29.01149198, 3.475, 97.525))
42
43
    out <- posterior_summary(hs.binom)
44
    expect_equal(rownames(out),
45
                 names(c(hs.binom$betas$unpenalized, hs.binom$betas$penalized)))
46
    expect_equal(out[, "mean"],
47
                 c(hs.binom$betas$unpenalized, hs.binom$betas$penalized))
48
})
49
50
test_that("sampler.stats",
51
{
52
    out <- sampler.stats(hs.gauss)
53
    expect_is(out, "matrix")
54
    expect_equal(colnames(out),
55
                 c("accept.stat", "stepsize", "divergences", "treedepth",
56
                   "gradients", "warmup", "sample"))
57
    expect_equal(rownames(out),
58
                 c(paste0("chain:", 1:chains), "all"))
59
})
60
61
test_that("nsamples",
62
{
63
    expect_equal(nsamples(hs.gauss), iters * chains / 2)
64
})