--- a +++ b/R/misc.R @@ -0,0 +1,556 @@ +##============================================================================= +## +## Copyright (c) 2019 Marco Colombo +## +## This program is free software: you can redistribute it and/or modify +## it under the terms of the GNU General Public License as published by +## the Free Software Foundation, either version 3 of the License, or +## (at your option) any later version. +## +## This program is distributed in the hope that it will be useful, +## but WITHOUT ANY WARRANTY; without even the implied warranty of +## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +## GNU General Public License for more details. +## +## You should have received a copy of the GNU General Public License +## along with this program. If not, see <http://www.gnu.org/licenses/>. +## +##============================================================================= + + +#' Validate an hsstan object +#' +#' Check that the object has been created by [hsstan()]. +#' +#' @param obj An object to be checked. +#' +#' @return +#' Throws an error if the object is not an `hsstan` object. +#' +#' @noRd +validate.hsstan <- function(obj) { + if (!inherits(obj, "hsstan")) { + stop("Not an object of class 'hsstan'.") + } +} + +#' Validate the posterior samples +#' +#' Check that the object contains valid posterior samples in the +#' `stanfit` field. +#' +#' @param obj An object of class `hsstan`. +#' +#' @return +#' Throws an error if the object does not contain posterior samples. +#' +#' @noRd +validate.samples <- function(obj) { + if (!inherits(obj$stanfit, "stanfit")) { + stop("No valid posterior samples stored in the 'hsstan' object.") + } +} + +#' Validate new data +#' +#' Check that the new data contains all variables used in the model and no +#' missing values, and generate the corresponding model matrix. +#' +#' @param obj Object of class `hsstan`. +#' @param newdata Optional data frame containing the variables used in the +#' model. If `NULL`, the model matrix used when fitting the model +#' is returned. +#' +#' @return +#' A design matrix corresponding to the variables used in the model. +#' +#' @noRd +validate.newdata <- function(obj, newdata) { + if (is.null(newdata)) + newdata <- obj$data + else if (!inherits(newdata, c("data.frame", "matrix"))) + stop("'newdata' must be a data frame or a matrix.") + if (nrow(newdata) == 0 || ncol(newdata) == 0) + stop("'newdata' contains no rows or no columns.") + + ## only check for NAs in the variables used in the model + vars <- with(obj$model.terms, c(outcome, unpenalized, penalized)) + newdata <- newdata[, colnames(newdata) %in% vars, drop=FALSE] + if (any(is.na(newdata))) + stop("'newdata' contains missing values.") + + ## this adds the intercept column back + ordered.model.matrix(as.data.frame(newdata), + obj$model.terms$unpenalized, + obj$model.terms$penalized) +} + +#' Validate a model formula +#' +#' Check that the formula that specifies a model contains all required elements. +#' +#' @param model Formula to be checked. +#' @param penalized Vector of names for the penalized predictors. +#' +#' @return +#' A list containing the formula representing the covariates model, the name of +#' the outcome variable, the names of the upenalized and penalized predictors. +#' +#' @importFrom stats as.formula terms +#' @noRd +validate.model <- function(model, penalized) { + if (is.character(model) && length(model) > 1) + stop("Model formula specified incorrectly.") + model <- as.formula(model) + tt <- terms(model) + if (attr(tt, "response") == 0) + stop("No outcome variable specified in the model.") + if (attr(tt, "intercept") == 0) + stop("Models with no intercept are not supported.") + if (length(penalized) > 0 && !is.character(penalized)) + stop("'penalized' must be a character vector.") + if (any(grepl("[:*]", penalized))) + stop("Interaction terms in penalized predictors are not supported.") + penalized <- setdiff(unique(trimws(penalized)), "") + return(list(outcome=as.character(model)[2], + unpenalized=setdiff(attr(tt, "term.labels"), penalized), + penalized=penalized)) +} + +#' Validate the model data +#' +#' Check if the model data can be used with the given model formula and +#' penalized predictors. +#' +#' @param x An object to be checked. +#' @param model Validated model formula. +#' +#' @return +#' A data frame containing the model data. A factor or logical outcome variable +#' is replaced by its numeric equivalent. +#' +#' @noRd +validate.data <- function(x, model) { + if (!inherits(x, c("data.frame", "matrix"))) + stop("'x' must be a data frame or a matrix.") + x <- as.data.frame(x) + validate.variables(x, model$outcome) + validate.variables(x, c(model$unpenalized, model$penalized)) + x[[model$outcome]] <- validate.outcome(x[[model$outcome]]) + return(x) +} + +#' Validate variables +#' +#' Check that the required variables are in the dataset. +#' +#' @param x Data frame containing the variables of interest. +#' @param variables Vector of variable names. +#' +#' @return +#' Throws if variables are not present in the dataset or contain missing values. +#' +#' @noRd +validate.variables <- function(x, variables) { + ## unpack interaction terms + variables <- unique(unlist(strsplit(as.character(variables), ":"))) + if (length(variables) == 0) + stop("No predictors present in the model.") + var.match <- match(variables, colnames(x)) + if (anyNA(var.match)) + stop(collapse(variables[is.na(var.match)]), " not present in 'x'.") + if (anyNA(x[, variables])) + stop("Model variables contain missing values.") +} + +#' Validate the outcome variable +#' +#' Check that the outcome variable can be converted to a valid numerical +#' vector. +#' +#' @param y Outcome vector to be checked. +#' +#' @return +#' A numeric vector. +#' +#' @noRd +validate.outcome <- function(y) { + if (is.factor(y)) { + if (nlevels(y) != 2) + stop("A factor outcome variable can only have two levels.") + y <- as.integer(y) - 1 + } + if (!(is.numeric(y) || is.logical(y))) + stop("Outcome variable of invalid type.") + return(as.numeric(y)) +} + +#' Validate the family argument +#' +#' Ensure that the family argument has been specified correctly. +#' This is inspired by code in \code{\link{glm}}. +#' +#' @param family Family argument to test. +#' @param y Outcome variable. +#' +#' @return +#' A valid family. The function throws an error if the family argument cannot +#' be used. +#' +#' @importFrom methods is +#' @noRd +validate.family <- function(family, y) { + if (missing(family)) + stop("Argument of 'family' is missing.") + if (is.character(family)) + tryCatch( + family <- get(family, mode="function", envir=parent.frame(2)), + error=function(e) + stop("'", family, "' is not a valid family.") + ) + if (is.function(family)) + family <- family() + if (!is(family, "family")) + stop("Argument of 'family' is not a valid family.") + if (!family$family %in% c("gaussian", "binomial")) + stop("Only 'gaussian' and 'binomial' are supported families.") + + if (family$family == "binomial") { + if (length(table(y)) != 2) + stop("Outcome variable must contain two classes with family=binomial.") + if (!is.factor(y) && any(y < 0 | y > 1)) + stop("Outcome variable must contain 0-1 values with family=binomial.") + } + + return(family) +} + +#' Validate a vector of indices +#' +#' @param x Vector to be checked. +#' @param N Maximum valid index. +#' @param name Name of the vector to report in error messages. +#' @param throw.duplicates Whether the function should throw if the vector +#' contains duplicate elements (`TRUE` by default). +#' +#' @return +#' Throws an error if the given vector is not an integer vector or contains +#' missing, out of bounds or duplicate indices (if `throw.duplicates` is `TRUE`). +#' +#' @noRd +validate.indices <- function(x, N, name, throw.duplicates=TRUE) { + if (anyNA(x)) + stop("'", name, "' contains missing values.") + if (!is.numeric(x) || NCOL(x) > 1 || any(x != as.integer(x))) + stop("'", name, "' must be an integer vector.") + if (length(x) < 2) + stop("'", name, "' must contain at least two elements.") + if (any(x < 1 | x > N)) + stop("'", name, "' contains out of bounds indices.") + if (throw.duplicates && any(duplicated(x))) + stop("'", name, "' contains duplicate indices.") +} + +#' Validate the cross-validation folds +#' +#' @param folds Folds to be checked or `NULL`. +#' @param N Number of observations. +#' +#' @return +#' An integer vector with one element per observation indicating the +#' cross-validation fold in which the observation should be withdrawn. +#' +#' @noRd +validate.folds <- function(folds, N) { + if (is.null(folds)) + return(rep(1, N)) + validate.indices(folds, N, "folds", throw.duplicates=FALSE) + if (length(folds) != N) + stop("'folds' should have length ", N, ".") + K <- length(unique(folds)) + if (!all(1:K %in% folds)) + stop("'folds' must contain all indices up to ", K, ".") + folds <- as.integer(folds) +} + +#' Validate start.from +#' +#' Check that the predictor names provided is a valid subset of the variables +#' used in the model. +#' +#' @param obj An object of class `hsstan`. +#' @param start.from Vector to be checked. +#' +#' @return +#' A list of two elements: the names of the model terms matching `start.from` +#' and a vector of indices corresponding to the names listed in `start.from`. +#' Throws an error if any of the names mentioned does not match those available +#' in the model terms. +#' +#' @noRd +validate.start.from <- function(obj, start.from) { + unp.terms <- obj$model.terms$unpenalized + unp.betas <- names(obj$betas$unpenalized) + mod.terms <- c(unp.terms, obj$model.terms$penalized) + mod.betas <- c(unp.betas, names(obj$betas$penalized)) + start.from <- setdiff(start.from, "") + if (is.null(start.from)) { + if (length(obj$model.terms$penalized) > 0) + return(list(start.from=unp.terms, idx=seq_along(unp.betas))) + else + return(list(start.from=character(0), idx=1)) + } + if (length(start.from) == 0) + return(list(start.from=character(0), idx=1)) + if (anyNA(start.from)) + stop("'start.from' contains missing values.") + var.match <- match(start.from, mod.terms) + if (anyNA(var.match)) + stop("'start.from' contains ", collapse(start.from[is.na(var.match)]), + ", which cannot be matched.") + + ## unpack interaction terms so that also main effects are matched + start.from <- mod.terms[mod.terms %in% + c(start.from, unlist(strsplit(start.from, ":")))] + chosen <- expand.terms(obj$data, start.from) + + ## also consider interaction terms in reverse order + chosen <- c(chosen, sapply(strsplit(chosen[grep(":", chosen)], ":"), + function(z) c(z, paste(rev(z), collapse=":")))) + return(list(start.from=start.from, idx=which(mod.betas %in% chosen))) +} + +#' Validate a positive or non-negative scalar value +#' +#' @param x Value to validate. +#' @param name Variable name to report in case of error. +#' @param int Whether the value has to be an integer (`FALSE` by default). +#' +#' @return +#' Throws an error if the given value is not a positive or non-negative +#' scalar (or integer scalar). +#' +#' @noRd +validate.positive.scalar <- function(x, name, int=FALSE) { + if (!is.numeric(x) || length(x) != 1 || is.na(x) || x <= 0 || + (int && (x > .Machine$integer.max || x != as.integer(x)))) + stop(sprintf("'%s' must be a positive %s.", name, + ifelse(int, "integer", "scalar")), call.=FALSE) +} + +#' @noRd +validate.nonnegative.scalar <- function(x, name, int=FALSE) { + if (!is.numeric(x) || length(x) != 1 || is.na(x) || x < 0 || + (int && (x > .Machine$integer.max || x != as.integer(x)))) + stop(sprintf("'%s' must be a non-negative %s.", name, + ifelse(int, "integer", "scalar")), call.=FALSE) +} + +#' Validate adapt.delta +#' +#' Check that an adaptation acceptance probability is valid. +#' +#' @param adapt.delta Value to be checked. +#' +#' @return +#' Throws an error if the given value is not a valid acceptance probability +#' for adaptation. +#' +#' @noRd +validate.adapt.delta <- function(adapt.delta) { + if (!is.numeric(adapt.delta) || length(adapt.delta) != 1) { + stop("'adapt.delta' must be a single numerical value.") + } + if (adapt.delta < 0.8) { + stop("'adapt.delta' must be at least 0.8.") + } + if (adapt.delta >= 1) { + stop("'adapt.delta' must be less than 1.") + } +} + +#' Validate a probability +#' +#' Check that a probability value is valid. +#' +#' @param prob Value to be checked. +#' +#' @return +#' Throws an error if the given value is not a valid probability. +#' +#' @noRd +validate.probability <- function(prob) { + if (length(prob) != 1 || prob <= 0 || prob >= 1) + stop("'prob' must be a single value between 0 and 1.\n") +} + +#' Validate arguments passed to rstan +#' +#' Ensure that the options to be passed to \code{\link[rstan]{sampling}} are +#' valid, as to work around rstan issue #681. +#' +#' @param ... List of arguments to be checked. +#' +#' @return +#' Throws an error if any argument is not valid for \code{\link[rstan]{sampling}}. +#' +#' @noRd +validate.rstan.args <- function(...) { + valid.args <- c("chains", "cores", "pars", "thin", "init", "check_data", + "sample_file", "diagnostic_file", "verbose", "algorithm", + "control", "open_progress", "show_messages", "chain_id", + "init_r", "test_grad", "append_samples", "refresh", + "save_warmup", "enable_random_init", "iter", "warmup") + dots <- list(...) + for (arg in names(dots)) + if (!arg %in% valid.args) + stop("Argument '", arg, "' not recognized.") +} + +#' Parameter names +#' +#' Get the parameter names corresponding to the regression coefficients or +#' matching a regular expression. +#' +#' @param obj An object of class `hsstan`. +#' @param pars Regular expression to match a parameter name, or `NULL` +#' to retrieve the names of all regression coefficients. +#' +#' @return +#' A character vector. +#' +#' @noRd +get.pars <- function(object, pars) { + if (is.null(pars)) + pars <- grep("^beta_", object$stanfit@model_pars, value=TRUE) + else { + if (!is.character(pars)) + stop("'pars' must be a character vector.") + get.pars <- function(x) grep(x, object$stanfit@sim$fnames_oi, value=TRUE) + pars <- unlist(lapply(pars, get.pars)) + if (length(pars) == 0) + stop("No pattern in 'pars' matches parameter names.") + } + return(pars) +} + +#' Create a design matrix with all unpenalized predictors first +#' +#' This is required as `model.matrix` puts the interaction terms after the +#' penalized predictors, but the Stan models expects all unpenalized terms to +#' appear before the penalized ones. +#' +#' @param x Data frame containing the variables of interest. +#' @param unpenalized Vector of variable names for the unpenalized covariates. +#' @param penalized Vector of variable names for the penalized predictors. +#' +#' @return +#' A design matrix with all unpenalized covariates (including interaction terms) +#' before the penalized predictors. +#' +#' @importFrom stats model.matrix reformulate +#' @noRd +ordered.model.matrix <- function(x, unpenalized, penalized) { + X <- model.matrix(reformulate(c(unpenalized, penalized)), data=x) + if (any(grepl("[:*]", unpenalized))) + X <- X[, c(expand.terms(x, unpenalized), expand.terms(x, penalized)[-1])] + return(X) +} + +#' Expand variable names into formula terms +#' +#' @param x Data frame containing the variables of interest. +#' @param variables Vector of variable names. +#' +#' @return +#' A vector of variable names expanded by factor levels and interaction terms. +#' +#' @importFrom stats model.matrix reformulate +#' @noRd +expand.terms <- function(x, variables) { + if (length(variables) == 0) + return(character(0)) + colnames(model.matrix(reformulate(variables), x[1, ])) +} + +#' Summarize a vector +#' +#' @param x A numerical vector. +#' @param prob Width of the interval between quantiles. +#' +#' @return +#' The mean, standard deviation and quantiles for the input vector. +#' +#' @noRd +vector.summary <- function(x, prob) { + lower <- (1 - prob) / 2 + upper <- 1 - lower + c(mean=mean(x), sd=stats::sd(x), stats::quantile(x, c(lower, upper))) +} + +#' Check whether the model fitted is a logistic regression model. +#' +#' @param obj An object of class `hsstan`. +#' +#' @return +#' `TRUE` for logistic regression models, `FALSE` otherwise. +#' +#' @noRd +is.logistic <- function(obj) { + obj$family$family == "binomial" +} + +#' Comma-separated string concatenation +#' +#' Collapse the elements of a character vector into a comma-separated string. +#' +#' @param x Character vector. +#' +#' @return +#' A comma-separated string where each element of the original vector is +#' surrounded by single quotes. +#' +#' @noRd +collapse <- function(x) { + paste0("'", x, "'", collapse=", ") +} + +#' Fast computation of correlations +#' +#' This provides a loopless version of the computation of the correlation +#' coefficient between observed and predicted outcomes. +#' +#' @param y Vector of observed outcome. +#' @param x Matrix with as many columns as the number of elements in `y`, +#' where each row corresponds to a predicted outcome. +#' +#' @return +#' A vector of correlations with as many elements as the number of rows in `x`. +#' +#' @noRd +fastCor <- function(y, x) { + yx <- rbind(y, x) + if (.Machine$sizeof.pointer == 8) { + yx <- yx - rowMeans(yx) + yx <- yx / sqrt(rowSums(yx^2)) + corr <- tcrossprod(yx, yx) + } else { + corr <- stats::cor(yx) + } + return(corr[-1, 1]) +} + +#' Log of sum of exponentials +#' +#' @noRd +logSumExp <- function(x) { + xmax <- max(x) + xmax + log(sum(exp(x - xmax))) +} + +#' Log of average of exponentials +#' +#' @noRd +logMeanExp <- function(x) { + logSumExp(x) - log(length(x)) +}