Abstract
Background: We have run the simplified Naomi model using a range of inference methods: TMB, aghq, and tmbstan.
Task: In this report, we compare the accuracy of the posterior distributions obtained from these inference methods using histograms, empirical cumulative distribution function (ECDF) difference plots, and Kolmogorov-Smirnov tests.
We compare the inference results from TMB, aghq, and tmbstan.
Import these inference results as follows:
tmb <- readRDS("depends/tmb.rds")
aghq <- readRDS("depends/aghq.rds")
tmbstan <- readRDS("depends/tmbstan.rds")
depends <- yaml::read_yaml("orderly.yml")$depends
Check that the parameters (latent field, hyperparameters, model outputs) sampled from each of the three methods are the same:
# All three methods must have sampled exactly the same set of quantities, in
# the same order. identical() is used instead of `==` because an elementwise
# comparison yields logical(0) when one side is NULL or empty (which
# stopifnot() accepts) and recycles when the lengths differ.
stopifnot(identical(names(tmb$fit$sample), names(aghq$quad$sample)))
stopifnot(identical(names(tmb$fit$sample), names(tmbstan$mcmc$sample)))
For more information about the conditions under which these results were generated, see:
# TMB

#' Print and return the provenance of one orderly dependency
#'
#' Looks up the `i`-th entry of `depends` (read from `orderly.yml`), prints the
#' query used to select the upstream report, resolves that query to a report
#' ID, and prints the parameters that report was run with.
#'
#' @param i Position of the dependency within `depends`.
#' @return The parameter list of the resolved report, invisibly, so the result
#'   can be stored (e.g. as `tmbstan_details`) without being printed twice.
dependency_details <- function(i) {
  dependency <- depends[[i]]
  report_name <- names(dependency)
  query <- dependency[[report_name]]$id

  print(paste0("Inference results obtained from ", report_name, " with the query ", query))

  report_id <- orderly::orderly_search(query = query, report_name)
  print(paste0("Obtained report had ID ", report_id, " and was run with the following parameters:"))

  parameters <- orderly::orderly_info(report_id, report_name)$parameters
  print(parameters)

  # Return explicitly rather than relying on print()'s invisible return value.
  invisible(parameters)
}
dependency_details(1)
## [1] "Inference results obtained from naomi-simple_fit with the query latest(parameter:tmb == TRUE && parameter:random_only == TRUE)"
## [1] "Obtained report had ID 20230705-101408-12ab65a3 and was run with the following parameters:"
## $tmb
## [1] TRUE
##
## $random_only
## [1] TRUE
##
## $sample
## [1] TRUE
##
## $aghq
## [1] FALSE
##
## $k
## [1] 1
##
## $grid_type
## [1] "product"
##
## $s
## [1] 1
##
## $tmbstan
## [1] FALSE
##
## $hmc_laplace
## [1] FALSE
##
## $niter
## [1] 1000
##
## $nthin
## [1] 1
##
## $adam
## [1] FALSE
##
## $nsample
## [1] 1000
##
## $area_level
## [1] 4
# aghq
# Provenance of the second dependency listed in orderly.yml.
dependency_details(2)
## [1] "Inference results obtained from naomi-simple_fit with the query latest(parameter:aghq == TRUE && parameter:k == 3 && parameter:s == 8)"
## [1] "Obtained report had ID 20230714-122557-a10aa13d and was run with the following parameters:"
## $aghq
## [1] TRUE
##
## $k
## [1] 3
##
## $s
## [1] 8
##
## $grid_type
## [1] "scaled_pca"
##
## $sample
## [1] TRUE
##
## $tmb
## [1] FALSE
##
## $random_only
## [1] TRUE
##
## $tmbstan
## [1] FALSE
##
## $hmc_laplace
## [1] FALSE
##
## $niter
## [1] 1000
##
## $nthin
## [1] 1
##
## $adam
## [1] FALSE
##
## $nsample
## [1] 1000
##
## $area_level
## [1] 4
# tmbstan
# Keep the returned parameters: `niter` and `nthin` are needed later to
# convert effective-sample-size ratios into effective sample sizes.
tmbstan_details <- dependency_details(3)
## [1] "Inference results obtained from naomi-simple_fit with the query latest(parameter:tmbstan == TRUE && parameter:niter > 50000)"
## [1] "Obtained report had ID 20230414-110613-89cd4244 and was run with the following parameters:"
## $tmbstan
## [1] TRUE
##
## $niter
## [1] 1e+05
##
## $nthin
## [1] 40
##
## $tmb
## [1] FALSE
##
## $aghq
## [1] FALSE
##
## $k
## [1] 1
##
## $grid_type
## [1] "product"
##
## $s
## [1] 1
##
## $hmc_laplace
## [1] FALSE
##
## $adam
## [1] FALSE
##
## $nsample
## [1] 5000
##
## $area_level
## [1] 4
tmbstan_details
## $tmbstan
## [1] TRUE
##
## $niter
## [1] 1e+05
##
## $nthin
## [1] 40
##
## $tmb
## [1] FALSE
##
## $aghq
## [1] FALSE
##
## $k
## [1] 1
##
## $grid_type
## [1] "product"
##
## $s
## [1] 1
##
## $hmc_laplace
## [1] FALSE
##
## $adam
## [1] FALSE
##
## $nsample
## [1] 5000
##
## $area_level
## [1] 4
# Collect the `time` element stored with each fit into one table (one column
# per method), write it to disk, and display it.
time_taken <- data.frame(TMB = tmb$time, aghq = aghq$time, tmbstan = tmbstan$time)

