1 Background

We compare the inference results from TMB, aghq, and tmbstan. Import these inference results as follows:

# Inference results from each method, produced by the upstream reports that
# are listed under `depends` in orderly.yml.
tmb <- readRDS(file.path("depends", "tmb.rds"))
aghq <- readRDS(file.path("depends", "aghq.rds"))
tmbstan <- readRDS(file.path("depends", "tmbstan.rds"))

depends <- yaml::read_yaml("orderly.yml")$depends

Check that the parameters (latent field, hyperparameters, model outputs) sampled by each of the three methods are the same:

# All three methods must have sampled exactly the same quantities (latent
# field, hyperparameters, model outputs), in the same order. identical() also
# fails on differing lengths or missing names, which an elementwise `==` would
# silently recycle over (or pass vacuously on zero-length input).
stopifnot(identical(names(tmb$fit$sample), names(aghq$quad$sample)))
stopifnot(identical(names(tmb$fit$sample), names(tmbstan$mcmc$sample)))

1.1 Run details

For more information about the conditions under which these results were generated, see:

1.1.1 TMB

# Print provenance for the i-th dependency listed under `depends` in
# orderly.yml: the upstream report name, the orderly query used to select it,
# the ID that query resolves to, and the parameters that report was run with.
#
# i: position of the dependency in `depends`.
# Returns (invisibly) the parameter list of the resolved report so it can be
# stored by the caller.
dependency_details <- function(i) {
  dependency <- depends[[i]]
  report_name <- names(dependency)
  query <- dependency[[report_name]]$id
  print(paste0("Inference results obtained from ", report_name, " with the query ", query))
  report_id <- orderly::orderly_search(query = query, report_name)
  # Fail with a clear message instead of passing a missing ID to orderly_info().
  if (length(report_id) != 1 || is.na(report_id)) {
    stop("Query '", query, "' for report '", report_name,
         "' did not resolve to exactly one report", call. = FALSE)
  }
  print(paste0("Obtained report had ID ", report_id, " and was run with the following parameters:"))
  parameters <- orderly::orderly_info(report_id, report_name)$parameters
  print(parameters)
  invisible(parameters)
}

dependency_details(1)
## [1] "Inference results obtained from naomi-simple_fit with the query latest(parameter:tmb == TRUE && parameter:random_only == TRUE)"
## [1] "Obtained report had ID 20230705-101408-12ab65a3 and was run with the following parameters:"
## $tmb
## [1] TRUE
## 
## $random_only
## [1] TRUE
## 
## $sample
## [1] TRUE
## 
## $aghq
## [1] FALSE
## 
## $k
## [1] 1
## 
## $grid_type
## [1] "product"
## 
## $s
## [1] 1
## 
## $tmbstan
## [1] FALSE
## 
## $hmc_laplace
## [1] FALSE
## 
## $niter
## [1] 1000
## 
## $nthin
## [1] 1
## 
## $adam
## [1] FALSE
## 
## $nsample
## [1] 1000
## 
## $area_level
## [1] 4

1.1.2 aghq

# Provenance of the aghq fit (second dependency in orderly.yml).
dependency_details(i = 2)
## [1] "Inference results obtained from naomi-simple_fit with the query latest(parameter:aghq == TRUE && parameter:k == 3 && parameter:s == 8)"
## [1] "Obtained report had ID 20230714-122557-a10aa13d and was run with the following parameters:"
## $aghq
## [1] TRUE
## 
## $k
## [1] 3
## 
## $s
## [1] 8
## 
## $grid_type
## [1] "scaled_pca"
## 
## $sample
## [1] TRUE
## 
## $tmb
## [1] FALSE
## 
## $random_only
## [1] TRUE
## 
## $tmbstan
## [1] FALSE
## 
## $hmc_laplace
## [1] FALSE
## 
## $niter
## [1] 1000
## 
## $nthin
## [1] 1
## 
## $adam
## [1] FALSE
## 
## $nsample
## [1] 1000
## 
## $area_level
## [1] 4

1.1.3 tmbstan

