Skip to content

Commit

Permalink
make E_loo Pareto-k diagnostic more robust
Browse files Browse the repository at this point in the history
  • Loading branch information
avehtari committed Feb 29, 2024
1 parent 45c261e commit efcf522
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 8 deletions.
20 changes: 13 additions & 7 deletions R/E_loo.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@
#' Pareto-k's, which may produce optimistic estimates.
#'
#' For `type="mean"`, `type="var"`, and `type="sd"`, the returned Pareto-k is
#' the maximum of the Pareto-k's for the left and right tail of \eqn{hr} and
#' the right tail of \eqn{r}, where \eqn{r} is the importance ratio and
#' \eqn{h=x} for `type="mean"` and \eqn{h=x^2} for `type="var"` and
#' `type="sd"`. For `type="quantile"`, the returned Pareto-k is the Pareto-k
#' for the right tail of \eqn{r}.
#' usually the maximum of the Pareto-k's for the left and right tail of \eqn{hr}
#' and the right tail of \eqn{r}, where \eqn{r} is the importance ratio and
#' \eqn{h=x} for `type="mean"` and \eqn{h=x^2} for `type="var"` and `type="sd"`.
#' If \eqn{h} is binary, constant, or not finite, or if type="quantile"`, the
#' returned Pareto-k is the Pareto-k for the right tail of \eqn{r}.
#' }
#' }
#'
Expand Down Expand Up @@ -291,10 +291,16 @@ E_loo_khat.matrix <- function(x, psis_object, log_ratios, ...) {
h_theta <- x_i
r_theta <- exp(log_ratios_i - max(log_ratios_i))
khat_r <- posterior::pareto_khat(r_theta, tail = "right", ndraws_tail = tail_len_i)$khat
if (is.null(x_i)) {
if (is.null(x_i) || is_constant(x_i) || length(unique(x_i))==2 ||
anyNA(x_i) || any(is.infinite(x_i))) {
khat_r
} else {
khat_hr <- posterior::pareto_khat(h_theta * r_theta, tail = "both", ndraws_tail = tail_len_i)$khat
max(khat_hr, khat_r)
if (is.na(khat_hr) && is.na(khat_r)) {
k <- NA
} else {
k <- max(khat_hr, khat_r, na.rm=TRUE)
}
k
}
}
7 changes: 6 additions & 1 deletion tests/testthat/test_E_loo.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,12 @@ test_that("E_loo.matrix equal to reference", {
test_that("E_loo throws correct errors and warnings", {
# warnings
expect_no_warning(E_loo.matrix(x, psis_mat))
# now warnings if x is constant, binary, NA, NaN, Inf
expect_no_warning(E_loo.matrix(x*0, psis_mat))
expect_no_warning(E_loo.matrix(0+(x>0), psis_mat))
expect_no_warning(E_loo.matrix(x+NA, psis_mat))
expect_no_warning(E_loo.matrix(x*NaN, psis_mat))
expect_no_warning(E_loo.matrix(x*Inf, psis_mat))
expect_no_warning(E_test <- E_loo.default(x[, 1], psis_vec))
expect_length(E_test$pareto_k, 1)

Expand Down Expand Up @@ -191,4 +197,3 @@ test_that("weighted variance works", {
w <- c(rep(0.1, 10), rep(0, 90))
expect_equal(.wvar(x, w), var(x[w > 0]))
})

0 comments on commit efcf522

Please sign in to comment.