From 70a6201c3a20cb89371533c0a780b8822c97954c Mon Sep 17 00:00:00 2001 From: Andrew Manderson Date: Fri, 6 Nov 2020 14:16:48 +0000 Subject: [PATCH] add ref_interval option to mcmc_rank_* functions. Also add tests for mcmc_rank_* functions --- R/mcmc-traces.R | 67 ++++++++++++++++++++++++++++++- man/MCMC-traces.Rd | 17 +++++++- tests/testthat/test-mcmc-traces.R | 61 ++++++++++++++++++++++++++++ 3 files changed, 141 insertions(+), 4 deletions(-) diff --git a/R/mcmc-traces.R b/R/mcmc-traces.R index c23f269b..c2047bb6 100644 --- a/R/mcmc-traces.R +++ b/R/mcmc-traces.R @@ -261,6 +261,13 @@ trace_style_np <- function(div_color = "red", div_size = 0.25, div_alpha = 1) { #' of rank-normalized MCMC samples. Defaults to `20`. #' @param ref_line For the rank plots, whether to draw a horizontal line at the #' average number of ranks per bin. Defaults to `FALSE`. +#' @param ref_interval For the rank plots, whether to draw a reference +#' uncertainty interval based on the expected distribution of the rank histogram +#' bins. Defaults to `FALSE`. +#' @param interval_args If `ref_interval = TRUE`, optional arguments controlling +#' the width and alpha of the reference interval. The default is a `95\%` +#' uncertainty interval plotted with an alpha value of `0.2`. This must be a +#' list with elements named `width` and `alpha`. #' @export mcmc_rank_overlay <- function(x, pars = character(), @@ -269,7 +276,9 @@ mcmc_rank_overlay <- function(x, facet_args = list(), ..., n_bins = 20, - ref_line = FALSE) { + ref_line = FALSE, + ref_interval = FALSE, + interval_args = list(width = 0.95, alpha = 0.2)) { check_ignored_arguments(...) data <- mcmc_trace_data( x, @@ -278,6 +287,14 @@ mcmc_rank_overlay <- function(x, transformations = transformations ) + # mcmc_rank plots make no sense if there aren't multiple chains + # a rank plot of 1 chain is perfectly uniform by construction, and + # has no power as a diagnostic. + if (!(unique(data$n_chains) > 1)){ + STOP_need_multiple_chains() + } + + n_iter <- unique(data$n_iterations) n_chains <- unique(data$n_chains) n_param <- unique(data$n_parameters) @@ -316,6 +333,12 @@ mcmc_rank_overlay <- function(x, } else { NULL } + + interval_call <- if (ref_interval) { + rank_polygon_geom(n_iter, n_chains, n_bins, interval_args) + } else { + NULL + } facet_call <- NULL if (n_param > 1) { @@ -329,6 +352,7 @@ mcmc_rank_overlay <- function(x, geom_step() + layer_ref_line + facet_call + + interval_call + scale_color + ylim(c(0, NA)) + bayesplot_theme_get() + @@ -345,7 +369,9 @@ mcmc_rank_hist <- function(x, ..., facet_args = list(), n_bins = 20, - ref_line = FALSE) { + ref_line = FALSE, + ref_interval = FALSE, + interval_args = list(width = 0.95, alpha = 0.2)) { check_ignored_arguments(...) data <- mcmc_trace_data( x, @@ -354,6 +380,10 @@ mcmc_rank_hist <- function(x, transformations = transformations ) + if (!(unique(data$n_chains) > 1)){ + STOP_need_multiple_chains() + } + n_iter <- unique(data$n_iterations) n_chains <- unique(data$n_chains) n_param <- unique(data$n_parameters) @@ -396,6 +426,11 @@ mcmc_rank_hist <- function(x, } facet_call <- do.call(facet_f, facet_args) + interval_call <- if (ref_interval) { + rank_polygon_geom(n_iter, n_chains, n_bins, interval_args) + } else { + NULL + } ggplot(data) + aes_(x = ~ value_rank) + @@ -409,6 +444,7 @@ mcmc_rank_hist <- function(x, layer_ref_line + geom_blank(data = data_boundaries) + facet_call + + interval_call + force_x_axis_in_facets() + dont_expand_y_axis(c(0.005, 0)) + bayesplot_theme_get() + @@ -681,3 +717,30 @@ divergence_rug <- function(np, np_style, n_iter, n_chain) { alpha = np_style$alpha[["div"]] ) } + +rank_polygon_geom <- function(n_iter, n_chains, n_bins, interval_args) { + validate_interval_args(interval_args) + polygon_y_vals <- qbinom( + c((1 - interval_args$width) / 2, (1 + interval_args$width) / 2), + size = n_iter, + prob = (n_bins)^(-1) + ) + + polygon_df <- data.frame( + x = rep(c(0, n_iter * n_chains), each = 2), + y = c(polygon_y_vals, rev(polygon_y_vals)) + ) + + geom_polygon( + mapping = aes(x = x, y = y), + data = polygon_df, + inherit.aes = FALSE, + alpha = interval_args$alpha + ) +} + +validate_interval_args <- function(interval_args) { + stopifnot(all(names(interval_args) %in% c("width", "alpha"))) + stopifnot(interval_args$width %>% dplyr::between(0, 1)) + stopifnot(interval_args$alpha %>% dplyr::between(0, 1)) +} \ No newline at end of file diff --git a/man/MCMC-traces.Rd b/man/MCMC-traces.Rd index 9c12838c..dd56da09 100644 --- a/man/MCMC-traces.Rd +++ b/man/MCMC-traces.Rd @@ -50,7 +50,9 @@ mcmc_rank_overlay( facet_args = list(), ..., n_bins = 20, - ref_line = FALSE + ref_line = FALSE, + ref_interval = FALSE, + interval_args = list(width = 0.95, alpha = 0.2) ) mcmc_rank_hist( @@ -61,7 +63,9 @@ mcmc_rank_hist( ..., facet_args = list(), n_bins = 20, - ref_line = FALSE + ref_line = FALSE, + ref_interval = FALSE, + interval_args = list(width = 0.95, alpha = 0.2) ) mcmc_trace_data( @@ -172,6 +176,15 @@ of rank-normalized MCMC samples. Defaults to \code{20}.} \item{ref_line}{For the rank plots, whether to draw a horizontal line at the average number of ranks per bin. Defaults to \code{FALSE}.} + +\item{ref_interval}{For the rank plots, whether to draw a reference +uncertainty interval based on the expected distribution of the rank histogram +bins. Defaults to \code{FALSE}.} + +\item{interval_args}{If \code{ref_interval = TRUE}, optional arguments controlling +the width and alpha of the reference interval. The default is a \verb{95\\\%} +uncertainty interval plotted with an alpha value of \code{0.2}. This must be a +list with elements named \code{width} and \code{alpha}.} } \value{ The plotting functions return a ggplot object that can be further diff --git a/tests/testthat/test-mcmc-traces.R b/tests/testthat/test-mcmc-traces.R index 82dcec1e..bb416e8a 100644 --- a/tests/testthat/test-mcmc-traces.R +++ b/tests/testthat/test-mcmc-traces.R @@ -33,6 +33,30 @@ test_that("mcmc_trace_highlight throws error if highlight > number of chains", { expect_error(mcmc_trace_highlight(arr, pars = "sigma", highlight = 7), "'highlight' is 7") }) +test_that("mcmc_rank_hist returns a ggplot object", { + expect_gg(mcmc_rank_hist(arr, pars = "beta[1]", regex_pars = "x\\:")) + expect_gg(mcmc_rank_hist(dframe_multiple_chains)) + expect_gg(mcmc_rank_hist(chainlist)) +}) + +test_that("mcmc_rank_overlay returns a ggplot object", { + expect_gg(mcmc_rank_overlay(arr, pars = "beta[1]", regex_pars = "x\\:")) + expect_gg(mcmc_rank_overlay(dframe_multiple_chains)) + expect_gg(mcmc_rank_overlay(chainlist)) +}) + +test_that("mcmc_rank_hist errors if there is only 1 chain", { + expect_error(mcmc_rank_hist(mat), "requires multiple") + expect_error(mcmc_rank_hist(dframe), "requires multiple chains") + expect_error(mcmc_rank_hist(arr1chain), "requires multiple chains") +}) + +test_that("mcmc_rank_overlay errors if there is only 1 chain", { + expect_error(mcmc_rank_overlay(mat), "requires multiple") + expect_error(mcmc_rank_overlay(dframe), "requires multiple chains") + expect_error(mcmc_rank_overlay(arr1chain), "requires multiple chains") +}) + # options ----------------------------------------------------------------- test_that("mcmc_trace options work", { expect_gg(g1 <- mcmc_trace(arr, regex_pars = "beta", window = c(5, 10))) @@ -47,6 +71,43 @@ test_that("mcmc_trace options work", { expect_error(mcmc_trace(arr, n_warmup = 50, iter1 = 20)) }) +test_that("mcmc_rank_hist options work", { + expect_gg(mcmc_rank_hist(arr, regex_pars = "beta", ref_interval = TRUE)) + expect_gg( + mcmc_rank_hist(arr, + regex_pars = "beta", + n_bins = 15, + ref_line = TRUE, + ref_interval = TRUE, + interval_args = list(width = 0.8, alpha = 0.1)) + ) +}) + +test_that("mcmc_rank_overlay options work", { + expect_gg(mcmc_rank_overlay(arr, regex_pars = "beta", ref_interval = TRUE)) + expect_gg( + mcmc_rank_overlay(arr, + regex_pars = "beta", + n_bins = 15, + ref_line = TRUE, + ref_interval = TRUE, + interval_args = list(width = 0.8, alpha = 0.1)) + ) +}) + +test_that("mcmc_rank_* interval_args get validated", { + expect_error( + mcmc_rank_overlay(arr, + regex_pars = "beta", + n_bins = 15, + ref_line = TRUE, + ref_interval = TRUE, + interval_args = list(with = 0.8, alpha = 0.1)), # intended typo + "is not TRUE" + ) +}) + + # displaying divergences in traceplot ------------------------------------- test_that("mcmc_trace 'np' argument works", {