# Provenance of the tmbstan fit (third dependency); its run parameters are
# also stored in tmbstan_details.
tmbstan_details <- dependency_details(i = 3)
## [1] "Inference results obtained from naomi-simple_fit with the query latest(parameter:tmbstan == TRUE && parameter:niter > 50000)"
## [1] "Obtained report had ID 20230414-110613-89cd4244 and was run with the following parameters:"
## $tmbstan
## [1] TRUE
## 
## $niter
## [1] 1e+05
## 
## $nthin
## [1] 40
## 
## $tmb
## [1] FALSE
## 
## $aghq
## [1] FALSE
## 
## $k
## [1] 1
## 
## $grid_type
## [1] "product"
## 
## $s
## [1] 1
## 
## $hmc_laplace
## [1] FALSE
## 
## $adam
## [1] FALSE
## 
## $nsample
## [1] 5000
## 
## $area_level
## [1] 4
# Show the stored tmbstan run parameters.
print(tmbstan_details)
## $tmbstan
## [1] TRUE
## 
## $niter
## [1] 1e+05
## 
## $nthin
## [1] 40
## 
## $tmb
## [1] FALSE
## 
## $aghq
## [1] FALSE
## 
## $k
## [1] 1
## 
## $grid_type
## [1] "product"
## 
## $s
## [1] 1
## 
## $hmc_laplace
## [1] FALSE
## 
## $adam
## [1] FALSE
## 
## $nsample
## [1] 5000
## 
## $area_level
## [1] 4

1.2 Time taken

# Time taken by each method, one column per method; saved to time-taken.csv
# and printed.
time_taken <- data.frame(TMB = tmb$time, aghq = aghq$time, tmbstan = tmbstan$time)

write_csv(time_taken, "time-taken.csv")

time_taken

2 Histograms and ECDF difference plots

We create histograms and empirical cumulative distribution function (ECDF) difference plots of the samples from each method. The names of all sampled quantities (latent field, hyperparameters and model outputs) are given by pars:

# Names of every sampled quantity (parameters as well as model outputs).
pars <- tmb$fit$sample %>% names()

The entries of pars that are not TMB parameters are the Naomi outcome variables:

# Sampled quantities that are not entries of the TMB parameter vector.
names(tmb$fit$sample)[!is.element(names(tmb$fit$sample), names(tmb$fit$obj$env$par))]
##  [1] "anc_clients_t1_out"            "population_t1_out"             "plhiv_attend_t1_out"           "anc_already_art_t1_out"        "anc_tested_pos_t1_out"         "anc_tested_neg_t1_out"        
##  [7] "anc_rho_obs_t1_ll"             "hhs_prev_ll"                   "u_rho_xa"                      "anc_art_new_t1_out"            "untreated_plhiv_attend_t1_out" "untreated_plhiv_num_t1_out"   
## [13] "artattend_ij_t1_out"           "anc_known_pos_t1_out"          "artattend_t1_out"              "rho_t1_out"                    "anc_rho_t1_out"                "artnum_t1_out"                
## [19] "anc_alpha_t1_out"              "alpha_t1_out"                  "infections_t1_out"             "anc_plhiv_t1_out"              "plhiv_t1_out"                  "lambda_t1_out"                
## [25] "hhs_artcov_ll"                 "artnum_t1_ll"

We will especially focus on the outputs:

  • HIV prevalence (rho_t1_out) which has 6417 variables (rows)
  • ART coverage (alpha_t1_out) which has 6417 variables (rows)
  • HIV incidence (lambda_t1_out) which has 6417 variables (rows)

These outputs are reported at aggregate resolutions as well as at the finest resolution. To avoid double counting, we assess inferential accuracy only at the finest resolution, so we need the subset of these variables at that resolution. The data frame mf_out (model frame output) within tmb$naomi_data contains the area, age group and sex for each of these 6417 rows. As a sanity check, confirm that the number of areas times the number of age groups times the number of sexes equals the number of rows:

# Model frame for the outputs: the area, age group and sex of each output row.
mf_out <- tmb$naomi_data$mf_out
(n_area <- mf_out$area_id %>% unique() %>% length())
## [1] 69
(n_age <- mf_out$age_group %>% unique() %>% length())
## [1] 31
(n_sex <- mf_out$sex %>% unique() %>% length())
## [1] 3
# The number of (area, age group, sex) combinations should equal the number
# of rows of the rho_t1_out samples.
stopifnot(n_area * n_age * n_sex == nrow(tmb$fit$sample$rho_t1_out))

The unique levels for area, age and sex are:

unique(mf_out$area_id)
##  [1] "MWI"           "MWI_1_1_demo"  "MWI_1_2_demo"  "MWI_1_3_demo"  "MWI_2_1_demo"  "MWI_2_2_demo"  "MWI_2_3_demo"  "MWI_2_4_demo"  "MWI_2_5_demo"  "MWI_3_1_demo"  "MWI_3_10_demo" "MWI_3_11_demo"
## [13] "MWI_3_12_demo" "MWI_3_13_demo" "MWI_3_14_demo" "MWI_3_15_demo" "MWI_3_16_demo" "MWI_3_17_demo" "MWI_3_18_demo" "MWI_3_19_demo" "MWI_3_2_demo"  "MWI_3_20_demo" "MWI_3_21_demo" "MWI_3_22_demo"
## [25] "MWI_3_23_demo" "MWI_3_24_demo" "MWI_3_25_demo" "MWI_3_26_demo" "MWI_3_27_demo" "MWI_3_28_demo" "MWI_3_3_demo"  "MWI_3_4_demo"  "MWI_3_5_demo"  "MWI_3_6_demo"  "MWI_3_7_demo"  "MWI_3_8_demo" 
## [37] "MWI_3_9_demo"  "MWI_4_1_demo"  "MWI_4_10_demo" "MWI_4_11_demo" "MWI_4_12_demo" "MWI_4_13_demo" "MWI_4_14_demo" "MWI_4_15_demo" "MWI_4_16_demo" "MWI_4_17_demo" "MWI_4_18_demo" "MWI_4_19_demo"
## [49] "MWI_4_2_demo"  "MWI_4_20_demo" "MWI_4_21_demo" "MWI_4_22_demo" "MWI_4_23_demo" "MWI_4_24_demo" "MWI_4_25_demo" "MWI_4_26_demo" "MWI_4_27_demo" "MWI_4_28_demo" "MWI_4_29_demo" "MWI_4_3_demo" 
## [61] "MWI_4_30_demo" "MWI_4_31_demo" "MWI_4_32_demo" "MWI_4_4_demo"  "MWI_4_5_demo"  "MWI_4_6_demo"  "MWI_4_7_demo"  "MWI_4_8_demo"  "MWI_4_9_demo"
unique(mf_out$age_group)
##  [1] "Y000_004" "Y000_014" "Y000_064" "Y000_999" "Y005_009" "Y010_014" "Y010_019" "Y015_019" "Y015_024" "Y015_049" "Y015_064" "Y015_999" "Y020_024" "Y025_029" "Y025_034" "Y025_049" "Y030_034" "Y035_039"
## [19] "Y035_049" "Y040_044" "Y045_049" "Y050_054" "Y050_064" "Y050_999" "Y055_059" "Y060_064" "Y065_069" "Y065_999" "Y070_074" "Y075_079" "Y080_999"
unique(mf_out$sex)
## [1] "both"   "female" "male"

Filter to the finest resolution only, adding an id column to enable merging with the samples:

# Restrict the output model frame to the finest resolution: level-4 areas,
# separate sexes (not "both") and the non-overlapping five-year age bands.
# Aggregate bands such as "Y025_034" overlap finer bands ("Y025_029",
# "Y030_034") and are excluded so that no stratum is counted twice.
# `id` is each row's position in mf_out, which is passed as the output
# sample row id further below.
mf_out_fine <- mf_out %>%
  tibble::rowid_to_column("id") %>%
  mutate(id = as.numeric(id)) %>%
  filter(
    area_id %in% paste0("MWI_4_", 1:32, "_demo"),
    sex %in% c("male", "female"),
    age_group %in%
      c(
        "Y000_004", "Y005_009", "Y010_014", "Y015_019", "Y020_024", "Y025_029",
        "Y030_034", "Y035_039", "Y040_044", "Y045_049", "Y050_054", "Y055_059",
        "Y060_064", "Y065_069", "Y070_074", "Y075_079", "Y080_999"
      )
  )

For the time being, we produce plots for only a small subset of parameters. The choice of this subset is arbitrary.

# Named logical flags, one per entry of pars: TRUE only for the parameters
# chosen for plotting (beta_rho and u_rho_x).
pars_eval <- setNames(pars %in% c("beta_rho", "u_rho_x"), pars)

2.1 beta_rho

# Histograms and ECDF difference plots of the beta_rho samples from each method.
histogram_and_ecdf_list("beta_rho")
## [[1]]

## 
## [[2]]

