Skip to content

Commit

Permalink
Follow ggplot2 updates on facet_grid() and facet_wrap()
Browse files Browse the repository at this point in the history
- Replace facets with rows/cols for facet_grid()
- Replace formula ~parameter with vars(parameter)

Addresses stan-dev#304
  • Loading branch information
heavywatal committed Jun 4, 2023
1 parent bf145fc commit 260804e
Show file tree
Hide file tree
Showing 7 changed files with 20 additions and 18 deletions.
6 changes: 3 additions & 3 deletions R/mcmc-diagnostics-nuts.R
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@
#' mcmc_nuts_energy(np)
#' mcmc_nuts_energy(np, merge_chains = TRUE, binwidth = .15)
#' mcmc_nuts_energy(np) +
#' facet_wrap(~ Chain, nrow = 1) +
#' facet_wrap(vars(Chain), nrow = 1) +
#' coord_fixed(ratio = 150) +
#' ggtitle("NUTS Energy Diagnostic")
#' }
Expand Down Expand Up @@ -180,7 +180,7 @@ mcmc_nuts_acceptance <-
}
hists <- hists +
dont_expand_y_axis(c(0.005, 0)) +
facet_wrap(~ Parameter, scales = "free") +
facet_wrap(vars(Parameter), scales = "free") +
yaxis_text(FALSE) +
yaxis_title(FALSE) +
yaxis_ticks(FALSE) +
Expand Down Expand Up @@ -476,7 +476,7 @@ mcmc_nuts_energy <-
}

graph +
facet_wrap(~ Chain) +
facet_wrap(vars(Chain)) +
force_axes_in_facets()
}

Expand Down
3 changes: 2 additions & 1 deletion R/mcmc-diagnostics.R
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,8 @@ drop_NAs_and_warn <- function(x) {
plot_data <- acf_data(x = x, lags = lags)

if (num_chains(x) > 1) {
facet_args$facets <- "Chain ~ Parameter"
facet_args$rows <- vars(Chain)
facet_args$cols <- vars(Parameter)
facet_fun <- "facet_grid"
} else { # 1 chain
facet_args$facets <- "Parameter"
Expand Down
11 changes: 5 additions & 6 deletions R/mcmc-distributions.R
Original file line number Diff line number Diff line change
Expand Up @@ -405,14 +405,13 @@ mcmc_violin <- function(
facet_args[["scales"]] <- facet_args[["scales"]] %||% "free"
if (!by_chain) {
if (n_param > 1) {
facet_args[["facets"]] <- ~ Parameter
facet_args[["facets"]] <- vars(Parameter)
graph <- graph + do.call("facet_wrap", facet_args)
}
} else {
facet_args[["facets"]] <- if (n_param > 1) {
"Chain ~ Parameter"
} else {
"Chain ~ ."
facet_args[["rows"]] <- vars(Chain)
if (n_param > 1) {
facet_args[["cols"]] <- vars(Parameter)
}
graph <- graph +
do.call("facet_grid", facet_args) +
Expand Down Expand Up @@ -527,7 +526,7 @@ mcmc_violin <- function(
labs(x = if (violin) "Chain" else levels(data$Parameter),
y = if (violin) levels(data$Parameter) else NULL)
} else {
facet_args[["facets"]] <- ~ Parameter
facet_args[["facets"]] <- vars(Parameter)
facet_args[["scales"]] <- facet_args[["scales"]] %||% "free"
graph <- graph + do.call("facet_wrap", facet_args)
}
Expand Down
9 changes: 5 additions & 4 deletions R/mcmc-traces.R
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ mcmc_rank_overlay <- function(x,

facet_call <- NULL
if (n_param > 1) {
facet_args$facets <- ~ parameter
facet_args$facets <- vars(parameter)
facet_args$scales <- facet_args$scales %||% "fixed"
facet_call <- do.call("facet_wrap", facet_args)
}
Expand Down Expand Up @@ -387,7 +387,8 @@ mcmc_rank_hist <- function(x,
right_edge <- max(data_boundaries$value_rank)

facet_args[["scales"]] <- facet_args[["scales"]] %||% "fixed"
facet_args[["facets"]] <- facet_args[["facets"]] %||% (parameter ~ chain)
facet_args[["rows"]] <- facet_args[["rows"]] %||% vars(parameter)
facet_args[["cols"]] <- facet_args[["cols"]] %||% vars(chain)

# If there is one parameter, put the chains in one row.
# Otherwise, use a grid.
Expand Down Expand Up @@ -526,7 +527,7 @@ mcmc_rank_ecdf <-
if (n_param == 1) {
facet_call <- ylab(levels(data$parameter))
} else {
facet_args$facets <- ~parameter
facet_args$facets <- vars(parameter)
facet_args$scales <- facet_args$scales %||% "free"
facet_call <- do.call("facet_wrap", facet_args)
}
Expand Down Expand Up @@ -705,7 +706,7 @@ mcmc_trace_data <- function(x,
if (n_param == 1) {
facet_call <- ylab(levels(data$parameter))
} else {
facet_args$facets <- ~ parameter
facet_args$facets <- vars(parameter)
facet_args$scales <- facet_args$scales %||% "free"
facet_call <- do.call("facet_wrap", facet_args)
}
Expand Down
5 changes: 3 additions & 2 deletions R/ppc-errors.R
Original file line number Diff line number Diff line change
Expand Up @@ -398,10 +398,11 @@ error_hist_facets <-

if (grouped) {
facet_fun <- "facet_grid"
facet_args[["facets"]] <- rep_id ~ group
facet_args[["rows"]] <- vars(rep_id)
facet_args[["cols"]] <- vars(group)
} else {
facet_fun <- "facet_wrap"
facet_args[["facets"]] <- ~ rep_id
facet_args[["facets"]] <- vars(rep_id)
}
facet_args[["scales"]] <- facet_args[["scales"]] %||% scales_default

Expand Down
2 changes: 1 addition & 1 deletion man/MCMC-nuts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tests/testthat/test-ppc-distributions.R
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ test_that("ppc_dens,pp_hist,ppc_freqpoly,ppc_boxplot return ggplot objects", {

expect_gg(p <- ppc_hist(y, yrep[1:8, ], binwidth = 3))
if (utils::packageVersion("ggplot2") >= "3.0.0") {
facet_var <- "~rep_label"
facet_var <- vars(rep_label)
expect_equal(as.character(p$facet$params$facets[1]), facet_var)
}

Expand Down

0 comments on commit 260804e

Please sign in to comment.