|
a |
|
b/R/stan.R |
|
|
1 |
##============================================================================= |
|
|
2 |
## |
|
|
3 |
## Copyright (c) 2017-2019 Marco Colombo and Paul McKeigue |
|
|
4 |
## |
|
|
5 |
## kfold.hsstan() is based on code from https://github.com/stan-dev/rstanarm |
|
|
6 |
## Portions copyright (C) 2015, 2016, 2017 Trustees of Columbia University |
|
|
7 |
## |
|
|
8 |
## This program is free software: you can redistribute it and/or modify |
|
|
9 |
## it under the terms of the GNU General Public License as published by |
|
|
10 |
## the Free Software Foundation, either version 3 of the License, or |
|
|
11 |
## (at your option) any later version. |
|
|
12 |
## |
|
|
13 |
## This program is distributed in the hope that it will be useful, |
|
|
14 |
## but WITHOUT ANY WARRANTY; without even the implied warranty of |
|
|
15 |
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
|
|
16 |
## GNU General Public License for more details. |
|
|
17 |
## |
|
|
18 |
## You should have received a copy of the GNU General Public License |
|
|
19 |
## along with this program. If not, see <http://www.gnu.org/licenses/>. |
|
|
20 |
## |
|
|
21 |
##============================================================================= |
|
|
22 |
|
|
|
23 |
|
|
|
24 |
#' Hierarchical shrinkage models |
|
|
25 |
#' |
|
|
26 |
#' Run the No-U-Turn Sampler (NUTS) as implemented in Stan to fit a hierarchical |
|
|
27 |
#' shrinkage model. |
|
|
28 |
#' |
|
|
29 |
#' @param x Data frame containing outcome, covariates and penalized predictors. |
|
|
30 |
#' Continuous predictors and outcome variable should be standardized |
|
|
31 |
#' before fitting the models as priors assume them to have mean zero and |
|
|
32 |
#' unit variance. |
|
|
33 |
#' @param covs.model Formula containing the unpenalized covariates. |
|
|
34 |
#' @param penalized Names of the variables to be used as penalized predictors. |
|
|
35 |
#' Any variable that is already part of the `covs.model` formula will be |
|
|
36 |
#' penalized. If `NULL` or an empty vector, a model with only unpenalized |
|
|
37 |
#' covariates is fitted. |
|
|
38 |
#' @param family Type of model fitted: either `gaussian()` for linear regression |
|
|
39 |
#' (default) or `binomial()` for logistic regression. |
|
|
40 |
#' @param seed Optional integer defining the seed for the pseudo-random number |
|
|
41 |
#' generator. |
|
|
42 |
#' @param qr Whether the thin QR decomposition should be used to decorrelate the |
|
|
43 |
#' predictors (`TRUE` by default). This is silently set to `FALSE` if |
|
|
44 |
#' there are more predictors than observations. |
|
|
45 |
#' @param adapt.delta Target average proposal acceptance probability for |
|
|
46 |
#' adaptation, a value between 0.8 and 1 (excluded). If unspecified, |
|
|
47 |
#' it's set to 0.99 for hierarchical shrinkage models and to 0.95 for |
|
|
48 |
#' base models. |
|
|
49 |
#' @param iter Total number of iterations in each chain, including warmup |
|
|
50 |
#' (2000 by default). |
|
|
51 |
#' @param warmup Number of warmup iterations per chain (by default, half the |
|
|
52 |
#' total number of iterations). |
|
|
53 |
#' @param scale.u Prior scale (standard deviation) for the unpenalized |
|
|
54 |
#' covariates. |
|
|
55 |
#' @param regularized If `TRUE` (default), the regularized horseshoe prior |
|
|
56 |
#' is used as opposed to the original horseshoe prior. |
|
|
57 |
#' @param nu Number of degrees of freedom of the half-Student-t prior on the |
|
|
58 |
#' local shrinkage parameters (by default, 1 if `regularized=TRUE` |
|
|
59 |
#' and 3 otherwise). |
|
|
60 |
#' @param par.ratio Expected ratio of non-zero to zero coefficients (ignored |
|
|
61 |
#' if `regularized=FALSE`). The scale of the global shrinkage parameter |
|
|
62 |
#' corresponds to `par.ratio` divided by the square root of the number of |
|
|
63 |
#' observations; for linear regression only, it's further multiplied by |
|
|
64 |
#' the residual standard deviation `sigma`. |
|
|
65 |
#' @param global.df Number of degrees of freedom for the global shrinkage |
|
|
66 |
#' parameter (ignored if `regularized=FALSE`). Larger values induce more |
|
|
67 |
#' shrinkage. |
|
|
68 |
#' @param slab.scale Scale of the regularization parameter (ignored if |
|
|
69 |
#' `regularized=FALSE`). |
|
|
70 |
#' @param slab.df Number of degrees of freedom of the regularization parameter |
|
|
71 |
#' (ignored if `regularized=FALSE`). |
|
|
72 |
#' @param keep.hs.pars Whether the parameters for the horseshoe prior should be |
|
|
73 |
#' kept in the `stanfit` object returned (`FALSE` by default). |
|
|
74 |
#' @param ... Further arguments passed to [rstan::sampling()], |
|
|
75 |
#' such as `chains` (4 by default), `cores` (the value of |
|
|
76 |
#' `options("mc.cores")` by default), `refresh` (`iter / 10` by default). |
|
|
77 |
#' |
|
|
78 |
#' @return |
|
|
79 |
#' An object of class `hsstan` containing the following fields: |
|
|
80 |
#' \item{stanfit}{an object of class `stanfit` containing the output |
|
|
81 |
#' produced by Stan, including posterior samples and diagnostic summaries. |
|
|
82 |
#' It can be manipulated using methods from the **rstan** package.} |
|
|
83 |
#' \item{betas}{posterior means of the unpenalized and penalized regression |
|
|
84 |
#' parameters.} |
|
|
85 |
#' \item{call}{the matched call.} |
|
|
86 |
#' \item{data}{the dataset used in fitting the model.} |
|
|
87 |
#' \item{model.terms}{a list of names for the outcome variable, the unpenalized |
|
|
88 |
#' covariates and the penalized predictors.} |
|
|
89 |
#' \item{family}{the `family` object used.} |
|
|
90 |
#' \item{hsstan.settings}{the optional settings used in the model.} |
|
|
91 |
#' |
|
|
92 |
#' @seealso |
|
|
93 |
#' [kfold()] for cross-validating a fitted object. |
|
|
94 |
#' |
|
|
95 |
#' @examples |
|
|
96 |
#' \dontshow{oldopts <- options(mc.cores=2)} |
|
|
97 |
#' data(diabetes) |
|
|
98 |
#' |
|
|
99 |
#' # non-default settings for speed of the example |
|
|
100 |
#' df <- diabetes[1:50, ] |
|
|
101 |
#' hs.biom <- hsstan(df, Y ~ age + sex, penalized=colnames(df)[5:10], |
|
|
102 |
#' chains=2, iter=250) |
|
|
103 |
#' \dontshow{options(oldopts)} |
|
|
104 |
#' |
|
|
105 |
#' @importFrom stats gaussian |
|
|
106 |
#' @export |
|
|
107 |
hsstan <- function(x, covs.model, penalized=NULL, family=gaussian,
                   iter=2000, warmup=floor(iter / 2),
                   scale.u=2, regularized=TRUE, nu=ifelse(regularized, 1, 3),
                   par.ratio=0.05, global.df=1, slab.scale=2, slab.df=4,
                   qr=TRUE, seed=123, adapt.delta=NULL,
                   keep.hs.pars=FALSE, ...) {

    ## check the inputs; from here on `regularized` is an integer flag, which
    ## is also the value seen by the default expression for `nu` once forced
    model.terms <- validate.model(covs.model, penalized)
    x <- validate.data(x, model.terms)
    y <- x[[model.terms$outcome]]
    family <- validate.family(family, y)
    regularized <- as.integer(regularized)
    validate.positive.scalar(iter, "iter", int=TRUE)
    validate.positive.scalar(warmup, "warmup", int=TRUE)
    if (warmup >= iter)
        stop("'warmup' must be smaller than 'iter'.", call.=FALSE)

    ## reject invalid options for rstan::sampling upfront (workaround for
    ## rstan issue #681)
    validate.rstan.args(...)

    ## horseshoe parameters are dropped from the stanfit object unless the
    ## user asks to keep them, in which case no parameter filter is applied
    hs.pars <- if (keep.hs.pars) NA else c("lambda", "tau", "z", "c2")

    ## record the call, replacing each argument expression by its value;
    ## the first two elements (function name and "x") are left as they are
    call <- match.call(expand.dots=TRUE)
    arg.values <- c(as.list(environment()), list(...))
    for (arg.name in names(call)[-c(1:2)])
        call[[arg.name]] <- arg.values[[arg.name]]

    ## pick the stan model according to presence of penalized predictors
    ## and to the family
    model <- if (length(penalized) == 0) "base0" else "hs"
    if (family$family == "binomial")
        model <- paste0(model, "_logit")

    ## validate a user-provided adapt.delta, otherwise use a default that
    ## depends on the type of model
    if (!is.null(adapt.delta)) {
        validate.adapt.delta(adapt.delta)
    } else if (grepl("hs", model)) {
        adapt.delta <- 0.99
    } else {
        adapt.delta <- 0.95
    }

    ## design matrix and its dimensions
    X <- ordered.model.matrix(x, model.terms$unpenalized, model.terms$penalized)
    N <- nrow(X)
    P <- ncol(X)

    ## split the columns into penalized (K) and unpenalized (U)
    K <- length(expand.terms(x, model.terms$penalized)[-1])
    U <- P - K

    ## thin QR decomposition, disabled when there are more columns than rows
    if (P > N) qr <- FALSE
    if (qr) {
        qr.scale <- sqrt(N - 1)
        qr.dec <- qr(X)
        Q.qr <- qr.Q(qr.dec)
        R.inv <- qr.solve(qr.dec, Q.qr) * qr.scale
        Q.qr <- Q.qr * qr.scale
    }

    ## scale of the global shrinkage parameter
    if (regularized) {
        global.scale <- par.ratio / sqrt(N)
    } else {
        global.scale <- 1
    }

    ## all settings are always passed: models ignore those they don't use
    data.input <- list(X=if (qr) Q.qr else X, y=y, N=N,
                       P=P, U=U, scale_u=scale.u,
                       regularized=regularized, nu=nu,
                       global_scale=global.scale, global_df=global.df,
                       slab_scale=slab.scale, slab_df=slab.df)

    ## sample from the posterior
    samples <- rstan::sampling(stanmodels[[model]], data=data.input,
                               iter=iter, warmup=warmup, seed=seed, ...,
                               pars=hs.pars, include=keep.hs.pars,
                               control=list(adapt_delta=adapt.delta))
    if (is.na(nrow(samples)))
        stop("rstan::sampling failed, see error message above.", call.=FALSE)

    ## label the regression coefficients with the design matrix column names
    par.idx <- grep("^beta_[up]", names(samples))
    stopifnot(length(par.idx) == ncol(X))
    names(samples)[par.idx] <- colnames(X)

    ## the coefficients sampled under the QR reparametrization are mapped
    ## back to the scale of the original predictors and written into the
    ## stanfit object, chain by chain (warmup draws included)
    if (qr) {
        beta.pars <- grep("beta_", samples@sim$pars_oi, value=TRUE)
        stopifnot(beta.pars[1] == "beta_u")
        beta.tilde <- rstan::extract(samples, pars=beta.pars,
                                     inc_warmup=TRUE, permuted=FALSE)
        B <- apply(beta.tilde, 1:2, FUN=function(z) R.inv %*% z)
        n.chains <- ncol(beta.tilde)
        for (chain in seq_len(n.chains)) {
            for (p in seq_len(P)) {
                samples@sim$samples[[chain]][[par.idx[p]]] <- B[p, , chain]
            }
        }
    }

    ## settings to be stored in the returned object: those specific to the
    ## horseshoe prior are reported only if there are penalized columns
    opts <- list(adapt.delta=adapt.delta, qr=qr, seed=seed, scale.u=scale.u)
    if (K > 0) {
        opts <- c(opts, regularized=regularized, nu=nu, par.ratio=par.ratio,
                  global.scale=global.scale, global.df=global.df,
                  slab.scale=slab.scale, slab.df=slab.df)
    }

    ## posterior means of the regression coefficients; the penalized ones
    ## are NULL if extracting "beta_p" raises an error
    beta.u <- colMeans(as.matrix(samples, pars="beta_u"))
    beta.p <- tryCatch(colMeans(as.matrix(samples, pars="beta_p")),
                       error=function(e) NULL)
    betas <- list(unpenalized=beta.u, penalized=beta.p)

    obj <- list(stanfit=samples, betas=betas, call=call, data=x,
                model.terms=model.terms, family=family, hsstan.settings=opts)
    class(obj) <- "hsstan"

    return(obj)
}
|
|
225 |
|
|
|
226 |
#' K-fold cross-validation |
|
|
227 |
#' |
|
|
228 |
#' Perform K-fold cross-validation using the same settings used when fitting |
|
|
229 |
#' the model on the whole data. |
|
|
230 |
#' |
|
|
231 |
#' @param x An object of class `hsstan`. |
|
|
232 |
#' @param folds Integer vector with one element per observation indicating the |
|
|
233 |
#' cross-validation fold in which the observation should be withdrawn. |
|
|
234 |
#' @param chains Number of Markov chains to run. By default this is set to 1, |
|
|
235 |
#' independently of the number of chains used for `x`. |
|
|
236 |
#' @param store.fits Whether the fitted models for each fold should be stored |
|
|
237 |
#' in the returned object (`TRUE` by default). |
|
|
238 |
#' @param cores Number of cores to use for parallelization (the value of |
|
|
239 |
#' `options("mc.cores")` by default). The cross-validation folds will |
|
|
240 |
#' be distributed to the available cores, and the Markov chains for each |
|
|
241 |
#' model will be run sequentially. |
|
|
242 |
#' @param ... Further arguments passed to [rstan::sampling()]. |
|
|
243 |
#' |
|
|
244 |
#' @return |
|
|
245 |
#' An object with classes `kfold` and `loo` that has a similar structure as the |
|
|
246 |
#' objects returned by [loo()] and [waic()] and is compatible with the |
|
|
247 |
#' \code{\link[loo]{loo_compare}} function for |
|
|
248 |
#' comparing models. The object contains the following fields: |
|
|
249 |
#' \item{estimates}{a matrix containing point estimates and standard errors of |
|
|
250 |
#' the expected log pointwise predictive density ("elpd_kfold"), |
|
|
251 |
#' the effective number of parameters ("p_kfold", always `NA`) and the |
|
|
252 |
#' K-fold information criterion "kfoldic" (which is `-2 * elpd_kfold`, |
|
|
253 |
#' i.e., converted to the deviance scale).} |
|
|
254 |
#' \item{pointwise}{a matrix containing the pointwise contributions of |
|
|
255 |
#' "elpd_kfold", "p_kfold" and "kfoldic".} |
|
|
256 |
#' \item{fits}{a matrix with two columns and number of rows equal to the number |
|
|
257 |
#' of cross-validation folds. Column `fit` contains the fitted |
|
|
258 |
#' `hsstan` objects for each fold, and column `test.idx` contains |
|
|
259 |
#' the indices of the withdrawn observations for each fold. This is not |
|
|
260 |
#' present if `store.fits=FALSE`.} |
|
|
261 |
#' \item{data}{the dataset used in fitting the model (before withdrawing |
|
|
262 |
#' observations). This is not present if `store.fits=FALSE`.} |
|
|
263 |
#' |
|
|
264 |
#' @examples |
|
|
265 |
#' \donttest{ |
|
|
266 |
#' \dontshow{utils::example("hsstan", echo=FALSE)} |
|
|
267 |
#' # continued from ?hsstan |
|
|
268 |
#' # only 2 folds for speed of example |
|
|
269 |
#' folds <- rep(1:2, length.out=length(df$Y)) |
|
|
270 |
#' cv.biom <- kfold(hs.biom, folds=folds, cores=2) |
|
|
271 |
#' } |
|
|
272 |
#' |
|
|
273 |
#' @importFrom loo kfold |
|
|
274 |
#' @method kfold hsstan |
|
|
275 |
#' @aliases kfold |
|
|
276 |
#' @export kfold |
|
|
277 |
#' @export |
|
|
278 |
kfold.hsstan <- function(x, folds, chains=1, store.fits=TRUE,
                         cores=getOption("mc.cores", 1), ...) {
    data <- x$data
    N <- nrow(data)
    folds <- validate.folds(folds, N)
    num.folds <- max(folds)
    validate.positive.scalar(chains, "chains", int=TRUE)
    validate.rstan.args(...)

    ## there is no point in using more workers than models to be fitted
    num.workers <- min(cores, num.folds)

    ## collect the list of calls to be evaluated in parallel: each call refits
    ## the model of `x` on the data without the observations of one fold,
    ## running the chains sequentially and without progress output
    calls <- vector("list", num.folds)
    for (fold in seq_len(num.folds)) {
        test.idx <- which(folds == fold)
        fit.call <- stats::update(object=x, x=data[-test.idx, , drop=FALSE],
                                  chains=chains, cores=1, refresh=0,
                                  open_progress=FALSE, evaluate=FALSE, ...)

        ## store the training data in the call rather than the expression
        ## producing it, as `test.idx` changes at each iteration
        fit.call$x <- eval(fit.call$x)
        calls[[fold]] <- fit.call
    }

    ## evaluate the models
    message("Fitting ", num.folds, " models using ",
            num.workers, " cores")
    par.fun <- function(fold) {
        fit <- eval(calls[[fold]])

        ## log pointwise predictive densities (pointwise test log-likelihood)
        lppd <- log_lik(fit, newdata=data[which(folds == fold), , drop=FALSE])
        return(list(lppd=lppd, fit=if (store.fits) fit else NULL))
    }
    if (.Platform$OS.type != "windows") {
        cv <- parallel::mclapply(X=seq_len(num.folds), mc.cores=cores,
                                 mc.preschedule=FALSE, FUN=par.fun)
    } else { # windows
        cl <- parallel::makePSOCKcluster(num.workers)
        on.exit(parallel::stopCluster(cl), add=TRUE)
        cv <- parallel::parLapply(X=seq_len(num.folds), cl=cl, fun=par.fun)
    }

    ## mclapply() does not stop if a job fails: it returns a "try-error"
    ## object (or NULL if the child process died) in place of the result
    failed <- vapply(cv, function(z) is.null(z) || inherits(z, "try-error"),
                     logical(1))
    if (any(failed))
        stop("Model fitting failed for fold(s) ",
             paste(which(failed), collapse=", "), ".", call.=FALSE)

    ## expected log predictive densities: these are concatenated fold by fold,
    ## so that element j refers to observation obs.idx[j]
    elpds.unord <- unlist(lapply(cv, function(z) apply(z$lppd, 2, logMeanExp)))
    obs.idx <- unlist(lapply(seq_len(num.folds), function(z) which(folds == z)))

    ## put the elpds back in the order of the observations: indexing by
    ## obs.idx itself would apply the permutation in the wrong direction
    ## (assumes that validate.folds() guarantees that each observation is
    ## assigned to exactly one fold - TODO confirm)
    elpds <- elpds.unord[order(obs.idx)]

    pointwise <- cbind(elpd_kfold=elpds, p_kfold=NA, kfoldic=-2 * elpds)
    estimates <- colSums(pointwise)
    se.est <- sqrt(N * apply(pointwise, 2, stats::var))
    out <- list(estimates=cbind(Estimate=estimates, SE=se.est),
                pointwise=pointwise)
    rownames(out$estimates) <- colnames(pointwise)
    if (store.fits) {
        ## one row per fold, with the fitted model and the indices of the
        ## observations withdrawn in that fold
        fits <- array(list(), c(num.folds, 2), list(NULL, c("fit", "test.idx")))
        for (fold in seq_len(num.folds))
            fits[fold, ] <- list(fit=cv[[fold]][["fit"]],
                                 test.idx=which(folds == fold))
        out$fits <- fits
        out$data <- data
    }
    attr(out, "K") <- num.folds
    class(out) <- c("kfold", "loo")
    return(out)
}