write_csv(time_taken, "time-taken.csv")

time_taken
We create histograms and empirical cumulative distribution function (ECDF) difference plots of the samples from each method.
All of the possible latent field and hyperparameter names are given by pars:
pars <- names(tmb$fit$sample)
There are also all of the Naomi outcome variables:
# Sampled quantities whose names do not appear in the TMB parameter vector.
is_model_par <- names(tmb$fit$sample) %in% names(tmb$fit$obj$env$par)
names(tmb$fit$sample)[!is_model_par]
## [1] "anc_clients_t1_out" "population_t1_out" "plhiv_attend_t1_out" "anc_already_art_t1_out" "anc_tested_pos_t1_out" "anc_tested_neg_t1_out"
## [7] "anc_rho_obs_t1_ll" "hhs_prev_ll" "u_rho_xa" "anc_art_new_t1_out" "untreated_plhiv_attend_t1_out" "untreated_plhiv_num_t1_out"
## [13] "artattend_ij_t1_out" "anc_known_pos_t1_out" "artattend_t1_out" "rho_t1_out" "anc_rho_t1_out" "artnum_t1_out"
## [19] "anc_alpha_t1_out" "alpha_t1_out" "infections_t1_out" "anc_plhiv_t1_out" "plhiv_t1_out" "lambda_t1_out"
## [25] "hhs_artcov_ll" "artnum_t1_ll"
We will especially focus on the outputs:
- `rho_t1_out`, which has 6417 variables (rows)
- `alpha_t1_out`, which has 6417 variables (rows)
- `lambda_t1_out`, which has 6417 variables (rows)

