Abstract
Background: We have run the simplified Naomi model using a range of inference methods: TMB, aghq, and tmbstan.
Task: In this report, we compare the accuracy of the posterior distributions obtained from these inference methods using histograms and Kolmogorov-Smirnov tests.
We compare the inference results from TMB, aghq, and tmbstan.
Import these inference results as follows:
tmb <- readRDS("depends/tmb.rds")
aghq <- readRDS("depends/aghq.rds")
tmbstan <- readRDS("depends/tmbstan.rds")
depends <- yaml::read_yaml("orderly.yml")$depends
Check that the parameters (latent field, hyperparameters, model outputs) sampled from each of the three methods are the same:
stopifnot(names(tmb$fit$sample) == names(aghq$quad$sample))
stopifnot(names(tmb$fit$sample) == names(tmbstan$mcmc$sample))
For more information about the conditions under which these results were generated, see:
TMB
dependency_details <- function(i) {
report_name <- names(depends[[i]])
print(paste0("Inference results obtained from ", report_name, " with the query ", depends[[i]][[report_name]]$id))
report_id <- orderly::orderly_search(query = depends[[i]][[report_name]]$id, report_name)
print(paste0("Obtained report had ID ", report_id, " and was run with the following parameters:"))
print(orderly::orderly_info(report_id, report_name)$parameters)
}
dependency_details(1)
## [1] "Inference results obtained from naomi-simple_fit with the query latest(parameter:tmb == TRUE && parameter:random_only == TRUE)"
## [1] "Obtained report had ID 20230705-101408-12ab65a3 and was run with the following parameters:"
## $tmb
## [1] TRUE
##
## $random_only
## [1] TRUE
##
## $sample
## [1] TRUE
##
## $aghq
## [1] FALSE
##
## $k
## [1] 1
##
## $grid_type
## [1] "product"
##
## $s
## [1] 1
##
## $tmbstan
## [1] FALSE
##
## $hmc_laplace
## [1] FALSE
##
## $niter
## [1] 1000
##
## $nthin
## [1] 1
##
## $adam
## [1] FALSE
##
## $nsample
## [1] 1000
##
## $area_level
## [1] 4
aghq
dependency_details(2)
## [1] "Inference results obtained from naomi-simple_fit with the query latest(parameter:aghq == TRUE && parameter:k == 3 && parameter:s == 8)"
## [1] "Obtained report had ID 20230714-122557-a10aa13d and was run with the following parameters:"
## $aghq
## [1] TRUE
##
## $k
## [1] 3
##
## $s
## [1] 8
##
## $grid_type
## [1] "scaled_pca"
##
## $sample
## [1] TRUE
##
## $tmb
## [1] FALSE
##
## $random_only
## [1] TRUE
##
## $tmbstan
## [1] FALSE
##
## $hmc_laplace
## [1] FALSE
##
## $niter
## [1] 1000
##
## $nthin
## [1] 1
##
## $adam
## [1] FALSE
##
## $nsample
## [1] 1000
##
## $area_level
## [1] 4
tmbstan
tmbstan_details <- dependency_details(3)
## [1] "Inference results obtained from naomi-simple_fit with the query latest(parameter:tmbstan == TRUE && parameter:niter > 50000)"
## [1] "Obtained report had ID 20230414-110613-89cd4244 and was run with the following parameters:"
## $tmbstan
## [1] TRUE
##
## $niter
## [1] 1e+05
##
## $nthin
## [1] 40
##
## $tmb
## [1] FALSE
##
## $aghq
## [1] FALSE
##
## $k
## [1] 1
##
## $grid_type
## [1] "product"
##
## $s
## [1] 1
##
## $hmc_laplace
## [1] FALSE
##
## $adam
## [1] FALSE
##
## $nsample
## [1] 5000
##
## $area_level
## [1] 4
tmbstan_details
## $tmbstan
## [1] TRUE
##
## $niter
## [1] 1e+05
##
## $nthin
## [1] 40
##
## $tmb
## [1] FALSE
##
## $aghq
## [1] FALSE
##
## $k
## [1] 1
##
## $grid_type
## [1] "product"
##
## $s
## [1] 1
##
## $hmc_laplace
## [1] FALSE
##
## $adam
## [1] FALSE
##
## $nsample
## [1] 5000
##
## $area_level
## [1] 4
The time taken by each method is as follows:
time_taken <- data.frame(
"TMB" = tmb$time,
"aghq" = aghq$time,
"tmbstan" = tmbstan$time
)
write_csv(time_taken, "time-taken.csv")
time_taken
We create histograms and empirical cumulative distribution function (ECDF) difference plots of the samples from each method.
All of the possible latent field and hyperparameter names are given by pars:
pars <- names(tmb$fit$sample)
The remaining names are the Naomi model output variables:
names(tmb$fit$sample)[!(names(tmb$fit$sample) %in% unique(names(tmb$fit$obj$env$par)))]
## [1] "anc_clients_t1_out" "population_t1_out" "plhiv_attend_t1_out" "anc_already_art_t1_out" "anc_tested_pos_t1_out" "anc_tested_neg_t1_out"
## [7] "anc_rho_obs_t1_ll" "hhs_prev_ll" "u_rho_xa" "anc_art_new_t1_out" "untreated_plhiv_attend_t1_out" "untreated_plhiv_num_t1_out"
## [13] "artattend_ij_t1_out" "anc_known_pos_t1_out" "artattend_t1_out" "rho_t1_out" "anc_rho_t1_out" "artnum_t1_out"
## [19] "anc_alpha_t1_out" "alpha_t1_out" "infections_t1_out" "anc_plhiv_t1_out" "plhiv_t1_out" "lambda_t1_out"
## [25] "hhs_artcov_ll" "artnum_t1_ll"
We will especially focus on the following outputs:

- HIV prevalence (rho_t1_out), which has 6417 variables (rows)
- ART coverage (alpha_t1_out), which has 6417 variables (rows)
- HIV incidence (lambda_t1_out), which has 6417 variables (rows)

These outputs are at aggregate resolutions as well as at the finest resolution.
It seems to make sense to only assess inferential accuracy at the finest resolution, so as to avoid double counting.
As such, we need to find the subset of these variables at the finest resolution.
Within tmb$naomi_data there is a dataframe called mf_out (the model frame output) which contains the area, age and sex mapping for these 6417 rows.
As a sanity check, confirm that the number of areas times the number of age groups times the number of sexes indeed equals the number of rows:
mf_out <- tmb$naomi_data$mf_out
(n_area <- length(unique(mf_out$area_id)))
## [1] 69
(n_age <- length(unique(mf_out$age_group)))
## [1] 31
(n_sex <- length(unique(mf_out$sex)))
## [1] 3
stopifnot(n_area * n_age * n_sex == dim(tmb$fit$sample$rho_t1_out)[1])
The unique levels for area, age and sex are:
mf_out$area_id %>% unique()
## [1] "MWI" "MWI_1_1_demo" "MWI_1_2_demo" "MWI_1_3_demo" "MWI_2_1_demo" "MWI_2_2_demo" "MWI_2_3_demo" "MWI_2_4_demo" "MWI_2_5_demo" "MWI_3_1_demo" "MWI_3_10_demo" "MWI_3_11_demo"
## [13] "MWI_3_12_demo" "MWI_3_13_demo" "MWI_3_14_demo" "MWI_3_15_demo" "MWI_3_16_demo" "MWI_3_17_demo" "MWI_3_18_demo" "MWI_3_19_demo" "MWI_3_2_demo" "MWI_3_20_demo" "MWI_3_21_demo" "MWI_3_22_demo"
## [25] "MWI_3_23_demo" "MWI_3_24_demo" "MWI_3_25_demo" "MWI_3_26_demo" "MWI_3_27_demo" "MWI_3_28_demo" "MWI_3_3_demo" "MWI_3_4_demo" "MWI_3_5_demo" "MWI_3_6_demo" "MWI_3_7_demo" "MWI_3_8_demo"
## [37] "MWI_3_9_demo" "MWI_4_1_demo" "MWI_4_10_demo" "MWI_4_11_demo" "MWI_4_12_demo" "MWI_4_13_demo" "MWI_4_14_demo" "MWI_4_15_demo" "MWI_4_16_demo" "MWI_4_17_demo" "MWI_4_18_demo" "MWI_4_19_demo"
## [49] "MWI_4_2_demo" "MWI_4_20_demo" "MWI_4_21_demo" "MWI_4_22_demo" "MWI_4_23_demo" "MWI_4_24_demo" "MWI_4_25_demo" "MWI_4_26_demo" "MWI_4_27_demo" "MWI_4_28_demo" "MWI_4_29_demo" "MWI_4_3_demo"
## [61] "MWI_4_30_demo" "MWI_4_31_demo" "MWI_4_32_demo" "MWI_4_4_demo" "MWI_4_5_demo" "MWI_4_6_demo" "MWI_4_7_demo" "MWI_4_8_demo" "MWI_4_9_demo"
mf_out$age_group %>% unique()
## [1] "Y000_004" "Y000_014" "Y000_064" "Y000_999" "Y005_009" "Y010_014" "Y010_019" "Y015_019" "Y015_024" "Y015_049" "Y015_064" "Y015_999" "Y020_024" "Y025_029" "Y025_034" "Y025_049" "Y030_034" "Y035_039"
## [19] "Y035_049" "Y040_044" "Y045_049" "Y050_054" "Y050_064" "Y050_999" "Y055_059" "Y060_064" "Y065_069" "Y065_999" "Y070_074" "Y075_079" "Y080_999"
mf_out$sex %>% unique()
## [1] "both" "female" "male"
Filter to only the finest resolution, adding an id column to enable merging with the samples:
mf_out_fine <- mf_out %>%
tibble::rownames_to_column("id") %>%
mutate(id = as.numeric(id)) %>%
filter(
area_id %in% paste0("MWI_4_", 1:32, "_demo"),
sex %in% c("male", "female"),
age_group %in%
c(
"Y000_004", "Y005_009", "Y010_014", "Y015_019", "Y020_024", "Y025_029",
"Y025_034", "Y030_034", "Y035_039", "Y040_044", "Y045_049", "Y050_054",
"Y055_059", "Y060_064", "Y065_069", "Y070_074", "Y075_079", "Y080_999"
)
)
We will produce plots for a small subset of parameters for the time being. There is no particular reason to choose this subset rather than another; the choice is fairly arbitrary.
pars_eval <- pars %in% c("beta_rho", "u_rho_x")
names(pars_eval) <- pars
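The helper functions histogram_and_ecdf and histogram_and_ecdf_list used below are defined in the report's sourced scripts rather than shown here. Purely as a rough, hypothetical sketch of the idea, and assuming that each method's samples are stored as matrices with one row per variable and one column per draw, a comparable helper might look like this:

# Hypothetical sketch only: the real helper is defined elsewhere. It overlays
# histograms of draws from each method and plots the difference between each
# approximate method's ECDF and the NUTS ECDF, for the i-th element of `par`.
histogram_and_ecdf_sketch <- function(par, i = 1) {
  df <- dplyr::bind_rows(
    data.frame(method = "TMB", x = tmb$fit$sample[[par]][i, ]),
    data.frame(method = "PCA-AGHQ", x = aghq$quad$sample[[par]][i, ]),
    data.frame(method = "NUTS", x = tmbstan$mcmc$sample[[par]][i, ])
  )

  histogram <- ggplot2::ggplot(df, ggplot2::aes(x = x, fill = method)) +
    ggplot2::geom_histogram(position = "identity", alpha = 0.5, bins = 30) +
    ggplot2::labs(x = paste0(par, "[", i, "]"), y = "Count", fill = "Method") +
    ggplot2::theme_minimal()

  # ECDF differences evaluated on a common grid of the pooled draws
  grid <- sort(unique(df$x))
  nuts_ecdf <- ecdf(df$x[df$method == "NUTS"])
  ecdf_diff <- dplyr::bind_rows(lapply(c("TMB", "PCA-AGHQ"), function(m) {
    data.frame(
      method = m, x = grid,
      diff = ecdf(df$x[df$method == m])(grid) - nuts_ecdf(grid)
    )
  }))

  ecdf_plot <- ggplot2::ggplot(ecdf_diff, ggplot2::aes(x = x, y = diff, col = method)) +
    ggplot2::geom_line() +
    ggplot2::labs(x = paste0(par, "[", i, "]"), y = "ECDF difference", col = "Method") +
    ggplot2::theme_minimal()

  list(histogram, ecdf_plot)
}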
beta_rho
histogram_and_ecdf_list("beta_rho")
beta_alpha
histogram_and_ecdf_list("beta_alpha")
beta_lambda
histogram_and_ecdf_list("beta_lambda")
beta_anc_rho
histogram_and_ecdf_list("beta_anc_rho")
beta_anc_alpha
histogram_and_ecdf("beta_anc_alpha")
logit
lapply(pars[stringr::str_starts(pars, "logit")], histogram_and_ecdf)
log_sigma
lapply(pars[stringr::str_starts(pars, "log_sigma")], histogram_and_ecdf)
u_rho_x
histogram_and_ecdf_list("u_rho_x")
us_rho_x
histogram_and_ecdf_list("us_rho_x")
u_rho_xs
histogram_and_ecdf_list("u_rho_xs")
us_rho_xs
histogram_and_ecdf_list("us_rho_xs")
u_rho_as
histogram_and_ecdf_list("u_rho_as")
u_alpha_x
histogram_and_ecdf_list("u_alpha_x")
us_alpha_x
histogram_and_ecdf_list("us_alpha_x")
u_alpha_xs
histogram_and_ecdf_list("u_alpha_xs")
us_alpha_xs
histogram_and_ecdf_list("us_alpha_xs")
u_alpha_a
histogram_and_ecdf_list("u_alpha_a")
u_alpha_as
histogram_and_ecdf_list("u_alpha_as")
u_alpha_xa
histogram_and_ecdf_list("u_alpha_xa")
ui_lambda_x
histogram_and_ecdf_list("ui_lambda_x")
ui_anc_rho_x
histogram_and_ecdf_list("ui_anc_rho_x")
ui_anc_alpha_x
histogram_and_ecdf_list("ui_anc_alpha_x")
log_or_gamma
histogram_and_ecdf_list("log_or_gamma")
rho_t1_out
There are too many variables to plot here!
alpha_t1_out
There are too many variables to plot here!
lambda_t1_out
There are too many variables to plot here!
Next, we compute Kolmogorov-Smirnov (KS) test statistics comparing the samples from TMB and aghq to those from tmbstan, for each latent field parameter and hyperparameter:
r <- tmb$fit$obj$env$random
x_names <- names(tmb$fit$obj$env$par[r])
theta_names <- names(tmb$fit$obj$env$par[-r])
dict <- data.frame(
parname = c(unique(x_names), unique(theta_names)),
type = c(rep("Latent", length(unique(x_names))), rep("Hyper", length(unique(theta_names))))
)
ks_df <- lapply(unique(names(tmb$fit$obj$env$par)), to_ks_df) %>%
bind_rows() %>%
mutate(parname = str_extract(par, ".*(?=\\[)")) %>%
left_join(dict, by = "parname")
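The helper to_ks_df used above is defined in the report's sourced scripts rather than shown here. As a minimal sketch of the quantity it computes, the two-sample KS statistic between draws from an approximate method and draws from tmbstan is the largest absolute difference between their empirical CDFs:

# Hypothetical sketch only: the real computation is done inside to_ks_df
ks_statistic <- function(samples_method, samples_nuts) {
  grid <- sort(c(samples_method, samples_nuts))
  max(abs(ecdf(samples_method)(grid) - ecdf(samples_nuts)(grid)))
}

# For example, for the first element of u_rho_x:
# ks_statistic(tmb$fit$sample$u_rho_x[1, ], tmbstan$mcmc$sample$u_rho_x[1, ])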
ks_plot_starts_with <- function(param) {
ks_df %>%
filter(startsWith(par, param)) %>%
ks_plot(par = param)
}
ks_df_out <- function(par) {
to_ks_df(par = par, outputs = TRUE, id = mf_out_fine$id) %>%
select(-par) %>%
left_join(
mf_out_fine %>%
rename("full_id" = "id") %>%
tibble::rowid_to_column("index") %>%
mutate(index = as.numeric(index))
)
}
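The plotting helper ks_plot is also defined elsewhere. Purely as a hypothetical sketch of the kind of figure intended here, one could plot jittered points of the KS statistic for each variable, grouped by method:

# Hypothetical sketch only: the real ks_plot is defined in the sourced scripts
ks_plot_sketch <- function(df, par, alpha = 1) {
  ggplot2::ggplot(df, ggplot2::aes(x = method, y = ks)) +
    ggplot2::geom_jitter(width = 0.1, alpha = alpha, shape = 1) +
    ggplot2::labs(x = "Method", y = "KS(method, NUTS)", title = par) +
    ggplot2::theme_minimal()
}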
beta
ks_plot_starts_with("beta")
logit
ks_plot_starts_with("logit")
log_sigma
ks_plot_starts_with("log_sigma")
u_rho_x
ks_plot_starts_with("u_rho_x")
u_rho_xs
ks_plot_starts_with("u_rho_xs")
us_rho_x
ks_plot_starts_with("us_rho_x")
us_rho_xs
ks_plot_starts_with("us_rho_xs")
u_rho_a
ks_plot_starts_with("u_rho_a")
u_rho_as
ks_plot_starts_with("u_rho_as")
u_alpha_x
ks_plot_starts_with("u_alpha_x")
u_alpha_xs
ks_plot_starts_with("u_alpha_xs")
us_alpha_x
ks_plot_starts_with("us_alpha_x")
us_alpha_xs
ks_plot_starts_with("us_alpha_xs")
u_alpha_a
ks_plot_starts_with("u_alpha_a")
u_alpha_as
ks_plot_starts_with("u_alpha_as")
u_alpha_xa
ks_plot_starts_with("u_alpha_xa")
ui_anc_rho_x
ks_plot_starts_with("ui_anc_rho_x")
ui_anc_alpha_x
ks_plot_starts_with("ui_anc_alpha_x")
log_or_gamma
ks_plot_starts_with("log_or_gamma")
rho_t1_out
ks_df_out(par = "rho_t1_out") %>%
ks_plot(par = "rho_t1_out", alpha = 0.2)
## Joining with `by = join_by(index)`
alpha_t1_out
ks_df_out(par = "alpha_t1_out") %>%
ks_plot(par = "rho_t1_out", alpha = 0.2)
## Joining with `by = join_by(index)`
lambda_t1_out
ks_df_out(par = "lambda_t1_out") %>%
filter(!age_group %in% c("Y005_009", "Y010_014", "Y080_999")) %>%
ks_plot(par = "lambda_t1_out", alpha = 0.2)
## Joining with `by = join_by(index)`
We filter out the age groups 5-9, 10-14 and 80+ here because there are no new infections in those age groups, so the posterior samples from each method are exactly the same: in particular, draws which are identically zero.
options(dplyr.summarise.inform = FALSE)
ks_summary <- ks_df %>%
group_by(method, parname, type) %>%
summarise(
ks = mean(ks),
size = n()
) %>%
pivot_wider(names_from = "method", values_from = "ks")
saveRDS(ks_summary, "ks-summary.rds")
extended_cbpalette <- colorRampPalette(multi.utils::cbpalette())
ks_summary_latent <- filter(ks_summary, type == "Latent")
xy_length <- min(1, max(ks_summary_latent$aghq, ks_summary_latent$TMB) + 0.03)
ks_summary_latent %>%
ggplot(aes(x = TMB, y = aghq, col = parname, size = size)) +
geom_point(alpha = 0.4) +
xlim(0, xy_length) +
ylim(0, xy_length) +
scale_color_manual(values = extended_cbpalette(n = 20)) +
scale_size_continuous(breaks = c(2, 10, 32), labels = c(2, 10, 32)) +
geom_abline(slope = 1, intercept = 0, linetype = "dashed") +
labs(x = "KS(TMB, NUTS)", y = "KS(PCA-AGHQ, NUTS)", col = "Parameter", size = "Length", subtitle = "Smaller KS values indicate higher accuracy") +
theme_minimal() +
guides(size = guide_legend(title.position = "top", direction = "vertical"), col = guide_legend(title.position = "top")) +
theme(legend.position = "bottom", legend.title = element_text(size = rel(0.9)), legend.text = element_text(size = rel(0.7)))
ggsave("ks-summary.png", h = 5, w = 6.25, bg = "white")
ks_summary %>%
group_by(type) %>%
summarise(
TMB = mean(TMB),
aghq = mean(aghq)
) %>%
gt::gt()
| type   | TMB        | aghq       |
|--------|------------|------------|
| Latent | 0.08185225 | 0.07688924 |
| NA     | 0.53123077 | 0.47416923 |
ks_summary %>%
filter(type == "Latent") %>%
mutate(diff = signif(TMB - aghq, 3)) %>%
DT::datatable()
We want to create rank-ordered lists of the largest KS differences between the methods:
ks_df_wide <- ks_df %>%
pivot_wider(names_from = "method", values_from = "ks") %>%
mutate(diff = TMB - aghq)
TMB beats aghq
(tmb_beats_aghq <- ks_df_wide %>%
filter(!is.na(type)) %>%
arrange(diff) %>%
head(n = 10))
histogram_and_ecdf(par = tmb_beats_aghq$parname[1], i = tmb_beats_aghq$index[1])
histogram_and_ecdf(par = tmb_beats_aghq$parname[2], i = tmb_beats_aghq$index[2])
histogram_and_ecdf(par = tmb_beats_aghq$parname[3], i = tmb_beats_aghq$index[3])
aghq beats TMB
(aghq_beats_tmb <- ks_df_wide %>%
filter(!is.na(type)) %>%
arrange(desc(diff)) %>%
head(n = 10))
histogram_and_ecdf(par = aghq_beats_tmb$parname[1], i = aghq_beats_tmb$index[1])
histogram_and_ecdf(par = aghq_beats_tmb$parname[2], i = aghq_beats_tmb$index[2])
histogram_and_ecdf(par = aghq_beats_tmb$parname[3], i = aghq_beats_tmb$index[3])
Is there any correlation between the value of \(\text{KS}(\texttt{method}, \texttt{tmbstan})\) for a particular parameter and the ESS of that parameter from the tmbstan output?
# Potential scale reduction factors and ESS ratios from the NUTS fit
rhats <- bayesplot::rhat(tmbstan$mcmc$stanfit)
ess_ratio <- bayesplot::neff_ratio(tmbstan$mcmc$stanfit)
# Number of retained draws: 4 chains, half of each chain discarded as warmup, then thinned.
# With niter = 1e5 and nthin = 40 this is 0.5 * 4 * 1e5 / 40 = 5000 draws.
niter <- 0.5 * 4 * tmbstan_details$niter / tmbstan_details$nthin
ess <- ess_ratio * niter
# Relevel and rename factor levels in a single call
fct_reorg <- function(fac, ...) {
  fct_recode(fct_relevel(fac, ...), ...)
}
ks_df %>%
filter(type != "Hyper") %>%
mutate(method = fct_reorg(method, "TMB" = "TMB", "PCA-AGHQ" = "aghq")) %>%
left_join(data.frame(ess) %>%
tibble::rownames_to_column("par"),
) %>%
ggplot(aes(x = ess, y = ks)) +
geom_point(shape = 1, alpha = 0.5) +
geom_smooth(method = "lm", color = "#56B4E9", fullrange = TRUE) +
stat_poly_eq(use_label("eq"), label.x.npc = "right") +
stat_poly_eq(label.y = 0.9, label.x.npc = "right") +
scale_x_continuous(limits = c(0, NA)) +
facet_grid(~ method) +
theme_minimal() +
labs(x = "ESS", y = "KS(method, NUTS)")
## Joining with `by = join_by(par)`
## `geom_smooth()` using formula = 'y ~ x'
ggsave("ks-ess.png", h = 4, w = 6.25)
## `geom_smooth()` using formula = 'y ~ x'