Abstract
Background In the report `naomi-simple_fit` with parameter `tmbstan = TRUE`, we used the NUTS algorithm to perform MCMC inference for the simplified Naomi model.
Task Here we assess whether or not the results of the MCMC are suitable using a range of diagnostic tools.
We start by obtaining results from the latest version of `naomi-simple_fit` with `tmbstan = TRUE`.
# Saved output of the upstream fit, read from the depends/ directory
out <- readRDS("depends/out.rds")
# The MCMC fit object; every diagnostic below is computed from `mcmc`
mcmc <- out$mcmc$stanfit
# `depends` section of this report's orderly.yml, used below to report which fit was loaded
depends <- yaml::read_yaml("orderly.yml")$depends
#' Print the provenance of one orderly dependency
#'
#' Reads entry `i` of the global `depends` list (taken from orderly.yml),
#' prints the report name and its search query, resolves the query to a report
#' ID with `orderly::orderly_search()`, and prints that ID together with the
#' parameters returned by `orderly::orderly_info()`.
#'
#' @param i Index into `depends`.
#' @return The result of the final `print()` call (the parameter list).
dependency_details <- function(i) {
  dependency <- depends[[i]]
  report_name <- names(dependency)
  query <- dependency[[report_name]]$id

  print(paste0("Inference results obtained from ", report_name, " with the query ", query))

  report_id <- orderly::orderly_search(query = query, report_name)
  print(paste0("Obtained report had ID ", report_id, " and was run with the following parameters:"))

  report_info <- orderly::orderly_info(report_id, report_name)
  print(report_info$parameters)
}

dependency_details(1)
## [1] "Inference results obtained from naomi-simple_fit with the query latest(parameter:tmbstan == TRUE && parameter:niter > 50000)"
## [1] "Obtained report had ID 20230414-110613-89cd4244 and was run with the following parameters:"
## $tmbstan
## [1] TRUE
##
## $niter
## [1] 1e+05
##
## $nthin
## [1] 40
##
## $tmb
## [1] FALSE
##
## $aghq
## [1] FALSE
##
## $k
## [1] 1
##
## $grid_type
## [1] "product"
##
## $s
## [1] 1
##
## $hmc_laplace
## [1] FALSE
##
## $adam
## [1] FALSE
##
## $nsample
## [1] 5000
##
## $area_level
## [1] 4
This MCMC took 3.28 days to run
# Palette of eight hex colours (presumably colour-blind friendly, going by the name).
# NOTE(review): the plots in the visible source hard-code hex values instead of indexing this vector.
cbpalette <- c("#56B4E9", "#009E73", "#E69F00", "#F0E442", "#0072B2", "#D55E00", "#CC79A7", "#999999")
# Colour scheme for bayesplot's built-in plotting functions
color_scheme_set("viridis")
# Default ggplot2 theme for the figures that follow
ggplot2::theme_set(theme_minimal())
We are looking for values of \(\hat R\) less than 1.05 here.
# R-hat (potential scale reduction factor) for every parameter in the fit
rhats <- bayesplot::rhat(mcmc)

# One horizontal segment per parameter, coloured by the R-hat category that
# bayesplot::mcmc_rhat_data() assigns in its `description` column.
bayesplot::mcmc_rhat_data(rhats) %>%
  ggplot(aes(x = value, y = parameter, col = description)) +
  geom_segment(aes(yend = parameter, xend = ifelse(min(value) < 1, 1, -Inf)), na.rm = TRUE, alpha = 0.7) +
  # bayesplot can report more than one R-hat category (up to three per its
  # documentation). A single colour makes ggplot stop with "Insufficient values
  # in manual scale" as soon as any parameter leaves the best category, i.e.
  # exactly when this plot is needed. The first colour is unchanged, so the
  # figure is identical when only one category is present.
  scale_color_manual(values = c("#E69F00", "#56B4E9", "#009E73")) +
  # Dashed line marks the 1.05 threshold used in the text
  geom_vline(xintercept = 1.05, linetype = "dashed", col = "grey40") +
  labs(x = "Potential scale reduction factor", y = "NUTS parameter", col = "") +
  theme_minimal() +
  theme(
    # There are too many parameters to label individually
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank(),
    panel.grid.major = element_blank()
  )