These outputs are at an aggregate resolution as well as at the finest resolution.
It seems to make sense to assess inferential accuracy only at the finest resolution, so as to avoid double counting.
As such, we need to find the subset of these variables at the finest resolution.
Within tmb$naomi_data there is a dataframe called mf_out (model frame output) which contains the area, age, sex mapping for these 6417 rows.
As a sanity check, confirm that the number of areas times number of ages times number of sexes indeed equals the number of rows:
# The output model frame should hold one row for every combination of area,
# age group and sex; count the distinct levels of each and compare their
# product with the number of rows sampled for rho_t1_out.
mf_out <- tmb$naomi_data$mf_out
(n_area <- length(unique(mf_out$area_id)))
## [1] 69
(n_age <- length(unique(mf_out$age_group)))
## [1] 31
(n_sex <- length(unique(mf_out$sex)))
## [1] 3
stopifnot(nrow(tmb$fit$sample$rho_t1_out) == n_area * n_age * n_sex)
The unique levels for area, age and sex are:
# Distinct area identifiers present in the output model frame.
unique(mf_out$area_id)
## [1] "MWI" "MWI_1_1_demo" "MWI_1_2_demo" "MWI_1_3_demo" "MWI_2_1_demo" "MWI_2_2_demo" "MWI_2_3_demo" "MWI_2_4_demo" "MWI_2_5_demo" "MWI_3_1_demo" "MWI_3_10_demo" "MWI_3_11_demo"
## [13] "MWI_3_12_demo" "MWI_3_13_demo" "MWI_3_14_demo" "MWI_3_15_demo" "MWI_3_16_demo" "MWI_3_17_demo" "MWI_3_18_demo" "MWI_3_19_demo" "MWI_3_2_demo" "MWI_3_20_demo" "MWI_3_21_demo" "MWI_3_22_demo"
## [25] "MWI_3_23_demo" "MWI_3_24_demo" "MWI_3_25_demo" "MWI_3_26_demo" "MWI_3_27_demo" "MWI_3_28_demo" "MWI_3_3_demo" "MWI_3_4_demo" "MWI_3_5_demo" "MWI_3_6_demo" "MWI_3_7_demo" "MWI_3_8_demo"
## [37] "MWI_3_9_demo" "MWI_4_1_demo" "MWI_4_10_demo" "MWI_4_11_demo" "MWI_4_12_demo" "MWI_4_13_demo" "MWI_4_14_demo" "MWI_4_15_demo" "MWI_4_16_demo" "MWI_4_17_demo" "MWI_4_18_demo" "MWI_4_19_demo"
## [49] "MWI_4_2_demo" "MWI_4_20_demo" "MWI_4_21_demo" "MWI_4_22_demo" "MWI_4_23_demo" "MWI_4_24_demo" "MWI_4_25_demo" "MWI_4_26_demo" "MWI_4_27_demo" "MWI_4_28_demo" "MWI_4_29_demo" "MWI_4_3_demo"
## [61] "MWI_4_30_demo" "MWI_4_31_demo" "MWI_4_32_demo" "MWI_4_4_demo" "MWI_4_5_demo" "MWI_4_6_demo" "MWI_4_7_demo" "MWI_4_8_demo" "MWI_4_9_demo"
# Distinct age groups present in the output model frame.
unique(mf_out$age_group)
## [1] "Y000_004" "Y000_014" "Y000_064" "Y000_999" "Y005_009" "Y010_014" "Y010_019" "Y015_019" "Y015_024" "Y015_049" "Y015_064" "Y015_999" "Y020_024" "Y025_029" "Y025_034" "Y025_049" "Y030_034" "Y035_039"
## [19] "Y035_049" "Y040_044" "Y045_049" "Y050_054" "Y050_064" "Y050_999" "Y055_059" "Y060_064" "Y065_069" "Y065_999" "Y070_074" "Y075_079" "Y080_999"
# Distinct sex categories present in the output model frame.
unique(mf_out$sex)
## [1] "both" "female" "male"
Filter to only the finest resolution, adding an id column to enable merging with the samples:
# Rows of the output model frame at the finest resolution only: level-4 areas,
# single sexes, and non-overlapping age groups. The row names of `mf_out` are
# kept as a numeric `id` column so that rows can be matched back to the
# corresponding rows of the sampled outputs.
mf_out_fine <- mf_out %>%
  tibble::rownames_to_column("id") %>%
  mutate(id = as.numeric(id)) %>%
  filter(
    area_id %in% paste0("MWI_4_", 1:32, "_demo"),
    sex %in% c("male", "female"),
    # Five-year age groups (plus the open-ended 80+ group) only. The aggregate
    # "Y025_034" is deliberately excluded: it overlaps "Y025_029" and
    # "Y030_034", so keeping it would double count those ages.
    age_group %in%
      c(
        "Y000_004", "Y005_009", "Y010_014", "Y015_019", "Y020_024", "Y025_029",
        "Y030_034", "Y035_039", "Y040_044", "Y045_049", "Y050_054", "Y055_059",
        "Y060_064", "Y065_069", "Y070_074", "Y075_079", "Y080_999"
      )
  )
We will produce plots for a small subset of parameters for the time being. There is no particular reason to choose this subset rather than another; the choice is quite arbitrary.
# Named logical vector over all sampled quantities: TRUE for the parameters
# selected for plotting, FALSE otherwise.
pars_eval <- setNames(pars %in% c("beta_rho", "u_rho_x"), pars)
# Histogram and ECDF difference plots, one tab per parameter. Each tab label
# is kept as a comment on the line above the call it belongs to.

