Diff of /R/misc.R [000000] .. [ede2d4]

Switch to unified view

a b/R/misc.R
1
##=============================================================================
2
##
3
## Copyright (c) 2019 Marco Colombo
4
##
5
## This program is free software: you can redistribute it and/or modify
6
## it under the terms of the GNU General Public License as published by
7
## the Free Software Foundation, either version 3 of the License, or
8
## (at your option) any later version.
9
##
10
## This program is distributed in the hope that it will be useful,
11
## but WITHOUT ANY WARRANTY; without even the implied warranty of
12
## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13
## GNU General Public License for more details.
14
##
15
## You should have received a copy of the GNU General Public License
16
## along with this program.  If not, see <http://www.gnu.org/licenses/>.
17
##
18
##=============================================================================
19
20
21
#' Validate an hsstan object
22
#'
23
#' Check that the object has been created by [hsstan()].
24
#'
25
#' @param obj An object to be checked.
26
#'
27
#' @return
28
#' Throws an error if the object is not an `hsstan` object.
29
#'
30
#' @noRd
31
validate.hsstan <- function(obj) {
32
    if (!inherits(obj, "hsstan")) {
33
        stop("Not an object of class 'hsstan'.")
34
    }
35
}
36
37
#' Validate the posterior samples
38
#'
39
#' Check that the object contains valid posterior samples in the
40
#' `stanfit` field.
41
#'
42
#' @param obj An object of class `hsstan`.
43
#'
44
#' @return
45
#' Throws an error if the object does not contain posterior samples.
46
#'
47
#' @noRd
48
validate.samples <- function(obj) {
49
    if (!inherits(obj$stanfit, "stanfit")) {
50
        stop("No valid posterior samples stored in the 'hsstan' object.")
51
    }
52
}
53
54
#' Validate new data
55
#'
56
#' Check that the new data contains all variables used in the model and no
57
#' missing values, and generate the corresponding model matrix.
58
#'
59
#' @param obj Object of class `hsstan`.
60
#' @param newdata Optional data frame containing the variables used in the
61
#'        model. If `NULL`, the model matrix used when fitting the model
62
#'        is returned.
63
#'
64
#' @return
65
#' A design matrix corresponding to the variables used in the model.
66
#'
67
#' @noRd
68
validate.newdata <- function(obj, newdata) {
69
    if (is.null(newdata))
70
        newdata <- obj$data
71
    else if (!inherits(newdata, c("data.frame", "matrix")))
72
        stop("'newdata' must be a data frame or a matrix.")
73
    if (nrow(newdata) == 0 || ncol(newdata) == 0)
74
        stop("'newdata' contains no rows or no columns.")
75
76
    ## only check for NAs in the variables used in the model
77
    vars <- with(obj$model.terms, c(outcome, unpenalized, penalized))
78
    newdata <- newdata[, colnames(newdata) %in% vars, drop=FALSE]
79
    if (any(is.na(newdata)))
80
        stop("'newdata' contains missing values.")
81
82
    ## this adds the intercept column back
83
    ordered.model.matrix(as.data.frame(newdata),
84
                         obj$model.terms$unpenalized,
85
                         obj$model.terms$penalized)
86
}
87
88
#' Validate a model formula
89
#'
90
#' Check that the formula that specifies a model contains all required elements.
91
#'
92
#' @param model Formula to be checked.
93
#' @param penalized Vector of names for the penalized predictors.
94
#'
95
#' @return
96
#' A list containing the formula representing the covariates model, the name of
97
#' the outcome variable, the names of the upenalized and penalized predictors.
98
#'
99
#' @importFrom stats as.formula terms
100
#' @noRd
101
validate.model <- function(model, penalized) {
102
    if (is.character(model) && length(model) > 1)
103
        stop("Model formula specified incorrectly.")
104
    model <- as.formula(model)
105
    tt <- terms(model)
106
    if (attr(tt, "response") == 0)
107
        stop("No outcome variable specified in the model.")
108
    if (attr(tt, "intercept") == 0)
109
        stop("Models with no intercept are not supported.")
110
    if (length(penalized) > 0 && !is.character(penalized))
111
        stop("'penalized' must be a character vector.")
112
    if (any(grepl("[:*]", penalized)))
113
        stop("Interaction terms in penalized predictors are not supported.")
114
    penalized <- setdiff(unique(trimws(penalized)), "")
115
    return(list(outcome=as.character(model)[2],
116
                unpenalized=setdiff(attr(tt, "term.labels"), penalized),
117
                penalized=penalized))
118
}
119
120
#' Validate the model data
121
#'
122
#' Check if the model data can be used with the given model formula and
123
#' penalized predictors.
124
#'
125
#' @param x An object to be checked.
126
#' @param model Validated model formula.
127
#'
128
#' @return
129
#' A data frame containing the model data. A factor or logical outcome variable
130
#' is replaced by its numeric equivalent.
131
#'
132
#' @noRd
133
validate.data <- function(x, model) {
134
    if (!inherits(x, c("data.frame", "matrix")))
135
        stop("'x' must be a data frame or a matrix.")
136
    x <- as.data.frame(x)
137
    validate.variables(x, model$outcome)
138
    validate.variables(x, c(model$unpenalized, model$penalized))
139
    x[[model$outcome]] <- validate.outcome(x[[model$outcome]])
140
    return(x)
141
}
142
143
#' Validate variables
144
#'
145
#' Check that the required variables are in the dataset.
146
#'
147
#' @param x Data frame containing the variables of interest.
148
#' @param variables Vector of variable names.
149
#'
150
#' @return
151
#' Throws if variables are not present in the dataset or contain missing values.
152
#'
153
#' @noRd
154
validate.variables <- function(x, variables) {
155
    ## unpack interaction terms
156
    variables <- unique(unlist(strsplit(as.character(variables), ":")))
157
    if (length(variables) == 0)
158
        stop("No predictors present in the model.")
159
    var.match <- match(variables, colnames(x))
160
    if (anyNA(var.match))
161
        stop(collapse(variables[is.na(var.match)]), " not present in 'x'.")
162
    if (anyNA(x[, variables]))
163
        stop("Model variables contain missing values.")
164
}
165
166
#' Validate the outcome variable
167
#'
168
#' Check that the outcome variable can be converted to a valid numerical
169
#' vector.
170
#'
171
#' @param y Outcome vector to be checked.
172
#'
173
#' @return
174
#' A numeric vector.
175
#'
176
#' @noRd
177
validate.outcome <- function(y) {
178
    if (is.factor(y)) {
179
        if (nlevels(y) != 2)
180
            stop("A factor outcome variable can only have two levels.")
181
        y <- as.integer(y) - 1
182
    }
183
    if (!(is.numeric(y) || is.logical(y)))
184
        stop("Outcome variable of invalid type.")
185
    return(as.numeric(y))
186
}
187
188
#' Validate the family argument
189
#'
190
#' Ensure that the family argument has been specified correctly.
191
#' This is inspired by code in \code{\link{glm}}.
192
#'
193
#' @param family Family argument to test.
194
#' @param y Outcome variable.
195
#'
196
#' @return
197
#' A valid family. The function throws an error if the family argument cannot
198
#' be used.
199
#'
200
#' @importFrom methods is
201
#' @noRd
202
validate.family <- function(family, y) {
203
    if (missing(family))
204
        stop("Argument of 'family' is missing.")
205
    if (is.character(family))
206
        tryCatch(
207
            family <- get(family, mode="function", envir=parent.frame(2)),
208
                          error=function(e)
209
                              stop("'", family, "' is not a valid family.")
210
        )
211
    if (is.function(family))
212
        family <- family()
213
    if (!is(family, "family"))
214
        stop("Argument of 'family' is not a valid family.")
215
    if (!family$family %in% c("gaussian", "binomial"))
216
        stop("Only 'gaussian' and 'binomial' are supported families.")
217
218
    if (family$family == "binomial") {
219
        if (length(table(y)) != 2)
220
            stop("Outcome variable must contain two classes with family=binomial.")
221
        if (!is.factor(y) && any(y < 0 | y > 1))
222
            stop("Outcome variable must contain 0-1 values with family=binomial.")
223
    }
224
225
    return(family)
226
}
227
228
#' Validate a vector of indices
229
#'
230
#' @param x Vector to be checked.
231
#' @param N Maximum valid index.
232
#' @param name Name of the vector to report in error messages.
233
#' @param throw.duplicates Whether the function should throw if the vector
234
#'        contains duplicate elements (`TRUE` by default).
235
#'
236
#' @return
237
#' Throws an error if the given vector is not an integer vector or contains
238
#' missing, out of bounds or duplicate indices (if `throw.duplicates` is `TRUE`).
239
#'
240
#' @noRd
241
validate.indices <- function(x, N, name, throw.duplicates=TRUE) {
242
    if (anyNA(x))
243
        stop("'", name, "' contains missing values.")
244
    if (!is.numeric(x) || NCOL(x) > 1 || any(x != as.integer(x)))
245
        stop("'", name, "' must be an integer vector.")
246
    if (length(x) < 2)
247
        stop("'", name, "' must contain at least two elements.")
248
    if (any(x < 1 | x > N))
249
        stop("'", name, "' contains out of bounds indices.")
250
    if (throw.duplicates && any(duplicated(x)))
251
        stop("'", name, "' contains duplicate indices.")
252
}
253
254
#' Validate the cross-validation folds
255
#'
256
#' @param folds Folds to be checked or `NULL`.
257
#' @param N Number of observations.
258
#'
259
#' @return
260
#' An integer vector with one element per observation indicating the
261
#' cross-validation fold in which the observation should be withdrawn.
262
#'
263
#' @noRd
264
validate.folds <- function(folds, N) {
265
    if (is.null(folds))
266
        return(rep(1, N))
267
    validate.indices(folds, N, "folds", throw.duplicates=FALSE)
268
    if (length(folds) != N)
269
        stop("'folds' should have length ", N, ".")
270
    K <- length(unique(folds))
271
    if (!all(1:K %in% folds))
272
        stop("'folds' must contain all indices up to ", K, ".")
273
    folds <- as.integer(folds)
274
}
275
276
#' Validate start.from
277
#'
278
#' Check that the predictor names provided is a valid subset of the variables
279
#' used in the model.
280
#'
281
#' @param obj An object of class `hsstan`.
282
#' @param start.from Vector to be checked.
283
#'
284
#' @return
285
#' A list of two elements: the names of the model terms matching `start.from`
286
#' and a vector of indices corresponding to the names listed in `start.from`.
287
#' Throws an error if any of the names mentioned does not match those available
288
#' in the model terms.
289
#'
290
#' @noRd
291
validate.start.from <- function(obj, start.from) {
292
    unp.terms <- obj$model.terms$unpenalized
293
    unp.betas <- names(obj$betas$unpenalized)
294
    mod.terms <- c(unp.terms, obj$model.terms$penalized)
295
    mod.betas <- c(unp.betas, names(obj$betas$penalized))
296
    start.from <- setdiff(start.from, "")
297
    if (is.null(start.from)) {
298
        if (length(obj$model.terms$penalized) > 0)
299
            return(list(start.from=unp.terms, idx=seq_along(unp.betas)))
300
        else
301
            return(list(start.from=character(0), idx=1))
302
    }
303
    if (length(start.from) == 0)
304
        return(list(start.from=character(0), idx=1))
305
    if (anyNA(start.from))
306
        stop("'start.from' contains missing values.")
307
    var.match <- match(start.from, mod.terms)
308
    if (anyNA(var.match))
309
        stop("'start.from' contains ", collapse(start.from[is.na(var.match)]),
310
             ", which cannot be matched.")
311
312
    ## unpack interaction terms so that also main effects are matched
313
    start.from <- mod.terms[mod.terms %in%
314
                            c(start.from, unlist(strsplit(start.from, ":")))]
315
    chosen <- expand.terms(obj$data, start.from)
316
317
    ## also consider interaction terms in reverse order
318
    chosen <- c(chosen, sapply(strsplit(chosen[grep(":", chosen)], ":"),
319
                               function(z) c(z, paste(rev(z), collapse=":"))))
320
    return(list(start.from=start.from, idx=which(mod.betas %in% chosen)))
321
}
322
323
#' Validate a positive or non-negative scalar value
324
#'
325
#' @param x Value to validate.
326
#' @param name Variable name to report in case of error.
327
#' @param int Whether the value has to be an integer (`FALSE` by default).
328
#'
329
#' @return
330
#' Throws an error if the given value is not a positive or non-negative
331
#' scalar (or integer scalar).
332
#'
333
#' @noRd
334
validate.positive.scalar <- function(x, name, int=FALSE) {
335
    if (!is.numeric(x) || length(x) != 1 || is.na(x) || x <= 0 ||
336
        (int && (x > .Machine$integer.max || x != as.integer(x))))
337
        stop(sprintf("'%s' must be a positive %s.", name,
338
                     ifelse(int, "integer", "scalar")), call.=FALSE)
339
}
340
341
#' @noRd
342
validate.nonnegative.scalar <- function(x, name, int=FALSE) {
343
    if (!is.numeric(x) || length(x) != 1 || is.na(x) || x < 0 ||
344
        (int && (x > .Machine$integer.max || x != as.integer(x))))
345
        stop(sprintf("'%s' must be a non-negative %s.", name,
346
                     ifelse(int, "integer", "scalar")), call.=FALSE)
347
}
348
349
#' Validate adapt.delta
350
#'
351
#' Check that an adaptation acceptance probability is valid.
352
#'
353
#' @param adapt.delta Value to be checked.
354
#'
355
#' @return
356
#' Throws an error if the given value is not a valid acceptance probability
357
#' for adaptation.
358
#'
359
#' @noRd
360
validate.adapt.delta <- function(adapt.delta) {
361
    if (!is.numeric(adapt.delta) || length(adapt.delta) != 1) {
362
        stop("'adapt.delta' must be a single numerical value.")
363
    }
364
    if (adapt.delta < 0.8) {
365
        stop("'adapt.delta' must be at least 0.8.")
366
    }
367
    if (adapt.delta >= 1) {
368
        stop("'adapt.delta' must be less than 1.")
369
    }
370
}
371
372
#' Validate a probability
373
#'
374
#' Check that a probability value is valid.
375
#'
376
#' @param prob Value to be checked.
377
#'
378
#' @return
379
#' Throws an error if the given value is not a valid probability.
380
#'
381
#' @noRd
382
validate.probability <- function(prob) {
383
    if (length(prob) != 1 || prob <= 0 || prob >= 1)
384
        stop("'prob' must be a single value between 0 and 1.\n")
385
}
386
387
#' Validate arguments passed to rstan
388
#'
389
#' Ensure that the options to be passed to \code{\link[rstan]{sampling}} are
390
#' valid, as to work around rstan issue #681.
391
#'
392
#' @param ... List of arguments to be checked.
393
#'
394
#' @return
395
#' Throws an error if any argument is not valid for \code{\link[rstan]{sampling}}.
396
#'
397
#' @noRd
398
validate.rstan.args <- function(...) {
399
    valid.args <- c("chains", "cores", "pars", "thin", "init", "check_data",
400
                    "sample_file", "diagnostic_file", "verbose", "algorithm",
401
                    "control", "open_progress", "show_messages", "chain_id",
402
                    "init_r", "test_grad", "append_samples", "refresh",
403
                    "save_warmup", "enable_random_init", "iter", "warmup")
404
    dots <- list(...)
405
    for (arg in names(dots))
406
        if (!arg %in% valid.args)
407
           stop("Argument '", arg, "' not recognized.")
408
}
409
410
#' Parameter names
411
#'
412
#' Get the parameter names corresponding to the regression coefficients or
413
#' matching a regular expression.
414
#'
415
#' @param obj An object of class `hsstan`.
416
#' @param pars Regular expression to match a parameter name, or `NULL`
417
#'        to retrieve the names of all regression coefficients.
418
#'
419
#' @return
420
#' A character vector.
421
#'
422
#' @noRd
423
get.pars <- function(object, pars) {
424
    if (is.null(pars))
425
        pars <- grep("^beta_", object$stanfit@model_pars, value=TRUE)
426
    else {
427
        if (!is.character(pars))
428
            stop("'pars' must be a character vector.")
429
        get.pars <- function(x) grep(x, object$stanfit@sim$fnames_oi, value=TRUE)
430
        pars <- unlist(lapply(pars, get.pars))
431
        if (length(pars) == 0)
432
            stop("No pattern in 'pars' matches parameter names.")
433
    }
434
    return(pars)
435
}
436
437
#' Create a design matrix with all unpenalized predictors first
438
#'
439
#' This is required as `model.matrix` puts the interaction terms after the
440
#' penalized predictors, but the Stan models expects all unpenalized terms to
441
#' appear before the penalized ones.
442
#'
443
#' @param x Data frame containing the variables of interest.
444
#' @param unpenalized Vector of variable names for the unpenalized covariates.
445
#' @param penalized Vector of variable names for the penalized predictors.
446
#'
447
#' @return
448
#' A design matrix with all unpenalized covariates (including interaction terms)
449
#' before the penalized predictors.
450
#'
451
#' @importFrom stats model.matrix reformulate
452
#' @noRd
453
ordered.model.matrix <- function(x, unpenalized, penalized) {
454
    X <- model.matrix(reformulate(c(unpenalized, penalized)), data=x)
455
    if (any(grepl("[:*]", unpenalized)))
456
        X <- X[, c(expand.terms(x, unpenalized), expand.terms(x, penalized)[-1])]
457
    return(X)
458
}
459
460
#' Expand variable names into formula terms
461
#'
462
#' @param x Data frame containing the variables of interest.
463
#' @param variables Vector of variable names.
464
#'
465
#' @return
466
#' A vector of variable names expanded by factor levels and interaction terms.
467
#'
468
#' @importFrom stats model.matrix reformulate
469
#' @noRd
470
expand.terms <- function(x, variables) {
471
    if (length(variables) == 0)
472
        return(character(0))
473
    colnames(model.matrix(reformulate(variables), x[1, ]))
474
}
475
476
#' Summarize a vector
477
#'
478
#' @param x A numerical vector.
479
#' @param prob Width of the interval between quantiles.
480
#'
481
#' @return
482
#' The mean, standard deviation and quantiles for the input vector.
483
#'
484
#' @noRd
485
vector.summary <- function(x, prob) {
486
    lower <- (1 - prob) / 2
487
    upper <- 1 - lower
488
    c(mean=mean(x), sd=stats::sd(x), stats::quantile(x, c(lower, upper)))
489
}
490
491
#' Check whether the model fitted is a logistic regression model.
492
#'
493
#' @param obj An object of class `hsstan`.
494
#'
495
#' @return
496
#' `TRUE` for logistic regression models, `FALSE` otherwise.
497
#'
498
#' @noRd
499
is.logistic <- function(obj) {
500
    obj$family$family == "binomial"
501
}
502
503
#' Comma-separated string concatenation
504
#'
505
#' Collapse the elements of a character vector into a comma-separated string.
506
#'
507
#' @param x Character vector.
508
#'
509
#' @return
510
#' A comma-separated string where each element of the original vector is
511
#' surrounded by single quotes.
512
#'
513
#' @noRd
514
collapse <- function(x) {
515
    paste0("'", x, "'", collapse=", ")
516
}
517
518
#' Fast computation of correlations
519
#'
520
#' This provides a loopless version of the computation of the correlation
521
#' coefficient between observed and predicted outcomes.
522
#'
523
#' @param y Vector of observed outcome.
524
#' @param x Matrix with as many columns as the number of elements in `y`,
525
#'          where each row corresponds to a predicted outcome.
526
#'
527
#' @return
528
#' A vector of correlations with as many elements as the number of rows in `x`.
529
#'
530
#' @noRd
531
fastCor <- function(y, x) {
532
    yx <- rbind(y, x)
533
    if (.Machine$sizeof.pointer == 8) {
534
        yx <- yx - rowMeans(yx)
535
        yx <- yx / sqrt(rowSums(yx^2))
536
        corr <- tcrossprod(yx, yx)
537
    } else {
538
        corr <- stats::cor(yx)
539
    }
540
    return(corr[-1, 1])
541
}
542
543
#' Log of sum of exponentials
544
#'
545
#' @noRd
546
logSumExp <- function(x) {
547
    xmax <- max(x)
548
    xmax + log(sum(exp(x - xmax)))
549
}
550
551
#' Log of average of exponentials
552
#'
553
#' @noRd
554
logMeanExp <- function(x) {
555
    logSumExp(x) - log(length(x))
556
}