# Argument names spelled out rather than relying on partial matching of h / w
ggsave("rhat.png", height = 3, width = 6.25)
# Parameters whose R-hat exceeds the 1.05 threshold (outer parentheses print the result)
(big_rhats <- rhats[rhats > 1.05])
## named numeric(0)
# Proportion of parameters above the threshold
length(big_rhats) / length(rhats)
## [1] 0
# Largest R-hat over all parameters; stored for the saved manuscript outputs
(max_rhat <- max(rhats))
## [1] 1.021442
It is reasonable to be worried about ESS ratios less than 0.1 here.
# Effective-sample-size ratio for every parameter in the fit
ratios <- bayesplot::neff_ratio(mcmc)
# NOTE(review): `breaks` is not referenced again in the visible source
breaks <- c(0, 0.1, 0.25, 0.5, 0.75, 1)

# One horizontal segment per parameter, coloured by bayesplot's ESS-ratio category
ratios %>%
  bayesplot::mcmc_neff_data() %>%
  ggplot(mapping = aes(x = value, y = parameter, color = description)) +
  geom_segment(aes(yend = parameter, xend = -Inf), na.rm = TRUE, alpha = 0.7) +
  scale_color_manual(values = c("#56B4E9", "#009E73", "#E69F00")) +
  # Dashed reference lines at ratios of 0.1, 0.5 and 1
  geom_vline(xintercept = c(0.1, 0.5, 1), linetype = "dashed", col = "grey40") +
  labs(x = "ESS ratio", y = "NUTS parameter", col = "") +
  theme_minimal() +
  theme(
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank(),
    panel.grid.major = element_blank()
  )

ggsave("ratio.png", height = 3, width = 6.25)
(average_ess_ratio <- mean(ratios))
## [1] 0.2529535
What are the total effective sample sizes?
#' Per the rstan documentation, summary(<stanfit>)$summary holds statistics for all
#' chains merged (per-chain values are in $c_summary); this assumes mcmc is an rstan stanfit
mcmc_summary <- summary(mcmc)$summary
# Histogram of the n_eff column (effective sample size) over all rows of the summary
data.frame(mcmc_summary) %>%
tibble::rownames_to_column("param") %>%
ggplot(aes(x = n_eff)) +
geom_histogram(alpha = 0.8) +
labs(x = "ESS", y = "Count")
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
ggsave("ess.png", h = 3, w = 6.25)
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
# Summary statistics of the n_eff column; each is printed and kept for the saved outputs
# Smallest ESS over all rows of mcmc_summary
(ess_min <- min(mcmc_summary[, "n_eff"]))
## [1] 208.0701
# 2.5% quantile (the quantile() results keep their percentage name)
(ess_lower <- quantile(mcmc_summary[, "n_eff"], 0.025))
## 2.5%
## 318.1532
# Median
(ess_median <- quantile(mcmc_summary[, "n_eff"], 0.50))
## 50%
## 1231.034
# 97.5% quantile
(ess_upper <- quantile(mcmc_summary[, "n_eff"], 0.975))
## 97.5%
## 2776.293
# Largest ESS
(ess_max <- max(mcmc_summary[, "n_eff"]))
## [1] 4250.653
Save outputs for use in manuscript:
# Bundle the headline convergence numbers into a named list and write it to disk.
# Note: this reassigns `out`, replacing the fit output read from depends/out.rds.
out <- list(
  max_rhat = max_rhat,
  average_ess_ratio = average_ess_ratio,
  ess_min = ess_min,
  ess_lower = ess_lower,
  ess_median = ess_median,
  ess_upper = ess_upper,
  ess_max = ess_max
)