# beta_rho
histogram_and_ecdf_list("beta_rho")
## [[1]]
##
## [[2]]

# beta_alpha
histogram_and_ecdf_list("beta_alpha")

# beta_lambda
histogram_and_ecdf_list("beta_lambda")

# beta_anc_rho
histogram_and_ecdf_list("beta_anc_rho")

# beta_anc_alpha
# NOTE(review): the sibling tabs call histogram_and_ecdf_list(); confirm that
# the single-plot helper is intended here.
histogram_and_ecdf("beta_anc_alpha")

# logit
# One plot for every sampled quantity whose name starts with "logit".
lapply(pars[stringr::str_starts(pars, "logit")], histogram_and_ecdf)

# log_sigma
# One plot for every sampled quantity whose name starts with "log_sigma".
lapply(pars[stringr::str_starts(pars, "log_sigma")], histogram_and_ecdf)

# u_rho_x
histogram_and_ecdf_list("u_rho_x")
## [[1]]
##
## [[2]]
##
## [[3]]
##
## [[4]]
##
## [[5]]
##
## [[6]]
##
## [[7]]
##
## [[8]]
##
## [[9]]
##
## [[10]]
##
## [[11]]
##
## [[12]]
##
## [[13]]
##
## [[14]]
##
## [[15]]
##
## [[16]]
##
## [[17]]
##
## [[18]]
##
## [[19]]
##
## [[20]]
##
## [[21]]
##
## [[22]]
##
## [[23]]
##
## [[24]]
##
## [[25]]
##
## [[26]]
##
## [[27]]
##
## [[28]]
##
## [[29]]
##
## [[30]]
##
## [[31]]
##
## [[32]]
# Histogram and ECDF difference plots, one tab per parameter. Each tab label
# is kept as a comment on the line above the call it belongs to.

# us_rho_x
histogram_and_ecdf_list("us_rho_x")

# u_rho_xs
histogram_and_ecdf_list("u_rho_xs")

# us_rho_xs
histogram_and_ecdf_list("us_rho_xs")

# u_rho_as
histogram_and_ecdf_list("u_rho_as")

# u_alpha_x
histogram_and_ecdf_list("u_alpha_x")

# us_alpha_x
histogram_and_ecdf_list("us_alpha_x")

# u_alpha_xs
histogram_and_ecdf_list("u_alpha_xs")

# us_alpha_xs
histogram_and_ecdf_list("us_alpha_xs")
## [[1]]
##
## [[2]]
##
## [[3]]
##
## [[4]]
##
## [[5]]
##
## [[6]]
##
## [[7]]
##
## [[8]]
##
## [[9]]
##
## [[10]]
##
## [[11]]
##
## [[12]]
##
## [[13]]
##
## [[14]]
##
## [[15]]
##
## [[16]]
##
## [[17]]
##
## [[18]]
##
## [[19]]
##
## [[20]]
##
## [[21]]
##
## [[22]]
##
## [[23]]
##
## [[24]]
##
## [[25]]
##
## [[26]]
##
## [[27]]
##
## [[28]]
##
## [[29]]
##
## [[30]]
##
## [[31]]
##
## [[32]]
# Histogram and ECDF difference plots, one tab per parameter. Each tab label
# is kept as a comment on the line above the call it belongs to.

# u_alpha_a
histogram_and_ecdf_list("u_alpha_a")

# u_alpha_as
histogram_and_ecdf_list("u_alpha_as")

# u_alpha_xa
histogram_and_ecdf_list("u_alpha_xa")

# ui_lambda_x
histogram_and_ecdf_list("ui_lambda_x")

# ui_anc_rho_x
histogram_and_ecdf_list("ui_anc_rho_x")

# ui_anc_alpha_x
histogram_and_ecdf_list("ui_anc_alpha_x")

