--- a
+++ b/R/outbreaker_chains_methods.R
@@ -0,0 +1,486 @@
+#' Basic methods for processing outbreaker results
+#'
+#' Several methods are defined for instances of the class
+#' \code{outbreaker_chains}, returned by \code{\link{outbreaker}}, including:
+#' \code{print}, \code{plot}, \code{summary}
+#'
+#' @rdname outbreaker_chains
+#'
+#' @aliases outbreaker_chains print.outbreaker_chains plot.outbreaker_chains
+#'   summary.outbreaker_chains
+#'
+#' @author Thibaut Jombart (\email{thibautjombart@@gmail.com}).
+#'
+#' @param x an \code{outbreaker_chains} object as returned by \code{outbreaker}.
+#' @param n_row the number of rows to display in head and tail; defaults to 3.
+#' @param n_col the number of columns to display; defaults to 8.
+#' @param ... further arguments to be passed to other methods.
+#'
+#' @export
+#' @importFrom utils head tail
+#'
+print.outbreaker_chains <- function(x, n_row = 3, n_col = 8, ...) {
+  cat("\n\n ///// outbreaker results ///\n")
+  cat("\nclass: ", class(x))
+  cat("\ndimensions", nrow(x), "rows, ", ncol(x), "columns")
+
+  ## process names of variables not shown
+  if (ncol(x) > n_col) {
+    ori_names <- names(x)
+    x <- x[, seq_len(min(n_col, ncol(x)))]
+
+    not_shown <- setdiff(ori_names, names(x))
+
+    alpha_txt <- paste(not_shown[range(grep("alpha", not_shown))], collapse=" - ")
+    t_inf_txt <- paste(not_shown[range(grep("t_inf", not_shown))], collapse=" - ")
+    kappa_txt <- paste(not_shown[range(grep("kappa", not_shown))], collapse=" - ")
+
+    cat("\nancestries not shown:", alpha_txt)
+    cat("\ninfection dates not shown:", t_inf_txt)
+    cat("\nintermediate generations not shown:", kappa_txt)
+  }
+
+  ## heads and tails
+  cat("\n\n/// head //\n")
+  print(head(as.data.frame(x), n_row))
+  cat("\n...")
+  cat("\n/// tail //\n")
+  print(tail(as.data.frame(x), n_row))
+}
+
+
+
+
+
+#' @rdname outbreaker_chains
+#'
+#' @param y a character string indicating which element of an
+#'   \code{outbreaker_chains} object to plot.
+#'
+#' @param type a character string indicating the kind of plot to be used (see details).
+#' 
+#' @param group a vector of character strings indicating the parameters to display, 
+#' or "all" to display all global parameters (non node-specific parameters).
+#'
+#' @param burnin the number of iterations to be discarded as burnin.
+#'
+#' @param min_support a number between 0 and 1 indicating the minimum support of
+#' ancestries to be plotted; only used if 'type' is 'network'.
+#'
+#' @param labels a vector of length N indicating the case labels (must be
+#'   provided in the same order used for dates of symptom onset).
+#'
+## #' @param dens_all a logical indicating if the overal density computed over
+## all runs should be displayed; defaults to TRUE #' @param col the colors to be
+## used for different runs.
+#'
+#' @export
+#'
+#' @seealso See \href{http://www.repidemicsconsortium.org/outbreaker2/articles/introduction.html#graphics}{introduction vignette} for detailed examples on how to visualise \code{outbreaker_chains} objects.
+#'
+#' @details \code{type} indicates the type of graphic to plot:
+#'
+#' \itemize{
+#'
+#' \item \code{trace} to visualise MCMC traces for parameters or augmented data (plots the
+#' log-likelihood by default)
+#'
+#' \item \code{hist} to plot histograms of quantitative values
+#'
+#' \item \code{density} to plot kernel density estimations of quantitative values
+#'
+#' \item \code{alpha} to visualise the posterior frequency of ancestries
+#'
+#' \item \code{network} to visualise the transmission tree; note that
+#'  this opens up an interactive plot and requires a web browser with
+#'  Javascript enabled; the argument `min_support` is useful to select only the
+#'  most supported ancestries and avoid displaying too many links
+#'
+#' \item \code{kappa} to visualise the distributions generations between cases and their
+#' ancestor/infector
+#'
+#' }
+#'
+#' @importFrom ggplot2 ggplot geom_line geom_point geom_histogram geom_density
+#'   geom_violin aes aes_string coord_flip labs guides scale_size_area
+#'   scale_x_discrete scale_y_discrete scale_color_manual scale_fill_manual
+#'   scale_x_continuous scale_y_continuous theme_bw facet_wrap
+#'
+#' @importFrom stats reshape
+#' @importFrom grDevices xyTable
+#' @importFrom graphics plot
+#'
+plot.outbreaker_chains <- function(x, y = "post",
+                                   type = c("trace", "hist", "density",
+                                            "alpha", "t_inf", "kappa", "network"),
+                                   group = NULL, 
+                                   burnin = 0, min_support = 0.1, labels = NULL, ...) {
+
+  ## CHECKS ##
+  type <- match.arg(type)
+  if (!y %in% names(x)) {
+    stop(paste(y,"is not a column of x"))
+  }
+
+  ## THIS IS JUST TO APPEASE R CMD check
+
+  ## hopefully cran will avoid spurious warnings along the lines of "no
+  ## visible binding for global variable" when using ggplot2::aes(...)
+
+  frequency <- NULL
+
+  ## GET DATA TO PLOT ##
+  if (burnin > max(x$step)) {
+    stop("burnin exceeds the number of steps in x")
+  }
+  x <- x[x$step>burnin,,drop = FALSE]
+
+  ## check group
+  if(!is.null(group)) {
+    if(length(group) == 1L && group == "all") {
+      ##remove _[digit] vars
+      y_vars = names(x)[!grepl("(_[[:digit:]]+$)", names(x))]
+    } else if(all(group %in% names(x))) {
+      y_vars = c("step", group)
+    } else {
+      stop("grouping variables not found in outbreaker object")
+    }
+    ## get only relevant data
+    x_sub = as.data.frame(x)[,y_vars]
+    ## switch it to long format to use in ggplot
+    x_long = reshape(x_sub, 
+                     idvar = "step", 
+                     ids = x_sub$step,
+                     direction = "long", 
+                     new.row.names = NULL,
+                     timevar = "Parameters", 
+                     v.names = "y",
+                     varying = list(names(x_sub)[2:ncol(x_sub)]), 
+                     times = names(x_sub)[2:ncol(x_sub)])
+  }
+
+  ## MAKE PLOT ##
+  if (type == "trace") {
+    if (!is.null(group)) {
+      out <- ggplot(x_long) +
+        geom_line(aes_string(x = "step", y = "y")) +
+        scale_x_continuous(name = "Iteration") + 
+        scale_y_continuous(name = NULL) + 
+        facet_wrap(~ Parameters, scales = "free") 
+    } else {
+      out <- ggplot(x) +
+        geom_line(aes_string(x = "step", y = y)) +
+        labs(x = "Iteration",
+             y = y,
+             title = paste("trace:",y))
+    }
+  }
+
+  if (type == "hist") {
+    if (!is.null(group)) {
+      out <- ggplot(x_long) +
+        geom_histogram(aes_string(x = "y")) +
+        geom_point(aes_string(x = "y", y = 0),
+                   shape="|",
+                   alpha = 0.5,
+                   size = 3) +
+        scale_x_continuous(name = NULL) + 
+        scale_y_continuous(name = NULL) + 
+        facet_wrap(~Parameters, scales = "free") 
+    } else {
+      out <- ggplot(x) +
+        geom_histogram(aes_string(x = y)) +
+        geom_point(aes_string(x = y, y = 0),
+                   shape="|",
+                   alpha = 0.5,
+                   size = 3) +
+        labs(x = y,
+             title = paste("histogram:",y))
+    }
+  }
+
+  if (type == "density") {
+    if (!is.null(group)) {
+      out <- ggplot(x_long) +
+        geom_density(aes_string(x = "y")) +
+        geom_point(aes_string(x = "y", y = 0),
+                   shape="|",
+                   alpha = 0.5,
+                   size = 3) +
+        scale_x_continuous(name = NULL) + 
+        scale_y_continuous(name = NULL) + 
+        facet_wrap(~Parameters, scales = "free") 
+    } else {
+      out <- ggplot(x) +
+        geom_density(aes_string(x = y)) +
+        geom_point(aes_string(x = y, y = 0),
+                   shape="|",
+                   alpha = 0.5,
+                   size = 3) +
+        labs(x = y,
+             title = paste("density:",y))
+    }
+  }
+
+  if (type == "alpha") {
+    alpha <- as.matrix(x[,grep("alpha", names(x))])
+    colnames(alpha) <- seq_len(ncol(alpha))
+    from <- as.vector(alpha)
+    to <- as.vector(col(alpha))
+    from[is.na(from)] <- 0
+    out_dat <- data.frame(xyTable(from,to))
+    names(out_dat) <- c("from", "to", "frequency")
+    ## Calculate proportion among ancestries
+    get_prop <- function(i) {
+        ind <- which(out_dat$to == out_dat$to[i])
+        out_dat[[3]][i]/sum(out_dat[[3]][ind])
+    }
+    ## Return labels, if provided
+    get_alpha_lab <- function(axis, labels = NULL) {
+      if(is.null(labels)) labels <- seq_len(ncol(alpha))
+      if(axis == 'x') return(labels) else
+      if(axis == 'y') return(c("Import", labels))
+    }
+    ## Return custom colors if provided
+    get_alpha_color <- function(color = NULL) {
+      if(is.null(color)) return(NULL)
+      else return(scale_color_manual(values = color))
+    }
+    ## This joining function is needed so that the '...' argument can be passed
+    ## to two functions with different arguments
+    get_lab_color <- function(labels = NULL, color = NULL) {
+      list(alpha_lab_x = get_alpha_lab('x', labels),
+           alpha_lab_y = get_alpha_lab('y', labels),
+           alpha_color = get_alpha_color(color))
+    }
+
+    tmp <- get_lab_color(labels, ...)
+    out_dat[3] <- vapply(seq_along(out_dat[[3]]), get_prop, 1)
+    out_dat$from <- factor(out_dat$from, levels = c(0, sort(unique(out_dat$to))))
+    out_dat$to <- factor(out_dat$to, levels = sort(unique(out_dat$to)))
+    out <- ggplot(out_dat) +
+      geom_point(aes(x = to, y = from, size = frequency, color = to)) +
+      scale_x_discrete(drop = FALSE, labels = tmp$alpha_lab_x) +
+      scale_y_discrete(drop = FALSE, labels = tmp$alpha_lab_y) +
+      labs(x = 'To', y = 'From', size = 'Posterior\nfrequency') +
+      tmp$alpha_color +
+      scale_size_area() +
+      guides(colour = "none")
+  }
+
+  if (type == "t_inf") {
+    get_t_inf_lab <- function(labels = NULL) {
+      N <- ncol(t_inf)
+      if(is.null(labels)) labels <- 1:N
+      return(labels)
+    }
+    ## Return custom colors if provided
+    get_t_inf_color <- function(color = NULL) {
+      if(is.null(color)) return(NULL)
+      else return(scale_fill_manual(values = color))
+    }
+    ## This joining function is needed so that the '...' argument can be passed
+    ## to two functions with different arguments
+    get_lab_color <- function(labels = NULL, color = NULL) {
+      list(t_inf_lab_x = get_t_inf_lab(labels),
+           t_inf_color = get_t_inf_color(color))
+    }
+
+    t_inf <- as.matrix(x[,grep("t_inf", names(x))])
+    tmp <- get_lab_color(labels, ...)
+    dates <- as.vector(t_inf)
+    cases <- as.vector(col(t_inf))
+    out_dat <- data.frame(cases = factor(cases), dates = dates)
+    out <- ggplot(out_dat) +
+      geom_violin(aes(x = cases, y = dates, fill = cases)) +
+      coord_flip() + 
+      guides(fill = "none") +
+      labs(y = 'Infection time', x = NULL) +
+      tmp$t_inf_color +
+      scale_x_discrete(labels = tmp$t_inf_lab)
+  }
+
+  if (type == "kappa") {
+    get_kappa_lab <- function(labels = NULL) {
+      N <- ncol(kappa)
+      if(is.null(labels)) labels <- 1:N
+      return(labels)
+    }
+    kappa <- as.matrix(x[,grep("kappa", names(x))])
+    generations <- as.vector(kappa)
+    cases <- as.vector(col(kappa))
+    to_keep <- !is.na(generations)
+    generations <- generations[to_keep]
+    cases <- cases[to_keep]
+    out_dat <- data.frame(xyTable(generations, cases))
+    get_prop <- function(i) {
+        ind <- which(out_dat$y == out_dat$y[i])
+        out_dat[[3]][i]/sum(out_dat[[3]][ind])
+    }
+    out_dat[3] <- vapply(seq_along(out_dat[[3]]), get_prop, 1)
+    names(out_dat) <- c("generations", "cases", "frequency")
+    out <- ggplot(out_dat) +
+      geom_point(aes(x = generations, y = as.factor(cases), size = frequency, color = factor(cases))) +
+      scale_size_area() +
+      scale_y_discrete(labels = get_kappa_lab(labels)) +
+      guides(colour = "none") +
+      labs(title = "number of generations between cases",
+           x = "number of generations to ancestor",
+           y = NULL)
+  }
+
+  if (type == "network") {
+    ## extract edge info: ancestries
+    alpha <- x[, grep("alpha",names(x)), drop = FALSE]
+    from <- unlist(alpha)
+    to <- as.vector(col(alpha))
+    N <- ncol(alpha)
+    edges <- stats::na.omit(data.frame(xyTable(from, to)))
+    edges[3] <- edges$number/nrow(alpha)
+    names(edges) <- c("from", "to", "value")
+    edges <- edges[edges$value > min_support,,drop = FALSE]
+    edges$arrows <- "to"
+    case_cols <- cases_pal(N)
+    edges$color <- case_cols[edges$from]
+
+
+    ## ## extract edge info: timing
+    ## t_inf <- x[, grep("t_inf",names(x)), drop = FALSE]
+    ## mean_time <- apply(t_inf, 2, mean)
+    ## mean_delay <- mean_time[edges$to] - mean_time[edges$from]
+    ## mean_delay[mean_delay<1] <- 1
+    ## edges$label <- paste(round(mean_delay), "days")
+
+    ## node info
+    find_nodes_size <- function(i) {
+      sum(from==i, na.rm = TRUE) / nrow(alpha)
+    }
+    get_node_lab <- function(labels = NULL) {
+      if(is.null(labels)) labels <- 1:N
+      return(labels)
+    }
+    nodes <- data.frame(id = seq_len(ncol(alpha)),
+                        label = seq_len(ncol(alpha)))
+    nodes$value <- vapply(nodes$id,
+                          find_nodes_size,
+                          numeric(1))
+    nodes$color <- case_cols
+    nodes$shape <- rep("dot", N)
+    nodes$label <- get_node_lab(labels)
+
+    smry <- summary(x, burnin = burnin)
+    is_imported <- is.na(smry$tree$from)
+    nodes$shaped[is_imported] <- "star"
+
+    ## generate graph
+    out <- visNetwork::visNetwork(nodes = nodes, edges = edges, ...)
+    out <- visNetwork::visNodes(out, shadow = list(enabled = TRUE, size = 10),
+                                color = list(highlight = "red"))
+    out <- visNetwork::visEdges(out, arrows = list(
+      to = list(enabled = TRUE, scaleFactor = 0.2)),
+      color = list(highlight = "red"))
+
+  }
+
+  return(out)
+}
+
+
+
+
+#' @rdname outbreaker_chains
+#' @param object an \code{outbreaker_chains} object as returned by \code{outbreaker}.
+#'
+#' @param method the method used to determine consensus ancestries. 'mpa'
+#'   (maximum posterior ancestry) simply returns the posterior ancestry with the
+#'   highest posterior support for each case, even if this includes
+#'   cycles. 'decycle' will return the maximum posterior ancestry, except when
+#'   cycles are detected, in which case the link in the cycle with the lowest
+#'   support is pruned and the tree recalculated.
+#'
+#' @export
+#' @importFrom stats median
+summary.outbreaker_chains <- function(object, burnin = 0, method = c("mpa", "decycle"), ...) {
+  ## check burnin ##
+  x <- object
+  if (burnin > max(x$step)) {
+    stop("burnin exceeds the number of steps in object")
+  }
+  x <- x[x$step>burnin,,drop = FALSE]
+
+
+  ## make output ##
+  out <- list()
+
+
+  ## summary for $step ##
+  interv <- ifelse(nrow(x)>2, diff(tail(x$step, 2)), NA)
+  out$step <- c(first = min(x$step),
+                last = max(x$step),
+                interval = interv,
+                n_steps = length(x$step)
+                )
+
+
+  ## summary of post, like, prior ##
+  out$post <- summary(x$post)
+  out$like <- summary(x$like)
+  out$prior <- summary(x$prior)
+
+
+  ## summary for mu ##
+  out$mu <- summary(x$mu)
+
+  ## summary for pi ##
+  out$pi <- summary(x$pi)
+
+
+  ## summary tree ##
+  out$tree <- list()
+
+  ## summary of alpha ##
+  alpha <- as.matrix(x[,grep("alpha", names(x))])
+
+  method <- match.arg(method)
+
+  if(method == 'mpa') {
+
+    ## function to get most frequent item
+    f1 <- function(x) {
+      as.integer(names(sort(table(x, exclude = NULL), decreasing = TRUE)[1]))
+    }
+    out$tree$from <- apply(alpha, 2, f1)
+    out$tree$to <- seq_len(ncol(alpha))
+
+    ## function to get frequency of most frequent item
+    f2 <- function(x) {
+      (sort(table(x), decreasing = TRUE)/length(x))[1]
+    }
+    support <- apply(alpha, 2, f2)
+
+  } else if(method == 'decycle') {
+
+    cons <- .decycle_tree(x)
+    out$tree$from <- cons$from
+    out$tree$to <- cons$to
+    support <- cons$support
+
+  }
+
+  ## summary of t_inf ##
+  t_inf <- as.matrix(x[,grep("t_inf", names(x))])
+  out$tree$time <- apply(t_inf, 2, median)
+
+  out$tree$support <- support
+
+  ## summary of kappa ##
+  kappa <- as.matrix(x[,grep("kappa", names(x))])
+  out$tree$generations <- apply(kappa, 2, median, na.rm = TRUE)
+  out$tree$generations[is.na(out$tree$from)] <- NA
+
+  ## shape tree as a data.frame
+  out$tree <- as.data.frame(out$tree)
+  rownames(out$tree) <- NULL
+
+  return(out)
+}