saveRDS(object = out, file = "out.rds")
How much autocorrelation is there in the chains?
# Autocorrelation plot for every parameter whose name starts with "beta"
bayesplot::mcmc_acf(mcmc, pars = vars(starts_with("beta")))
# Trace plots, grouped by parameter-name prefix
bayesplot::mcmc_trace(mcmc, pars = vars(starts_with("beta")))
bayesplot::mcmc_trace(mcmc, pars = vars(starts_with("logit")))
bayesplot::mcmc_trace(mcmc, pars = vars(starts_with("log_sigma")))
The parameters with the worst ESS and the worst \(\hat R\):
# Trace plots for the parameter with the smallest ESS and the parameter with the
# largest R-hat. unique() covers the case where one parameter is worst on both
# criteria, which would otherwise pass the same name to mcmc_trace() twice.
(plot <- bayesplot::mcmc_trace(
  mcmc,
  pars = unique(c(
    names(which.min(mcmc_summary[, "n_eff"])),
    names(which.max(rhats))
  ))
))
# Argument names spelled out rather than relying on partial matching of h / w
ggsave("worst-trace.png", plot, height = 4, width = 6.25)
# Trace plots for each indexed parameter vector, selected by name prefix.
# The trailing "[" in each prefix keeps e.g. "u_rho_x[" from also matching "u_rho_xs[...]".
bayesplot::mcmc_trace(mcmc, pars = vars(starts_with("u_rho_x[")))
bayesplot::mcmc_trace(mcmc, pars = vars(starts_with("u_rho_xs[")))
bayesplot::mcmc_trace(mcmc, pars = vars(starts_with("us_rho_x[")))
bayesplot::mcmc_trace(mcmc, pars = vars(starts_with("us_rho_xs[")))
bayesplot::mcmc_trace(mcmc, pars = vars(starts_with("u_rho_a[")))
bayesplot::mcmc_trace(mcmc, pars = vars(starts_with("u_rho_as[")))
bayesplot::mcmc_trace(mcmc, pars = vars(starts_with("u_alpha_x[")))
bayesplot::mcmc_trace(mcmc, pars = vars(starts_with("u_alpha_xs[")))
bayesplot::mcmc_trace(mcmc, pars = vars(starts_with("us_alpha_x[")))
bayesplot::mcmc_trace(mcmc, pars = vars(starts_with("us_alpha_xs[")))
bayesplot::mcmc_trace(mcmc, pars = vars(starts_with("u_alpha_a[")))
bayesplot::mcmc_trace(mcmc, pars = vars(starts_with("u_alpha_as[")))
bayesplot::mcmc_trace(mcmc, pars = vars(starts_with("u_alpha_xa[")))
bayesplot::mcmc_trace(mcmc, pars = vars(starts_with("ui_lambda_x[")))
bayesplot::mcmc_trace(mcmc, pars = vars(starts_with("ui_anc_rho_x[")))
bayesplot::mcmc_trace(mcmc, pars = vars(starts_with("ui_anc_alpha_x[")))
# NOTE(review): the prose later in this report discusses log_or_gamma in the context of the
# ART attendance model; confirm whether "ANC" in the comment on the next line is intended.
bayesplot::mcmc_trace(mcmc, pars = vars(starts_with("log_or_gamma["))) #' N.B. these are from the ANC attendance model
Variation between units can be explained by high correlation and high standard deviation, or low correlation and low standard deviation. Hence there is an unidentifiability here that leads to correlated posteriors.
# Pairs plots of each (log_sigma_*, logit_phi_*) hyperparameter pair:
# histograms on the diagonal, hex-binned scatter off the diagonal
bayesplot::mcmc_pairs(mcmc, pars = c("log_sigma_alpha_a", "logit_phi_alpha_a"), diag_fun = "hist", off_diag_fun = "hex")
bayesplot::mcmc_pairs(mcmc, pars = c("log_sigma_alpha_as", "logit_phi_alpha_as"), diag_fun = "hist", off_diag_fun = "hex")
bayesplot::mcmc_pairs(mcmc, pars = c("log_sigma_rho_as", "logit_phi_rho_as"), diag_fun = "hist", off_diag_fun = "hex")
# The rho_a pair is also kept in `plot` so it can be written to file
(plot <- bayesplot::mcmc_pairs(mcmc, pars = c("log_sigma_rho_a", "logit_phi_rho_a"), diag_fun = "hist", off_diag_fun = "hex"))
ggsave("rho_a.png", plot, h = 4, w = 6.25)
#' Posterior correlation between a `log_sigma_*` / `logit_phi_*` pair
#'
#' Extracts the draws of `log_sigma_<par>` and `logit_phi_<par>` from the global
#' `mcmc` fit and returns the off-diagonal entry of their correlation matrix.
#'
#' @param par Suffix shared by the two hyperparameter names, e.g. "rho_a".
#' @return The correlation between the two sets of draws.
get_correlation <- function(par) {
  sigma_name <- paste0("log_sigma_", par)
  phi_name <- paste0("logit_phi_", par)
  draws <- rstan::extract(mcmc, c(sigma_name, phi_name))
  cor(as.data.frame(draws))[1, 2]
}
# Posterior correlation of each (log_sigma, logit_phi) pair labelled "ar1"
ar1_cor_df <- data.frame(
  par = c("alpha_a", "alpha_as", "rho_as", "rho_a"),
  type = "ar1"
) %>%
  mutate(cor = purrr::map_dbl(par, get_correlation))