2.2 beta_alpha

# Histograms and ECDF difference plots of the samples from each method, one
# section per parameter.
histogram_and_ecdf_list("beta_alpha")

2.3 beta_lambda

histogram_and_ecdf_list("beta_lambda")

2.4 beta_anc_rho

histogram_and_ecdf_list("beta_anc_rho")

2.5 beta_anc_alpha

# NOTE(review): calls histogram_and_ecdf() rather than histogram_and_ecdf_list()
# as the neighbouring sections do; confirm this is intended.
histogram_and_ecdf("beta_anc_alpha")

2.6 logit

# Apply histogram_and_ecdf() to every parameter whose name starts with "logit".
lapply(pars[stringr::str_starts(pars, "logit")], histogram_and_ecdf)

2.7 log_sigma

# Apply histogram_and_ecdf() to every parameter whose name starts with "log_sigma".
lapply(pars[stringr::str_starts(pars, "log_sigma")], histogram_and_ecdf)

2.8 u_rho_x

# Histograms and ECDF difference plots of the u_rho_x samples from each method.
histogram_and_ecdf_list("u_rho_x")
## [[1]]

## 
## [[2]]

## 
## [[3]]

## 
## [[4]]

## 
## [[5]]

## 
## [[6]]

## 
## [[7]]

## 
## [[8]]

## 
## [[9]]

## 
## [[10]]

## 
## [[11]]

## 
## [[12]]

## 
## [[13]]

## 
## [[14]]

## 
## [[15]]

## 
## [[16]]

## 
## [[17]]

## 
## [[18]]

## 
## [[19]]

## 
## [[20]]

## 
## [[21]]

## 
## [[22]]

## 
## [[23]]

## 
## [[24]]

## 
## [[25]]

## 
## [[26]]

## 
## [[27]]

## 
## [[28]]

## 
## [[29]]

## 
## [[30]]

## 
## [[31]]

## 
## [[32]]

2.9 us_rho_x

# Histograms and ECDF difference plots of the samples from each method, one
# section per latent field parameter.
histogram_and_ecdf_list("us_rho_x")

2.10 u_rho_xs

histogram_and_ecdf_list("u_rho_xs")

2.11 us_rho_xs

histogram_and_ecdf_list("us_rho_xs")

2.12 u_rho_as

histogram_and_ecdf_list("u_rho_as")

2.13 u_alpha_x

histogram_and_ecdf_list("u_alpha_x")

2.14 us_alpha_x

histogram_and_ecdf_list("us_alpha_x")

2.15 u_alpha_xs

histogram_and_ecdf_list("u_alpha_xs")

2.16 us_alpha_xs

# Histograms and ECDF difference plots of the us_alpha_xs samples from each method.
histogram_and_ecdf_list("us_alpha_xs")
## [[1]]

## 
## [[2]]

## 
## [[3]]

## 
## [[4]]

## 
## [[5]]

## 
## [[6]]

## 
## [[7]]

## 
## [[8]]

## 
## [[9]]

## 
## [[10]]

## 
## [[11]]

## 
## [[12]]

## 
## [[13]]

## 
## [[14]]

## 
## [[15]]

## 
## [[16]]

## 
## [[17]]

## 
## [[18]]

## 
## [[19]]

## 
## [[20]]

## 
## [[21]]

## 
## [[22]]

## 
## [[23]]

## 
## [[24]]

## 
## [[25]]

## 
## [[26]]

## 
## [[27]]

## 
## [[28]]

## 
## [[29]]

## 
## [[30]]

## 
## [[31]]

## 
## [[32]]

2.17 u_alpha_a

# Histograms and ECDF difference plots of the samples from each method, one
# section per latent field parameter.
histogram_and_ecdf_list("u_alpha_a")

2.18 u_alpha_as

histogram_and_ecdf_list("u_alpha_as")

2.19 u_alpha_xa

histogram_and_ecdf_list("u_alpha_xa")

2.20 ui_lambda_x

histogram_and_ecdf_list("ui_lambda_x")

2.21 ui_anc_rho_x

histogram_and_ecdf_list("ui_anc_rho_x")

2.22 ui_anc_alpha_x

histogram_and_ecdf_list("ui_anc_alpha_x")

2.23 log_or_gamma

histogram_and_ecdf_list("log_or_gamma")

2.24 rho_t1_out

There are too many variables to plot this here!

