Abstract
Background Multiple approaches for conducting approximate Bayesian inference exist.
Task We compare methods for assessing the accuracy of approximate posterior distributions. These methods include marginal Kolmogorov-Smirnov (KS) tests, simulation-based calibration, and maximum mean discrepancy.
There are lots of ways you can compare posterior distributions.
The simplest is to compare point estimates (such as the mean or mode) or moments (such as the variance). In some sense, recovering these summaries is the least we should expect from any good approximate Bayesian inference method.
More challenging is to match the distribution everywhere. One way to assess this is the Kolmogorov-Smirnov (KS) distance, which is the maximum absolute difference between two empirical cumulative distribution functions (ECDFs). It is directly interpretable: a value of 0.05 means that no matter where you compute a tail probability, you’re never off by more than 0.05.
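As a minimal illustration (with arbitrary simulated data, not from this report), the KS distance is what base R’s ks.test reports as the statistic D:
set.seed(1)
x <- rnorm(1000)
y <- rnorm(1000, mean = 0.1)  # slightly shifted second sample
ks.test(x, y)$statistic       # maximum absolute gap between the two ECDFs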
A downside of the KS distance is that it only captures marginal distributions. This, in some way, favors methods like INLA, which directly approximate marginals. A further challenge would be to assess agreement between joint distributions, rather than only marginals.
Throughout this notebook we will use exemplar data from the report dev_sinla.
# Load samples from the dev_sinla report: MCMC (used as the reference)
# together with Gaussian and Laplace approximations
samples <- readRDS("depends/model1-samples-m250.rds")
samples_mcmc <- samples$mcmc
samples_gaussian <- samples$gaussian
samples_laplace <- samples$laplace
Compute the empirical CDFs, and look for the maximum distance between them.
n <- length(unique(samples_mcmc$index))  # number of latent field indices

ks_gaussian_mcmc <- lapply(1:n, function(i) {
  # KS test between the Gaussian approximation and MCMC for index i
  samples_gaussian_i <- filter(samples_gaussian, index == i)$value
  samples_mcmc_i <- filter(samples_mcmc, index == i)$value
  inf.utils::ks_test(samples_gaussian_i, samples_mcmc_i)
})

ks_laplace_mcmc <- lapply(1:n, function(i) {
  # KS test between the Laplace approximation and MCMC for index i
  samples_laplace_i <- filter(samples_laplace, index == i)$value
  samples_mcmc_i <- filter(samples_mcmc, index == i)$value
  inf.utils::ks_test(samples_laplace_i, samples_mcmc_i)
})
ks_results <- bind_rows(
  bind_rows(ks_gaussian_mcmc, .id = "index") %>%
    mutate(type = "gaussian"),
  bind_rows(ks_laplace_mcmc, .id = "index") %>%
    mutate(type = "laplace")
)
ks_results %>%
  select(index, D, type) %>%
  pivot_wider(
    names_from = type,
    values_from = D
  ) %>%
  ggplot(aes(x = laplace, y = gaussian)) +
  geom_point(alpha = 0.5) +
  geom_abline(intercept = 0, slope = 1, linetype = "dashed") +
  lims(x = c(0, 1), y = c(0, 1)) +
  labs(x = "KS(laplace, mcmc)", y = "KS(gaussian, mcmc)") +
  theme_minimal()
ks_results %>%
  group_by(type) %>%
  summarise(mean = mean(D)) %>%
  knitr::kable()
| type | mean |
|---|---|
| gaussian | 0.0509201 |
| laplace | 0.0464028 |
SBC (Talts et al. 2018) is a corrected version of Cook, Gelman, and Rubin (2006). Consider the performance of an algorithm over the entire Bayesian joint distribution. Generate \((\tilde \theta, \tilde y)\) via \(\tilde \theta \sim p(\theta)\) then \(\tilde y \sim p(y \, | \, \tilde \theta)\). Let the data-averaged posterior (DAP) be \[ p(\theta) = \int p(\theta \, | \, \tilde y) \, p(\tilde y \, | \, \tilde \theta) \, p(\tilde \theta) \, \text{d} \tilde y \, \text{d} \tilde \theta. \tag{4.1} \] For any model, the average of any exact posterior expectation with respect to data generated from the Bayesian joint distribution reduces to the corresponding prior expectation. Any discrepancy between the DAP and the prior therefore indicates some error in the analysis (either in the implementation or in the algorithm).
Consider a sequence of samples from the Bayesian joint distribution \[ \tilde \theta \sim p(\theta) \\ \tilde y \sim p(y \, | \, \tilde \theta) \\ \{\theta_1, \ldots, \theta_L \} \sim p(\theta \, | \, \tilde y) \] Equation (4.1) implies that \(\tilde \theta\) and \(\{\theta_1, \ldots, \theta_L \}\) are both distributed according to the prior. Consequently, for any one-dimensional function \(f: \Theta \to \mathbb{R}\) the rank statistic of the prior sample relative to the posterior sample \[ r(\{f(\theta_1), \ldots, f(\theta_L) \}, f(\tilde \theta)) = \sum_{l = 1}^L \mathbb{I}[f(\theta_l) < f(\tilde \theta)] \in [0, L] \] will be distributed uniformly across the integers \(\{0, \ldots, L\}\). Uniformity can be assessed with rank histograms and ECDF difference plots, and the shape of any deviation is informative: a U-shaped histogram indicates an under-dispersed (overconfident) approximate posterior, an inverted-U shape indicates over-dispersion, and a skewed histogram indicates bias.
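As a concrete sketch (not from the original report), here is the rank statistic computed for a conjugate Gaussian model, where the posterior is available exactly and the ranks should therefore be uniform. The values of L and N are arbitrary choices:
set.seed(1)
L <- 99    # posterior draws per replicate
N <- 1000  # number of SBC replicates
ranks <- replicate(N, {
  theta_tilde <- rnorm(1)                 # tilde theta ~ p(theta) = N(0, 1)
  y_tilde <- rnorm(1, mean = theta_tilde) # tilde y | tilde theta ~ N(tilde theta, 1)
  # exact conjugate posterior: theta | y ~ N(y / 2, 1 / 2)
  theta <- rnorm(L, mean = y_tilde / 2, sd = sqrt(1 / 2))
  sum(theta < theta_tilde)                # rank statistic, in {0, ..., L}
})
hist(ranks, breaks = seq(-0.5, L + 0.5, by = 10))  # should look uniform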
There are two procedures you could use to generate samples from the joint distribution. Firstly, you could sample from the prior on \(\theta\) first: \[ \tilde \theta_1 \sim p(\theta) \\ \tilde y_1 \sim p(y \, | \, \tilde \theta_1) \\ \implies (\tilde \theta_1, \tilde y_1) \sim p(\theta, y) \] Secondly, you could sample from the “prior” on \(y\) first: \[ \tilde y_2 \sim p(y) \\ \tilde \theta_2 \sim p(\theta \, | \, \tilde y_2) \\ \implies (\tilde \theta_2, \tilde y_2) \sim p(\theta, y) \] Both will generate samples from the joint distribution \(p(\theta, y)\). This implies that if you look at the draws from procedure one \(\{ \tilde \theta_1\}\) and the draws from procedure two \(\{ \tilde \theta_2\}\) then they have the same distribution.
Importantly, procedure two involves computing a posterior distribution, so the two sets of samples will only have the same distribution if the Bayesian inference procedure is exact. It is this fact that SBC exploits to check the implementation and accuracy of a given approximate Bayesian inference method.
Let’s have a go at running the R-INLA section from Talts et al. (2018) as described in the GitHub repository seantalts/simulation-based-calibration.
load("sbc-talts.rdata")
#' To-do...
Consider the model:
\[ y_i \sim \mathcal{N}(\mu, \sigma^2) \\ \mu \sim \mathcal{N}(0 , 1) \\ \sigma \sim \text{lognormal}(0, 1) \]
We want M samples.
M <- 1000       # number of posterior samples to keep
n_chains <- 4
n_iter <- 2000
# No data argument is passed: sbc.stan simulates the data internally
fit <- rstan::stan(file = "sbc.stan", iter = n_iter, chains = n_chains, refresh = 0)
Check that all ESS are greater than M:
all(summary(fit)$summary[, "n_eff"] > M)
## [1] TRUE
Extract samples and thin down to just M:
samples <- extract(fit)
# Keep every k-th draw, where k = floor(total draws / M), leaving exactly M draws
samples <- lapply(samples, function(sample) sample[1:M * floor(length(sample) / M)])
So the quantiles of the simulated parameter values within their posteriors (you probably want to be more careful about ranks than I’m being here) are:
sum(samples$mu_lt_sim) / length(samples$mu_lt_sim)
## [1] 0.707
sum(samples$sigma_lt_sim) / length(samples$sigma_lt_sim)
## [1] 0.534
These should be \(\mathcal{U}(0, 1)\) distributed, so by repeating this process many times (it looks like quite a costly procedure!) we can assess whether this is in fact true. A sketch of such a loop follows.
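This sketch assumes sbc.stan simulates \((\tilde \theta, \tilde y)\) internally and records indicator variables such as mu_lt_sim in its generated quantities, as in the seantalts/simulation-based-calibration programs; the number of replicates is an arbitrary choice:
model <- rstan::stan_model("sbc.stan")  # compile once, then sample repeatedly
n_replicates <- 100
ranks_mu <- replicate(n_replicates, {
  fit <- rstan::sampling(model, iter = n_iter, chains = n_chains, refresh = 0)
  draws <- rstan::extract(fit)
  idx <- 1:M * floor(length(draws$mu_lt_sim) / M)  # thin to M draws as above
  sum(draws$mu_lt_sim[idx])  # rank of the simulated mu among the M posterior draws
})
hist(ranks_mu)  # should look uniform over 0, ..., M if the inference is calibrated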
Say that we have draws \((\theta_1, \ldots, \theta_S)\) from a proposal distribution \(q(\theta)\); then we can estimate \(\mathbb{E}_p[h(\theta)]\) by \[ \frac{\sum_{s = 1}^S h(\theta_s) w_s}{\sum_{s = 1}^S w_s}. \] When \(w_s = 1\) this is the direct Monte Carlo estimate; when \(w_s = r_s = p(\theta_s, y) / q(\theta_s)\) it is importance sampling (IS). The finite sample performance of IS contains information about how close the proposal distribution is to the target distribution (the true posterior distribution of interest).
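As a minimal sketch (with a hypothetical target and proposal, not tied to any model above): estimate the mean of a \(\mathcal{N}(0, 1)\) target using draws from a wider \(\mathcal{N}(0, 2^2)\) proposal:
set.seed(1)
S <- 10000
theta <- rnorm(S, sd = 2)  # draws from the proposal q
log_r <- dnorm(theta, log = TRUE) - dnorm(theta, sd = 2, log = TRUE)  # log ratios
w <- exp(log_r - max(log_r))  # exponentiate stably
sum(theta * w) / sum(w)       # self-normalised IS estimate of E_p[theta], close to 0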
Pareto smoothed importance sampling (PSIS; Yao et al. 2018) can improve the estimates from IS by stabilising the importance ratios. Let \[ p(y \, | \, \mu, \sigma, k) = \begin{cases} \frac{1}{\sigma} \left( 1 + k \left( \frac{y - \mu}{\sigma} \right) \right)^{- \frac{1}{k} - 1}, \quad k \neq 0 \\ \frac{1}{\sigma} \exp \left( - \frac{y - \mu}{\sigma} \right), \quad k = 0 \end{cases} \] be the generalised Pareto distribution (GPD) with shape parameter \(k\) and location-scale parameters \((\mu, \sigma)\). In PSIS this distribution is fit to the \(M\) largest ratios \(r_s\), where \(M = \min (S/5, 3\sqrt{S})\) is a heuristic choice. The fitted shape \(\hat k\) is reported, and the \(M\) largest \(r_s\) are replaced by the expected values of their order statistics under the fitted GPD. All weights are truncated at the raw weight maximum \(\max(r_s)\).
PSIS can be thought of as a Bayesian version of IS, with a prior on the largest importance ratios. Comment: this is interesting, but it seems somewhat suspicious to me that the prior would apply only to the largest ratios. I suppose you could claim that we have genuine prior information that the weights shouldn’t be that big.
This procedure is claimed to provide better estimates than plain IS or IS with truncated weights, but it can also be used as a diagnostic: as a rule of thumb, \(\hat k < 0.5\) indicates that the IS estimates are reliable, \(0.5 < \hat k < 0.7\) that they are usable, and \(\hat k > 0.7\) that the proposal is too far from the target to be trusted.
The psis function in the loo package takes as input a collection of importance ratios on the log scale, as the argument log_ratios.
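For instance (a hypothetical example, with a deliberately too-narrow proposal so that the ratios are heavy-tailed):
set.seed(1)
theta <- rnorm(4000, sd = 0.5)  # proposal q = N(0, 0.5^2), target p = N(0, 1)
log_ratios <- dnorm(theta, log = TRUE) - dnorm(theta, sd = 0.5, log = TRUE)
fit <- loo::psis(log_ratios, r_eff = NA)
fit$diagnostics$pareto_k  # the theoretical shape here is 1 - 0.5^2 = 0.75, above 0.7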
See these slides by Wittawat Jitkrittum. Seth suggests starting with the MMD with a Gaussian kernel, using the median heuristic to pick the kernel length-scale.
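A minimal sketch of that suggestion, estimating the (biased, V-statistic) squared MMD between two one-dimensional samples; mmd2_gaussian is a hypothetical helper, not part of any package used above:
mmd2_gaussian <- function(x, y) {
  sigma <- median(as.vector(dist(c(x, y))))  # median heuristic length-scale
  k <- function(a, b) exp(-outer(a, b, "-")^2 / (2 * sigma^2))  # Gaussian kernel
  mean(k(x, x)) + mean(k(y, y)) - 2 * mean(k(x, y))  # biased estimate of MMD^2
}
# e.g. compare approximate and MCMC draws for the first latent index:
# mmd2_gaussian(filter(samples_gaussian, index == 1)$value,
#               filter(samples_mcmc, index == 1)$value)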
sessionInfo()
## R version 4.2.0 (2022-04-22)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS Monterey 12.6
##
## Matrix products: default
## LAPACK: /Library/Frameworks/R.framework/Versions/4.2/Resources/lib/libRlapack.dylib
##
## locale:
## [1] en_GB.UTF-8/en_GB.UTF-8/en_GB.UTF-8/C/en_GB.UTF-8/en_GB.UTF-8
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] multi.utils_0.1.0 inf.utils_0.1.0 aghq_0.4.0 tmbstan_1.0.4 rstan_2.21.5 StanHeaders_2.21.0-7 TMB_1.9.1 Matrix_1.5-1
## [9] stringr_1.4.0 purrr_0.3.4 tidyr_1.2.0 readr_2.1.2 ggplot2_3.3.6 forcats_0.5.1 dplyr_1.0.9
##
## loaded via a namespace (and not attached):
## [1] matrixStats_0.62.0 webshot_0.5.3 httr_1.4.3 rprojroot_2.0.3 numDeriv_2016.8-1.1 tools_4.2.0 bslib_0.3.1 utf8_1.2.2
## [9] R6_2.5.1 DBI_1.1.2 colorspace_2.0-3 withr_2.5.0 tidyselect_1.1.2 gridExtra_2.3 prettyunits_1.1.1 processx_3.5.3
## [17] mvQuad_1.0-6 compiler_4.2.0 cli_3.3.0 rvest_1.0.2 xml2_1.3.3 labeling_0.4.2 bookdown_0.26 sass_0.4.1
## [25] scales_1.2.0 askpass_1.1 callr_3.7.0 systemfonts_1.0.4 digest_0.6.29 rmarkdown_2.14 svglite_2.1.0 pkgconfig_2.0.3
## [33] htmltools_0.5.2 fastmap_1.1.0 highr_0.9 rlang_1.0.2 rstudioapi_0.13 jquerylib_0.1.4 generics_0.1.2 farver_2.1.0
## [41] jsonlite_1.8.0 inline_0.3.19 magrittr_2.0.3 polynom_1.4-1 kableExtra_1.3.4 loo_2.5.1 Rcpp_1.0.8.3 munsell_0.5.0
## [49] fansi_1.0.3 lifecycle_1.0.1 stringi_1.7.6 yaml_2.3.5 pkgbuild_1.3.1 grid_4.2.0 parallel_4.2.0 crayon_1.5.1
## [57] lattice_0.20-45 cowplot_1.1.1 hms_1.1.1 knitr_1.39 ps_1.7.0 pillar_1.7.0 uuid_1.1-0 codetools_0.2-18
## [65] stats4_4.2.0 glue_1.6.2 evaluate_0.15 data.table_1.14.2 RcppParallel_5.1.5 vctrs_0.4.1 tzdb_0.3.0 ids_1.0.1
## [73] openssl_2.0.1 gtable_0.3.0 assertthat_0.2.1 xfun_0.31 viridisLite_0.4.0 tibble_3.1.7 orderly_1.4.3 statmod_1.4.36
## [81] ellipsis_0.3.2