ar1_cor_df
Does the supposed orthogonality of the BYM2 model play out? Looks like the answer is mostly yes.
# Pairs plots of each (log_sigma_*, logit_phi_*) hyperparameter pair for the *_x / *_xs terms:
# histograms on the diagonal, hex-binned scatter off the diagonal
bayesplot::mcmc_pairs(mcmc, pars = c("log_sigma_rho_x", "logit_phi_rho_x"), diag_fun = "hist", off_diag_fun = "hex")
bayesplot::mcmc_pairs(mcmc, pars = c("log_sigma_rho_xs", "logit_phi_rho_xs"), diag_fun = "hist", off_diag_fun = "hex")
bayesplot::mcmc_pairs(mcmc, pars = c("log_sigma_alpha_xs", "logit_phi_alpha_xs"), diag_fun = "hist", off_diag_fun = "hex")
# The alpha_x pair is also kept in `plot` so it can be written to file
(plot <- bayesplot::mcmc_pairs(mcmc, pars = c("log_sigma_alpha_x", "logit_phi_alpha_x"), diag_fun = "hist", off_diag_fun = "hex"))
ggsave("alpha_x.png", plot, h = 4, w = 6.25)
# Posterior correlation of each (log_sigma, logit_phi) pair labelled "bym2"
bym2_cor_df <- data.frame(
  par = c("rho_x", "rho_xs", "alpha_x", "alpha_xs"),
  type = "bym2"
) %>%
  mutate(cor = purrr::map_dbl(par, get_correlation))

bym2_cor_df