2.25 alpha_t1_out

There are too many variables to plot this here!

2.26 lambda_t1_out

There are too many variables to plot this here!

3 KS plots

# Split the TMB parameter vector into latent field (random effects) and
# hyperparameters, and record the type of each parameter name.
r <- tmb$fit$obj$env$random
x_names <- names(tmb$fit$obj$env$par[r])
theta_names <- names(tmb$fit$obj$env$par[-r])

dict <- data.frame(
  parname = c(unique(x_names), unique(theta_names)),
  type = c(rep("Latent", length(unique(x_names))), rep("Hyper", length(unique(theta_names))))
)

# KS statistics for every parameter. `parname` is the label with any "[i]"
# index suffix removed, so it can be matched against `dict`. A label without a
# suffix is kept whole; a pattern that requires "[" would give NA for it and
# therefore an NA `type`.
ks_df <- lapply(unique(names(tmb$fit$obj$env$par)), to_ks_df) %>%
  bind_rows() %>%
  mutate(parname = str_extract(par, "^[^\\[]+")) %>%
  left_join(dict, by = "parname")

3.1 Individual parameters

# KS plot for a single parameter, or for a family of parameters sharing a
# name prefix.
#
# param: a parameter name (e.g. "u_rho_x") or a prefix (e.g. "beta", "logit").
# exact: if TRUE, keep only rows whose `parname` equals `param`; if FALSE,
#   keep every row whose `par` label starts with `param`. Defaults to an exact
#   match whenever `param` is itself a parameter name, because prefix matching
#   alone would make "u_rho_x" also select "u_rho_xs" (and likewise
#   "u_alpha_x"/"u_alpha_xs", "u_rho_a"/"u_rho_as", ...).
ks_plot_starts_with <- function(param, exact = param %in% ks_df$parname) {
  selected <- if (exact) {
    filter(ks_df, parname == param)
  } else {
    filter(ks_df, startsWith(par, param))
  }
  ks_plot(selected, par = param)
}

# KS data frame for a Naomi output (e.g. "rho_t1_out"), restricted to the
# finest-resolution strata in mf_out_fine and annotated with each stratum's
# area, age group and sex.
#
# par: name of the output in the samples.
# The `index` column returned by to_ks_df() is assumed to be the position
# within mf_out_fine$id (1, 2, ...); it is matched to the row number of
# mf_out_fine. The row id in the full mf_out is kept as `full_id`.
ks_df_out <- function(par) {
  to_ks_df(par = par, outputs = TRUE, id = mf_out_fine$id) %>%
    select(-par) %>%
    left_join(
      mf_out_fine %>%
        rename("full_id" = "id") %>%
        tibble::rowid_to_column("index") %>%
        mutate(index = as.numeric(index)),
      by = "index"
    )
}

3.1.1 beta

# KS plots, one section per parameter or parameter family, selected by name
# through ks_plot_starts_with().
ks_plot_starts_with("beta")

3.1.2 logit

ks_plot_starts_with("logit")

3.1.3 log_sigma

ks_plot_starts_with("log_sigma")

3.1.4 u_rho_x

ks_plot_starts_with("u_rho_x")

3.1.5 u_rho_xs

ks_plot_starts_with("u_rho_xs")

3.1.6 us_rho_x

ks_plot_starts_with("us_rho_x")

3.1.7 us_rho_xs

ks_plot_starts_with("us_rho_xs")

3.1.8 u_rho_a

ks_plot_starts_with("u_rho_a")

3.1.9 u_rho_as

ks_plot_starts_with("u_rho_as")

3.1.10 u_alpha_x

ks_plot_starts_with("u_alpha_x")

3.1.11 u_alpha_xs

ks_plot_starts_with("u_alpha_xs")

3.1.12 us_alpha_x

ks_plot_starts_with("us_alpha_x")

3.1.13 us_alpha_xs

ks_plot_starts_with("us_alpha_xs")

3.1.14 u_alpha_a

ks_plot_starts_with("u_alpha_a")

3.1.15 u_alpha_as

ks_plot_starts_with("u_alpha_as")

3.1.16 u_alpha_xa

ks_plot_starts_with("u_alpha_xa")

3.1.17 ui_anc_rho_x

ks_plot_starts_with("ui_anc_rho_x")

3.1.18 ui_anc_alpha_x

ks_plot_starts_with("ui_anc_alpha_x")

