Skip to content

Commit

Permalink
out_stan_file for brms backends
Browse files Browse the repository at this point in the history
  • Loading branch information
martinmodrak committed Feb 18, 2024
1 parent a3a452d commit 1970994
Show file tree
Hide file tree
Showing 10 changed files with 70 additions and 15 deletions.
8 changes: 5 additions & 3 deletions R/backend-brms.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ new_SBC_backend_brms <- function(compiled_model,

validate_SBC_backend_brms_args <- function(args) {
if(!is.null(args$algorithm) && args$algorithm != "sampling") {
stop("Algorithms other than sampling not supported yet")
stop("Algorithms other than sampling not supported yet. Comment on https://github.com/hyunjimoon/SBC/issues/91 to express your interest.")
}

unacceptable_params <- c("data", "cores", "empty")
Expand All @@ -42,8 +42,10 @@ validate_SBC_backend_brms_args <- function(args) {
#' @param template_data a representative value for the `data` argument in `brm`
#' that can be used to generate code.
#' @param template_dataset DEPRECATED. Use `template_data`
#' @param out_stan_file A filename for the generated Stan code. Useful for
#' debugging and for avoiding unnecessary recompilations.
#' @export
SBC_backend_brms <- function(..., template_data, template_dataset = NULL) {
SBC_backend_brms <- function(..., template_data, out_stan_file = NULL, template_dataset = NULL) {
if(!is.null(template_dataset)) {
warning("Argument 'template_dataset' is deprecated, use 'template_data' instead")
if(missing(template_data)) {
Expand All @@ -53,7 +55,7 @@ SBC_backend_brms <- function(..., template_data, template_dataset = NULL) {
args = list(...)
validate_SBC_backend_brms_args(args)

stanmodel <- stanmodel_for_brms(data = template_data, ...)
stanmodel <- stanmodel_for_brms(data = template_data, out_stan_file = out_stan_file, ...)

new_SBC_backend_brms(stanmodel, args)
}
Expand Down
1 change: 1 addition & 0 deletions R/backend-stan-shared.R
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ get_expected_max_rhat <- function(n_vars, prob = 0.99, approx_sd = 0.005) {
1 + std_val_max * approx_sd
}

#' @export
get_diagnostic_messages.SBC_nuts_diagnostics_summary <- function(x) {
message_list <- list()
i <- 1
Expand Down
37 changes: 34 additions & 3 deletions R/brms-helpers.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
stanmodel_for_brms <- function(...) {
stanmodel_for_brms <- function(..., out_stan_file = NULL) {
model_code <- brms::make_stancode(...)

args <- list(...)
Expand All @@ -13,9 +13,19 @@ stanmodel_for_brms <- function(...) {
backend <- getOption("brms.backend", "rstan")
}
if(backend == "cmdstanr") {
compiled_model <- cmdstanr::cmdstan_model(cmdstanr::write_stan_file(model_code))
if(is.null(out_stan_file)) {
out_stan_file <- cmdstanr::write_stan_file(model_code)
} else {
write_stan_file_simple(out_stan_file, model_code)
}
compiled_model <- cmdstanr::cmdstan_model(out_stan_file)
} else if(backend == "rstan") {
compiled_model <- rstan::stan_model(model_code = model_code)
if(is.null(out_stan_file)) {
compiled_model <- rstan::stan_model(model_code = model_code)
} else {
write_stan_file_simple(out_stan_file, model_code)
compiled_model <- rstan::stan_model(file = out_stan_file)
}
} else {
stop("Unsupported backend: ", backend)

Expand All @@ -24,6 +34,27 @@ stanmodel_for_brms <- function(...) {
compiled_model
}

# write code to file, not touching the file if the code matches
write_stan_file_simple <- function(file, code) {
overwrite <- TRUE
if(file.exists(file)) {
collapsed_code <- paste0(code, collapse = "\n")
tryCatch({
file_contents <- paste0(readLines(file), collapse = "\n")
if (gsub("(\r|\n)+", "\n", file_contents) == gsub("(\r|\n)+", "\n", collapsed_code)) {
overwrite <- FALSE
}
},
error = function(e) {
warning("Error when checking old file contents", e)
})
}

if(overwrite) {
cat(code, file = file, sep = "\n")
}
}

translate_rstan_args_to_cmdstan <- function(args, include_unrecognized = TRUE) {
ignored_args <- c("cores", "data")
recognized_but_unchanged <- c("thin", "refresh")
Expand Down
7 changes: 5 additions & 2 deletions R/generator-brms.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@
#' the chunk size used with [future.apply::future_mapply()]. Set quite high
#' by default as the parallelism only benefits for large individual datasets/number of
#' simulations.
#' @param out_stan_file A filename for the generated Stan code. Useful for
#' debugging and for avoiding unnecessary recompilations.
#' @export
SBC_generator_brms <- function(formula, data, ..., generate_lp = TRUE,
generate_lp_chunksize = 5000 / nrow(data)) {
generate_lp_chunksize = 5000 / nrow(data),
out_stan_file = NULL) {
require_brms_version("brms generator")

model_data <- brms::make_standata(formula = formula, data = data, ..., sample_prior = "only")
Expand All @@ -25,7 +28,7 @@ SBC_generator_brms <- function(formula, data, ..., generate_lp = TRUE,
stop("Algorithms other than sampling not supported yet")
}

compiled_model <- stanmodel_for_brms(formula = formula, data = data, ...)
compiled_model <- stanmodel_for_brms(formula = formula, data = data, out_stan_file = out_stan_file, ...)



Expand Down
3 changes: 3 additions & 0 deletions R/plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ data_for_ecdf_plots <- function(x, ...,
}


#' @export
data_for_ecdf_plots.SBC_results <- function(x, variables = NULL,
prob = 0.95,
gamma = NULL,
Expand All @@ -308,6 +309,7 @@ data_for_ecdf_plots.SBC_results <- function(x, variables = NULL,
}


#' @export
data_for_ecdf_plots.data.frame <- function(x, variables = NULL,
prob = 0.95,
gamma = NULL,
Expand Down Expand Up @@ -370,6 +372,7 @@ data_for_ecdf_plots.data.frame <- function(x, variables = NULL,
gamma = gamma, K = K)
}

#' @export
data_for_ecdf_plots.matrix <- function(x,
max_rank,
variables = NULL,
Expand Down
1 change: 1 addition & 0 deletions R/results.R
Original file line number Diff line number Diff line change
Expand Up @@ -1054,6 +1054,7 @@ validate_diagnostic_messages <- function(x) {
x
}

#' @export
print.SBC_diagnostic_messages <- function(x, include_ok = TRUE, print_func = cat) {
x <- validate_diagnostic_messages(x)
if(!include_ok) {
Expand Down
12 changes: 10 additions & 2 deletions man/SBC_backend_brms.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/SBC_backend_brms_from_generator.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 5 additions & 1 deletion man/SBC_generator_brms.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 5 additions & 3 deletions vignettes/brms.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ generator <- SBC_generator_brms(y ~ x, data = template_data, prior = priors,
thin = 50, warmup = 10000, refresh = 2000,
# Will generate the log density - this is useful,
#but a bit computationally expensive
generate_lp = TRUE
generate_lp = TRUE,
out_stan_file = file.path(cache_dir, "brms_linreg1.stan")
)
```

Expand Down Expand Up @@ -208,8 +209,9 @@ priors_func <- prior(normal(0,1), class = "b") +
backend_func <- SBC_backend_brms(y ~ x + (1 | group),
prior = priors_func, chains = 1,
template_data = datasets_func$generated[[1]])
prior = priors_func, chains = 1,
template_data = datasets_func$generated[[1]],
out_stan_file = file.path(cache_dir, "brms_linreg2.stan"))
```
So we can happily compute:
Expand Down

0 comments on commit 1970994

Please sign in to comment.