1 Background

We compare the inference results from TMB, aghq (labelled PCA-AGHQ below) and tmbstan (labelled NUTS below); NUTS is treated as the ground truth. Import these inference results as follows:

# Read the saved fit for each inference method. Later code labels these
# "TMB", "PCA-AGHQ" (aghq) and "NUTS" (tmbstan) respectively.
tmb <- readRDS("depends/tmb.rds")
aghq <- readRDS("depends/aghq.rds")
tmbstan <- readRDS("depends/tmbstan.rds")

2 Data manipulation

We are interested in the latent field and in three model outputs: prevalence (`rho_t1_out`), incidence (`lambda_t1_out`) and treatment coverage (`alpha_t1_out`):

# Latent field parameters: names present in the full parameter vector
# (`par.full`) but not in `par`. setdiff() returns the unique names in their
# first-appearance order, exactly as the previous unique()/%in% filter did.
latent_pars <- setdiff(names(tmb$fit$par.full), names(tmb$fit$par))

# Model outputs of interest (prevalence, incidence and treatment coverage)
output_pars <- c("rho_t1_out", "lambda_t1_out", "alpha_t1_out")

# Lookup table from parameter name to parameter type ("latent" / "output")
type_key <- data.frame(
  par = c(latent_pars, output_pars),
  type = rep(c("latent", "output"), times = c(length(latent_pars), length(output_pars)))
)

Extract the mean and standard deviation for each parameter for each inference method, then bind them together:

# Summarise posterior samples into one row per parameter element.
#
# samples:      named list of sample matrices. Rows are summarised with
#               rowMeans()/rowSds(), so each row is presumably one element of
#               the parameter and each column one draw (TODO confirm against
#               the fitting code).
# pars:         names of the list entries to summarise.
# method_label: value written into the `method` column.
#
# Returns a data frame with columns `par`, `id` (row index within the
# parameter), `mean`, `sd` and `method`.
summarise_samples <- function(samples, pars, method_label) {
  lapply(samples[pars], function(draws) {
    data.frame("mean" = rowMeans(draws), "sd" = matrixStats::rowSds(draws)) %>%
      tibble::rowid_to_column("id")
  }) %>%
    dplyr::bind_rows(.id = "par") %>%
    mutate(method = method_label)
}

df_tmb <- summarise_samples(tmb$fit$sample, type_key$par, "TMB")
df_aghq <- summarise_samples(aghq$quad$sample, type_key$par, "PCA-AGHQ")
df_tmbstan <- summarise_samples(tmbstan$mcmc$sample, type_key$par, "NUTS")

# Stack all methods and attach the parameter type ("latent" / "output")
df <- bind_rows(df_tmb, df_aghq, df_tmbstan) %>%
  left_join(type_key, by = "par")

head(df)

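Pivot so that each row pairs an approximate estimate (TMB or PCA-AGHQ) with the corresponding NUTS estimate (the `truth` column), ready for plotting: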

# Reshape to one row per (parameter element, indicator, approximate method),
# carrying the NUTS estimate alongside as `truth`. Intermediates are kept local
# so that only `df_plot` is created in the global environment.
df_plot <- local({
  long_by_indicator <- pivot_longer(
    df,
    cols = c("sd", "mean"),
    names_to = "indicator",
    values_to = "estimate"
  )

  wide_by_method <- pivot_wider(long_by_indicator, names_from = "method", values_from = "estimate")
  wide_by_method <- rename(wide_by_method, "truth" = "NUTS")

  paired <- pivot_longer(
    wide_by_method,
    cols = c("TMB", "PCA-AGHQ"),
    names_to = "method",
    values_to = "approximate"
  )

  # Human-readable facet labels and a fixed TMB-then-PCA-AGHQ method order
  mutate(
    paired,
    indicator = fct_recode(indicator, "Posterior mean estimate" = "mean", "Posterior SD estimate" = "sd"),
    method = fct_relevel(method, "TMB", "PCA-AGHQ")
  )
})

head(df_plot)

Calculate the root mean square error (RMSE) and mean absolute error (MAE) of each approximate method relative to NUTS (taken to be the truth):

# RMSE and MAE of each approximate method against NUTS, per indicator and
# parameter type. `.groups = "drop"` returns an ungrouped result directly
# (replacing the separate ungroup() and silencing the regrouping message).
df_metrics <- df_plot %>%
  group_by(method, indicator, type) %>%
  summarise(
    rmse = sqrt(mean((truth - approximate)^2)),
    mae = mean(abs(truth - approximate)),
    .groups = "drop"
  )

head(df_metrics)

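Add the percentage difference in each metric between PCA-AGHQ and TMB, scaled by the larger of the two values (negative values mean PCA-AGHQ is closer to NUTS than TMB):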

# Percentage difference of PCA-AGHQ relative to TMB for each metric, scaled by
# the larger of the two values; negative means PCA-AGHQ is closer to NUTS.
# Methods are picked out by name rather than by relying on TMB's row coming
# first within each group (which the previous diff() approach assumed).
df_metrics_pct <- df_metrics %>%
  group_by(indicator, type) %>%
  summarise(
    rmse_diff = 100 * (rmse[method == "PCA-AGHQ"] - rmse[method == "TMB"]) / max(rmse),
    mae_diff = 100 * (mae[method == "PCA-AGHQ"] - mae[method == "TMB"]) / max(mae),
    .groups = "drop"
  )

# Attach the percentage differences and build the text label drawn in each
# plot facet; only the PCA-AGHQ label shows the difference relative to TMB.
df_metrics <- df_metrics %>%
  left_join(df_metrics_pct, by = c("indicator", "type")) %>%
  mutate(
    label = if_else(
      method != "PCA-AGHQ",
      paste0("RMSE: ", signif(rmse, 2), "\nMAE: ", signif(mae, 2)),
      paste0(
        "RMSE: ", signif(rmse, 2), " (", signif(rmse_diff, 2), "%)",
        "\nMAE: ", signif(mae, 2), " (", signif(mae_diff, 2), "%)"
      )
    )
  )

write_csv(df_metrics, "mean-sd.csv")

3 Plots

3.1 Latent field

df_plot_latent <- filter(df_plot, type == "latent")
df_metrics_latent <- filter(df_metrics, type == "latent")

# Error of each approximation relative to NUTS for every latent field element;
# facet rows are the indicators (mean, SD) and columns the approximate methods.
mean_sd_plot_latent <- df_plot_latent %>%
  ggplot(aes(x = truth, y = approximate - truth)) +
  geom_point(shape = 1, alpha = 0.4) +
  facet_grid(indicator ~ method) +
  geom_abline(slope = 0, intercept = 0, linetype = "dashed") +
  geom_text(data = df_metrics_latent, aes(x = -Inf, y = Inf, label = label), size = 3, hjust = 0, vjust = 1.5) +
  labs(x = "NUTS", y = "Approximation - NUTS") +
  theme_minimal()

mean_sd_plot_latent

# Save the plot object explicitly rather than relying on last_plot(), and use a
# white background: theme_minimal() leaves the plot background blank, which
# ggsave() would otherwise write as transparent.
ggsave("mean-sd-latent.png", mean_sd_plot_latent, height = 6, width = 6.25, bg = "white")

Version split into two plots (posterior mean and posterior SD) for presentations:

jitter_amount <- 0.02

# Error-vs-NUTS scatter of the latent field for a single indicator.
#
# .indicator: "Posterior mean estimate" or "Posterior SD estimate".
# y_limits:   y-axis limits; values outside them are removed by the scale.
# jitter:     horizontal and vertical jitter amount (defaults to jitter_amount).
plot_latent_indicator <- function(.indicator, y_limits, jitter = jitter_amount) {
  df_plot_latent %>%
    filter(indicator == .indicator) %>%
    ggplot(aes(x = truth, y = approximate - truth)) +
    geom_jitter(shape = 1, alpha = 0.4, width = jitter, height = jitter) +
    lims(y = y_limits) +
    facet_grid(indicator ~ method) +
    # `linewidth` replaces the `size` aesthetic for lines (deprecated in ggplot2 3.4.0)
    geom_abline(slope = 0, intercept = 0, linetype = "dashed", linewidth = 0.25) +
    geom_text(
      data = filter(df_metrics_latent, indicator == .indicator),
      aes(x = -Inf, y = Inf, label = label), size = 3, hjust = 0, vjust = 1.5
    ) +
    labs(x = "NUTS", y = "Approximation - NUTS") +
    theme_minimal()
}

mean_plot_latent <- plot_latent_indicator("Posterior mean estimate", y_limits = c(-0.4, 0.4))
ggsave("mean-latent.png", mean_plot_latent, height = 4, width = 6.25, bg = "white")

sd_plot_latent <- plot_latent_indicator("Posterior SD estimate", y_limits = c(-0.6, 0.6))
ggsave("sd-latent.png", sd_plot_latent, height = 4, width = 6.25, bg = "white")

# Stacked mean/SD latent plot sharing a single y-axis title.

# Lower panel: drop the method strips (already shown on the upper panel) and
# blank its own y title in favour of the shared one.
sd_plot_latent_alt <- sd_plot_latent +
  theme(
    strip.background = element_blank(),
    strip.text.x = element_blank()
  ) +
  labs(y = "")

# Narrow panel holding the shared y-axis title as rotated text
y_axis <- ggplot(data.frame(l = mean_plot_latent$labels$y, x = 1, y = 1)) +
  geom_text(aes(x, y, label = l), angle = 90) +
  theme_void() +
  coord_cartesian(clip = "off")

# Upper panel: blank both axis titles (labs() instead of editing $labels directly)
mean_plot_latent_alt <- mean_plot_latent + labs(x = "", y = "")

(mean_sd_plot_latent_alt <- y_axis + (mean_plot_latent_alt / sd_plot_latent_alt) +
  plot_layout(widths = c(1, 30)))

ggsave("mean-sd-alt-latent.png", mean_sd_plot_latent_alt, height = 6, width = 6.25, bg = "white")

Where are the greatest errors occurring for the mean?

# Latent field elements ranked by absolute error in the posterior mean
df_plot_latent %>%
  filter(indicator == "Posterior mean estimate") %>%
  transmute(indicator, par, id, method, truth, approximate, diff = abs(approximate - truth)) %>%
  arrange(-diff) %>%
  DT::datatable()

What about for the standard deviations?

# Latent field elements ranked by absolute error in the posterior SD
df_plot_latent %>%
  filter(indicator == "Posterior SD estimate") %>%
  transmute(indicator, par, id, method, truth, approximate, diff = abs(approximate - truth)) %>%
  arrange(-diff) %>%
  DT::datatable()

3.2 Outputs

df_plot_output <- filter(df_plot, type == "output")
df_metrics_output <- filter(df_metrics, type == "output")
cbpalette <- c("#56B4E9","#009E73", "#E69F00", "#F0E442", "#0072B2", "#D55E00", "#CC79A7", "#999999")

# Error of each approximation relative to NUTS for the model outputs, coloured
# by output; facet rows are the indicators and columns the approximate methods.
mean_sd_plot_output <- df_plot_output %>%
  ggplot() +
  geom_point(aes(x = truth, y = approximate - truth, color = par), shape = 1, alpha = 0.4) +
  scale_color_manual(values = cbpalette) +
  facet_grid(indicator ~ method) +
  geom_abline(slope = 0, intercept = 0, linetype = "dashed") +
  geom_text(data = df_metrics_output, aes(x = -Inf, y = Inf, label = label), size = 3, hjust = 0, vjust = 1.5) +
  labs(x = "NUTS", y = "Approximation - NUTS", col = "Output") +
  theme_minimal() +
  theme(legend.position = "bottom")

mean_sd_plot_output

# Save the plot object explicitly rather than relying on last_plot(); white
# background because theme_minimal() would otherwise give a transparent PNG.
ggsave("mean-sd-output.png", mean_sd_plot_output, height = 6, width = 6.25, bg = "white")

Version split into two plots (posterior mean and posterior SD) for presentations:

jitter_amount <- 0

# Error-vs-NUTS scatter of the model outputs for a single indicator, coloured
# by output.
#
# .indicator: "Posterior mean estimate" or "Posterior SD estimate".
# jitter:     horizontal and vertical jitter amount (defaults to jitter_amount;
#             0 means no jitter).
plot_output_indicator <- function(.indicator, jitter = jitter_amount) {
  df_plot_output %>%
    filter(indicator == .indicator) %>%
    ggplot() +
    # Colour is mapped in this layer only: the geom_text data has no `par` column
    geom_jitter(
      aes(x = truth, y = approximate - truth, color = par),
      shape = 1, alpha = 0.4, width = jitter, height = jitter
    ) +
    scale_color_manual(values = cbpalette) +
    facet_grid(indicator ~ method) +
    # `linewidth` replaces the `size` aesthetic for lines (deprecated in ggplot2 3.4.0)
    geom_abline(slope = 0, intercept = 0, linetype = "dashed", linewidth = 0.25) +
    geom_text(
      data = filter(df_metrics_output, indicator == .indicator),
      aes(x = -Inf, y = Inf, label = label), size = 3, hjust = 0, vjust = 1.5
    ) +
    labs(x = "NUTS", y = "Approximation - NUTS", col = "Output") +
    theme_minimal()
}

mean_plot_output <- plot_output_indicator("Posterior mean estimate")
ggsave("mean-output.png", mean_plot_output, height = 4, width = 6.25, bg = "white")

sd_plot_output <- plot_output_indicator("Posterior SD estimate")
ggsave("sd-output.png", sd_plot_output, height = 4, width = 6.25, bg = "white")

# Stacked mean/SD output plot sharing a single y-axis title, with the legend
# shown only under the lower panel.

# Lower panel: drop the method strips, move the legend to the bottom and blank
# its own y title in favour of the shared one.
sd_plot_output_alt <- sd_plot_output +
  theme(
    strip.background = element_blank(),
    strip.text.x = element_blank(),
    legend.position = "bottom"
  ) +
  labs(y = "")

# Narrow panel holding the shared y-axis title as rotated text
y_axis <- ggplot(data.frame(l = mean_plot_output$labels$y, x = 1, y = 1)) +
  geom_text(aes(x, y, label = l), angle = 90) +
  theme_void() +
  coord_cartesian(clip = "off")

# Upper panel: no legend and blank axis titles (labs() instead of editing $labels)
mean_plot_output_alt <- mean_plot_output +
  theme(legend.position = "none") +
  labs(x = "", y = "")

(mean_sd_plot_output_alt <- y_axis + (mean_plot_output_alt / sd_plot_output_alt) +
  plot_layout(widths = c(1, 30)))

ggsave("mean-sd-alt-output.png", mean_sd_plot_output_alt, height = 6, width = 6.25, bg = "white")

Version split by output type:

# RMSE and MAE against NUTS, computed separately for each model output.
# `.groups = "drop"` gives an ungrouped result and silences the regroup message.
df_metrics_output_split <- df_plot_output %>%
  group_by(method, indicator, type, par) %>%
  summarise(
    rmse = sqrt(mean((truth - approximate)^2)),
    mae = mean(abs(truth - approximate)),
    .groups = "drop"
  )

# Percentage difference of PCA-AGHQ relative to TMB, scaled by the larger of the
# two values; methods are selected by name rather than by row order.
df_metrics_output_split_pct <- df_metrics_output_split %>%
  group_by(indicator, type, par) %>%
  summarise(
    rmse_diff = 100 * (rmse[method == "PCA-AGHQ"] - rmse[method == "TMB"]) / max(rmse),
    mae_diff = 100 * (mae[method == "PCA-AGHQ"] - mae[method == "TMB"]) / max(mae),
    .groups = "drop"
  )

# Facet label text; only the PCA-AGHQ label shows the difference relative to TMB
df_metrics_output_split <- df_metrics_output_split %>%
  left_join(
    df_metrics_output_split_pct,
    by = c("indicator", "type", "par")
  ) %>%
  mutate(
    label = ifelse(
      method == "PCA-AGHQ",
      paste0("RMSE: ", signif(rmse, 2), " (", signif(rmse_diff, 2), "%)", "\nMAE: ", signif(mae, 2), " (", signif(mae_diff, 2), "%)"),
      paste0("RMSE: ", signif(rmse, 2), "\nMAE: ", signif(mae, 2))
    )
  )

# Error-vs-NUTS scatter for a single model output and indicator.
#
# .indicator:   "Posterior mean estimate" or "Posterior SD estimate".
# .par:         output name, e.g. "rho_t1_out".
# col:          point colour.
# jitter:       horizontal and vertical jitter amount (defaults to jitter_amount).
# plot_data:    points to plot (defaults to df_plot_output).
# metrics_data: per-output metric labels (defaults to df_metrics_output_split).
plot_outputs <- function(.indicator, .par, col,
                         jitter = jitter_amount,
                         plot_data = df_plot_output,
                         metrics_data = df_metrics_output_split) {
  plot_data %>%
    filter(indicator == .indicator, par == .par) %>%
    ggplot() +
    geom_jitter(
      aes(x = truth, y = approximate - truth, color = par),
      shape = 1, alpha = 0.4, width = jitter, height = jitter
    ) +
    # Only one output remains after filtering, so a single colour suffices
    scale_color_manual(values = col) +
    facet_grid(indicator + par ~ method) +
    # `linewidth` replaces the `size` aesthetic for lines (deprecated in ggplot2 3.4.0)
    geom_abline(slope = 0, intercept = 0, linetype = "dashed", linewidth = 0.25) +
    geom_text(
      data = filter(metrics_data, indicator == .indicator, par == .par),
      aes(x = -Inf, y = Inf, label = label), size = 3, hjust = 0, vjust = 1.5
    ) +
    labs(x = "NUTS", y = "Approximation - NUTS", col = "Output") +
    theme_minimal() +
    theme(legend.position = "none")
}

# One plot per output and indicator. Colours match those that `cbpalette` gives
# each output in the combined output plot (alphabetical levels: alpha, lambda, rho).
plot_outputs(.indicator = "Posterior mean estimate", .par = "rho_t1_out", col = "#E69F00")

plot_outputs(.indicator = "Posterior mean estimate", .par = "lambda_t1_out", col = "#009E73")

plot_outputs(.indicator = "Posterior mean estimate", .par = "alpha_t1_out", col = "#56B4E9")

plot_outputs(.indicator = "Posterior SD estimate", .par = "rho_t1_out", col = "#E69F00")

plot_outputs(.indicator = "Posterior SD estimate", .par = "lambda_t1_out", col = "#009E73")

plot_outputs(.indicator = "Posterior SD estimate", .par = "alpha_t1_out", col = "#56B4E9")

Original computing environment

# Record the R and package versions used to produce these results
sessionInfo()
## R version 4.2.0 (2022-04-22)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS 13.3.1
## 
## Matrix products: default
## LAPACK: /Library/Frameworks/R.framework/Versions/4.2/Resources/lib/libRlapack.dylib
## 
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
##  [1] multi.utils_0.1.0    patchwork_1.1.2      sf_1.0-9             bayesplot_1.9.0      rstan_2.21.5         StanHeaders_2.21.0-7 Matrix_1.5-4.1       stringr_1.5.0        purrr_1.0.1         
## [10] tidyr_1.3.0          readr_2.1.3          ggplot2_3.4.0        forcats_0.5.2        dplyr_1.0.10        
## 
## loaded via a namespace (and not attached):
##  [1] traduire_0.0.6      matrixStats_0.62.0  bit64_4.0.5         rprojroot_2.0.3     numDeriv_2016.8-1.1 tools_4.2.0         TMB_1.9.2           bslib_0.4.1         utf8_1.2.3         
## [10] R6_2.5.1            DT_0.26             KernSmooth_2.23-20  DBI_1.1.3           colorspace_2.0-3    withr_2.5.0         tidyselect_1.2.0    gridExtra_2.3       prettyunits_1.1.1  
## [19] processx_3.8.0      mvQuad_1.0-6        curl_5.0.0          bit_4.0.5           compiler_4.2.0      aghq_0.4.1          textshaping_0.3.6   cli_3.6.1           naomi_2.8.5        
## [28] bookdown_0.26       labeling_0.4.2      sass_0.4.4          scales_1.2.1        mvtnorm_1.1-3       classInt_0.4-8      askpass_1.1         ggridges_0.5.3      callr_3.7.3        
## [37] proxy_0.4-27        systemfonts_1.0.4   digest_0.6.31       rmarkdown_2.18      pkgconfig_2.0.3     htmltools_0.5.3     highr_0.9           fastmap_1.1.0       htmlwidgets_1.5.4  
## [46] rlang_1.1.0         rstudioapi_0.14     jquerylib_0.1.4     generics_0.1.3      farver_2.1.1        jsonlite_1.8.4      crosstalk_1.2.0     vroom_1.6.0         inline_0.3.19      
## [55] magrittr_2.0.3      polynom_1.4-1       loo_2.5.1           Rcpp_1.0.10         munsell_0.5.0       fansi_1.0.4         lifecycle_1.0.3     stringi_1.7.8       tmbstan_1.0.4      
## [64] yaml_2.3.7          pkgbuild_1.3.1      plyr_1.8.8          grid_4.2.0          parallel_4.2.0      crayon_1.5.2        lattice_0.20-45     splines_4.2.0       hms_1.1.2          
## [73] knitr_1.41          ps_1.7.3            pillar_1.9.0        uuid_1.1-0          codetools_0.2-18    stats4_4.2.0        glue_1.6.2          evaluate_0.20       V8_4.2.2           
## [82] data.table_1.14.6   RcppParallel_5.1.5  ids_1.0.1           vctrs_0.6.1         tzdb_0.3.0          openssl_2.0.5       gtable_0.3.1        assertthat_0.2.1    cachem_1.0.6       
## [91] xfun_0.37           e1071_1.7-12        ragg_1.2.2          class_7.3-20        tibble_3.2.1        orderly_1.4.3       units_0.8-0         statmod_1.4.36      ellipsis_0.3.2

Bibliography