[dfe06d]: / R / outbreaker_chains_methods.R

Download this file

487 lines (426 with data), 16.3 kB

  1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
#' Basic methods for processing outbreaker results
#'
#' Several methods are defined for instances of the class
#' \code{outbreaker_chains}, returned by \code{\link{outbreaker}}, including:
#' \code{print}, \code{plot}, \code{summary}
#'
#' @rdname outbreaker_chains
#'
#' @aliases outbreaker_chains print.outbreaker_chains plot.outbreaker_chains
#' summary.outbreaker_chains
#'
#' @author Thibaut Jombart (\email{thibautjombart@@gmail.com}).
#'
#' @param x an \code{outbreaker_chains} object as returned by \code{outbreaker}.
#' @param n_row the number of rows to display in head and tail; defaults to 3.
#' @param n_col the number of columns to display; defaults to 8.
#' @param ... further arguments to be passed to other methods.
#'
#' @export
#' @importFrom utils head tail
#'
print.outbreaker_chains <- function(x, n_row = 3, n_col = 8, ...) {
  ## Print a compact overview of an outbreaker_chains object: its class and
  ## dimensions, the range of hidden alpha / t_inf / kappa columns when more
  ## than `n_col` columns are present, and the first and last `n_row` rows.
  ## Returns the untruncated input invisibly, as print methods conventionally
  ## do.
  full <- x

  cat("\n\n ///// outbreaker results ///\n")
  cat("\nclass: ", class(x))
  cat("\ndimensions", nrow(x), "rows, ", ncol(x), "columns")

  ## process names of variables not shown
  if (ncol(x) > n_col) {
    ori_names <- names(x)
    ## drop = FALSE keeps a data.frame (and its column names) even when a
    ## single column is requested
    x <- x[, seq_len(n_col), drop = FALSE]
    not_shown <- setdiff(ori_names, names(x))

    ## Print "first - last" of the hidden columns matching `pattern`; print
    ## nothing when no such column is hidden (range() on an empty match would
    ## warn and yield "NA - NA")
    report_hidden <- function(header, pattern) {
      hits <- grep(pattern, not_shown)
      if (length(hits) > 0L) {
        cat(header, paste(not_shown[range(hits)], collapse = " - "))
      }
    }

    report_hidden("\nancestries not shown:", "alpha")
    report_hidden("\ninfection dates not shown:", "t_inf")
    report_hidden("\nintermediate generations not shown:", "kappa")
  }

  ## heads and tails
  shown <- as.data.frame(x)
  cat("\n\n/// head //\n")
  print(head(shown, n_row))
  cat("\n...")
  cat("\n/// tail //\n")
  print(tail(shown, n_row))

  invisible(full)
}
#' @rdname outbreaker_chains
#'
#' @param y a character string indicating which element of an
#' \code{outbreaker_chains} object to plot.
#'
#' @param type a character string indicating the kind of plot to be used (see details).
#'
#' @param group a vector of character strings indicating the parameters to display,
#' or "all" to display all global parameters (non node-specific parameters).
#'
#' @param burnin the number of iterations to be discarded as burnin.
#'
#' @param min_support a number between 0 and 1 indicating the minimum support of
#' ancestries to be plotted; only used if 'type' is 'network'.
#'
#' @param labels a vector of length N indicating the case labels (must be
#' provided in the same order used for dates of symptom onset).
#'
## #' @param dens_all a logical indicating if the overal density computed over
## all runs should be displayed; defaults to TRUE #' @param col the colors to be
## used for different runs.
#'
#' @export
#'
#' @seealso See \href{http://www.repidemicsconsortium.org/outbreaker2/articles/introduction.html#graphics}{introduction vignette} for detailed examples on how to visualise \code{outbreaker_chains} objects.
#'
#' @details \code{type} indicates the type of graphic to plot:
#'
#' \itemize{
#'
#' \item \code{trace} to visualise MCMC traces for parameters or augmented data (plots the
#' log-likelihood by default)
#'
#' \item \code{hist} to plot histograms of quantitative values
#'
#' \item \code{density} to plot kernel density estimations of quantitative values
#'
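#' \item \code{alpha} to visualise the posterior frequency of ancestries
#'
#' \item \code{t_inf} to visualise the distributions of infection times for
#' each case, as violin plots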
#'
#' \item \code{network} to visualise the transmission tree; note that
#' this opens up an interactive plot and requires a web browser with
#' Javascript enabled; the argument `min_support` is useful to select only the
#' most supported ancestries and avoid displaying too many links
#'
#' \item \code{kappa} to visualise the distribution of the number of
#' generations between cases and their ancestor/infector
#'
#' }
#'
#' @importFrom ggplot2 ggplot geom_line geom_point geom_histogram geom_density
#' geom_violin aes aes_string coord_flip labs guides scale_size_area
#' scale_x_discrete scale_y_discrete scale_color_manual scale_fill_manual
#' scale_x_continuous scale_y_continuous theme_bw facet_wrap
#'
#' @importFrom stats reshape
#' @importFrom grDevices xyTable
#' @importFrom graphics plot
#'
plot.outbreaker_chains <- function(x, y = "post",
                                   type = c("trace", "hist", "density",
                                            "alpha", "t_inf", "kappa",
                                            "network"),
                                   group = NULL,
                                   burnin = 0, min_support = 0.1,
                                   labels = NULL, ...) {
  ## Plot the content of an outbreaker_chains object. Depending on `type`,
  ## this builds a ggplot2 object (trace, hist, density, alpha, t_inf, kappa)
  ## or a visNetwork widget (network), and returns it. Samples with
  ## step <= burnin are discarded first. When `group` is provided, the
  ## trace / hist / density plots are faceted over the selected columns;
  ## otherwise they display the single column `y`.

  ## CHECKS ##
  type <- match.arg(type)
  if (!y %in% names(x)) {
    stop(paste(y,"is not a column of x"))
  }

  ## THIS IS JUST TO APPEASE R CMD check
  ## hopefully cran will avoid spurious warnings along the lines of "no
  ## visible binding for global variable" when using ggplot2::aes(...)
  frequency <- NULL

  ## GET DATA TO PLOT ##
  if (burnin > max(x$step)) {
    stop("burnin exceeds the number of steps in x")
  }
  x <- x[x$step > burnin, , drop = FALSE]

  ## check group
  if (!is.null(group)) {
    if (length(group) == 1L && group == "all") {
      ## remove _[digit] vars, i.e. keep the columns that are not indexed by
      ## a trailing number
      y_vars <- names(x)[!grepl("(_[[:digit:]]+$)", names(x))]
    } else if (all(group %in% names(x))) {
      y_vars <- c("step", group)
    } else {
      stop("grouping variables not found in outbreaker object")
    }

    ## get only relevant data
    x_sub <- as.data.frame(x)[, y_vars, drop = FALSE]

    ## switch it to long format to use in ggplot: one row per
    ## (step, parameter) pair, with the value stored in column 'y'
    value_vars <- names(x_sub)[-1L]
    x_long <- reshape(x_sub,
                      idvar = "step",
                      ids = x_sub$step,
                      direction = "long",
                      new.row.names = NULL,
                      timevar = "Parameters",
                      v.names = "y",
                      varying = list(value_vars),
                      times = value_vars)
  }

  ## MAKE PLOT ##
  if (type == "trace") {
    if (!is.null(group)) {
      out <- ggplot(x_long) +
        geom_line(aes_string(x = "step", y = "y")) +
        scale_x_continuous(name = "Iteration") +
        scale_y_continuous(name = NULL) +
        facet_wrap(~ Parameters, scales = "free")
    } else {
      out <- ggplot(x) +
        geom_line(aes_string(x = "step", y = y)) +
        labs(x = "Iteration",
             y = y,
             title = paste("trace:",y))
    }
  }

  if (type == "hist") {
    if (!is.null(group)) {
      out <- ggplot(x_long) +
        geom_histogram(aes_string(x = "y")) +
        geom_point(aes_string(x = "y", y = 0),
                   shape = "|",
                   alpha = 0.5,
                   size = 3) +
        scale_x_continuous(name = NULL) +
        scale_y_continuous(name = NULL) +
        facet_wrap(~Parameters, scales = "free")
    } else {
      out <- ggplot(x) +
        geom_histogram(aes_string(x = y)) +
        geom_point(aes_string(x = y, y = 0),
                   shape = "|",
                   alpha = 0.5,
                   size = 3) +
        labs(x = y,
             title = paste("histogram:",y))
    }
  }

  if (type == "density") {
    if (!is.null(group)) {
      out <- ggplot(x_long) +
        geom_density(aes_string(x = "y")) +
        geom_point(aes_string(x = "y", y = 0),
                   shape = "|",
                   alpha = 0.5,
                   size = 3) +
        scale_x_continuous(name = NULL) +
        scale_y_continuous(name = NULL) +
        facet_wrap(~Parameters, scales = "free")
    } else {
      out <- ggplot(x) +
        geom_density(aes_string(x = y)) +
        geom_point(aes_string(x = y, y = 0),
                   shape = "|",
                   alpha = 0.5,
                   size = 3) +
        labs(x = y,
             title = paste("density:",y))
    }
  }

  if (type == "alpha") {
    alpha <- as.matrix(x[, grep("alpha", names(x))])
    colnames(alpha) <- seq_len(ncol(alpha))
    from <- as.vector(alpha)
    to <- as.vector(col(alpha))
    ## missing ancestries are recoded as 0, displayed as "Import" on the y axis
    from[is.na(from)] <- 0
    out_dat <- data.frame(xyTable(from, to))
    names(out_dat) <- c("from", "to", "frequency")

    ## Calculate proportion among ancestries
    get_prop <- function(i) {
      ind <- which(out_dat$to == out_dat$to[i])
      out_dat[[3]][i] / sum(out_dat[[3]][ind])
    }

    ## Return labels, if provided
    get_alpha_lab <- function(axis, labels = NULL) {
      if (is.null(labels)) labels <- seq_len(ncol(alpha))
      if (axis == "x") return(labels) else
        if (axis == "y") return(c("Import", labels))
    }

    ## Return custom colors if provided
    get_alpha_color <- function(color = NULL) {
      if (is.null(color)) return(NULL)
      else return(scale_color_manual(values = color))
    }

    ## This joining function is needed so that the '...' argument can be passed
    ## to two functions with different arguments
    get_lab_color <- function(labels = NULL, color = NULL) {
      list(alpha_lab_x = get_alpha_lab("x", labels),
           alpha_lab_y = get_alpha_lab("y", labels),
           alpha_color = get_alpha_color(color))
    }

    tmp <- get_lab_color(labels, ...)

    out_dat[3] <- vapply(seq_along(out_dat[[3]]), get_prop, 1)
    out_dat$from <- factor(out_dat$from,
                           levels = c(0, sort(unique(out_dat$to))))
    out_dat$to <- factor(out_dat$to, levels = sort(unique(out_dat$to)))

    out <- ggplot(out_dat) +
      geom_point(aes(x = to, y = from, size = frequency, color = to)) +
      scale_x_discrete(drop = FALSE, labels = tmp$alpha_lab_x) +
      scale_y_discrete(drop = FALSE, labels = tmp$alpha_lab_y) +
      labs(x = "To", y = "From", size = "Posterior\nfrequency") +
      tmp$alpha_color +
      scale_size_area() +
      guides(colour = "none")
  }

  if (type == "t_inf") {
    ## t_inf is defined below, before this helper is first called
    get_t_inf_lab <- function(labels = NULL) {
      N <- ncol(t_inf)
      if (is.null(labels)) labels <- seq_len(N)
      return(labels)
    }

    ## Return custom colors if provided
    get_t_inf_color <- function(color = NULL) {
      if (is.null(color)) return(NULL)
      else return(scale_fill_manual(values = color))
    }

    ## This joining function is needed so that the '...' argument can be passed
    ## to two functions with different arguments
    get_lab_color <- function(labels = NULL, color = NULL) {
      list(t_inf_lab_x = get_t_inf_lab(labels),
           t_inf_color = get_t_inf_color(color))
    }

    t_inf <- as.matrix(x[, grep("t_inf", names(x))])
    tmp <- get_lab_color(labels, ...)
    dates <- as.vector(t_inf)
    cases <- as.vector(col(t_inf))
    out_dat <- data.frame(cases = factor(cases), dates = dates)

    ## use the exact element name rather than relying on partial matching
    ## of `$` on lists
    out <- ggplot(out_dat) +
      geom_violin(aes(x = cases, y = dates, fill = cases)) +
      coord_flip() +
      guides(fill = "none") +
      labs(y = "Infection time", x = NULL) +
      tmp$t_inf_color +
      scale_x_discrete(labels = tmp$t_inf_lab_x)
  }

  if (type == "kappa") {
    ## kappa is defined below, before this helper is first called
    get_kappa_lab <- function(labels = NULL) {
      N <- ncol(kappa)
      if (is.null(labels)) labels <- seq_len(N)
      return(labels)
    }

    kappa <- as.matrix(x[, grep("kappa", names(x))])
    generations <- as.vector(kappa)
    cases <- as.vector(col(kappa))
    to_keep <- !is.na(generations)
    generations <- generations[to_keep]
    cases <- cases[to_keep]
    out_dat <- data.frame(xyTable(generations, cases))

    ## proportion of each number of generations within a case
    get_prop <- function(i) {
      ind <- which(out_dat$y == out_dat$y[i])
      out_dat[[3]][i] / sum(out_dat[[3]][ind])
    }

    out_dat[3] <- vapply(seq_along(out_dat[[3]]), get_prop, 1)
    names(out_dat) <- c("generations", "cases", "frequency")

    out <- ggplot(out_dat) +
      geom_point(aes(x = generations, y = as.factor(cases),
                     size = frequency, color = factor(cases))) +
      scale_size_area() +
      scale_y_discrete(labels = get_kappa_lab(labels)) +
      guides(colour = "none") +
      labs(title = "number of generations between cases",
           x = "number of generations to ancestor",
           y = NULL)
  }

  if (type == "network") {
    ## extract edge info: ancestries
    alpha <- x[, grep("alpha",names(x)), drop = FALSE]
    from <- unlist(alpha)
    to <- as.vector(col(alpha))
    N <- ncol(alpha)
    edges <- stats::na.omit(data.frame(xyTable(from, to)))
    ## edge weight: proportion of retained samples supporting the ancestry
    edges[3] <- edges$number / nrow(alpha)
    names(edges) <- c("from", "to", "value")
    edges <- edges[edges$value > min_support, , drop = FALSE]
    edges$arrows <- "to"
    case_cols <- cases_pal(N)
    edges$color <- case_cols[edges$from]

    ## ## extract edge info: timing
    ## t_inf <- x[, grep("t_inf",names(x)), drop = FALSE]
    ## mean_time <- apply(t_inf, 2, mean)
    ## mean_delay <- mean_time[edges$to] - mean_time[edges$from]
    ## mean_delay[mean_delay<1] <- 1
    ## edges$label <- paste(round(mean_delay), "days")

    ## node info
    find_nodes_size <- function(i) {
      sum(from == i, na.rm = TRUE) / nrow(alpha)
    }

    get_node_lab <- function(labels = NULL) {
      if (is.null(labels)) labels <- seq_len(N)
      return(labels)
    }

    nodes <- data.frame(id = seq_len(ncol(alpha)),
                        label = seq_len(ncol(alpha)))
    nodes$value <- vapply(nodes$id,
                          find_nodes_size,
                          numeric(1))
    nodes$color <- case_cols
    nodes$shape <- rep("dot", N)
    nodes$label <- get_node_lab(labels)

    ## cases whose consensus ancestry is missing are treated as imports and
    ## drawn with a different shape; this must update the 'shape' column
    ## created above (the previous code wrote to a new 'shaped' column,
    ## leaving every node as a dot)
    smry <- summary(x, burnin = burnin)
    is_imported <- is.na(smry$tree$from)
    nodes$shape[is_imported] <- "star"

    ## generate graph
    out <- visNetwork::visNetwork(nodes = nodes, edges = edges, ...)
    out <- visNetwork::visNodes(out, shadow = list(enabled = TRUE, size = 10),
                                color = list(highlight = "red"))
    out <- visNetwork::visEdges(out, arrows = list(
      to = list(enabled = TRUE, scaleFactor = 0.2)),
      color = list(highlight = "red"))
  }

  return(out)
}
#' @rdname outbreaker_chains
#' @param object an \code{outbreaker_chains} object as returned by \code{outbreaker}.
#'
#' @param method the method used to determine consensus ancestries. 'mpa'
#' (maximum posterior ancestry) simply returns the posterior ancestry with the
#' highest posterior support for each case, even if this includes
#' cycles. 'decycle' will return the maximum posterior ancestry, except when
#' cycles are detected, in which case the link in the cycle with the lowest
#' support is pruned and the tree recalculated.
#'
#' @export
#' @importFrom stats median
summary.outbreaker_chains <- function(object, burnin = 0,
                                      method = c("mpa", "decycle"), ...) {
  ## Summarise an outbreaker_chains object after discarding samples with
  ## step <= burnin. Returns a list with:
  ##  - $step: first and last retained step, sampling interval, sample count
  ##  - $post, $like, $prior, $mu, $pi: summary() of the matching columns
  ##  - $tree: data.frame with one row per case: consensus ancestor (from),
  ##    case index (to), median infection time (time), support of the
  ##    consensus ancestry (support), median number of generations
  ##    (generations; NA when the consensus ancestry is missing)
  method <- match.arg(method)

  ## check burnin ##
  x <- object
  if (burnin > max(x$step)) {
    stop("burnin exceeds the number of steps in object")
  }
  x <- x[x$step > burnin, , drop = FALSE]

  ## make output ##
  out <- list()

  ## summary for $step ##
  ## the sampling interval is defined as soon as two samples are retained
  interv <- if (nrow(x) > 1) diff(tail(x$step, 2)) else NA
  out$step <- c(first = min(x$step),
                last = max(x$step),
                interval = interv,
                n_steps = length(x$step))

  ## summary of post, like, prior ##
  out$post <- summary(x$post)
  out$like <- summary(x$like)
  out$prior <- summary(x$prior)

  ## summary for mu ##
  out$mu <- summary(x$mu)

  ## summary for pi ##
  out$pi <- summary(x$pi)

  ## summary tree ##
  out$tree <- list()

  ## summary of alpha ##
  alpha <- as.matrix(x[, grep("alpha", names(x))])

  if (method == "mpa") {
    ## function to get most frequent item; NA (no ancestor) is counted as a
    ## value in its own right
    f1 <- function(x) {
      as.integer(names(sort(table(x, exclude = NULL), decreasing = TRUE)[1]))
    }
    out$tree$from <- apply(alpha, 2, f1)
    out$tree$to <- seq_len(ncol(alpha))

    ## function to get frequency of most frequent item; NA must be tabulated
    ## here as well, so that the support refers to the ancestry returned by f1
    f2 <- function(x) {
      (sort(table(x, exclude = NULL), decreasing = TRUE) / length(x))[1]
    }
    support <- apply(alpha, 2, f2)
  } else if (method == "decycle") {
    cons <- .decycle_tree(x)
    out$tree$from <- cons$from
    out$tree$to <- cons$to
    support <- cons$support
  }

  ## summary of t_inf ##
  t_inf <- as.matrix(x[, grep("t_inf", names(x))])
  out$tree$time <- apply(t_inf, 2, median)
  out$tree$support <- support

  ## summary of kappa ##
  kappa <- as.matrix(x[, grep("kappa", names(x))])
  out$tree$generations <- apply(kappa, 2, median, na.rm = TRUE)
  ## no number of generations is reported when the consensus ancestry is
  ## missing
  out$tree$generations[is.na(out$tree$from)] <- NA

  ## shape tree as a data.frame
  out$tree <- as.data.frame(out$tree)
  rownames(out$tree) <- NULL

  return(out)
}