Skip to content

Commit

Permalink
Merge pull request #247 from stan-dev/E_loo-improvements
Browse files Browse the repository at this point in the history
improved E_loo Pareto-k diagnostics
  • Loading branch information
jgabry authored Feb 22, 2024
2 parents 8108ac4 + 69de50f commit a362836
Show file tree
Hide file tree
Showing 15 changed files with 79 additions and 68 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ Imports:
checkmate,
matrixStats (>= 0.52),
parallel,
posterior (>= 1.5.0),
stats
Suggests:
bayesplot (>= 1.7.0),
brms (>= 2.10.0),
ggplot2,
graphics,
knitr,
posterior,
rmarkdown,
rstan,
rstanarm (>= 2.19.0),
Expand Down
91 changes: 57 additions & 34 deletions R/E_loo.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#' @param log_ratios Optionally, a vector or matrix (the same dimensions as `x`)
#' of raw (not smoothed) log ratios. If working with log-likelihood values,
#' the log ratios are the **negative** of those values. If `log_ratios` is
#' specified we are able to compute [Pareto k][pareto-k-diagnostic]
#' specified we are able to compute more accurate [Pareto k][pareto-k-diagnostic]
#' diagnostics specific to `E_loo()`.
#' @param type The type of expectation to compute. The options are
#' `"mean"`, `"variance"`, `"sd"`, and `"quantile"`.
Expand All @@ -33,18 +33,26 @@
#' the returned object is a `length(probs)` by `ncol(x)` matrix.
#'
#' For the default/vector method the `value` component is scalar, with
#' one exception: when `type` is `"quantile"` and multiple values
#' one exception: when `type="quantile"` and multiple values
#' are specified in `probs` the `value` component is a vector with
#' `length(probs)` elements.
#' }
#' \item{`pareto_k`}{
#' Function-specific diagnostic.
#'
#' If `log_ratios` is not specified when calling `E_loo()`,
#' `pareto_k` will be `NULL`. Otherwise, for the matrix method it
#' will be a vector of length `ncol(x)` containing estimates of the shape
#' parameter \eqn{k} of the generalized Pareto distribution. For the
#' default/vector method, the estimate is a scalar.
#' For the matrix method it will be a vector of length `ncol(x)`
#' containing estimates of the shape parameter \eqn{k} of the
#' generalized Pareto distribution. For the default/vector method,
#' the estimate is a scalar. If `log_ratios` is not specified when
#' calling `E_loo()`, the smoothed log-weights are used to estimate
#' Pareto-k's, which may produce optimistic estimates.
#'
#' For `type="mean"`, `type="variance"`, and `type="sd"`, the returned Pareto-k is
#' the maximum of the Pareto-k's for the left and right tail of \eqn{hr} and
#' the right tail of \eqn{r}, where \eqn{r} is the importance ratio and
#' \eqn{h=x} for `type="mean"` and \eqn{h=x^2} for `type="variance"` and
#' `type="sd"`. For `type="quantile"`, the returned Pareto-k is the Pareto-k
#' for the right tail of \eqn{r}.
#' }
#' }
#'
Expand Down Expand Up @@ -79,8 +87,8 @@
#' E_loo(yrep, psis_object, type = "quantile", probs = 0.5) # median
#' E_loo(yrep, psis_object, type = "quantile", probs = c(0.1, 0.9))
#'
#' # To get Pareto k diagnostic with E_loo we also need to provide the negative
#' # log-likelihood values using the log_ratios argument.
#' # We can get a more accurate Pareto k diagnostic if we also provide
#' # the log_ratios argument
#' E_loo(yrep, psis_object, type = "mean", log_ratios = log_ratios)
#' }
#' }
Expand Down Expand Up @@ -111,12 +119,18 @@ E_loo.default <-
out <- E_fun(x, w, probs)

if (is.null(log_ratios)) {
warning("'log_ratios' not specified. Can't compute k-hat diagnostic.",
call. = FALSE)
khat <- NULL
} else {
khat <- E_loo_khat.default(x, psis_object, log_ratios)
# Use of smoothed ratios gives slightly optimistic
# Pareto-k's, but these are still better than nothing
log_ratios <- weights(psis_object, log = TRUE)
}
h <- switch(
type,
"mean" = x,
"variance" = x^2,
"sd" = x^2,
"quantile" = NULL
)
khat <- E_loo_khat.default(h, psis_object, log_ratios)
list(value = out, pareto_k = khat)
}

Expand Down Expand Up @@ -153,12 +167,18 @@ E_loo.matrix <-
}, FUN.VALUE = fun_val)

if (is.null(log_ratios)) {
warning("'log_ratios' not specified. Can't compute k-hat diagnostic.",
call. = FALSE)
khat <- NULL
} else {
khat <- E_loo_khat.matrix(x, psis_object, log_ratios)
# Use of smoothed ratios gives slightly optimistic
# Pareto-k's, but these are still better than nothing
log_ratios <- weights(psis_object, log = TRUE)
}
h <- switch(
type,
"mean" = x,
"variance" = x^2,
"sd" = x^2,
"quantile" = NULL
)
khat <- E_loo_khat.matrix(h, psis_object, log_ratios)
list(value = out, pareto_k = khat)
}