# log_or_gamma
histogram_and_ecdf_list("log_or_gamma")
rho_t1_out: There are too many variables to plot here!
alpha_t1_out: There are too many variables to plot here!
lambda_t1_out: There are too many variables to plot here!
# Index vector stored as `random` on the TMB objective; used below to split
# the parameter vector into the entries it selects and the remainder.
r <- tmb$fit$obj$env$random
# Names of the parameter entries selected by `r` (labelled "Latent" below)
# and of all other entries (labelled "Hyper" below).
x_names <- names(tmb$fit$obj$env$par[r])
theta_names <- names(tmb$fit$obj$env$par[-r])
# Lookup table mapping each distinct parameter name to its type label.
dict <- data.frame(
parname = c(unique(x_names), unique(theta_names)),
type = c(rep("Latent", length(unique(x_names))), rep("Hyper", length(unique(theta_names))))
)
# Apply to_ks_df() to every distinct parameter name, stack the results, and
# attach the type label. `parname` is the part of `par` before the first "[".
# NOTE(review): str_extract() returns NA when `par` contains no "[", so such
# rows get an NA `parname` and find no match in `dict`. The summary table
# later in this report shows an NA type group and no "Hyper" group, which is
# consistent with that. Confirm whether un-indexed names should fall back to
# `par` itself; downstream filters on `type` may currently rely on the NA.
ks_df <- lapply(unique(names(tmb$fit$obj$env$par)), to_ks_df) %>%
bind_rows() %>%
mutate(parname = str_extract(par, ".*(?=\\[)")) %>%
left_join(dict, by = "parname")
#' KS plot for every parameter whose name begins with a prefix
#'
#' @param param Prefix matched against the `par` column of `ks_df`. Matching
#'   is by prefix only, so e.g. "u_rho_x" also selects names that start with
#'   "u_rho_xs".
#' @return Whatever `ks_plot()` returns for the matching rows.
ks_plot_starts_with <- function(param) {
  matching_rows <- filter(ks_df, startsWith(par, param))
  ks_plot(matching_rows, par = param)
}
#' KS results for one model output, restricted to the finest resolution
#'
#' Calls `to_ks_df()` for the rows of the output listed in `mf_out_fine$id`,
#' drops the `par` column, and attaches the area / age / sex information from
#' `mf_out_fine`.
#'
#' @param par Name of the model output, e.g. "rho_t1_out".
#' @return The `to_ks_df()` result joined to `mf_out_fine` on `index`.
ks_df_out <- function(par) {
  # `index` is the row position within mf_out_fine; the `id` column of
  # mf_out_fine is kept under the name `full_id`.
  strata <- mf_out_fine %>%
    rename("full_id" = "id") %>%
    tibble::rowid_to_column("index") %>%
    mutate(index = as.numeric(index))

  to_ks_df(par = par, outputs = TRUE, id = mf_out_fine$id) %>%
    select(-par) %>%
    # Name the key explicitly so the join cannot silently pick up additional
    # shared columns, and so no "Joining with" message is emitted.
    left_join(strata, by = "index")
}
# KS plots, one tab per parameter prefix. Each tab label is kept as a comment
# on the line above the call it belongs to.

# beta
ks_plot_starts_with("beta")

# logit
ks_plot_starts_with("logit")

# log_sigma
ks_plot_starts_with("log_sigma")

# u_rho_x
ks_plot_starts_with("u_rho_x")

# u_rho_xs
ks_plot_starts_with("u_rho_xs")

# us_rho_x
ks_plot_starts_with("us_rho_x")

# us_rho_xs
ks_plot_starts_with("us_rho_xs")

# u_rho_a
ks_plot_starts_with("u_rho_a")

# u_rho_as
ks_plot_starts_with("u_rho_as")

# u_alpha_x
ks_plot_starts_with("u_alpha_x")

# u_alpha_xs
ks_plot_starts_with("u_alpha_xs")

# us_alpha_x
ks_plot_starts_with("us_alpha_x")

# us_alpha_xs
ks_plot_starts_with("us_alpha_xs")

# u_alpha_a
ks_plot_starts_with("u_alpha_a")

# u_alpha_as
ks_plot_starts_with("u_alpha_as")

# u_alpha_xa
ks_plot_starts_with("u_alpha_xa")

# ui_anc_rho_x
ks_plot_starts_with("ui_anc_rho_x")

# ui_anc_alpha_x
ks_plot_starts_with("ui_anc_alpha_x")

