diff --git a/R/backend-brms.R b/R/backend-brms.R index 8cc1001..40094c1 100644 --- a/R/backend-brms.R +++ b/R/backend-brms.R @@ -25,7 +25,7 @@ new_SBC_backend_brms <- function(compiled_model, validate_SBC_backend_brms_args <- function(args) { if(!is.null(args$algorithm) && args$algorithm != "sampling") { - stop("Algorithms other than sampling not supported yet") + stop("Algorithms other than sampling not supported yet. Comment on https://github.com/hyunjimoon/SBC/issues/91 to express your interest.") } unacceptable_params <- c("data", "cores", "empty") @@ -42,8 +42,10 @@ validate_SBC_backend_brms_args <- function(args) { #' @param template_data a representative value for the `data` argument in `brm` #' that can be used to generate code. #' @param template_dataset DEPRECATED. Use `template_data` +#' @param out_stan_file A filename for the generated Stan code. Useful for +#' debugging and for avoiding unnecessary recompilations. #' @export -SBC_backend_brms <- function(..., template_data, template_dataset = NULL) { +SBC_backend_brms <- function(..., template_data, out_stan_file = NULL, template_dataset = NULL) { if(!is.null(template_dataset)) { warning("Argument 'template_dataset' is deprecated, use 'template_data' instead") if(missing(template_data)) { @@ -53,7 +55,7 @@ SBC_backend_brms <- function(..., template_data, template_dataset = NULL) { args = list(...) validate_SBC_backend_brms_args(args) - stanmodel <- stanmodel_for_brms(data = template_data, ...) + stanmodel <- stanmodel_for_brms(data = template_data, out_stan_file = out_stan_file, ...) 
new_SBC_backend_brms(stanmodel, args) } diff --git a/R/backend-stan-shared.R b/R/backend-stan-shared.R index 778f1ae..2fadea3 100644 --- a/R/backend-stan-shared.R +++ b/R/backend-stan-shared.R @@ -69,6 +69,7 @@ get_expected_max_rhat <- function(n_vars, prob = 0.99, approx_sd = 0.005) { 1 + std_val_max * approx_sd } +#' @export get_diagnostic_messages.SBC_nuts_diagnostics_summary <- function(x) { message_list <- list() i <- 1 diff --git a/R/brms-helpers.R b/R/brms-helpers.R index 6ec5f9c..35bea2d 100644 --- a/R/brms-helpers.R +++ b/R/brms-helpers.R @@ -1,4 +1,4 @@ -stanmodel_for_brms <- function(...) { +stanmodel_for_brms <- function(..., out_stan_file = NULL) { model_code <- brms::make_stancode(...) args <- list(...) @@ -13,9 +13,19 @@ stanmodel_for_brms <- function(...) { backend <- getOption("brms.backend", "rstan") } if(backend == "cmdstanr") { - compiled_model <- cmdstanr::cmdstan_model(cmdstanr::write_stan_file(model_code)) + if(is.null(out_stan_file)) { + out_stan_file <- cmdstanr::write_stan_file(model_code) + } else { + write_stan_file_simple(out_stan_file, model_code) + } + compiled_model <- cmdstanr::cmdstan_model(out_stan_file) } else if(backend == "rstan") { - compiled_model <- rstan::stan_model(model_code = model_code) + if(is.null(out_stan_file)) { + compiled_model <- rstan::stan_model(model_code = model_code) + } else { + write_stan_file_simple(out_stan_file, model_code) + compiled_model <- rstan::stan_model(file = out_stan_file) + } } else { stop("Unsupported backend: ", backend) @@ -24,6 +34,27 @@ stanmodel_for_brms <- function(...) 
{ compiled_model } +# Write Stan code to `file`, leaving an existing file untouched when it already holds the same code (differences in line endings and blank lines are ignored) +write_stan_file_simple <- function(file, code) { + overwrite <- TRUE + if(file.exists(file)) { + collapsed_code <- paste0(code, collapse = "\n") + tryCatch({ + file_contents <- paste0(readLines(file), collapse = "\n") + if (gsub("(\r|\n)+", "\n", file_contents) == gsub("(\r|\n)+", "\n", collapsed_code)) { + overwrite <- FALSE + } + }, + error = function(e) { + warning("Error when checking old file contents: ", conditionMessage(e)) + }) + } + + if(overwrite) { + cat(code, file = file, sep = "\n") + } +} + translate_rstan_args_to_cmdstan <- function(args, include_unrecognized = TRUE) { ignored_args <- c("cores", "data") recognized_but_unchanged <- c("thin", "refresh") diff --git a/R/generator-brms.R b/R/generator-brms.R index 36506bc..d327475 100644 --- a/R/generator-brms.R +++ b/R/generator-brms.R @@ -11,9 +11,12 @@ #' the chunk size used with [future.apply::future_mapply()]. Set quite high #' by default as the parallelism only benefits for large individual datasets/number of #' simulations. +#' @param out_stan_file A filename for the generated Stan code. Useful for +#' debugging and for avoiding unnecessary recompilations. #' @export SBC_generator_brms <- function(formula, data, ..., generate_lp = TRUE, - generate_lp_chunksize = 5000 / nrow(data)) { + generate_lp_chunksize = 5000 / nrow(data), + out_stan_file = NULL) { require_brms_version("brms generator") model_data <- brms::make_standata(formula = formula, data = data, ..., sample_prior = "only") @@ -25,7 +28,7 @@ SBC_generator_brms <- function(formula, data, ..., generate_lp = TRUE, stop("Algorithms other than sampling not supported yet") } - compiled_model <- stanmodel_for_brms(formula = formula, data = data, ...) + compiled_model <- stanmodel_for_brms(formula = formula, data = data, out_stan_file = out_stan_file, ...) 
diff --git a/R/plot.R b/R/plot.R index 169dc95..5bc170e 100755 --- a/R/plot.R +++ b/R/plot.R @@ -292,6 +292,7 @@ data_for_ecdf_plots <- function(x, ..., } +#' @export data_for_ecdf_plots.SBC_results <- function(x, variables = NULL, prob = 0.95, gamma = NULL, @@ -308,6 +309,7 @@ data_for_ecdf_plots.SBC_results <- function(x, variables = NULL, } +#' @export data_for_ecdf_plots.data.frame <- function(x, variables = NULL, prob = 0.95, gamma = NULL, @@ -370,6 +372,7 @@ data_for_ecdf_plots.data.frame <- function(x, variables = NULL, gamma = gamma, K = K) } +#' @export data_for_ecdf_plots.matrix <- function(x, max_rank, variables = NULL, diff --git a/R/results.R b/R/results.R index 4e7815a..b9daf52 100644 --- a/R/results.R +++ b/R/results.R @@ -1054,6 +1054,7 @@ validate_diagnostic_messages <- function(x) { x } +#' @export print.SBC_diagnostic_messages <- function(x, include_ok = TRUE, print_func = cat) { x <- validate_diagnostic_messages(x) if(!include_ok) { diff --git a/man/SBC_backend_brms.Rd b/man/SBC_backend_brms.Rd index 97ea664..60ceb75 100644 --- a/man/SBC_backend_brms.Rd +++ b/man/SBC_backend_brms.Rd @@ -1,10 +1,15 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/backends.R +% Please edit documentation in R/backend-brms.R \name{SBC_backend_brms} \alias{SBC_backend_brms} \title{Build a backend based on the \code{brms} package.} \usage{ -SBC_backend_brms(..., template_data, template_dataset = NULL) +SBC_backend_brms( + ..., + template_data, + out_stan_file = NULL, + template_dataset = NULL +) } \arguments{ \item{...}{arguments passed to \code{brm}.} @@ -12,6 +17,9 @@ SBC_backend_brms(..., template_data, template_dataset = NULL) \item{template_data}{a representative value for the \code{data} argument in \code{brm} that can be used to generate code.} +\item{out_stan_file}{A filename for the generated Stan code. Useful for +debugging and for avoiding unnecessary recompilations.} + \item{template_dataset}{DEPRECATED. 
Use \code{template_data}} } \description{ diff --git a/man/SBC_backend_brms_from_generator.Rd b/man/SBC_backend_brms_from_generator.Rd index 93eb1ee..b01c7be 100644 --- a/man/SBC_backend_brms_from_generator.Rd +++ b/man/SBC_backend_brms_from_generator.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/backends.R +% Please edit documentation in R/backend-brms.R \name{SBC_backend_brms_from_generator} \alias{SBC_backend_brms_from_generator} \title{Build a brms backend, reusing the compiled model from a previously created \code{SBC_generator_brms} diff --git a/man/SBC_generator_brms.Rd b/man/SBC_generator_brms.Rd index 79d47e8..434eba9 100644 --- a/man/SBC_generator_brms.Rd +++ b/man/SBC_generator_brms.Rd @@ -9,7 +9,8 @@ SBC_generator_brms( data, ..., generate_lp = TRUE, - generate_lp_chunksize = 5000/nrow(data) + generate_lp_chunksize = 5000/nrow(data), + out_stan_file = NULL ) } \arguments{ @@ -23,6 +24,9 @@ but improves sensitivity of the SBC process.} the chunk size used with \code{\link[future.apply:future_mapply]{future.apply::future_mapply()}}. Set quite high by default as the parallelism only benefits for large individual datasets/number of simulations.} + +\item{out_stan_file}{A filename for the generated Stan code. 
Useful for +debugging and for avoiding unnecessary recompilations.} } \description{ Brms generator uses a brms model with \code{sample_prior = "only"} to generate diff --git a/vignettes/brms.Rmd b/vignettes/brms.Rmd index c611c8c..81cdb5c 100644 --- a/vignettes/brms.Rmd +++ b/vignettes/brms.Rmd @@ -77,7 +77,8 @@ generator <- SBC_generator_brms(y ~ x, data = template_data, prior = priors, thin = 50, warmup = 10000, refresh = 2000, # Will generate the log density - this is useful, #but a bit computationally expensive - generate_lp = TRUE + generate_lp = TRUE, + out_stan_file = file.path(cache_dir, "brms_linreg1.stan") ) ``` @@ -208,8 +209,9 @@ priors_func <- prior(normal(0,1), class = "b") + backend_func <- SBC_backend_brms(y ~ x + (1 | group), - prior = priors_func, chains = 1, - template_data = datasets_func$generated[[1]]) + prior = priors_func, chains = 1, + template_data = datasets_func$generated[[1]], + out_stan_file = file.path(cache_dir, "brms_linreg2.stan")) ``` So we can happily compute: