Diff of /R/stan.R [000000] .. [ede2d4]

Switch to unified view

a b/R/stan.R
1
##=============================================================================
2
##
3
## Copyright (c) 2017-2019 Marco Colombo and Paul McKeigue
4
##
5
## kfold.hsstan() is based on code from https://github.com/stan-dev/rstanarm
6
## Portions copyright (C) 2015, 2016, 2017 Trustees of Columbia University
7
##
8
## This program is free software: you can redistribute it and/or modify
9
## it under the terms of the GNU General Public License as published by
10
## the Free Software Foundation, either version 3 of the License, or
11
## (at your option) any later version.
12
##
13
## This program is distributed in the hope that it will be useful,
14
## but WITHOUT ANY WARRANTY; without even the implied warranty of
15
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16
## GNU General Public License for more details.
17
##
18
## You should have received a copy of the GNU General Public License
19
## along with this program.  If not, see <http://www.gnu.org/licenses/>.
20
##
21
##=============================================================================
22
23
24
#' Hierarchical shrinkage models
25
#'
26
#' Run the No-U-Turn Sampler (NUTS) as implemented in Stan to fit a hierarchical
27
#' shrinkage model.
28
#'
29
#' @param x Data frame containing outcome, covariates and penalized predictors.
30
#'        Continuous predictors and outcome variable should be standardized
31
#'        before fitting the models as priors assume them to have mean zero and
32
#'        unit variance.
33
#' @param covs.model Formula containing the unpenalized covariates.
34
#' @param penalized Names of the variables to be used as penalized predictors.
35
#'        Any variable that is already part of the `covs.model` formula will not be
36
#'        penalized. If `NULL` or an empty vector, a model with only unpenalized
37
#'        covariates is fitted.
38
#' @param family Type of model fitted: either `gaussian()` for linear regression
39
#'        (default) or `binomial()` for logistic regression.
40
#' @param seed Optional integer defining the seed for the pseudo-random number
41
#'        generator.
42
#' @param qr Whether the thin QR decomposition should be used to decorrelate the
43
#'        predictors (`TRUE` by default). This is silently set to `FALSE` if
44
#'        there are more predictors than observations.
45
#' @param adapt.delta Target average proposal acceptance probability for
46
#'        adaptation, a value between 0.8 and 1 (excluded). If unspecified,
47
#'        it's set to 0.99 for hierarchical shrinkage models and to 0.95 for
48
#'        base models.
49
#' @param iter Total number of iterations in each chain, including warmup
50
#'        (2000 by default).
51
#' @param warmup Number of warmup iterations per chain (by default, half the
52
#'        total number of iterations).
53
#' @param scale.u Prior scale (standard deviation) for the unpenalized
54
#'        covariates.
55
#' @param regularized If `TRUE` (default), the regularized horseshoe prior
56
#'        is used as opposed to the original horseshoe prior.
57
#' @param nu Number of degrees of freedom of the half-Student-t prior on the
58
#'        local shrinkage parameters (by default, 1 if `regularized=TRUE`
59
#'        and 3 otherwise).
60
#' @param par.ratio Expected ratio of non-zero to zero coefficients (ignored
61
#'        if `regularized=FALSE`). The scale of the global shrinkage parameter
62
#'        corresponds to `par.ratio` divided by the square root of the number of
63
#'        observations; for linear regression only, it's further multiplied by
64
#'        the residual standard deviation `sigma`.
65
#' @param global.df Number of degrees of freedom for the global shrinkage
66
#'        parameter (ignored if `regularized=FALSE`). Larger values induce more
67
#'        shrinkage.
68
#' @param slab.scale Scale of the regularization parameter (ignored if
69
#'        `regularized=FALSE`).
70
#' @param slab.df Number of degrees of freedom of the regularization parameter
71
#'        (ignored if `regularized=FALSE`).
72
#' @param keep.hs.pars Whether the parameters for the horseshoe prior should be
73
#'        kept in the `stanfit` object returned (`FALSE` by default).
74
#' @param ... Further arguments passed to [rstan::sampling()],
75
#'        such as `chains` (4 by default), `cores` (the value of
76
#'        `options("mc.cores")` by default), `refresh` (`iter / 10` by default).
77
#'
78
#' @return
79
#' An object of class `hsstan` containing the following fields:
80
#' \item{stanfit}{an object of class `stanfit` containing the output
81
#'       produced by Stan, including posterior samples and diagnostic summaries.
82
#'       It can be manipulated using methods from the **rstan** package.}
83
#' \item{betas}{posterior means of the unpenalized and penalized regression
84
#'       parameters.}
85
#' \item{call}{the matched call.}
86
#' \item{data}{the dataset used in fitting the model.}
87
#' \item{model.terms}{a list of names for the outcome variable, the unpenalized
88
#'       covariates and the penalized predictors.}
89
#' \item{family}{the `family` object used.}
90
#' \item{hsstan.settings}{the optional settings used in the model.}
91
#'
92
#' @seealso
93
#' [kfold()] for cross-validating a fitted object.
94
#'
95
#' @examples
96
#' \dontshow{oldopts <- options(mc.cores=2)}
97
#' data(diabetes)
98
#'
99
#' # non-default settings for speed of the example
100
#' df <- diabetes[1:50, ]
101
#' hs.biom <- hsstan(df, Y ~ age + sex, penalized=colnames(df)[5:10],
102
#'                   chains=2, iter=250)
103
#' \dontshow{options(oldopts)}
104
#'
105
#' @importFrom stats gaussian
106
#' @export
107
hsstan <- function(x, covs.model, penalized=NULL, family=gaussian,
                   iter=2000, warmup=floor(iter / 2),
                   scale.u=2, regularized=TRUE, nu=ifelse(regularized, 1, 3),
                   par.ratio=0.05, global.df=1, slab.scale=2, slab.df=4,
                   qr=TRUE, seed=123, adapt.delta=NULL,
                   keep.hs.pars=FALSE, ...) {

    ## validate the inputs; from here on `x`, `family` and `regularized` hold
    ## the validated/converted values rather than what the caller supplied
    model.terms <- validate.model(covs.model, penalized)
    x <- validate.data(x, model.terms)
    y <- x[[model.terms$outcome]]
    family <- validate.family(family, y)
    regularized <- as.integer(regularized)
    validate.positive.scalar(iter, "iter", int=TRUE)
    validate.positive.scalar(warmup, "warmup", int=TRUE)
    if (iter <= warmup)
        stop("'warmup' must be smaller than 'iter'.", call.=FALSE)

    ## stop if options to be passed to rstan::sampling are not valid, as
    ## to work around rstan issue #681
    validate.rstan.args(...)

    ## parameter names not to include by default in the stanfit object:
    ## these are passed as `pars` with `include=keep.hs.pars`, so that they
    ## are excluded when keep.hs.pars=FALSE, while `pars=NA, include=TRUE`
    ## is used when all parameters should be kept
    hs.pars <- c("lambda", "tau", "z", "c2")
    if (keep.hs.pars)
        hs.pars <- NA

    ## retrieve the call and replace each argument expression (other than
    ## `x`) with the value it currently has in this function
    call <- match.call(expand.dots=TRUE)
    args <- c(as.list(environment()), list(...))
    for (nm in names(call)[-c(1:2)]) # exclude "" and "x"
        call[[nm]] <- args[[nm]]

    ## choose the model to be fitted
    model <- if (length(penalized) == 0) "base0" else "hs"
    if (family$family == "binomial") model <- paste0(model, "_logit")

    ## set or check adapt.delta
    if (is.null(adapt.delta)) {
        adapt.delta <- if (grepl("hs", model)) 0.99 else 0.95
    } else {
        validate.adapt.delta(adapt.delta)
    }

    ## create the design matrix
    X <- ordered.model.matrix(x, model.terms$unpenalized, model.terms$penalized)
    N <- nrow(X)
    P <- ncol(X)

    ## number of penalized and unpenalized columns in the design matrix
    K <- length(expand.terms(x, model.terms$penalized)[-1])
    U <- P - K

    ## thin QR decomposition (disabled if there are more columns than rows)
    if (P > N) qr <- FALSE
    if (qr) {
        qr.dec <- qr(X)
        Q.qr <- qr.Q(qr.dec)
        R.inv <- qr.solve(qr.dec, Q.qr) * sqrt(N - 1)
        Q.qr <- Q.qr * sqrt(N - 1)
    }

    ## global scale for regularized horseshoe prior
    global.scale <- if (regularized) par.ratio / sqrt(N) else 1

    ## core block
    {
        ## parameters not used by a model are ignored
        data.input <- list(X=if (qr) Q.qr else X, y=y, N=N,
                           P=P, U=U, scale_u=scale.u,
                           regularized=regularized, nu=nu,
                           global_scale=global.scale, global_df=global.df,
                           slab_scale=slab.scale, slab_df=slab.df)

        ## run the stan model
        samples <- rstan::sampling(stanmodels[[model]], data=data.input,
                                   iter=iter, warmup=warmup, seed=seed, ...,
                                   pars=hs.pars, include=keep.hs.pars,
                                   control=list(adapt_delta=adapt.delta))
        if (is.na(nrow(samples)))
            stop("rstan::sampling failed, see error message above.", call.=FALSE)

        ## assign proper names
        par.idx <- grep("^beta_[up]", names(samples))
        stopifnot(length(par.idx) == ncol(X))
        names(samples)[par.idx] <- colnames(X)

        if (qr) {
            ## the sampled coefficients refer to the columns of Q.qr:
            ## premultiply each draw by R.inv and overwrite the stored draws
            pars <- grep("beta_", samples@sim$pars_oi, value=TRUE)
            stopifnot(pars[1] == "beta_u")
            beta.tilde <- rstan::extract(samples, pars=pars,
                                         inc_warmup=TRUE, permuted=FALSE)
            B <- apply(beta.tilde, 1:2, FUN=function(z) R.inv %*% z)

            ## apply() drops the leading dimension when each call returns a
            ## single value, which would break the 3-d indexing below
            if (P == 1)
                dim(B) <- c(1L, dim(beta.tilde)[1:2])

            chains <- ncol(beta.tilde)
            for (chain in seq_len(chains)) {
                for (p in seq_len(P))
                    samples@sim$samples[[chain]][[par.idx[p]]] <- B[p, , chain]
            }
        }

        ## store the hierarchical shrinkage settings (the horseshoe-specific
        ## ones only if there are penalized columns)
        opts <- list(adapt.delta=adapt.delta, qr=qr, seed=seed, scale.u=scale.u)
        if (K > 0)
            opts <- c(opts, regularized=regularized, nu=nu, par.ratio=par.ratio,
                      global.scale=global.scale, global.df=global.df,
                      slab.scale=slab.scale, slab.df=slab.df)

        ## compute the posterior means of the regression coefficients;
        ## `penalized` is NULL if extracting "beta_p" raises an error
        betas <- list(unpenalized=colMeans(as.matrix(samples, pars="beta_u")),
                      penalized=tryCatch(colMeans(as.matrix(samples,
                                                            pars="beta_p")),
                                         error=function(e) NULL))
        obj <- list(stanfit=samples, betas=betas, call=call, data=x,
                    model.terms=model.terms, family=family, hsstan.settings=opts)
        class(obj) <- "hsstan"
    }

    return(obj)
}
225
226
#' K-fold cross-validation
227
#'
228
#' Perform K-fold cross-validation using the same settings used when fitting
229
#' the model on the whole data.
230
#'
231
#' @param x An object of class `hsstan`.
232
#' @param folds Integer vector with one element per observation indicating the
233
#'        cross-validation fold in which the observation should be withdrawn.
234
#' @param chains Number of Markov chains to run. By default this is set to 1,
235
#'        independently of the number of chains used for `x`.
236
#' @param store.fits Whether the fitted models for each fold should be stored
237
#'        in the returned object (`TRUE` by default).
238
#' @param cores Number of cores to use for parallelization (the value of
239
#'        `options("mc.cores")` by default). The cross-validation folds will
240
#'        be distributed to the available cores, and the Markov chains for each
241
#'        model will be run sequentially.
242
#' @param ... Further arguments passed to [rstan::sampling()].
243
#'
244
#' @return
245
#' An object with classes `kfold` and `loo` that has a similar structure as the
246
#' objects returned by [loo()] and [waic()] and is compatible with the
247
#' \code{\link[loo]{loo_compare}} function for
248
#' comparing models. The object contains the following fields:
249
#' \item{estimates}{a matrix containing point estimates and standard errors of
250
#'       the expected log pointwise predictive density ("elpd_kfold"),
251
#'       the effective number of parameters ("p_kfold", always `NA`) and the
252
#'       K-fold information criterion "kfoldic" (which is `-2 * elpd_kfold`,
253
#'       i.e., converted to the deviance scale).}
254
#' \item{pointwise}{a matrix containing the pointwise contributions of
255
#'       "elpd_kfold", "p_kfold" and "kfoldic".}
256
#' \item{fits}{a matrix with two columns and number of rows equal to the number
257
#'       of cross-validation folds. Column `fit` contains the fitted
258
#'       `hsstan` objects for each fold, and column `test.idx` contains
259
#'       the indices of the withdrawn observations for each fold. This is not
260
#'       present if `store.fits=FALSE`.}
261
#' \item{data}{the dataset used in fitting the model (before withdrawing
262
#'       observations). This is not present if `store.fits=FALSE`.}
263
#'
264
#' @examples
265
#' \donttest{
266
#' \dontshow{utils::example("hsstan", echo=FALSE)}
267
#' # continued from ?hsstan
268
#' # only 2 folds for speed of example
269
#' folds <- rep(1:2, length.out=length(df$Y))
270
#' cv.biom <- kfold(hs.biom, folds=folds, cores=2)
271
#' }
272
#'
273
#' @importFrom loo kfold
274
#' @method kfold hsstan
275
#' @aliases kfold
276
#' @export kfold
277
#' @export
278
kfold.hsstan <- function(x, folds, chains=1, store.fits=TRUE,
                         cores=getOption("mc.cores", 1), ...) {
    data <- x$data
    N <- nrow(data)
    folds <- validate.folds(folds, N)
    num.folds <- max(folds)
    validate.positive.scalar(chains, "chains", int=TRUE)
    validate.rstan.args(...)

    ## no more workers than folds are ever needed
    num.cores <- min(cores, num.folds)

    ## collect the list of calls to be evaluated in parallel: each call refits
    ## the model after withdrawing the observations of one fold
    calls <- vector("list", num.folds)
    for (fold in seq_len(num.folds)) {
        test.idx <- which(folds == fold)
        fit.call <- stats::update(object=x, x=data[-test.idx, , drop=FALSE],
                                  chains=chains, cores=1, refresh=0,
                                  open_progress=FALSE, evaluate=FALSE, ...)

        ## evaluate the training data now, so that the stored call carries
        ## the data itself rather than an expression referring to local
        ## variables of this function
        fit.call$x <- eval(fit.call$x)
        calls[[fold]] <- fit.call
    }

    ## evaluate the models
    message("Fitting ", num.folds, " models using ",
            num.cores, " cores")
    par.fun <- function(fold) {
        fit <- eval(calls[[fold]])

        ## log pointwise predictive densities (pointwise test log-likelihood)
        lppd <- log_lik(fit, newdata=data[which(folds == fold), , drop=FALSE])
        return(list(lppd=lppd, fit=if (store.fits) fit else NULL))
    }
    if (.Platform$OS.type != "windows") {
        cv <- parallel::mclapply(X=seq_len(num.folds), mc.cores=cores,
                                 mc.preschedule=FALSE, FUN=par.fun)
    } else { # windows
        cl <- parallel::makePSOCKcluster(num.cores)
        on.exit(parallel::stopCluster(cl), add=TRUE)
        cv <- parallel::parLapply(X=seq_len(num.folds), cl=cl, fun=par.fun)
    }

    ## expected log predictive densities: `elpds.unord` lists the test
    ## observations fold by fold, and `obs.idx[j]` is the row of `data` that
    ## the j-th element refers to
    elpds.unord <- unlist(lapply(cv, function(z) apply(z$lppd, 2, logMeanExp)))
    obs.idx <- unlist(lapply(seq_len(num.folds),
                             function(z) which(folds == z)))

    ## put the values back in the order of the observations: this requires
    ## the inverse of the permutation in `obs.idx`, not `obs.idx` itself
    elpds <- elpds.unord[order(obs.idx)]

    pointwise <- cbind(elpd_kfold=elpds, p_kfold=NA, kfoldic=-2 * elpds)
    estimates <- colSums(pointwise)
    se.est <- sqrt(N * apply(pointwise, 2, stats::var))
    out <- list(estimates=cbind(Estimate=estimates, SE=se.est),
                pointwise=pointwise)
    rownames(out$estimates) <- colnames(pointwise)
    if (store.fits) {
        ## list-matrix with one row per fold: fitted model and indices of the
        ## observations withdrawn in that fold
        fits <- array(list(), c(num.folds, 2), list(NULL, c("fit", "test.idx")))
        for (fold in seq_len(num.folds))
            fits[fold, ] <- list(fit=cv[[fold]][["fit"]],
                                 test.idx=which(folds == fold))
        out$fits <- fits
        out$data <- data
    }
    attr(out, "K") <- num.folds
    class(out) <- c("kfold", "loo")
    return(out)
}