Abstract
Background. We have run the simplified Naomi model using three inference methods: `TMB`, `aghq`, and `tmbstan`.
Task. In this report, we compare the accuracy of the posterior distributions obtained from these inference methods using point estimates of the posterior mean and standard deviation.
We compare the inference results from `TMB`, `aghq`, and `tmbstan`.
Import these inference results as follows:
# Fitted objects produced by the upstream inference tasks, one per method
tmb <- readRDS(file = "depends/tmb.rds")
aghq <- readRDS(file = "depends/aghq.rds")
tmbstan <- readRDS(file = "depends/tmbstan.rds")
We’re interested in the latent field and certain model outputs (prevalence, incidence and treatment coverage):
# Latent parameters: names present in the full parameter vector `par.full`
# but absent from `par` (duplicated names collapse to one entry)
latent_pars <- setdiff(names(tmb$fit$par.full), names(tmb$fit$par))
# Model outputs of interest (see the prose above for their meaning)
output_pars <- c("rho_t1_out", "lambda_t1_out", "alpha_t1_out")

# Lookup table tagging each parameter as "latent" or "output"
type_key <- data.frame(
  par = c(latent_pars, output_pars),
  type = rep(
    c("latent", "output"),
    times = c(length(latent_pars), length(output_pars))
  )
)
Extract the mean and standard deviation of each parameter under each inference method, then bind the results together:
#' Summarise posterior samples as a per-element mean and standard deviation
#'
#' @param samples Named list of sample matrices, one per parameter. Each row
#'   is summarised separately (presumably one row per parameter element and
#'   one column per posterior draw -- confirm against the upstream tasks).
#' @param pars Names of the parameters in `samples` to summarise.
#' @param method Label stored in the `method` column of the result.
#' @return Data frame with columns `par`, `id` (row index within the
#'   parameter), `mean`, `sd` and `method`.
summarise_samples <- function(samples, pars, method) {
  lapply(samples[pars], function(draws) {
    data.frame("mean" = rowMeans(draws), "sd" = matrixStats::rowSds(draws)) %>%
      tibble::rowid_to_column("id")
  }) %>%
    dplyr::bind_rows(.id = "par") %>%
    # `.env` makes it explicit that the function argument is meant
    mutate(method = .env$method)
}

df_tmb <- summarise_samples(tmb$fit$sample, type_key$par, "TMB")
df_aghq <- summarise_samples(aghq$quad$sample, type_key$par, "PCA-AGHQ")
df_tmbstan <- summarise_samples(tmbstan$mcmc$sample, type_key$par, "NUTS")

# Stack all methods and attach the latent/output type of each parameter
df <- bind_rows(df_tmb, df_aghq, df_tmbstan) %>%
  left_join(
    type_key,
    by = "par"
  )
head(df)
Pivot the data for plotting, pairing each approximate estimate (TMB or PCA-AGHQ) with the corresponding NUTS estimate:
# Reshape so that each row pairs one approximate estimate (TMB or PCA-AGHQ)
# with the matching NUTS estimate, which is treated as the truth
df_plot <- df %>%
  pivot_longer(
    cols = c("sd", "mean"),
    names_to = "indicator",
    values_to = "estimate"
  ) %>%
  pivot_wider(names_from = "method", values_from = "estimate") %>%
  rename(truth = "NUTS") %>%
  pivot_longer(
    cols = c("TMB", "PCA-AGHQ"),
    names_to = "method",
    values_to = "approximate"
  ) %>%
  mutate(
    indicator = fct_recode(
      indicator,
      "Posterior mean estimate" = "mean",
      "Posterior SD estimate" = "sd"
    ),
    method = fct_relevel(method, "TMB", "PCA-AGHQ")
  )
head(df_plot)
Calculate the root mean square error (RMSE) and mean absolute error (MAE) between each approximate method and NUTS (taken to be the truth):
# RMSE and MAE of each approximation against NUTS, for every combination of
# method, indicator (mean / SD) and parameter type (latent / output).
# `.groups = "drop"` returns an ungrouped tibble without the grouping message.
df_metrics <- df_plot %>%
  group_by(method, indicator, type) %>%
  summarise(
    rmse = sqrt(mean((truth - approximate)^2)),
    mae = mean(abs(truth - approximate)),
    .groups = "drop"
  )
head(df_metrics)
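Add the percentage difference in each metric between PCA-AGHQ and TMB, relative to the larger of the two values (negative values mean PCA-AGHQ has the smaller error):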
# Percentage difference of PCA-AGHQ relative to TMB for each metric, scaled by
# the larger of the two values (negative = PCA-AGHQ has the smaller error).
# Methods are selected by name rather than by row order, and the grouping
# columns keep their types (e.g. `indicator` stays a factor), so there is no
# need to split and re-parse group names.
df_metrics_pct <- df_metrics %>%
  group_by(indicator, type) %>%
  summarise(
    rmse_diff = 100 * (rmse[method == "PCA-AGHQ"] - rmse[method == "TMB"]) / max(rmse),
    mae_diff = 100 * (mae[method == "PCA-AGHQ"] - mae[method == "TMB"]) / max(mae),
    .groups = "drop"
  )
# Build the facet annotation text: every panel shows RMSE and MAE, and the
# PCA-AGHQ panels also show the percentage difference relative to TMB
df_metrics <- left_join(df_metrics, df_metrics_pct, by = c("indicator", "type")) %>%
  mutate(
    label = ifelse(
      method != "PCA-AGHQ",
      paste0("RMSE: ", signif(rmse, 2), "\nMAE: ", signif(mae, 2)),
      paste0(
        "RMSE: ", signif(rmse, 2), " (", signif(rmse_diff, 2), "%)",
        "\nMAE: ", signif(mae, 2), " (", signif(mae_diff, 2), "%)"
      )
    )
  )
write_csv(df_metrics, "mean-sd.csv")
# Latent-field subset of the comparison data and its facet annotations
df_plot_latent <- filter(df_plot, type == "latent")
df_metrics_latent <- filter(df_metrics, type == "latent")

# Error (approximation minus NUTS) against the NUTS estimate, one panel per
# indicator x method; the dashed line marks zero error. The plot is kept in a
# variable and passed to ggsave() explicitly instead of relying on last_plot().
(mean_sd_plot_latent <- df_plot_latent %>%
  ggplot(aes(x = truth, y = approximate - truth)) +
  geom_point(shape = 1, alpha = 0.4) +
  facet_grid(indicator ~ method) +
  geom_hline(yintercept = 0, linetype = "dashed") +
  geom_text(
    data = df_metrics_latent,
    aes(x = -Inf, y = Inf, label = label), size = 3, hjust = 0, vjust = 1.5
  ) +
  labs(x = "NUTS", y = "Approximation - NUTS") +
  theme_minimal())

ggsave("mean-sd-latent.png", mean_sd_plot_latent, height = 6, width = 6.25)
Version split into two plots (posterior mean and posterior SD) for presentations:
# Jitter applied to both axes of the split latent-field plots
jitter_amount <- 0.02

# Posterior-mean errors only. NOTE: `lims()` sets scale limits, so any point
# whose (jittered) error lies outside +/-0.4 is dropped from the plot with a
# warning, not just clipped.
mean_plot_latent <- df_plot_latent %>%
  filter(indicator == "Posterior mean estimate") %>%
  ggplot(aes(x = truth, y = approximate - truth)) +
  geom_jitter(
    shape = 1, alpha = 0.4, width = jitter_amount, height = jitter_amount
  ) +
  lims(y = c(-0.4, 0.4)) +
  facet_grid(indicator ~ method) +
  # `linewidth` replaces the deprecated `size` for lines (ggplot2 >= 3.4.0)
  geom_hline(yintercept = 0, linetype = "dashed", linewidth = 0.25) +
  geom_text(
    data = filter(df_metrics_latent, indicator == "Posterior mean estimate"),
    aes(x = -Inf, y = Inf, label = label), size = 3, hjust = 0, vjust = 1.5
  ) +
  labs(x = "NUTS", y = "Approximation - NUTS") +
  theme_minimal()

ggsave("mean-latent.png", mean_plot_latent, height = 4, width = 6.25)
# Posterior-SD errors only. NOTE: `lims()` sets scale limits, so any point
# whose (jittered) error lies outside +/-0.6 is dropped from the plot with a
# warning, not just clipped.
sd_plot_latent <- df_plot_latent %>%
  filter(indicator == "Posterior SD estimate") %>%
  ggplot(aes(x = truth, y = approximate - truth)) +
  geom_jitter(
    shape = 1, alpha = 0.4, width = jitter_amount, height = jitter_amount
  ) +
  lims(y = c(-0.6, 0.6)) +
  facet_grid(indicator ~ method) +
  # `linewidth` replaces the deprecated `size` for lines (ggplot2 >= 3.4.0)
  geom_hline(yintercept = 0, linetype = "dashed", linewidth = 0.25) +
  geom_text(
    data = filter(df_metrics_latent, indicator == "Posterior SD estimate"),
    aes(x = -Inf, y = Inf, label = label), size = 3, hjust = 0, vjust = 1.5
  ) +
  labs(x = "NUTS", y = "Approximation - NUTS") +
  theme_minimal()

ggsave("sd-latent.png", sd_plot_latent, height = 4, width = 6.25)
# Stacked mean/SD layout: hide the method strips on the lower (SD) plot and
# replace both per-plot y-axis titles with one shared title. Labels are
# changed through labs() rather than by editing the plot object's internals.
sd_plot_latent_alt <- sd_plot_latent +
  theme(
    strip.background = element_blank(),
    strip.text.x = element_blank()
  ) +
  labs(y = "")

# Narrow text-only plot that carries the shared y-axis title
y_axis <- ggplot(data.frame(l = mean_plot_latent$labels$y, x = 1, y = 1)) +
  geom_text(aes(x, y, label = l), angle = 90) +
  theme_void() +
  coord_cartesian(clip = "off")

mean_plot_latent_alt <- mean_plot_latent + labs(x = "", y = "")

(mean_sd_plot_latent_alt <- y_axis + (mean_plot_latent_alt / sd_plot_latent_alt) +
  plot_layout(widths = c(1, 30)))

ggsave("mean-sd-alt-latent.png", mean_sd_plot_latent_alt, height = 6, width = 6.25)
Where are the greatest errors occurring for the mean?
# Latent-field posterior means ranked by absolute error against NUTS
# (largest first), shown as an interactive table
df_plot_latent %>%
  filter(indicator == "Posterior mean estimate") %>%
  select(indicator, par, id, method, truth, approximate) %>%
  mutate(diff = abs(approximate - truth)) %>%
  arrange(-diff) %>%
  DT::datatable()
What about for the standard deviations?
# Latent-field posterior SDs ranked by absolute error against NUTS
# (largest first), shown as an interactive table
df_plot_latent %>%
  filter(indicator == "Posterior SD estimate") %>%
  select(indicator, par, id, method, truth, approximate) %>%
  mutate(diff = abs(approximate - truth)) %>%
  arrange(-diff) %>%
  DT::datatable()
# Model-output subset of the comparison data and its facet annotations
df_plot_output <- filter(df_plot, type == "output")
df_metrics_output <- filter(df_metrics, type == "output")

# Colour-blind-friendly palette; colours are assigned to the values of `par`
# in the order ggplot2 uses for the discrete colour scale
cbpalette <- c("#56B4E9", "#009E73", "#E69F00", "#F0E442", "#0072B2", "#D55E00", "#CC79A7", "#999999")

# Error (approximation minus NUTS) against the NUTS estimate, coloured by
# output. The plot is kept in a variable and passed to ggsave() explicitly
# instead of relying on last_plot().
(mean_sd_plot_output <- df_plot_output %>%
  ggplot() +
  geom_point(
    aes(x = truth, y = approximate - truth, color = par),
    shape = 1, alpha = 0.4
  ) +
  scale_color_manual(values = cbpalette) +
  facet_grid(indicator ~ method) +
  geom_hline(yintercept = 0, linetype = "dashed") +
  geom_text(
    data = df_metrics_output,
    aes(x = -Inf, y = Inf, label = label), size = 3, hjust = 0, vjust = 1.5
  ) +
  labs(x = "NUTS", y = "Approximation - NUTS", col = "Output") +
  theme_minimal() +
  theme(legend.position = "bottom"))

ggsave("mean-sd-output.png", mean_sd_plot_output, height = 6, width = 6.25)
Version split into two plots (posterior mean and posterior SD) for presentations:
# Jitter for the split output plots is set to zero
jitter_amount <- 0

# Posterior-mean errors for the model outputs, coloured by output
mean_plot_output <- df_plot_output %>%
  filter(indicator == "Posterior mean estimate") %>%
  ggplot() +
  geom_jitter(
    aes(x = truth, y = approximate - truth, color = par),
    shape = 1, alpha = 0.4, width = jitter_amount, height = jitter_amount
  ) +
  scale_color_manual(values = cbpalette) +
  facet_grid(indicator ~ method) +
  # `linewidth` replaces the deprecated `size` for lines (ggplot2 >= 3.4.0)
  geom_hline(yintercept = 0, linetype = "dashed", linewidth = 0.25) +
  geom_text(
    data = filter(df_metrics_output, indicator == "Posterior mean estimate"),
    aes(x = -Inf, y = Inf, label = label), size = 3, hjust = 0, vjust = 1.5
  ) +
  labs(x = "NUTS", y = "Approximation - NUTS", col = "Output") +
  theme_minimal()

ggsave("mean-output.png", mean_plot_output, height = 4, width = 6.25)
# Posterior-SD errors for the model outputs, coloured by output
sd_plot_output <- df_plot_output %>%
  filter(indicator == "Posterior SD estimate") %>%
  ggplot() +
  geom_jitter(
    aes(x = truth, y = approximate - truth, color = par),
    shape = 1, alpha = 0.4, width = jitter_amount, height = jitter_amount
  ) +
  scale_color_manual(values = cbpalette) +
  facet_grid(indicator ~ method) +
  # `linewidth` replaces the deprecated `size` for lines (ggplot2 >= 3.4.0)
  geom_hline(yintercept = 0, linetype = "dashed", linewidth = 0.25) +
  geom_text(
    data = filter(df_metrics_output, indicator == "Posterior SD estimate"),
    aes(x = -Inf, y = Inf, label = label), size = 3, hjust = 0, vjust = 1.5
  ) +
  labs(x = "NUTS", y = "Approximation - NUTS", col = "Output") +
  theme_minimal()

ggsave("sd-output.png", sd_plot_output, height = 4, width = 6.25)
# Stacked mean/SD layout for the outputs: hide the method strips on the lower
# (SD) plot, keep a single legend at the bottom, and replace both per-plot
# y-axis titles with one shared title. Labels are changed through labs()
# rather than by editing the plot objects' internals.
sd_plot_output_alt <- sd_plot_output +
  theme(
    strip.background = element_blank(),
    strip.text.x = element_blank(),
    legend.position = "bottom"
  ) +
  labs(y = "")

# Narrow text-only plot that carries the shared y-axis title
y_axis <- ggplot(data.frame(l = mean_plot_output$labels$y, x = 1, y = 1)) +
  geom_text(aes(x, y, label = l), angle = 90) +
  theme_void() +
  coord_cartesian(clip = "off")

mean_plot_output_alt <- mean_plot_output +
  theme(legend.position = "none") +
  labs(x = "", y = "")

(mean_sd_plot_output_alt <- y_axis + (mean_plot_output_alt / sd_plot_output_alt) +
  plot_layout(widths = c(1, 30)))

ggsave(
  "mean-sd-alt-output.png", mean_sd_plot_output_alt,
  height = 6, width = 6.25, bg = "white"
)
Version split by output type:
# RMSE and MAE against NUTS, computed separately for each output parameter.
# `.groups = "drop"` returns an ungrouped tibble without the grouping message.
df_metrics_output_split <- df_plot_output %>%
  group_by(method, indicator, type, par) %>%
  summarise(
    rmse = sqrt(mean((truth - approximate)^2)),
    mae = mean(abs(truth - approximate)),
    .groups = "drop"
  )
# Per-output percentage difference of PCA-AGHQ relative to TMB, scaled by the
# larger of the two values (negative = PCA-AGHQ has the smaller error).
# Methods are selected by name rather than by row order, and the grouping
# columns keep their types, so group names need not be split and re-parsed.
df_metrics_output_split_pct <- df_metrics_output_split %>%
  group_by(indicator, type, par) %>%
  summarise(
    rmse_diff = 100 * (rmse[method == "PCA-AGHQ"] - rmse[method == "TMB"]) / max(rmse),
    mae_diff = 100 * (mae[method == "PCA-AGHQ"] - mae[method == "TMB"]) / max(mae),
    .groups = "drop"
  )
# Per-output facet annotation text: every panel shows RMSE and MAE, and the
# PCA-AGHQ panels also show the percentage difference relative to TMB
df_metrics_output_split <- left_join(
  df_metrics_output_split,
  df_metrics_output_split_pct,
  by = c("indicator", "type", "par")
) %>%
  mutate(
    label = ifelse(
      method != "PCA-AGHQ",
      paste0("RMSE: ", signif(rmse, 2), "\nMAE: ", signif(mae, 2)),
      paste0(
        "RMSE: ", signif(rmse, 2), " (", signif(rmse_diff, 2), "%)",
        "\nMAE: ", signif(mae, 2), " (", signif(mae_diff, 2), "%)"
      )
    )
  )
#' Plot approximation error against NUTS for a single output and indicator
#'
#' @param .indicator Value of `indicator` to keep, e.g. "Posterior mean estimate".
#' @param .par Output parameter to keep, e.g. "rho_t1_out".
#' @param col Colour of the points.
#' @param jitter_width Jitter applied to both axes. Defaults to the value of
#'   the global `jitter_amount` at call time, as the function used before.
#' @param data Comparison data with columns indicator, par, method, truth and
#'   approximate. Defaults to the global `df_plot_output`.
#' @param metrics Facet annotations with columns indicator, par, method and
#'   label. Defaults to the global `df_metrics_output_split`.
#' @return A ggplot object (auto-printed when called at top level).
plot_outputs <- function(.indicator, .par, col,
                         jitter_width = jitter_amount,
                         data = df_plot_output,
                         metrics = df_metrics_output_split) {
  data %>%
    filter(indicator == .indicator, par == .par) %>%
    ggplot() +
    geom_jitter(
      aes(x = truth, y = approximate - truth, color = par),
      shape = 1, alpha = 0.4, width = jitter_width, height = jitter_width
    ) +
    # Only one output survives the filter, so one colour is all that is used
    scale_color_manual(values = col) +
    facet_grid(indicator + par ~ method) +
    # `linewidth` replaces the deprecated `size` for lines (ggplot2 >= 3.4.0)
    geom_hline(yintercept = 0, linetype = "dashed", linewidth = 0.25) +
    geom_text(
      data = filter(metrics, indicator == .indicator, par == .par),
      aes(x = -Inf, y = Inf, label = label), size = 3, hjust = 0, vjust = 1.5
    ) +
    labs(x = "NUTS", y = "Approximation - NUTS", col = "Output") +
    theme_minimal() +
    theme(legend.position = "none")
}

plot_outputs(.indicator = "Posterior mean estimate", .par = "rho_t1_out", col = "#E69F00")
plot_outputs(.indicator = "Posterior mean estimate", .par = "lambda_t1_out", col = "#009E73")
plot_outputs(.indicator = "Posterior mean estimate", .par = "alpha_t1_out", col = "#56B4E9")
plot_outputs(.indicator = "Posterior SD estimate", .par = "rho_t1_out", col = "#E69F00")
plot_outputs(.indicator = "Posterior SD estimate", .par = "lambda_t1_out", col = "#009E73")
plot_outputs(.indicator = "Posterior SD estimate", .par = "alpha_t1_out", col = "#56B4E9")
# Record the R version, platform and package versions used to render this report
sessionInfo()
## R version 4.2.0 (2022-04-22)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS 13.3.1
##
## Matrix products: default
## LAPACK: /Library/Frameworks/R.framework/Versions/4.2/Resources/lib/libRlapack.dylib
##
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] multi.utils_0.1.0 patchwork_1.1.2 sf_1.0-9 bayesplot_1.9.0 rstan_2.21.5 StanHeaders_2.21.0-7 Matrix_1.5-4.1 stringr_1.5.0 purrr_1.0.1
## [10] tidyr_1.3.0 readr_2.1.3 ggplot2_3.4.0 forcats_0.5.2 dplyr_1.0.10
##
## loaded via a namespace (and not attached):
## [1] traduire_0.0.6 matrixStats_0.62.0 bit64_4.0.5 rprojroot_2.0.3 numDeriv_2016.8-1.1 tools_4.2.0 TMB_1.9.2 bslib_0.4.1 utf8_1.2.3
## [10] R6_2.5.1 DT_0.26 KernSmooth_2.23-20 DBI_1.1.3 colorspace_2.0-3 withr_2.5.0 tidyselect_1.2.0 gridExtra_2.3 prettyunits_1.1.1
## [19] processx_3.8.0 mvQuad_1.0-6 curl_5.0.0 bit_4.0.5 compiler_4.2.0 aghq_0.4.1 textshaping_0.3.6 cli_3.6.1 naomi_2.8.5
## [28] bookdown_0.26 labeling_0.4.2 sass_0.4.4 scales_1.2.1 mvtnorm_1.1-3 classInt_0.4-8 askpass_1.1 ggridges_0.5.3 callr_3.7.3
## [37] proxy_0.4-27 systemfonts_1.0.4 digest_0.6.31 rmarkdown_2.18 pkgconfig_2.0.3 htmltools_0.5.3 highr_0.9 fastmap_1.1.0 htmlwidgets_1.5.4
## [46] rlang_1.1.0 rstudioapi_0.14 jquerylib_0.1.4 generics_0.1.3 farver_2.1.1 jsonlite_1.8.4 crosstalk_1.2.0 vroom_1.6.0 inline_0.3.19
## [55] magrittr_2.0.3 polynom_1.4-1 loo_2.5.1 Rcpp_1.0.10 munsell_0.5.0 fansi_1.0.4 lifecycle_1.0.3 stringi_1.7.8 tmbstan_1.0.4
## [64] yaml_2.3.7 pkgbuild_1.3.1 plyr_1.8.8 grid_4.2.0 parallel_4.2.0 crayon_1.5.2 lattice_0.20-45 splines_4.2.0 hms_1.1.2
## [73] knitr_1.41 ps_1.7.3 pillar_1.9.0 uuid_1.1-0 codetools_0.2-18 stats4_4.2.0 glue_1.6.2 evaluate_0.20 V8_4.2.2
## [82] data.table_1.14.6 RcppParallel_5.1.5 ids_1.0.1 vctrs_0.6.1 tzdb_0.3.0 openssl_2.0.5 gtable_0.3.1 assertthat_0.2.1 cachem_1.0.6
## [91] xfun_0.37 e1071_1.7-12 ragg_1.2.2 class_7.3-20 tibble_3.2.1 orderly_1.4.3 units_0.8-0 statmod_1.4.36 ellipsis_0.3.2