Abstract
Background: We have run the simplified Naomi model using a range of inference methods: TMB, aghq, and tmbstan.
Task: In this report, we compare the accuracy of the posterior distributions obtained from these inference methods using point estimates (the posterior mean and standard deviation).
We compare the inference results from TMB, aghq, and tmbstan.
Import these inference results as follows:
# Inference results produced by the upstream tasks: TMB, aghq (PCA-AGHQ) and
# tmbstan (NUTS)
tmb <- readRDS(file.path("depends", "tmb.rds"))
aghq <- readRDS(file.path("depends", "aghq.rds"))
tmbstan <- readRDS(file.path("depends", "tmbstan.rds"))
We’re interested in the latent field and certain model outputs (prevalence, incidence and treatment coverage):
# Latent-field parameters: names in the full parameter vector that are not in
# `tmb$fit$par` (presumably the optimised hyperparameters -- confirm upstream).
# setdiff() de-duplicates and keeps first-occurrence order, so there is no need
# to compute unique() twice.
latent_pars <- setdiff(names(tmb$fit$par.full), names(tmb$fit$par))
# Model outputs of interest: prevalence, incidence and treatment coverage
output_pars <- c("rho_t1_out", "lambda_t1_out", "alpha_t1_out")
# Lookup table labelling each parameter as "latent" or "output"
type_key <- data.frame(
  par = c(latent_pars, output_pars),
  type = rep(
    c("latent", "output"),
    times = c(length(latent_pars), length(output_pars))
  )
)
Extract the mean and standard deviation for each parameter for each inference method, then bind them together:
#' Posterior mean and standard deviation of each parameter element
#'
#' @param samples Named list of sample matrices, indexed by parameter name.
#'   Each row is summarised across its columns (presumably one column per
#'   posterior draw -- confirm against the upstream tasks).
#' @param pars Character vector of parameter names to extract from `samples`.
#' @param method_label Value stored in the `method` column.
#' @return A data frame with columns `par`, `id` (row number within the
#'   parameter), `mean`, `sd` and `method`.
summarise_samples <- function(samples, pars, method_label) {
  lapply(samples[pars], function(draws) {
    data.frame("mean" = rowMeans(draws), "sd" = matrixStats::rowSds(draws)) %>%
      tibble::rowid_to_column("id")
  }) %>%
    dplyr::bind_rows(.id = "par") %>%
    mutate(method = method_label)
}

df_tmb <- summarise_samples(tmb$fit$sample, type_key$par, "TMB")
df_aghq <- summarise_samples(aghq$quad$sample, type_key$par, "PCA-AGHQ")
df_tmbstan <- summarise_samples(tmbstan$mcmc$sample, type_key$par, "NUTS")

# Stack the three methods and tag each parameter as latent or output
df <- bind_rows(df_tmb, df_aghq, df_tmbstan) %>%
  left_join(type_key, by = "par")
head(df)
Pivot so that the data is ready to plot:
# Reshape to one row per (parameter element, indicator, approximate method),
# with the NUTS estimate carried alongside as the reference value `truth`.
df_plot <- df %>%
  pivot_longer(
    cols = c(sd, mean),
    names_to = "indicator",
    values_to = "estimate"
  ) %>%
  pivot_wider(names_from = method, values_from = estimate) %>%
  rename(truth = NUTS) %>%
  pivot_longer(
    cols = c(TMB, `PCA-AGHQ`),
    names_to = "method",
    values_to = "approximate"
  ) %>%
  mutate(
    indicator = fct_recode(
      indicator,
      "Posterior mean estimate" = "mean",
      "Posterior SD estimate" = "sd"
    ),
    method = fct_relevel(method, "TMB", "PCA-AGHQ")
  )
head(df_plot)
Calculate the root mean square error (RMSE) and mean absolute error (MAE) between the approximate methods and NUTS (taken to be the ground truth):
# RMSE and MAE of each approximation relative to NUTS, per method, indicator
# and parameter type. `.groups = "drop"` returns an ungrouped tibble directly
# and avoids summarise()'s regrouping message.
df_metrics <- df_plot %>%
  group_by(method, indicator, type) %>%
  summarise(
    rmse = sqrt(mean((truth - approximate)^2)),
    mae = mean(abs(truth - approximate)),
    .groups = "drop"
  )
head(df_metrics)
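Add the percentage difference in each metric between PCA-AGHQ and TMB, relative to the larger of the two (negative values mean PCA-AGHQ is closer to NUTS):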
# Percentage change in each metric going from TMB to PCA-AGHQ, relative to the
# larger of the two values; negative means PCA-AGHQ has the smaller error.
# Methods are matched by name rather than relying on TMB preceding PCA-AGHQ in
# the row order, and grouping columns keep their types (no paste/split round
# trip on ".").
df_metrics_pct <- df_metrics %>%
  group_by(indicator, type) %>%
  summarise(
    rmse_diff = 100 * (rmse[method == "PCA-AGHQ"] - rmse[method == "TMB"]) / max(rmse),
    mae_diff = 100 * (mae[method == "PCA-AGHQ"] - mae[method == "TMB"]) / max(mae),
    .groups = "drop"
  )
# Attach the percentage differences and build the facet label text; only the
# PCA-AGHQ panels also show the % change relative to TMB.
df_metrics <- df_metrics %>%
  left_join(df_metrics_pct, by = c("indicator", "type")) %>%
  mutate(
    label = if_else(
      method != "PCA-AGHQ",
      paste0("RMSE: ", signif(rmse, 2), "\nMAE: ", signif(mae, 2)),
      paste0(
        "RMSE: ", signif(rmse, 2), " (", signif(rmse_diff, 2), "%)",
        "\nMAE: ", signif(mae, 2), " (", signif(mae_diff, 2), "%)"
      )
    )
  )

write_csv(df_metrics, "mean-sd.csv")
# Latent-field subset of the comparison data and of the metrics
df_plot_latent <- filter(df_plot, type == "latent")
df_metrics_latent <- filter(df_metrics, type == "latent")

# Error of each approximation (approximate - NUTS) against the NUTS value, one
# panel per indicator and method, annotated with the RMSE/MAE labels.
# The plot is passed to ggsave() explicitly instead of relying on last_plot().
# It is saved on a white background because theme_minimal() leaves the plot
# background blank, which ggsave() writes as transparent.
(mean_sd_plot_latent <- df_plot_latent %>%
  ggplot(aes(x = truth, y = approximate - truth)) +
  geom_point(shape = 1, alpha = 0.4) +
  facet_grid(indicator ~ method) +
  geom_abline(slope = 0, intercept = 0, linetype = "dashed") +
  geom_text(data = df_metrics_latent, aes(x = -Inf, y = Inf, label = label), size = 3, hjust = 0, vjust = 1.5) +
  labs(x = "NUTS", y = "Approximation - NUTS") +
  theme_minimal())
ggsave("mean-sd-latent.png", mean_sd_plot_latent, height = 6, width = 6.25, bg = "white")
Version split into two plots (posterior mean and posterior SD) for presentations:
# Amount of horizontal/vertical jitter applied to points to reduce overplotting
jitter_amount <- 0.02

# Posterior-mean errors for the latent field; the y-axis is clipped to
# [-0.4, 0.4], so points outside that range are dropped by the scale.
mean_plot_latent <- df_plot_latent %>%
  filter(indicator == "Posterior mean estimate") %>%
  ggplot(aes(x = truth, y = approximate - truth)) +
  geom_jitter(shape = 1, alpha = 0.4, width = jitter_amount, height = jitter_amount) +
  lims(y = c(-0.4, 0.4)) +
  facet_grid(indicator ~ method) +
  # `linewidth` replaces the deprecated `size` for lines in ggplot2 >= 3.4.0
  geom_abline(slope = 0, intercept = 0, linetype = "dashed", linewidth = 0.25) +
  geom_text(
    data = filter(df_metrics_latent, indicator == "Posterior mean estimate"),
    aes(x = -Inf, y = Inf, label = label), size = 3, hjust = 0, vjust = 1.5
  ) +
  labs(x = "NUTS", y = "Approximation - NUTS") +
  theme_minimal()
# White background: theme_minimal() otherwise yields a transparent PNG
ggsave("mean-latent.png", mean_plot_latent, height = 4, width = 6.25, bg = "white")
# Posterior-SD errors for the latent field; the y-axis is clipped to
# [-0.6, 0.6], so points outside that range are dropped by the scale.
sd_plot_latent <- df_plot_latent %>%
  filter(indicator == "Posterior SD estimate") %>%
  ggplot(aes(x = truth, y = approximate - truth)) +
  geom_jitter(shape = 1, alpha = 0.4, width = jitter_amount, height = jitter_amount) +
  lims(y = c(-0.6, 0.6)) +
  facet_grid(indicator ~ method) +
  # `linewidth` replaces the deprecated `size` for lines in ggplot2 >= 3.4.0
  geom_abline(slope = 0, intercept = 0, linetype = "dashed", linewidth = 0.25) +
  geom_text(
    data = filter(df_metrics_latent, indicator == "Posterior SD estimate"),
    aes(x = -Inf, y = Inf, label = label), size = 3, hjust = 0, vjust = 1.5
  ) +
  labs(x = "NUTS", y = "Approximation - NUTS") +
  theme_minimal()
# White background: theme_minimal() otherwise yields a transparent PNG
ggsave("sd-latent.png", sd_plot_latent, height = 4, width = 6.25, bg = "white")
# Combined latent figure: mean panel stacked above the SD panel (whose method
# strips are removed), with one shared, rotated y-axis title drawn in a narrow
# left-hand panel instead of a title on each plot.
sd_plot_latent_alt <- sd_plot_latent +
  theme(
    strip.background = element_blank(),
    strip.text.x = element_blank()
  ) +
  labs(y = "")
y_axis <- ggplot(data.frame(l = mean_plot_latent$labels$y, x = 1, y = 1)) +
  geom_text(aes(x, y, label = l), angle = 90) +
  theme_void() +
  coord_cartesian(clip = "off")
mean_plot_latent_alt <- mean_plot_latent + labs(x = "", y = "")
(mean_sd_plot_latent_alt <- y_axis + (mean_plot_latent_alt / sd_plot_latent_alt) +
  plot_layout(widths = c(1, 30)))
# White background: theme_minimal() otherwise yields a transparent PNG
ggsave("mean-sd-alt-latent.png", mean_sd_plot_latent_alt, height = 6, width = 6.25, bg = "white")
Where are the greatest errors occurring for the mean?
# Latent-field posterior means ranked by absolute error, largest first
df_plot_latent %>%
  filter(indicator == "Posterior mean estimate") %>%
  mutate(diff = abs(approximate - truth)) %>%
  select(indicator, par, id, method, truth, approximate, diff) %>%
  arrange(-diff) %>%
  DT::datatable()
What about for the standard deviations?
# Latent-field posterior SDs ranked by absolute error, largest first
df_plot_latent %>%
  filter(indicator == "Posterior SD estimate") %>%
  mutate(diff = abs(approximate - truth)) %>%
  select(indicator, par, id, method, truth, approximate, diff) %>%
  arrange(-diff) %>%
  DT::datatable()
# Model-output subset of the comparison data and of the metrics
df_plot_output <- filter(df_plot, type == "output")
df_metrics_output <- filter(df_metrics, type == "output")
# Palette for the output colours; the manual scale assigns colours in the
# order of the `par` levels
cbpalette <- c("#56B4E9","#009E73", "#E69F00", "#F0E442", "#0072B2", "#D55E00", "#CC79A7", "#999999")

# Error of each approximation against NUTS for the outputs, coloured by output.
# The plot is passed to ggsave() explicitly instead of relying on last_plot(),
# and saved on white because theme_minimal() leaves the background blank.
(mean_sd_plot_output <- df_plot_output %>%
  ggplot() +
  geom_point(aes(x = truth, y = approximate - truth, color = par), shape = 1, alpha = 0.4) +
  scale_color_manual(values = cbpalette) +
  facet_grid(indicator ~ method) +
  geom_abline(slope = 0, intercept = 0, linetype = "dashed") +
  geom_text(data = df_metrics_output, aes(x = -Inf, y = Inf, label = label), size = 3, hjust = 0, vjust = 1.5) +
  labs(x = "NUTS", y = "Approximation - NUTS", col = "Output") +
  theme_minimal() +
  theme(legend.position = "bottom"))
ggsave("mean-sd-output.png", mean_sd_plot_output, height = 6, width = 6.25, bg = "white")
Version split into two plots (posterior mean and posterior SD) for presentations:
# No jitter for the outputs: points are drawn at their exact values
jitter_amount <- 0

# Posterior-mean errors for the model outputs, coloured by output
mean_plot_output <- df_plot_output %>%
  filter(indicator == "Posterior mean estimate") %>%
  ggplot() +
  geom_jitter(aes(x = truth, y = approximate - truth, color = par), shape = 1, alpha = 0.4, width = jitter_amount, height = jitter_amount) +
  scale_color_manual(values = cbpalette) +
  facet_grid(indicator ~ method) +
  # `linewidth` replaces the deprecated `size` for lines in ggplot2 >= 3.4.0
  geom_abline(slope = 0, intercept = 0, linetype = "dashed", linewidth = 0.25) +
  geom_text(
    data = filter(df_metrics_output, indicator == "Posterior mean estimate"),
    aes(x = -Inf, y = Inf, label = label), size = 3, hjust = 0, vjust = 1.5
  ) +
  labs(x = "NUTS", y = "Approximation - NUTS", col = "Output") +
  theme_minimal()
# White background: theme_minimal() otherwise yields a transparent PNG
ggsave("mean-output.png", mean_plot_output, height = 4, width = 6.25, bg = "white")
# Posterior-SD errors for the model outputs, coloured by output
sd_plot_output <- df_plot_output %>%
  filter(indicator == "Posterior SD estimate") %>%
  ggplot() +
  geom_jitter(aes(x = truth, y = approximate - truth, color = par), shape = 1, alpha = 0.4, width = jitter_amount, height = jitter_amount) +
  scale_color_manual(values = cbpalette) +
  facet_grid(indicator ~ method) +
  # `linewidth` replaces the deprecated `size` for lines in ggplot2 >= 3.4.0
  geom_abline(slope = 0, intercept = 0, linetype = "dashed", linewidth = 0.25) +
  geom_text(
    data = filter(df_metrics_output, indicator == "Posterior SD estimate"),
    aes(x = -Inf, y = Inf, label = label), size = 3, hjust = 0, vjust = 1.5
  ) +
  labs(x = "NUTS", y = "Approximation - NUTS", col = "Output") +
  theme_minimal()
# White background: theme_minimal() otherwise yields a transparent PNG
ggsave("sd-output.png", sd_plot_output, height = 4, width = 6.25, bg = "white")
# Combined output figure: the mean panel (legend hidden) sits above the SD
# panel (method strips removed, legend at the bottom). A narrow left-hand panel
# carries one shared, rotated y-axis title.
sd_plot_output_alt <- sd_plot_output +
  theme(
    strip.background = element_blank(),
    strip.text.x = element_blank(),
    legend.position = "bottom"
  ) +
  labs(y = "")
y_axis <- data.frame(l = mean_plot_output$labels$y, x = 1, y = 1) %>%
  ggplot() +
  geom_text(aes(x, y, label = l), angle = 90) +
  theme_void() +
  coord_cartesian(clip = "off")
mean_plot_output_alt <- mean_plot_output +
  theme(legend.position = "none") +
  labs(x = "", y = "")
(mean_sd_plot_output_alt <- y_axis +
  (mean_plot_output_alt / sd_plot_output_alt) +
  plot_layout(widths = c(1, 30)))
ggsave("mean-sd-alt-output.png", mean_sd_plot_output_alt, height = 6, width = 6.25, bg = "white")
Version with the metrics computed separately for each model output (prevalence, incidence and treatment coverage):
# RMSE and MAE relative to NUTS for each output separately.
# `.groups = "drop"` returns an ungrouped tibble directly and avoids
# summarise()'s regrouping message.
df_metrics_output_split <- df_plot_output %>%
  group_by(method, indicator, type, par) %>%
  summarise(
    rmse = sqrt(mean((truth - approximate)^2)),
    mae = mean(abs(truth - approximate)),
    .groups = "drop"
  )
# Per-output percentage change in each metric going from TMB to PCA-AGHQ,
# relative to the larger of the two; negative means PCA-AGHQ has the smaller
# error. Methods are matched by name rather than by row order, and grouping
# columns keep their types (no paste/split round trip on ".").
df_metrics_output_split_pct <- df_metrics_output_split %>%
  group_by(indicator, type, par) %>%
  summarise(
    rmse_diff = 100 * (rmse[method == "PCA-AGHQ"] - rmse[method == "TMB"]) / max(rmse),
    mae_diff = 100 * (mae[method == "PCA-AGHQ"] - mae[method == "TMB"]) / max(mae),
    .groups = "drop"
  )
# Attach the per-output percentage differences and build the facet label text;
# only the PCA-AGHQ panels also show the % change relative to TMB.
df_metrics_output_split <- df_metrics_output_split %>%
  left_join(df_metrics_output_split_pct, by = c("indicator", "type", "par")) %>%
  mutate(
    label = if_else(
      method != "PCA-AGHQ",
      paste0("RMSE: ", signif(rmse, 2), "\nMAE: ", signif(mae, 2)),
      paste0(
        "RMSE: ", signif(rmse, 2), " (", signif(rmse_diff, 2), "%)",
        "\nMAE: ", signif(mae, 2), " (", signif(mae_diff, 2), "%)"
      )
    )
  )
#' Error plot (approximation - NUTS against NUTS) for one output and indicator
#'
#' Reads the file-level objects `df_plot_output` and `df_metrics_output_split`.
#'
#' @param .indicator Indicator to plot, e.g. "Posterior mean estimate".
#' @param .par Output parameter name, e.g. "rho_t1_out".
#' @param col Colour used for the points.
#' @param jitter_width Jitter applied in both directions; defaults to the
#'   file-level `jitter_amount` at call time (same as before).
#' @return A ggplot object faceted by indicator and output (rows) and method
#'   (columns).
plot_outputs <- function(.indicator, .par, col, jitter_width = jitter_amount) {
  df_plot_output %>%
    filter(indicator == .indicator, par == .par) %>%
    ggplot() +
    geom_jitter(aes(x = truth, y = approximate - truth, color = par), shape = 1, alpha = 0.4, width = jitter_width, height = jitter_width) +
    # Only one `par` level survives the filter, so a single colour suffices
    scale_color_manual(values = col) +
    facet_grid(indicator + par ~ method) +
    # `linewidth` replaces the deprecated `size` for lines in ggplot2 >= 3.4.0
    geom_abline(slope = 0, intercept = 0, linetype = "dashed", linewidth = 0.25) +
    geom_text(
      data = filter(df_metrics_output_split, indicator == .indicator, par == .par),
      aes(x = -Inf, y = Inf, label = label), size = 3, hjust = 0, vjust = 1.5
    ) +
    labs(x = "NUTS", y = "Approximation - NUTS", col = "Output") +
    theme_minimal() +
    theme(legend.position = "none")
}
# One figure per (indicator, output) pair. Each output keeps the colour the
# palette gives it in the combined output plots.
output_cols <- c(
  rho_t1_out = "#E69F00",
  lambda_t1_out = "#009E73",
  alpha_t1_out = "#56B4E9"
)
for (ind in c("Posterior mean estimate", "Posterior SD estimate")) {
  for (p in names(output_cols)) {
    print(plot_outputs(.indicator = ind, .par = p, col = output_cols[[p]]))
  }
}
# Record the R version and package versions used to produce these results
utils::sessionInfo()
## R version 4.2.0 (2022-04-22)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS 13.3.1
##
## Matrix products: default
## LAPACK: /Library/Frameworks/R.framework/Versions/4.2/Resources/lib/libRlapack.dylib
##
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] multi.utils_0.1.0 patchwork_1.1.2 sf_1.0-9 bayesplot_1.9.0 rstan_2.21.5 StanHeaders_2.21.0-7 Matrix_1.5-4.1 stringr_1.5.0 purrr_1.0.1
## [10] tidyr_1.3.0 readr_2.1.3 ggplot2_3.4.0 forcats_0.5.2 dplyr_1.0.10
##
## loaded via a namespace (and not attached):
## [1] traduire_0.0.6 matrixStats_0.62.0 bit64_4.0.5 rprojroot_2.0.3 numDeriv_2016.8-1.1 tools_4.2.0 TMB_1.9.2 bslib_0.4.1 utf8_1.2.3
## [10] R6_2.5.1 DT_0.26 KernSmooth_2.23-20 DBI_1.1.3 colorspace_2.0-3 withr_2.5.0 tidyselect_1.2.0 gridExtra_2.3 prettyunits_1.1.1
## [19] processx_3.8.0 mvQuad_1.0-6 curl_5.0.0 bit_4.0.5 compiler_4.2.0 aghq_0.4.1 textshaping_0.3.6 cli_3.6.1 naomi_2.8.5
## [28] bookdown_0.26 labeling_0.4.2 sass_0.4.4 scales_1.2.1 mvtnorm_1.1-3 classInt_0.4-8 askpass_1.1 ggridges_0.5.3 callr_3.7.3
## [37] proxy_0.4-27 systemfonts_1.0.4 digest_0.6.31 rmarkdown_2.18 pkgconfig_2.0.3 htmltools_0.5.3 highr_0.9 fastmap_1.1.0 htmlwidgets_1.5.4
## [46] rlang_1.1.0 rstudioapi_0.14 jquerylib_0.1.4 generics_0.1.3 farver_2.1.1 jsonlite_1.8.4 crosstalk_1.2.0 vroom_1.6.0 inline_0.3.19
## [55] magrittr_2.0.3 polynom_1.4-1 loo_2.5.1 Rcpp_1.0.10 munsell_0.5.0 fansi_1.0.4 lifecycle_1.0.3 stringi_1.7.8 tmbstan_1.0.4
## [64] yaml_2.3.7 pkgbuild_1.3.1 plyr_1.8.8 grid_4.2.0 parallel_4.2.0 crayon_1.5.2 lattice_0.20-45 splines_4.2.0 hms_1.1.2
## [73] knitr_1.41 ps_1.7.3 pillar_1.9.0 uuid_1.1-0 codetools_0.2-18 stats4_4.2.0 glue_1.6.2 evaluate_0.20 V8_4.2.2
## [82] data.table_1.14.6 RcppParallel_5.1.5 ids_1.0.1 vctrs_0.6.1 tzdb_0.3.0 openssl_2.0.5 gtable_0.3.1 assertthat_0.2.1 cachem_1.0.6
## [91] xfun_0.37 e1071_1.7-12 ragg_1.2.2 class_7.3-20 tibble_3.2.1 orderly_1.4.3 units_0.8-0 statmod_1.4.36 ellipsis_0.3.2