Expand Down Expand Up @@ -247,9 +267,15 @@ E_loo_khat.default <- function(x, psis_object, log_ratios, ...) {
#' @export
E_loo_khat.matrix <- function(x, psis_object, log_ratios, ...) {
  # Matrix method: compute one function-specific k-hat per column of
  # `log_ratios`.
  #
  # x            Matrix of h values with one column per column of
  #              `log_ratios`, or NULL, in which case NULL is passed on to
  #              .E_loo_khat_i() for every column.
  # psis_object  Object whose "tail_len" attribute holds one tail length per
  #              column.
  # log_ratios   Matrix of log ratios.
  # ...          Unused here.
  #
  # Returns a numeric vector of length ncol(log_ratios).
  tail_lengths <- attr(psis_object, "tail_len")

  # Iterate over the columns of `log_ratios` rather than `x`: `x` may be NULL,
  # and ncol(NULL) is NULL, which seq_len() cannot handle.
  vapply(seq_len(ncol(log_ratios)), function(i) {
    x_i <- if (is.null(x)) NULL else x[, i]
    .E_loo_khat_i(x_i, log_ratios[, i], tail_lengths[i])
  }, FUN.VALUE = numeric(1))
}

#' Compute function-specific khat estimates
Expand All @@ -262,16 +288,13 @@ E_loo_khat.matrix <- function(x, psis_object, log_ratios, ...) {
#' @return Scalar h-specific k-hat estimate.
#'
.E_loo_khat_i <- function(x_i, log_ratios_i, tail_len_i) {
  # x_i           Vector of h values for one observation, or NULL to diagnose
  #               the ratios only.
  # log_ratios_i  Vector of log ratios for the same observation.
  # tail_len_i    Number of tail draws, forwarded as `ndraws_tail`.
  #
  # Returns the k-hat of the right tail of r when `x_i` is NULL, otherwise the
  # larger of that value and the k-hat of both tails of h * r.

  # Subtract the maximum before exponentiating so the largest exponent is 0.
  r_theta <- exp(log_ratios_i - max(log_ratios_i))

  khat_r <- posterior::pareto_khat(
    r_theta,
    tail = "right",
    ndraws_tail = tail_len_i
  )$khat

  if (is.null(x_i)) {
    return(khat_r)
  }

  khat_hr <- posterior::pareto_khat(
    x_i * r_theta,
    tail = "both",
    ndraws_tail = tail_len_i
  )$khat

  max(khat_hr, khat_r)
}
22 changes: 1 addition & 21 deletions R/effective_sample_sizes.R
Original file line number Diff line number Diff line change
Expand Up @@ -205,11 +205,7 @@ ess_rfun <- function(sims) {
if (is.vector(sims)) dim(sims) <- c(length(sims), 1)
chains <- ncol(sims)
n_samples <- nrow(sims)
if (requireNamespace("posterior", quietly = TRUE)) {
acov <- lapply(1:chains, FUN = function(i) posterior::autocovariance(sims[,i]))
} else {
acov <- lapply(1:chains, FUN = function(i) autocovariance(sims[,i]))
}
acov <- lapply(1:chains, FUN = function(i) posterior::autocovariance(sims[,i]))
acov <- do.call(cbind, acov)
chain_mean <- colMeans(sims)
mean_var <- mean(acov[1,]) * n_samples / (n_samples - 1)
Expand Down Expand Up @@ -275,19 +271,3 @@ fft_next_good_size <- function(N) {
N = N + 1
}
}

# autocovariance function to use if posterior::autocovariance is not available
#
# Computes autocovariance estimates for every lag of the input sequence using
# a fast Fourier transform approach.
#
# y: numeric vector.
# Returns a numeric vector with one estimate per element of `y` (lag 0 first);
# an empty input gives an empty result.
autocovariance <- function(y) {
  N <- length(y)
  # Zero-pad the centred series to twice the length chosen by
  # fft_next_good_size().
  M <- fft_next_good_size(N)
  Mt2 <- 2 * M
  yc <- y - mean(y)
  yc <- c(yc, rep.int(0, Mt2 - N))
  transform <- stats::fft(yc)
  ac <- stats::fft(Conj(transform) * transform, inverse = TRUE)
  # use "biased" estimate as recommended by Geyer (1992)
  # seq_len(N) instead of 1:N: for N == 0, 1:N is c(1, 0) and would select a
  # stray element instead of nothing.
  Re(ac)[seq_len(N)] / (N^2 * 2)
}
26 changes: 17 additions & 9 deletions man/E_loo.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Binary file modified tests/testthat/reference-results/E_loo_default_mean.rds
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file modified tests/testthat/reference-results/E_loo_default_sd.rds
Binary file not shown.
Binary file modified tests/testthat/reference-results/E_loo_default_var.rds
Binary file not shown.
Binary file modified tests/testthat/reference-results/E_loo_matrix_mean.rds
Binary file not shown.
Binary file modified tests/testthat/reference-results/E_loo_matrix_quantile_10_90.rds
Binary file not shown.
Binary file modified tests/testthat/reference-results/E_loo_matrix_quantile_50.rds
Binary file not shown.
Binary file modified tests/testthat/reference-results/E_loo_matrix_sd.rds
Binary file not shown.
Binary file modified tests/testthat/reference-results/E_loo_matrix_var.rds
Binary file not shown.
6 changes: 3 additions & 3 deletions tests/testthat/test_E_loo.R
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,9 @@ test_that("E_loo.matrix equal to reference", {

test_that("E_loo throws correct errors and warnings", {
# warnings
expect_warning(E_loo.matrix(x, psis_mat), "'log_ratios' not specified")
expect_warning(E_test <- E_loo.default(x[, 1], psis_vec), "'log_ratios' not specified")
expect_null(E_test$pareto_k)
expect_no_warning(E_loo.matrix(x, psis_mat))
expect_no_warning(E_test <- E_loo.default(x[, 1], psis_vec))
expect_length(E_test$pareto_k, 1)

# errors
expect_error(E_loo(x, 1), "is.psis")
Expand Down

0 comments on commit a362836

Please sign in to comment.