# log_or_gamma
ks_plot_starts_with("log_or_gamma")
# KS plots for the model outputs at the finest resolution, one tab per
# output. Each tab label is kept as a comment above its call, and each plot
# is passed the name of the output it actually shows.

# rho_t1_out
ks_df_out(par = "rho_t1_out") %>%
  ks_plot(par = "rho_t1_out", alpha = 0.2)
## Joining with `by = join_by(index)`

# alpha_t1_out
ks_df_out(par = "alpha_t1_out") %>%
  ks_plot(par = "alpha_t1_out", alpha = 0.2)
## Joining with `by = join_by(index)`

# lambda_t1_out
ks_df_out(par = "lambda_t1_out") %>%
  filter(!age_group %in% c("Y005_009", "Y010_014", "Y080_999")) %>%
  ks_plot(par = "lambda_t1_out", alpha = 0.2)
## Joining with `by = join_by(index)`
We filter out the 5-9, 10-14 and 80+ age groups here because there are no new infections in those age groups, so the posterior samples from each method are exactly the same: in particular, the draws are identically zero.
# Silence dplyr's message about the grouping of summarise() output.
options(dplyr.summarise.inform = FALSE)

# Mean KS statistic and number of entries for each (method, parameter, type)
# combination, reshaped to one column of mean KS per method.
ks_summary <- pivot_wider(
  summarise(
    group_by(ks_df, method, parname, type),
    ks = mean(ks),
    size = n()
  ),
  names_from = "method",
  values_from = "ks"
)

saveRDS(ks_summary, "ks-summary.rds")
# Interpolate the colour-blind palette so that it can supply one colour per
# parameter.
extended_cbpalette <- colorRampPalette(multi.utils::cbpalette())

# The type labels are defined in `dict` as "Latent" and "Hyper"; the filter
# must use exactly those values, otherwise no rows are selected and the axis
# limit below is computed from an empty set.
ks_summary_latent <- filter(ks_summary, type == "Latent")

# Common upper limit for both axes: slightly above the largest mean KS value,
# capped at 1.
xy_length <- min(1, max(ks_summary_latent$aghq, ks_summary_latent$TMB) + 0.03)

# Mean KS of TMB against mean KS of aghq for each latent-field parameter.
# Points below the dashed identity line have a smaller KS for aghq.
ks_summary_latent %>%
  ggplot(aes(x = TMB, y = aghq, col = parname, size = size)) +
  geom_point(alpha = 0.4) +
  xlim(0, xy_length) +
  ylim(0, xy_length) +
  scale_color_manual(values = extended_cbpalette(n = 20)) +
  scale_size_continuous(breaks = c(2, 10, 32), labels = c(2, 10, 32)) +
  geom_abline(slope = 1, intercept = 0, linetype = "dashed") +
  labs(x = "KS(TMB, NUTS)", y = "KS(PCA-AGHQ, NUTS)", col = "Parameter", size = "Length", subtitle = "Smaller KS values indicate higher accuracy") +
  theme_minimal() +
  guides(size = guide_legend(title.position = "top", direction = "vertical"), col = guide_legend(title.position = "top")) +
  theme(legend.position = "bottom", legend.title = element_text(size = rel(0.9)), legend.text = element_text(size = rel(0.7)))

ggsave("ks-summary.png", h = 5, w = 6.25, bg = "white")
# Average of the per-parameter mean KS values within each parameter type,
# rendered as a table.
gt::gt(
  summarise(
    group_by(ks_summary, type),
    TMB = mean(TMB),
    aghq = mean(aghq)
  )
)
| type | TMB | aghq |
|---|---|---|
| Latent | 0.08185225 | 0.07688924 |
| NA | 0.53123077 | 0.47416923 |
# Interactive table of the latent-field rows, with the TMB minus aghq
# difference in mean KS rounded to three significant figures.
DT::datatable(
  mutate(
    filter(ks_summary, type == "Latent"),
    diff = signif(TMB - aghq, 3)
  )
)
Want to create rank order lists of largest KS differences between methods:
# One row per parameter entry with a KS column per method, plus the
# difference between them (negative when the TMB value is the smaller one).
ks_df_wide <- mutate(
  pivot_wider(ks_df, names_from = "method", values_from = "ks"),
  diff = TMB - aghq
)
# TMB beats aghq
# The ten entries with the most negative KS(TMB) - KS(aghq), i.e. where the
# TMB value is smallest relative to the aghq value. Rows without a type label
# are excluded. The outer parentheses print the result as well as storing it.
(tmb_beats_aghq <- ks_df_wide %>%
  filter(!is.na(type)) %>%
  arrange(diff) %>%
  head(n = 10))