3.1.19 log_or_gamma

ks_plot_starts_with("log_or_gamma")

3.1.20 rho_t1_out

# KS plot for HIV prevalence at the finest resolution.
ks_plot(ks_df_out(par = "rho_t1_out"), par = "rho_t1_out", alpha = 0.2)
## Joining with `by = join_by(index)`

3.1.21 alpha_t1_out

# KS plot for ART coverage at the finest resolution. The `par` passed to
# ks_plot() must name the output actually plotted (it was previously
# "rho_t1_out", a copy-paste slip).
ks_df_out(par = "alpha_t1_out") %>%
  ks_plot(par = "alpha_t1_out", alpha = 0.2)
## Joining with `by = join_by(index)`

3.1.22 lambda_t1_out

# KS plot for HIV incidence at the finest resolution, excluding the age groups
# in which every method draws identically zero new infections.
ks_df_out(par = "lambda_t1_out") %>%
  filter(!(age_group %in% c("Y005_009", "Y010_014", "Y080_999"))) %>%
  ks_plot(par = "lambda_t1_out", alpha = 0.2)
## Joining with `by = join_by(index)`

We filter out the age groups 5-9, 10-14 and 80+ because there are no new infections in those age groups: the posterior samples from each method are exactly the same, namely draws that are identically zero.

3.2 Summary

options(dplyr.summarise.inform = FALSE)

# Mean KS statistic per parameter, one column per method, with `size` = the
# number of elements of that parameter. `.groups = "drop"` returns an
# ungrouped table, so no leftover grouping is saved to ks-summary.rds or
# carried into later summaries.
ks_summary <- ks_df %>%
  group_by(method, parname, type) %>%
  summarise(
    ks = mean(ks),
    size = n(),
    .groups = "drop"
  ) %>%
  pivot_wider(names_from = "method", values_from = "ks")

saveRDS(ks_summary, "ks-summary.rds")

extended_cbpalette <- colorRampPalette(multi.utils::cbpalette())

# Latent field parameters only. `dict` labels these rows "Latent"; filtering
# on "Latent field" matched nothing, leaving an empty table, and max() below
# then warned and returned -Inf.
ks_summary_latent <- filter(ks_summary, type == "Latent")

# Shared upper limit for both axes: just beyond the largest mean KS, capped at 1.
xy_length <- min(1, max(ks_summary_latent$aghq, ks_summary_latent$TMB, na.rm = TRUE) + 0.03)
ks_summary_latent %>%
  ggplot(aes(x = TMB, y = aghq, col = parname, size = size)) +
  geom_point(alpha = 0.4) +
  xlim(0, xy_length) +
  ylim(0, xy_length) +
  # One colour per latent field parameter.
  scale_color_manual(values = extended_cbpalette(n = length(unique(ks_summary_latent$parname)))) +
  scale_size_continuous(breaks = c(2, 10, 32), labels = c(2, 10, 32)) +
  # Points below the dashed line: aghq has the smaller (better) KS than TMB.
  geom_abline(slope = 1, intercept = 0, linetype = "dashed") +
  labs(x = "KS(TMB, NUTS)", y = "KS(PCA-AGHQ, NUTS)", col = "Parameter", size = "Length", subtitle = "Smaller KS values indicate higher accuracy") +
  theme_minimal() +
  guides(size = guide_legend(title.position = "top", direction = "vertical"), col = guide_legend(title.position = "top")) +
  theme(legend.position = "bottom", legend.title = element_text(size = rel(0.9)), legend.text = element_text(size = rel(0.7)))

ggsave("ks-summary.png", height = 5, width = 6.25, bg = "white")
# Average, within each parameter type, of the per-parameter mean KS values.
ks_summary %>%
  group_by(type) %>%
  summarise(across(c(TMB, aghq), mean)) %>%
  gt::gt()
type TMB aghq
Latent 0.08185225 0.07688924
NA 0.53123077 0.47416923

3.2.1 Latent field

# Latent field parameters with diff = KS(TMB) - KS(aghq); positive values mean
# aghq is closer to NUTS.
DT::datatable(
  mutate(filter(ks_summary, type == "Latent"), diff = signif(TMB - aghq, 3))
)

3.3 Investigation into large KS values

We create rank-ordered lists of the largest KS differences between methods:

# One KS column per method; diff > 0 means aghq is closer to NUTS than TMB.
ks_df_wide <- ks_df %>%
  pivot_wider(names_from = method, values_from = ks) %>%
  mutate(diff = TMB - aghq)

3.3.1 Nodes where TMB beats aghq

# The ten latent field nodes with the most negative KS(TMB) - KS(aghq), i.e.
# where TMB is most accurate relative to aghq. Select latent nodes explicitly
# rather than using a non-missing `type` as a proxy, which would also admit
# any "Hyper" rows.
(tmb_beats_aghq <- ks_df_wide %>%
  filter(type == "Latent") %>%
  arrange(diff) %>%
  head(n = 10))
histogram_and_ecdf(par = tmb_beats_aghq$parname[1], i = tmb_beats_aghq$index[1])

histogram_and_ecdf(par = tmb_beats_aghq$parname[2], i = tmb_beats_aghq$index[2])

histogram_and_ecdf(par = tmb_beats_aghq$parname[3], i = tmb_beats_aghq$index[3])

3.3.2 Nodes where aghq beats TMB

# The ten latent field nodes with the largest KS(TMB) - KS(aghq), i.e. where
# aghq is most accurate relative to TMB. Select latent nodes explicitly
# rather than using a non-missing `type` as a proxy, which would also admit
# any "Hyper" rows.
(aghq_beats_tmb <- ks_df_wide %>%
  filter(type == "Latent") %>%
  arrange(desc(diff)) %>%
  head(n = 10))
histogram_and_ecdf(par = aghq_beats_tmb$parname[1], i = aghq_beats_tmb$index[1])

histogram_and_ecdf(par = aghq_beats_tmb$parname[2], i = aghq_beats_tmb$index[2])

histogram_and_ecdf(par = aghq_beats_tmb$parname[3], i = aghq_beats_tmb$index[3])

3.4 Correlation between KS values and ESS

Is there any correlation between the value of \(\text{KS}(\texttt{method}, \texttt{tmbstan})\) for a particular parameter and the ESS of that parameter from tmbstan output?

# Convergence diagnostics for the tmbstan (NUTS) fit.
rhats <- bayesplot::rhat(tmbstan$mcmc$stanfit)
ess_ratio <- bayesplot::neff_ratio(tmbstan$mcmc$stanfit)
# neff_ratio() is the ESS divided by the total number of post-warmup draws.
# Multiply back by that total, read from the fit itself (saved draws minus
# saved warmup draws, summed over chains), rather than assuming 4 chains with
# half of niter spent on warmup.
niter <- with(tmbstan$mcmc$stanfit@sim, sum(n_save - warmup2))
ess <- ess_ratio * niter

# Reorder and relabel factor levels in one call. Each argument in `...` has
# the form new_label = "old_level". The named levels are first moved to the
# front in the given order (fct_relevel), then renamed (fct_recode).
fct_reorg <- function(fac, ...) {
  relevelled <- fct_relevel(fac, ...)
  fct_recode(relevelled, ...)
}

# KS(method, NUTS) against the tmbstan ESS of each parameter element, one
# panel per approximate method. Rows of type "Hyper" are dropped, and so are
# rows with a missing type (filter() discards NA conditions).
ks_df %>%
  filter(type != "Hyper") %>%
  mutate(method = fct_reorg(method, "TMB" = "TMB", "PCA-AGHQ" = "aghq")) %>%
  left_join(
    data.frame(ess) %>%
      tibble::rownames_to_column("par"),
    by = "par"
  ) %>%
  ggplot(aes(x = ess, y = ks)) +
  geom_point(shape = 1, alpha = 0.5) +
  geom_smooth(method = "lm", color = "#56B4E9", fullrange = TRUE) +
  stat_poly_eq(use_label("eq"), label.x.npc = "right") +
  stat_poly_eq(label.y = 0.9, label.x.npc = "right") +
  scale_x_continuous(limits = c(0, NA)) +
  facet_grid(~ method) +
  theme_minimal() +
  labs(x = "ESS", y = "KS(method, NUTS)")
## Joining with `by = join_by(par)`
## `geom_smooth()` using formula = 'y ~ x'

# White background, as for ks-summary.png: theme_minimal() leaves the plot
# background blank, which can otherwise give a transparent PNG.
ggsave("ks-ess.png", height = 4, width = 6.25, bg = "white")
## `geom_smooth()` using formula = 'y ~ x'