|
a |
|
b/R/misc.R |
|
|
1 |
##============================================================================= |
|
|
2 |
## |
|
|
3 |
## Copyright (c) 2019 Marco Colombo |
|
|
4 |
## |
|
|
5 |
## This program is free software: you can redistribute it and/or modify |
|
|
6 |
## it under the terms of the GNU General Public License as published by |
|
|
7 |
## the Free Software Foundation, either version 3 of the License, or |
|
|
8 |
## (at your option) any later version. |
|
|
9 |
## |
|
|
10 |
## This program is distributed in the hope that it will be useful, |
|
|
11 |
## but WITHOUT ANY WARRANTY; without even the implied warranty of |
|
|
12 |
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the |
|
|
13 |
## GNU General Public License for more details. |
|
|
14 |
## |
|
|
15 |
## You should have received a copy of the GNU General Public License |
|
|
16 |
## along with this program. If not, see <http://www.gnu.org/licenses/>. |
|
|
17 |
## |
|
|
18 |
##============================================================================= |
|
|
19 |
|
|
|
20 |
|
|
|
21 |
#' Validate an hsstan object |
|
|
22 |
#' |
|
|
23 |
#' Check that the object has been created by [hsstan()]. |
|
|
24 |
#' |
|
|
25 |
#' @param obj An object to be checked. |
|
|
26 |
#' |
|
|
27 |
#' @return |
|
|
28 |
#' Throws an error if the object is not an `hsstan` object. |
|
|
29 |
#' |
|
|
30 |
#' @noRd |
|
|
31 |
validate.hsstan <- function(obj) { |
|
|
32 |
if (!inherits(obj, "hsstan")) { |
|
|
33 |
stop("Not an object of class 'hsstan'.") |
|
|
34 |
} |
|
|
35 |
} |
|
|
36 |
|
|
|
37 |
#' Validate the posterior samples |
|
|
38 |
#' |
|
|
39 |
#' Check that the object contains valid posterior samples in the |
|
|
40 |
#' `stanfit` field. |
|
|
41 |
#' |
|
|
42 |
#' @param obj An object of class `hsstan`. |
|
|
43 |
#' |
|
|
44 |
#' @return |
|
|
45 |
#' Throws an error if the object does not contain posterior samples. |
|
|
46 |
#' |
|
|
47 |
#' @noRd |
|
|
48 |
validate.samples <- function(obj) { |
|
|
49 |
if (!inherits(obj$stanfit, "stanfit")) { |
|
|
50 |
stop("No valid posterior samples stored in the 'hsstan' object.") |
|
|
51 |
} |
|
|
52 |
} |
|
|
53 |
|
|
|
54 |
#' Validate new data |
|
|
55 |
#' |
|
|
56 |
#' Check that the new data contains all variables used in the model and no |
|
|
57 |
#' missing values, and generate the corresponding model matrix. |
|
|
58 |
#' |
|
|
59 |
#' @param obj Object of class `hsstan`. |
|
|
60 |
#' @param newdata Optional data frame containing the variables used in the |
|
|
61 |
#' model. If `NULL`, the model matrix used when fitting the model |
|
|
62 |
#' is returned. |
|
|
63 |
#' |
|
|
64 |
#' @return |
|
|
65 |
#' A design matrix corresponding to the variables used in the model. |
|
|
66 |
#' |
|
|
67 |
#' @noRd |
|
|
68 |
validate.newdata <- function(obj, newdata) { |
|
|
69 |
if (is.null(newdata)) |
|
|
70 |
newdata <- obj$data |
|
|
71 |
else if (!inherits(newdata, c("data.frame", "matrix"))) |
|
|
72 |
stop("'newdata' must be a data frame or a matrix.") |
|
|
73 |
if (nrow(newdata) == 0 || ncol(newdata) == 0) |
|
|
74 |
stop("'newdata' contains no rows or no columns.") |
|
|
75 |
|
|
|
76 |
## only check for NAs in the variables used in the model |
|
|
77 |
vars <- with(obj$model.terms, c(outcome, unpenalized, penalized)) |
|
|
78 |
newdata <- newdata[, colnames(newdata) %in% vars, drop=FALSE] |
|
|
79 |
if (any(is.na(newdata))) |
|
|
80 |
stop("'newdata' contains missing values.") |
|
|
81 |
|
|
|
82 |
## this adds the intercept column back |
|
|
83 |
ordered.model.matrix(as.data.frame(newdata), |
|
|
84 |
obj$model.terms$unpenalized, |
|
|
85 |
obj$model.terms$penalized) |
|
|
86 |
} |
|
|
87 |
|
|
|
88 |
#' Validate a model formula |
|
|
89 |
#' |
|
|
90 |
#' Check that the formula that specifies a model contains all required elements. |
|
|
91 |
#' |
|
|
92 |
#' @param model Formula to be checked. |
|
|
93 |
#' @param penalized Vector of names for the penalized predictors. |
|
|
94 |
#' |
|
|
95 |
#' @return |
|
|
96 |
#' A list containing the formula representing the covariates model, the name of |
|
|
97 |
#' the outcome variable, the names of the upenalized and penalized predictors. |
|
|
98 |
#' |
|
|
99 |
#' @importFrom stats as.formula terms |
|
|
100 |
#' @noRd |
|
|
101 |
validate.model <- function(model, penalized) { |
|
|
102 |
if (is.character(model) && length(model) > 1) |
|
|
103 |
stop("Model formula specified incorrectly.") |
|
|
104 |
model <- as.formula(model) |
|
|
105 |
tt <- terms(model) |
|
|
106 |
if (attr(tt, "response") == 0) |
|
|
107 |
stop("No outcome variable specified in the model.") |
|
|
108 |
if (attr(tt, "intercept") == 0) |
|
|
109 |
stop("Models with no intercept are not supported.") |
|
|
110 |
if (length(penalized) > 0 && !is.character(penalized)) |
|
|
111 |
stop("'penalized' must be a character vector.") |
|
|
112 |
if (any(grepl("[:*]", penalized))) |
|
|
113 |
stop("Interaction terms in penalized predictors are not supported.") |
|
|
114 |
penalized <- setdiff(unique(trimws(penalized)), "") |
|
|
115 |
return(list(outcome=as.character(model)[2], |
|
|
116 |
unpenalized=setdiff(attr(tt, "term.labels"), penalized), |
|
|
117 |
penalized=penalized)) |
|
|
118 |
} |
|
|
119 |
|
|
|
120 |
#' Validate the model data |
|
|
121 |
#' |
|
|
122 |
#' Check if the model data can be used with the given model formula and |
|
|
123 |
#' penalized predictors. |
|
|
124 |
#' |
|
|
125 |
#' @param x An object to be checked. |
|
|
126 |
#' @param model Validated model formula. |
|
|
127 |
#' |
|
|
128 |
#' @return |
|
|
129 |
#' A data frame containing the model data. A factor or logical outcome variable |
|
|
130 |
#' is replaced by its numeric equivalent. |
|
|
131 |
#' |
|
|
132 |
#' @noRd |
|
|
133 |
validate.data <- function(x, model) { |
|
|
134 |
if (!inherits(x, c("data.frame", "matrix"))) |
|
|
135 |
stop("'x' must be a data frame or a matrix.") |
|
|
136 |
x <- as.data.frame(x) |
|
|
137 |
validate.variables(x, model$outcome) |
|
|
138 |
validate.variables(x, c(model$unpenalized, model$penalized)) |
|
|
139 |
x[[model$outcome]] <- validate.outcome(x[[model$outcome]]) |
|
|
140 |
return(x) |
|
|
141 |
} |
|
|
142 |
|
|
|
143 |
#' Validate variables |
|
|
144 |
#' |
|
|
145 |
#' Check that the required variables are in the dataset. |
|
|
146 |
#' |
|
|
147 |
#' @param x Data frame containing the variables of interest. |
|
|
148 |
#' @param variables Vector of variable names. |
|
|
149 |
#' |
|
|
150 |
#' @return |
|
|
151 |
#' Throws if variables are not present in the dataset or contain missing values. |
|
|
152 |
#' |
|
|
153 |
#' @noRd |
|
|
154 |
validate.variables <- function(x, variables) { |
|
|
155 |
## unpack interaction terms |
|
|
156 |
variables <- unique(unlist(strsplit(as.character(variables), ":"))) |
|
|
157 |
if (length(variables) == 0) |
|
|
158 |
stop("No predictors present in the model.") |
|
|
159 |
var.match <- match(variables, colnames(x)) |
|
|
160 |
if (anyNA(var.match)) |
|
|
161 |
stop(collapse(variables[is.na(var.match)]), " not present in 'x'.") |
|
|
162 |
if (anyNA(x[, variables])) |
|
|
163 |
stop("Model variables contain missing values.") |
|
|
164 |
} |
|
|
165 |
|
|
|
166 |
#' Validate the outcome variable |
|
|
167 |
#' |
|
|
168 |
#' Check that the outcome variable can be converted to a valid numerical |
|
|
169 |
#' vector. |
|
|
170 |
#' |
|
|
171 |
#' @param y Outcome vector to be checked. |
|
|
172 |
#' |
|
|
173 |
#' @return |
|
|
174 |
#' A numeric vector. |
|
|
175 |
#' |
|
|
176 |
#' @noRd |
|
|
177 |
validate.outcome <- function(y) { |
|
|
178 |
if (is.factor(y)) { |
|
|
179 |
if (nlevels(y) != 2) |
|
|
180 |
stop("A factor outcome variable can only have two levels.") |
|
|
181 |
y <- as.integer(y) - 1 |
|
|
182 |
} |
|
|
183 |
if (!(is.numeric(y) || is.logical(y))) |
|
|
184 |
stop("Outcome variable of invalid type.") |
|
|
185 |
return(as.numeric(y)) |
|
|
186 |
} |
|
|
187 |
|
|
|
188 |
#' Validate the family argument |
|
|
189 |
#' |
|
|
190 |
#' Ensure that the family argument has been specified correctly. |
|
|
191 |
#' This is inspired by code in \code{\link{glm}}. |
|
|
192 |
#' |
|
|
193 |
#' @param family Family argument to test. |
|
|
194 |
#' @param y Outcome variable. |
|
|
195 |
#' |
|
|
196 |
#' @return |
|
|
197 |
#' A valid family. The function throws an error if the family argument cannot |
|
|
198 |
#' be used. |
|
|
199 |
#' |
|
|
200 |
#' @importFrom methods is |
|
|
201 |
#' @noRd |
|
|
202 |
validate.family <- function(family, y) { |
|
|
203 |
if (missing(family)) |
|
|
204 |
stop("Argument of 'family' is missing.") |
|
|
205 |
if (is.character(family)) |
|
|
206 |
tryCatch( |
|
|
207 |
family <- get(family, mode="function", envir=parent.frame(2)), |
|
|
208 |
error=function(e) |
|
|
209 |
stop("'", family, "' is not a valid family.") |
|
|
210 |
) |
|
|
211 |
if (is.function(family)) |
|
|
212 |
family <- family() |
|
|
213 |
if (!is(family, "family")) |
|
|
214 |
stop("Argument of 'family' is not a valid family.") |
|
|
215 |
if (!family$family %in% c("gaussian", "binomial")) |
|
|
216 |
stop("Only 'gaussian' and 'binomial' are supported families.") |
|
|
217 |
|
|
|
218 |
if (family$family == "binomial") { |
|
|
219 |
if (length(table(y)) != 2) |
|
|
220 |
stop("Outcome variable must contain two classes with family=binomial.") |
|
|
221 |
if (!is.factor(y) && any(y < 0 | y > 1)) |
|
|
222 |
stop("Outcome variable must contain 0-1 values with family=binomial.") |
|
|
223 |
} |
|
|
224 |
|
|
|
225 |
return(family) |
|
|
226 |
} |
|
|
227 |
|
|
|
228 |
#' Validate a vector of indices |
|
|
229 |
#' |
|
|
230 |
#' @param x Vector to be checked. |
|
|
231 |
#' @param N Maximum valid index. |
|
|
232 |
#' @param name Name of the vector to report in error messages. |
|
|
233 |
#' @param throw.duplicates Whether the function should throw if the vector |
|
|
234 |
#' contains duplicate elements (`TRUE` by default). |
|
|
235 |
#' |
|
|
236 |
#' @return |
|
|
237 |
#' Throws an error if the given vector is not an integer vector or contains |
|
|
238 |
#' missing, out of bounds or duplicate indices (if `throw.duplicates` is `TRUE`). |
|
|
239 |
#' |
|
|
240 |
#' @noRd |
|
|
241 |
validate.indices <- function(x, N, name, throw.duplicates=TRUE) { |
|
|
242 |
if (anyNA(x)) |
|
|
243 |
stop("'", name, "' contains missing values.") |
|
|
244 |
if (!is.numeric(x) || NCOL(x) > 1 || any(x != as.integer(x))) |
|
|
245 |
stop("'", name, "' must be an integer vector.") |
|
|
246 |
if (length(x) < 2) |
|
|
247 |
stop("'", name, "' must contain at least two elements.") |
|
|
248 |
if (any(x < 1 | x > N)) |
|
|
249 |
stop("'", name, "' contains out of bounds indices.") |
|
|
250 |
if (throw.duplicates && any(duplicated(x))) |
|
|
251 |
stop("'", name, "' contains duplicate indices.") |
|
|
252 |
} |
|
|
253 |
|
|
|
254 |
#' Validate the cross-validation folds |
|
|
255 |
#' |
|
|
256 |
#' @param folds Folds to be checked or `NULL`. |
|
|
257 |
#' @param N Number of observations. |
|
|
258 |
#' |
|
|
259 |
#' @return |
|
|
260 |
#' An integer vector with one element per observation indicating the |
|
|
261 |
#' cross-validation fold in which the observation should be withdrawn. |
|
|
262 |
#' |
|
|
263 |
#' @noRd |
|
|
264 |
validate.folds <- function(folds, N) { |
|
|
265 |
if (is.null(folds)) |
|
|
266 |
return(rep(1, N)) |
|
|
267 |
validate.indices(folds, N, "folds", throw.duplicates=FALSE) |
|
|
268 |
if (length(folds) != N) |
|
|
269 |
stop("'folds' should have length ", N, ".") |
|
|
270 |
K <- length(unique(folds)) |
|
|
271 |
if (!all(1:K %in% folds)) |
|
|
272 |
stop("'folds' must contain all indices up to ", K, ".") |
|
|
273 |
folds <- as.integer(folds) |
|
|
274 |
} |
|
|
275 |
|
|
|
276 |
#' Validate start.from |
|
|
277 |
#' |
|
|
278 |
#' Check that the predictor names provided is a valid subset of the variables |
|
|
279 |
#' used in the model. |
|
|
280 |
#' |
|
|
281 |
#' @param obj An object of class `hsstan`. |
|
|
282 |
#' @param start.from Vector to be checked. |
|
|
283 |
#' |
|
|
284 |
#' @return |
|
|
285 |
#' A list of two elements: the names of the model terms matching `start.from` |
|
|
286 |
#' and a vector of indices corresponding to the names listed in `start.from`. |
|
|
287 |
#' Throws an error if any of the names mentioned does not match those available |
|
|
288 |
#' in the model terms. |
|
|
289 |
#' |
|
|
290 |
#' @noRd |
|
|
291 |
validate.start.from <- function(obj, start.from) { |
|
|
292 |
unp.terms <- obj$model.terms$unpenalized |
|
|
293 |
unp.betas <- names(obj$betas$unpenalized) |
|
|
294 |
mod.terms <- c(unp.terms, obj$model.terms$penalized) |
|
|
295 |
mod.betas <- c(unp.betas, names(obj$betas$penalized)) |
|
|
296 |
start.from <- setdiff(start.from, "") |
|
|
297 |
if (is.null(start.from)) { |
|
|
298 |
if (length(obj$model.terms$penalized) > 0) |
|
|
299 |
return(list(start.from=unp.terms, idx=seq_along(unp.betas))) |
|
|
300 |
else |
|
|
301 |
return(list(start.from=character(0), idx=1)) |
|
|
302 |
} |
|
|
303 |
if (length(start.from) == 0) |
|
|
304 |
return(list(start.from=character(0), idx=1)) |
|
|
305 |
if (anyNA(start.from)) |
|
|
306 |
stop("'start.from' contains missing values.") |
|
|
307 |
var.match <- match(start.from, mod.terms) |
|
|
308 |
if (anyNA(var.match)) |
|
|
309 |
stop("'start.from' contains ", collapse(start.from[is.na(var.match)]), |
|
|
310 |
", which cannot be matched.") |
|
|
311 |
|
|
|
312 |
## unpack interaction terms so that also main effects are matched |
|
|
313 |
start.from <- mod.terms[mod.terms %in% |
|
|
314 |
c(start.from, unlist(strsplit(start.from, ":")))] |
|
|
315 |
chosen <- expand.terms(obj$data, start.from) |
|
|
316 |
|
|
|
317 |
## also consider interaction terms in reverse order |
|
|
318 |
chosen <- c(chosen, sapply(strsplit(chosen[grep(":", chosen)], ":"), |
|
|
319 |
function(z) c(z, paste(rev(z), collapse=":")))) |
|
|
320 |
return(list(start.from=start.from, idx=which(mod.betas %in% chosen))) |
|
|
321 |
} |
|
|
322 |
|
|
|
323 |
#' Validate a positive or non-negative scalar value |
|
|
324 |
#' |
|
|
325 |
#' @param x Value to validate. |
|
|
326 |
#' @param name Variable name to report in case of error. |
|
|
327 |
#' @param int Whether the value has to be an integer (`FALSE` by default). |
|
|
328 |
#' |
|
|
329 |
#' @return |
|
|
330 |
#' Throws an error if the given value is not a positive or non-negative |
|
|
331 |
#' scalar (or integer scalar). |
|
|
332 |
#' |
|
|
333 |
#' @noRd |
|
|
334 |
validate.positive.scalar <- function(x, name, int=FALSE) { |
|
|
335 |
if (!is.numeric(x) || length(x) != 1 || is.na(x) || x <= 0 || |
|
|
336 |
(int && (x > .Machine$integer.max || x != as.integer(x)))) |
|
|
337 |
stop(sprintf("'%s' must be a positive %s.", name, |
|
|
338 |
ifelse(int, "integer", "scalar")), call.=FALSE) |
|
|
339 |
} |
|
|
340 |
|
|
|
341 |
#' @noRd |
|
|
342 |
validate.nonnegative.scalar <- function(x, name, int=FALSE) { |
|
|
343 |
if (!is.numeric(x) || length(x) != 1 || is.na(x) || x < 0 || |
|
|
344 |
(int && (x > .Machine$integer.max || x != as.integer(x)))) |
|
|
345 |
stop(sprintf("'%s' must be a non-negative %s.", name, |
|
|
346 |
ifelse(int, "integer", "scalar")), call.=FALSE) |
|
|
347 |
} |
|
|
348 |
|
|
|
349 |
#' Validate adapt.delta |
|
|
350 |
#' |
|
|
351 |
#' Check that an adaptation acceptance probability is valid. |
|
|
352 |
#' |
|
|
353 |
#' @param adapt.delta Value to be checked. |
|
|
354 |
#' |
|
|
355 |
#' @return |
|
|
356 |
#' Throws an error if the given value is not a valid acceptance probability |
|
|
357 |
#' for adaptation. |
|
|
358 |
#' |
|
|
359 |
#' @noRd |
|
|
360 |
validate.adapt.delta <- function(adapt.delta) { |
|
|
361 |
if (!is.numeric(adapt.delta) || length(adapt.delta) != 1) { |
|
|
362 |
stop("'adapt.delta' must be a single numerical value.") |
|
|
363 |
} |
|
|
364 |
if (adapt.delta < 0.8) { |
|
|
365 |
stop("'adapt.delta' must be at least 0.8.") |
|
|
366 |
} |
|
|
367 |
if (adapt.delta >= 1) { |
|
|
368 |
stop("'adapt.delta' must be less than 1.") |
|
|
369 |
} |
|
|
370 |
} |
|
|
371 |
|
|
|
372 |
#' Validate a probability |
|
|
373 |
#' |
|
|
374 |
#' Check that a probability value is valid. |
|
|
375 |
#' |
|
|
376 |
#' @param prob Value to be checked. |
|
|
377 |
#' |
|
|
378 |
#' @return |
|
|
379 |
#' Throws an error if the given value is not a valid probability. |
|
|
380 |
#' |
|
|
381 |
#' @noRd |
|
|
382 |
validate.probability <- function(prob) { |
|
|
383 |
if (length(prob) != 1 || prob <= 0 || prob >= 1) |
|
|
384 |
stop("'prob' must be a single value between 0 and 1.\n") |
|
|
385 |
} |
|
|
386 |
|
|
|
387 |
#' Validate arguments passed to rstan |
|
|
388 |
#' |
|
|
389 |
#' Ensure that the options to be passed to \code{\link[rstan]{sampling}} are |
|
|
390 |
#' valid, as to work around rstan issue #681. |
|
|
391 |
#' |
|
|
392 |
#' @param ... List of arguments to be checked. |
|
|
393 |
#' |
|
|
394 |
#' @return |
|
|
395 |
#' Throws an error if any argument is not valid for \code{\link[rstan]{sampling}}. |
|
|
396 |
#' |
|
|
397 |
#' @noRd |
|
|
398 |
validate.rstan.args <- function(...) { |
|
|
399 |
valid.args <- c("chains", "cores", "pars", "thin", "init", "check_data", |
|
|
400 |
"sample_file", "diagnostic_file", "verbose", "algorithm", |
|
|
401 |
"control", "open_progress", "show_messages", "chain_id", |
|
|
402 |
"init_r", "test_grad", "append_samples", "refresh", |
|
|
403 |
"save_warmup", "enable_random_init", "iter", "warmup") |
|
|
404 |
dots <- list(...) |
|
|
405 |
for (arg in names(dots)) |
|
|
406 |
if (!arg %in% valid.args) |
|
|
407 |
stop("Argument '", arg, "' not recognized.") |
|
|
408 |
} |
|
|
409 |
|
|
|
410 |
#' Parameter names |
|
|
411 |
#' |
|
|
412 |
#' Get the parameter names corresponding to the regression coefficients or |
|
|
413 |
#' matching a regular expression. |
|
|
414 |
#' |
|
|
415 |
#' @param obj An object of class `hsstan`. |
|
|
416 |
#' @param pars Regular expression to match a parameter name, or `NULL` |
|
|
417 |
#' to retrieve the names of all regression coefficients. |
|
|
418 |
#' |
|
|
419 |
#' @return |
|
|
420 |
#' A character vector. |
|
|
421 |
#' |
|
|
422 |
#' @noRd |
|
|
423 |
get.pars <- function(object, pars) { |
|
|
424 |
if (is.null(pars)) |
|
|
425 |
pars <- grep("^beta_", object$stanfit@model_pars, value=TRUE) |
|
|
426 |
else { |
|
|
427 |
if (!is.character(pars)) |
|
|
428 |
stop("'pars' must be a character vector.") |
|
|
429 |
get.pars <- function(x) grep(x, object$stanfit@sim$fnames_oi, value=TRUE) |
|
|
430 |
pars <- unlist(lapply(pars, get.pars)) |
|
|
431 |
if (length(pars) == 0) |
|
|
432 |
stop("No pattern in 'pars' matches parameter names.") |
|
|
433 |
} |
|
|
434 |
return(pars) |
|
|
435 |
} |
|
|
436 |
|
|
|
437 |
#' Create a design matrix with all unpenalized predictors first |
|
|
438 |
#' |
|
|
439 |
#' This is required as `model.matrix` puts the interaction terms after the |
|
|
440 |
#' penalized predictors, but the Stan models expects all unpenalized terms to |
|
|
441 |
#' appear before the penalized ones. |
|
|
442 |
#' |
|
|
443 |
#' @param x Data frame containing the variables of interest. |
|
|
444 |
#' @param unpenalized Vector of variable names for the unpenalized covariates. |
|
|
445 |
#' @param penalized Vector of variable names for the penalized predictors. |
|
|
446 |
#' |
|
|
447 |
#' @return |
|
|
448 |
#' A design matrix with all unpenalized covariates (including interaction terms) |
|
|
449 |
#' before the penalized predictors. |
|
|
450 |
#' |
|
|
451 |
#' @importFrom stats model.matrix reformulate |
|
|
452 |
#' @noRd |
|
|
453 |
ordered.model.matrix <- function(x, unpenalized, penalized) { |
|
|
454 |
X <- model.matrix(reformulate(c(unpenalized, penalized)), data=x) |
|
|
455 |
if (any(grepl("[:*]", unpenalized))) |
|
|
456 |
X <- X[, c(expand.terms(x, unpenalized), expand.terms(x, penalized)[-1])] |
|
|
457 |
return(X) |
|
|
458 |
} |
|
|
459 |
|
|
|
460 |
#' Expand variable names into formula terms |
|
|
461 |
#' |
|
|
462 |
#' @param x Data frame containing the variables of interest. |
|
|
463 |
#' @param variables Vector of variable names. |
|
|
464 |
#' |
|
|
465 |
#' @return |
|
|
466 |
#' A vector of variable names expanded by factor levels and interaction terms. |
|
|
467 |
#' |
|
|
468 |
#' @importFrom stats model.matrix reformulate |
|
|
469 |
#' @noRd |
|
|
470 |
expand.terms <- function(x, variables) { |
|
|
471 |
if (length(variables) == 0) |
|
|
472 |
return(character(0)) |
|
|
473 |
colnames(model.matrix(reformulate(variables), x[1, ])) |
|
|
474 |
} |
|
|
475 |
|
|
|
476 |
#' Summarize a vector |
|
|
477 |
#' |
|
|
478 |
#' @param x A numerical vector. |
|
|
479 |
#' @param prob Width of the interval between quantiles. |
|
|
480 |
#' |
|
|
481 |
#' @return |
|
|
482 |
#' The mean, standard deviation and quantiles for the input vector. |
|
|
483 |
#' |
|
|
484 |
#' @noRd |
|
|
485 |
vector.summary <- function(x, prob) { |
|
|
486 |
lower <- (1 - prob) / 2 |
|
|
487 |
upper <- 1 - lower |
|
|
488 |
c(mean=mean(x), sd=stats::sd(x), stats::quantile(x, c(lower, upper))) |
|
|
489 |
} |
|
|
490 |
|
|
|
491 |
#' Check whether the model fitted is a logistic regression model. |
|
|
492 |
#' |
|
|
493 |
#' @param obj An object of class `hsstan`. |
|
|
494 |
#' |
|
|
495 |
#' @return |
|
|
496 |
#' `TRUE` for logistic regression models, `FALSE` otherwise. |
|
|
497 |
#' |
|
|
498 |
#' @noRd |
|
|
499 |
is.logistic <- function(obj) { |
|
|
500 |
obj$family$family == "binomial" |
|
|
501 |
} |
|
|
502 |
|
|
|
503 |
#' Comma-separated string concatenation |
|
|
504 |
#' |
|
|
505 |
#' Collapse the elements of a character vector into a comma-separated string. |
|
|
506 |
#' |
|
|
507 |
#' @param x Character vector. |
|
|
508 |
#' |
|
|
509 |
#' @return |
|
|
510 |
#' A comma-separated string where each element of the original vector is |
|
|
511 |
#' surrounded by single quotes. |
|
|
512 |
#' |
|
|
513 |
#' @noRd |
|
|
514 |
collapse <- function(x) { |
|
|
515 |
paste0("'", x, "'", collapse=", ") |
|
|
516 |
} |
|
|
517 |
|
|
|
518 |
#' Fast computation of correlations |
|
|
519 |
#' |
|
|
520 |
#' This provides a loopless version of the computation of the correlation |
|
|
521 |
#' coefficient between observed and predicted outcomes. |
|
|
522 |
#' |
|
|
523 |
#' @param y Vector of observed outcome. |
|
|
524 |
#' @param x Matrix with as many columns as the number of elements in `y`, |
|
|
525 |
#' where each row corresponds to a predicted outcome. |
|
|
526 |
#' |
|
|
527 |
#' @return |
|
|
528 |
#' A vector of correlations with as many elements as the number of rows in `x`. |
|
|
529 |
#' |
|
|
530 |
#' @noRd |
|
|
531 |
fastCor <- function(y, x) { |
|
|
532 |
yx <- rbind(y, x) |
|
|
533 |
if (.Machine$sizeof.pointer == 8) { |
|
|
534 |
yx <- yx - rowMeans(yx) |
|
|
535 |
yx <- yx / sqrt(rowSums(yx^2)) |
|
|
536 |
corr <- tcrossprod(yx, yx) |
|
|
537 |
} else { |
|
|
538 |
corr <- stats::cor(yx) |
|
|
539 |
} |
|
|
540 |
return(corr[-1, 1]) |
|
|
541 |
} |
|
|
542 |
|
|
|
543 |
#' Log of sum of exponentials |
|
|
544 |
#' |
|
|
545 |
#' @noRd |
|
|
546 |
logSumExp <- function(x) { |
|
|
547 |
xmax <- max(x) |
|
|
548 |
xmax + log(sum(exp(x - xmax))) |
|
|
549 |
} |
|
|
550 |
|
|
|
551 |
#' Log of average of exponentials |
|
|
552 |
#' |
|
|
553 |
#' @noRd |
|
|
554 |
logMeanExp <- function(x) { |
|
|
555 |
logSumExp(x) - log(length(x)) |
|
|
556 |
} |