# Histogram and ECDF difference plots for the top three entries.
histogram_and_ecdf(par = tmb_beats_aghq$parname[1], i = tmb_beats_aghq$index[1])
histogram_and_ecdf(par = tmb_beats_aghq$parname[2], i = tmb_beats_aghq$index[2])
histogram_and_ecdf(par = tmb_beats_aghq$parname[3], i = tmb_beats_aghq$index[3])
# aghq beats TMB
# The ten entries with the largest KS(TMB) - KS(aghq), i.e. where the aghq
# value is smallest relative to the TMB value. Rows without a type label are
# excluded. The outer parentheses print the result as well as storing it.
(aghq_beats_tmb <- ks_df_wide %>%
  filter(!is.na(type)) %>%
  arrange(desc(diff)) %>%
  head(n = 10))

# Histogram and ECDF difference plots for the top three entries.
histogram_and_ecdf(par = aghq_beats_tmb$parname[1], i = aghq_beats_tmb$index[1])
histogram_and_ecdf(par = aghq_beats_tmb$parname[2], i = aghq_beats_tmb$index[2])
histogram_and_ecdf(par = aghq_beats_tmb$parname[3], i = aghq_beats_tmb$index[3])
Is there any correlation between the value of \(\text{KS}(\texttt{method}, \texttt{tmbstan})\) for a particular parameter and the ESS of that parameter from tmbstan output?
# Convergence diagnostics taken from the stored Stan fit.
stanfit <- tmbstan$mcmc$stanfit
rhats <- bayesplot::rhat(stanfit)
ess_ratio <- bayesplot::neff_ratio(stanfit)

# Number of draws used to turn the ESS ratio into an ESS.
# NOTE(review): the factors 0.5 and 4 presumably stand for the post-warmup
# fraction and the number of chains; confirm against the tmbstan run.
niter <- 0.5 * 4 * tmbstan_details$niter / tmbstan_details$nthin
ess <- ess_ratio * niter
#' Reorder and rename factor levels in one step
#'
#' @param fac A factor.
#' @param ... Passed unchanged first to `fct_relevel()` and then to
#'   `fct_recode()`, e.g. `"New label" = "old_level"`.
#' @return The result of `fct_recode()` applied to the relevelled factor.
fct_reorg <- function(fac, ...) {
  releveled <- fct_relevel(fac, ...)
  fct_recode(releveled, ...)
}
# Scatter plot of KS against the tmbstan effective sample size for each
# parameter entry, faceted by method, with a fitted linear trend.
ks_df %>%
  # Keeps only rows whose type is known and is not "Hyper"; filter() also
  # drops rows where `type` is NA.
  filter(type != "Hyper") %>%
  mutate(method = fct_reorg(method, "TMB" = "TMB", "PCA-AGHQ" = "aghq")) %>%
  # Attach the ESS of each entry. The row names of data.frame(ess) become the
  # `par` column, and the join key is named explicitly.
  left_join(
    data.frame(ess) %>%
      tibble::rownames_to_column("par"),
    by = "par"
  ) %>%
  ggplot(aes(x = ess, y = ks)) +
  geom_point(shape = 1, alpha = 0.5) +
  geom_smooth(method = "lm", color = "#56B4E9", fullrange = TRUE) +
  stat_poly_eq(use_label("eq"), label.x.npc = "right") +
  stat_poly_eq(label.y = 0.9, label.x.npc = "right") +
  scale_x_continuous(limits = c(0, NA)) +
  facet_grid(~ method) +
  theme_minimal() +
  labs(x = "ESS", y = "KS(method, NUTS)")
## `geom_smooth()` using formula = 'y ~ x'
ggsave("ks-ess.png", h = 4, w = 6.25)
## `geom_smooth()` using formula = 'y ~ x'