From 3ac550c129ea944d588c602ad664cd8605c29229 Mon Sep 17 00:00:00 2001 From: Maximilian Scholz Date: Mon, 16 Dec 2024 14:35:49 +0100 Subject: [PATCH 1/3] Add split-chain option to rank overlay plots Related to #333 --- R/mcmc-traces.R | 36 ++++++++-- .../mcmc-rank-overlay-split-chains.svg | 67 +++++++++++++++++++ tests/testthat/data-for-mcmc-tests.R | 7 ++ tests/testthat/test-mcmc-traces.R | 7 ++ 4 files changed, 113 insertions(+), 4 deletions(-) create mode 100644 tests/testthat/_snaps/mcmc-traces/mcmc-rank-overlay-split-chains.svg diff --git a/R/mcmc-traces.R b/R/mcmc-traces.R index 571a6ce2..ac5bb467 100644 --- a/R/mcmc-traces.R +++ b/R/mcmc-traces.R @@ -277,6 +277,9 @@ trace_style_np <- function(div_color = "red", div_size = 0.25, div_alpha = 1) { #' of rank-normalized MCMC samples. Defaults to `20`. #' @param ref_line For the rank plots, whether to draw a horizontal line at the #' average number of ranks per bin. Defaults to `FALSE`. +#' @param split_chains Logical indicating whether to split each chain into two parts. +#' If TRUE, each chain is split into first and second half with "_1" and "_2" suffixes. +#' Defaults to `FALSE`. #' @export mcmc_rank_overlay <- function(x, pars = character(), @@ -285,7 +288,8 @@ mcmc_rank_overlay <- function(x, facet_args = list(), ..., n_bins = 20, - ref_line = FALSE) { + ref_line = FALSE, + split_chains = FALSE) { check_ignored_arguments(...) data <- mcmc_trace_data( x, @@ -294,7 +298,28 @@ mcmc_rank_overlay <- function(x, transformations = transformations ) - n_chains <- unique(data$n_chains) + # Split chains if requested + if (split_chains) { + data$n_chains = data$n_chains/2 + data$n_iterations = data$n_iterations/2 + # Calculate midpoint for each chain + n_samples <- length(unique(data$iteration)) + midpoint <- n_samples/2 + + # Create new data frame with split chains + data <- data %>% + group_by(.data$chain) %>% + mutate( + chain = ifelse( + iteration <= midpoint, + paste0(.data$chain, "_1"), + paste0(.data$chain, "_2") + ) + ) %>% + ungroup() + } + + n_chains <- length(unique(data$chain)) n_param <- unique(data$n_parameters) # We have to bin and count the data ourselves because @@ -319,6 +344,7 @@ mcmc_rank_overlay <- function(x, bin_start = unique(histobins$bin_start), stringsAsFactors = FALSE )) + d_bin_counts <- all_combos %>% left_join(d_bin_counts, by = c("parameter", "chain", "bin_start")) %>% mutate(n = dplyr::if_else(is.na(n), 0L, n)) @@ -331,7 +357,9 @@ mcmc_rank_overlay <- function(x, mutate(bin_start = right_edge) %>% dplyr::bind_rows(d_bin_counts) - scale_color <- scale_color_manual("Chain", values = chain_colors(n_chains)) + # Update legend title based on split_chains + legend_title <- if (split_chains) "Split Chains" else "Chain" + scale_color <- scale_color_manual(legend_title, values = chain_colors(n_chains)) layer_ref_line <- if (ref_line) { geom_hline( @@ -352,7 +380,7 @@ mcmc_rank_overlay <- function(x, } ggplot(d_bin_counts) + - aes(x = .data$bin_start, y = .data$n, color = .data$chain) + + aes(x = .data$bin_start, y = .data$n, color = .data$chain) + geom_step() + layer_ref_line + facet_call + diff --git a/tests/testthat/_snaps/mcmc-traces/mcmc-rank-overlay-split-chains.svg b/tests/testthat/_snaps/mcmc-traces/mcmc-rank-overlay-split-chains.svg new file mode 100644 index 00000000..0079d95e --- /dev/null +++ b/tests/testthat/_snaps/mcmc-traces/mcmc-rank-overlay-split-chains.svg @@ -0,0 +1,67 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + +0 +20 +40 +60 + + + + + + + + + + +0 +500 +1000 +1500 +2000 +Rank +Split Chains + + + + +1_1 +1_2 +2_1 +2_2 +mcmc_rank_overlay (split chains) + + diff --git a/tests/testthat/data-for-mcmc-tests.R b/tests/testthat/data-for-mcmc-tests.R index 1136cefe..17f5b8d0 100644 --- a/tests/testthat/data-for-mcmc-tests.R +++ b/tests/testthat/data-for-mcmc-tests.R @@ -80,4 +80,11 @@ vdiff_dframe_rank_overlay_bins_test <- posterior::as_draws_df( ) ) +vdiff_dframe_rank_overlay_split_chain_test <- posterior::as_draws_df( + list( + list(theta = -2 + 0.003 * 1:1000 + stats::arima.sim(list(ar = 0.7), n = 1000, sd = 0.5)), + list(theta = 1 + -0.003 * 1:1000 + stats::arima.sim(list(ar = 0.7), n = 1000, sd = 0.5)) + ) +) + set.seed(seed = NULL) diff --git a/tests/testthat/test-mcmc-traces.R b/tests/testthat/test-mcmc-traces.R index 62d46c88..7c6da533 100644 --- a/tests/testthat/test-mcmc-traces.R +++ b/tests/testthat/test-mcmc-traces.R @@ -157,6 +157,10 @@ test_that("mcmc_rank_overlay renders correctly", { # https://github.com/stan-dev/bayesplot/issues/331 p_not_all_bins_exist <- mcmc_rank_overlay(vdiff_dframe_rank_overlay_bins_test) + # https://github.com/stan-dev/bayesplot/issues/333 + p_split_chains <- mcmc_rank_overlay(vdiff_dframe_rank_overlay_split_chain_test, + split_chains = TRUE) + vdiffr::expect_doppelganger("mcmc_rank_overlay (default)", p_base) vdiffr::expect_doppelganger( "mcmc_rank_overlay (reference line)", @@ -170,6 +174,9 @@ test_that("mcmc_rank_overlay renders correctly", { # https://github.com/stan-dev/bayesplot/issues/331 vdiffr::expect_doppelganger("mcmc_rank_overlay (not all bins)", p_not_all_bins_exist) + + # https://github.com/stan-dev/bayesplot/issues/333 + vdiffr::expect_doppelganger("mcmc_rank_overlay (split chains)", p_split_chains) }) test_that("mcmc_rank_hist renders correctly", { From 8d5df1f89f3a330084767d202440a6943ecf49b9 Mon Sep 17 00:00:00 2001 From: Maximilian Scholz Date: Mon, 16 Dec 2024 15:08:23 +0100 Subject: [PATCH 2/3] Add split-chain option to rank ecdf plots Related to #333 --- R/mcmc-traces.R | 32 ++++++++- .../mcmc-rank-ecdf-split-chain.svg | 70 +++++++++++++++++++ tests/testthat/data-for-mcmc-tests.R | 2 +- tests/testthat/test-mcmc-traces.R | 10 ++- 4 files changed, 109 insertions(+), 5 deletions(-) create mode 100644 tests/testthat/_snaps/mcmc-traces/mcmc-rank-ecdf-split-chain.svg diff --git a/R/mcmc-traces.R b/R/mcmc-traces.R index ac5bb467..01f79e1f 100644 --- a/R/mcmc-traces.R +++ b/R/mcmc-traces.R @@ -485,6 +485,9 @@ mcmc_rank_hist <- function(x, #' @param plot_diff For `mcmc_rank_ecdf()`, a boolean specifying if the #' difference between the observed rank ECDFs and the theoretical expectation #' should be drawn instead of the unmodified rank ECDF plots. +#' @param split_chains Logical indicating whether to split each chain into two parts. +#' If TRUE, each chain is split into first and second half with "_1" and "_2" suffixes. +#' Defaults to `FALSE`. #' @export mcmc_rank_ecdf <- function(x, @@ -496,7 +499,8 @@ mcmc_rank_ecdf <- facet_args = list(), prob = 0.99, plot_diff = FALSE, - interpolate_adj = NULL) { + interpolate_adj = NULL, + split_chains = FALSE) { check_ignored_arguments(..., ok_args = c("K", "pit", "prob", "plot_diff", "interpolate_adj", "M") ) @@ -507,8 +511,28 @@ mcmc_rank_ecdf <- transformations = transformations, highlight = 1 ) + + # Split chains if requested + if (split_chains) { + data$n_chains = data$n_chains/2 + data$n_iterations = data$n_iterations/2 + n_samples <- length(unique(data$iteration)) + midpoint <- n_samples/2 + + data <- data %>% + group_by(.data$chain) %>% + mutate( + chain = ifelse( + iteration <= midpoint, + paste0(.data$chain, "_1"), + paste0(.data$chain, "_2") + ) + ) %>% + ungroup() + } + n_iter <- unique(data$n_iterations) - n_chain <- unique(data$n_chains) + n_chain <- length(unique(data$chain)) n_param <- unique(data$n_parameters) x <- if (is.null(K)) { @@ -561,7 +585,9 @@ mcmc_rank_ecdf <- group = .data$chain ) - scale_color <- scale_color_manual("Chain", values = chain_colors(n_chain)) + # Update legend title based on split_chains + legend_title <- if (split_chains) "Split Chains" else "Chain" + scale_color <- scale_color_manual(legend_title, values = chain_colors(n_chain)) facet_call <- NULL if (n_param == 1) { diff --git a/tests/testthat/_snaps/mcmc-traces/mcmc-rank-ecdf-split-chain.svg b/tests/testthat/_snaps/mcmc-traces/mcmc-rank-ecdf-split-chain.svg new file mode 100644 index 00000000..d16c9ca3 --- /dev/null +++ b/tests/testthat/_snaps/mcmc-traces/mcmc-rank-ecdf-split-chain.svg @@ -0,0 +1,70 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + +-0.4 +-0.2 +0.0 +0.2 + + + + + + + + + + + +0.0 +0.2 +0.4 +0.6 +0.8 +1.0 +theta +Split Chains + + + + +1_1 +1_2 +2_1 +2_2 +mcmc_rank_ecdf (split chain) + + diff --git a/tests/testthat/data-for-mcmc-tests.R b/tests/testthat/data-for-mcmc-tests.R index 17f5b8d0..fe892579 100644 --- a/tests/testthat/data-for-mcmc-tests.R +++ b/tests/testthat/data-for-mcmc-tests.R @@ -80,7 +80,7 @@ vdiff_dframe_rank_overlay_bins_test <- posterior::as_draws_df( ) ) -vdiff_dframe_rank_overlay_split_chain_test <- posterior::as_draws_df( +vdiff_dframe_rank_split_chain_test <- posterior::as_draws_df( list( list(theta = -2 + 0.003 * 1:1000 + stats::arima.sim(list(ar = 0.7), n = 1000, sd = 0.5)), list(theta = 1 + -0.003 * 1:1000 + stats::arima.sim(list(ar = 0.7), n = 1000, sd = 0.5)) diff --git a/tests/testthat/test-mcmc-traces.R b/tests/testthat/test-mcmc-traces.R index 7c6da533..f79b4c2e 100644 --- a/tests/testthat/test-mcmc-traces.R +++ b/tests/testthat/test-mcmc-traces.R @@ -158,7 +158,7 @@ test_that("mcmc_rank_overlay renders correctly", { p_not_all_bins_exist <- mcmc_rank_overlay(vdiff_dframe_rank_overlay_bins_test) # https://github.com/stan-dev/bayesplot/issues/333 - p_split_chains <- mcmc_rank_overlay(vdiff_dframe_rank_overlay_split_chain_test, + p_split_chains <- mcmc_rank_overlay(vdiff_dframe_rank_split_chain_test, split_chains = TRUE) vdiffr::expect_doppelganger("mcmc_rank_overlay (default)", p_base) @@ -261,6 +261,11 @@ test_that("mcmc_rank_ecdf renders correctly", { plot_diff = TRUE ) + # https://github.com/stan-dev/bayesplot/issues/333 + p_split_chains <- mcmc_rank_ecdf(vdiff_dframe_rank_split_chain_test, + plot_diff = TRUE, + split_chains = TRUE) + vdiffr::expect_doppelganger("mcmc_rank_ecdf (default)", p_base) vdiffr::expect_doppelganger("mcmc_rank_ecdf (one parameter)", p_one_param) vdiffr::expect_doppelganger("mcmc_rank_ecdf (diff)", p_diff) @@ -268,6 +273,9 @@ test_that("mcmc_rank_ecdf renders correctly", { "mcmc_rank_ecdf (one param, diff)", p_diff_one_param ) + + # https://github.com/stan-dev/bayesplot/issues/333 + vdiffr::expect_doppelganger("mcmc_rank_ecdf (split chain)", p_split_chains) }) test_that("mcmc_trace with 'np' renders correctly", { From 1b81c0b3a7f74b3ba44da7488b3268efbca957c8 Mon Sep 17 00:00:00 2001 From: Maximilian Scholz Date: Mon, 16 Dec 2024 15:45:35 +0100 Subject: [PATCH 3/3] Build documentation and fix some check() problems. --- R/mcmc-traces.R | 4 ++-- man/MCMC-traces.Rd | 10 ++++++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/R/mcmc-traces.R b/R/mcmc-traces.R index 01f79e1f..2e5f32e1 100644 --- a/R/mcmc-traces.R +++ b/R/mcmc-traces.R @@ -311,7 +311,7 @@ mcmc_rank_overlay <- function(x, group_by(.data$chain) %>% mutate( chain = ifelse( - iteration <= midpoint, + .data$iteration <= midpoint, paste0(.data$chain, "_1"), paste0(.data$chain, "_2") ) @@ -523,7 +523,7 @@ mcmc_rank_ecdf <- group_by(.data$chain) %>% mutate( chain = ifelse( - iteration <= midpoint, + .data$iteration <= midpoint, paste0(.data$chain, "_1"), paste0(.data$chain, "_2") ) diff --git a/man/MCMC-traces.Rd b/man/MCMC-traces.Rd index 1054591b..4f631067 100644 --- a/man/MCMC-traces.Rd +++ b/man/MCMC-traces.Rd @@ -51,7 +51,8 @@ mcmc_rank_overlay( facet_args = list(), ..., n_bins = 20, - ref_line = FALSE + ref_line = FALSE, + split_chains = FALSE ) mcmc_rank_hist( @@ -75,7 +76,8 @@ mcmc_rank_ecdf( facet_args = list(), prob = 0.99, plot_diff = FALSE, - interpolate_adj = NULL + interpolate_adj = NULL, + split_chains = FALSE ) mcmc_trace_data( @@ -193,6 +195,10 @@ of rank-normalized MCMC samples. Defaults to \code{20}.} \item{ref_line}{For the rank plots, whether to draw a horizontal line at the average number of ranks per bin. Defaults to \code{FALSE}.} +\item{split_chains}{Logical indicating whether to split each chain into two parts. +If TRUE, each chain is split into first and second half with "_1" and "_2" suffixes. +Defaults to \code{FALSE}.} + \item{K}{An optional integer defining the number of equally spaced evaluation points for the PIT-ECDF. Reducing K when using \code{interpolate_adj = FALSE} makes computing the confidence bands faster. For \code{ppc_pit_ecdf} and