# Stack the AR1 and BYM2 correlation tables and write them to one CSV
write_csv(
  bind_rows(ar1_cor_df, bym2_cor_df),
  file = "ar1-bym2-cor.csv"
)
There is a prior suspicion (from Jeff, Tim, Rachel) that the ART attendance model is unidentifiable.
Let’s have a look at the pairs plot for neighbouring districts and the `log_or_gamma` parameter.
# Demo area boundaries shipped with the naomi package
area_merged <- sf::read_sf(system.file("extdata/demo_areas.geojson", package = "naomi"))
# Neighbour structure for the areas at the largest area_level value;
# indexed as nb[[i]] below to get the neighbours of area i
nb <- area_merged %>%
filter(area_level == max(area_level)) %>%
bsae::sf_to_nb()
#' Pairs plot of an indexed parameter for one area and its neighbours
#'
#' Builds the names `par[i]` and `par[j]` for every `j` in the global neighbour
#' list entry `nb[[i]]`, then draws a bayesplot pairs plot of those parameters
#' from the global `mcmc` fit.
#'
#' @param par Parameter name without the index, e.g. "log_or_gamma".
#' @param i Index of the focal area in `nb`.
#' @return The object returned by `bayesplot::mcmc_pairs()`.
neighbours_pairs_plot <- function(par, i) {
  area_indices <- c(i, nb[[i]])
  neighbour_pars <- paste0(par, "[", area_indices, "]")
  bayesplot::mcmc_pairs(
    mcmc,
    pars = neighbour_pars,
    diag_fun = "hist",
    off_diag_fun = "hex"
  )
}
Here are Nkhata Bay and neighbours:
# NOTE(review): 5 is a hard-coded position in `nb`; confirm it still corresponds to the area named in the text
neighbours_pairs_plot("log_or_gamma", 5)
And here are Blantyre and neighbours:
# NOTE(review): 26 is a hard-coded position in `nb`; confirm it still corresponds to the area named in the text
neighbours_pairs_plot("log_or_gamma", 26)
# Sampler diagnostics for the NUTS run, as returned by bayesplot::nuts_params()
np <- bayesplot::nuts_params(mcmc)
# Saved to disk as nuts-params.rds
saveRDS(np, "nuts-params.rds")
Are there any divergent transitions?
# Total of the divergent__ values over the rows of np
np %>%
filter(Parameter == "divergent__") %>%
summarise(n_divergent = sum(Value))
# bayesplot's divergence diagnostic plot, which also takes the log-posterior draws
bayesplot::mcmc_nuts_divergence(np, bayesplot::log_posterior(mcmc))
We can also use energy plots (Betancourt 2017): ideally these two histograms would be the same. When the histograms are quite different, it may suggest the chains are not fully exploring the tails of the target distribution.
# Energy diagnostic histograms built from the NUTS sampler parameters in np
bayesplot::mcmc_nuts_energy(np)
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.
sessionInfo()
## R version 4.2.0 (2022-04-22)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS 13.3.1
##
## Matrix products: default
## LAPACK: /Library/Frameworks/R.framework/Versions/4.2/Resources/lib/libRlapack.dylib
##
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] patchwork_1.1.2 tibble_3.2.1 tidyverse_1.3.1 rmarkdown_2.18 multi.utils_0.1.0
## [6] sf_1.0-9 bayesplot_1.9.0 rstan_2.21.5 StanHeaders_2.21.0-7 Matrix_1.5-4.1
## [11] stringr_1.5.0 purrr_1.0.1 tidyr_1.3.0 readr_2.1.3 ggplot2_3.4.0
## [16] forcats_0.5.2 dplyr_1.0.10
##
## loaded via a namespace (and not attached):
## [1] readxl_1.4.0 uuid_1.1-0 backports_1.4.1 systemfonts_1.0.4 tmbstan_1.0.4
## [6] plyr_1.8.8 sp_1.5-1 TMB_1.9.2 inline_0.3.19 digest_0.6.31
## [11] htmltools_0.5.3 fansi_1.0.4 magrittr_2.0.3 checkmate_2.1.0 memoise_2.0.1
## [16] naomi_2.8.5 tzdb_0.3.0 modelr_0.1.8 RcppParallel_5.1.5 matrixStats_0.62.0
## [21] vroom_1.6.0 askpass_1.1 prettyunits_1.1.1 colorspace_2.0-3 blob_1.2.3
## [26] rvest_1.0.2 textshaping_0.3.6 haven_2.5.0 xfun_0.37 callr_3.7.3
## [31] crayon_1.5.2 jsonlite_1.8.4 hexbin_1.28.2 glue_1.6.2 gtable_0.3.1
## [36] V8_4.2.2 distributional_0.3.0 pkgbuild_1.3.1 traduire_0.0.6 abind_1.4-5
## [41] scales_1.2.1 DBI_1.1.3 Rcpp_1.0.10 isoband_0.2.6 spData_2.2.1
## [46] units_0.8-0 bit_4.0.5 spdep_1.2-7 proxy_0.4-27 stats4_4.2.0
## [51] httr_1.4.4 wk_0.7.0 posterior_1.2.2 ellipsis_0.3.2 pkgconfig_2.0.3
## [56] loo_2.5.1 farver_2.1.1 sass_0.4.4 dbplyr_2.1.1 deldir_1.0-6
## [61] utf8_1.2.3 tidyselect_1.2.0 labeling_0.4.2 rlang_1.1.0 reshape2_1.4.4
## [66] munsell_0.5.0 cellranger_1.1.0 tools_4.2.0 cachem_1.0.6 bsae_0.2.7
## [71] cli_3.6.1 generics_0.1.3 RSQLite_2.2.14 broom_0.8.0 ggridges_0.5.3
## [76] evaluate_0.20 fastmap_1.1.0 yaml_2.3.7 ragg_1.2.2 rticles_0.23.6
## [81] processx_3.8.0 knitr_1.41 bit64_4.0.5 fs_1.6.1 s2_1.1.1
## [86] orderly_1.4.3 xml2_1.3.3 compiler_4.2.0 rstudioapi_0.14 curl_5.0.0
## [91] e1071_1.7-12 gt_0.8.0 reprex_2.0.1 statmod_1.4.36 bslib_0.4.1
## [96] stringi_1.7.8 highr_0.9 ids_1.0.1 ps_1.7.3 lattice_0.20-45
## [101] classInt_0.4-8 mvQuad_1.0-6 tensorA_0.36.2 vctrs_0.6.1 pillar_1.9.0
## [106] lifecycle_1.0.3 jquerylib_0.1.4 data.table_1.14.6 cowplot_1.1.1 R6_2.5.1
## [111] bookdown_0.26 KernSmooth_2.23-20 aghq_0.4.1 gridExtra_2.3 codetools_0.2-18
## [116] boot_1.3-28 assertthat_0.2.1 openssl_2.0.5 rprojroot_2.0.3 withr_2.5.0
## [121] parallel_4.2.0 hms_1.1.2 grid_4.2.0 class_7.3-20 lubridate_1.8.0