Background The HIV inference group at Imperial routinely uses TMB to implement complex spatio-temporal statistical models, because the scale of the models and the run-time required in production make MCMC infeasible, and no suitably flexible implementation of INLA exists. However, inferences obtained from TMB are not as accurate as those from MCMC or INLA.
Task We compare inferences produced using TMB, tmbstan and aghq for a collection of HIV evidence synthesis models based upon the Naomi small-area estimation model.
Findings We find that, treating tmbstan as the gold standard, inference from TMB is substantially worse than that from aghq.
Next steps Work on presentation of document. Understand performance for different parameters. Understand cause of poor performance.
Jeffrey W. Eaton et al. (2019) specify a joint model linking small-area estimation models of HIV prevalence from household surveys, HIV prevalence from antenatal care clinics, and antiretroviral therapy (ART) coverage from routine health data collection. This model forms the basis of the Naomi small-area estimation model (Jeffrey W. Eaton et al. 2021). Modelling data from multiple sources concurrently increases statistical power and may mitigate the biases of any single source, giving a more complete picture of the situation and prompting investigation into any data conflicts. The model is described by three components, as follows.
Consider a country partitioned into \(n\) areas indexed by \(i\). Suppose a simple random household survey of \(m^\text{HS}_i\) people is conducted in each area, and \(y^\text{HS}_i\) HIV positive cases are observed. Cases may be modelled using a binomial logistic regression model \[\begin{align} y^\text{HS}_i &\sim \text{Bin}(m^\text{HS}_i, \rho^\text{HS}_i), \\ \text{logit}(\rho^\text{HS}_i) &\sim \mathcal{N}(\beta_\phi, \sigma_\phi^2), \end{align}\] where HIV prevalence \(\rho^\text{HS}_i\) is modelled on the logit scale by a Gaussian with mean \(\beta_\phi\) and standard deviation \(\sigma_\phi\).
Routinely collected data from pregnant women attending antenatal care clinics (ANCs) is another important source of information about the HIV epidemic. Suppose that of \(m^\text{ANC}_i\) women attending ANC, \(y^\text{ANC}_i\) are HIV positive. Then an analogous binomial logistic regression model \[\begin{align} y^\text{ANC}_i &\sim \text{Bin}(m^\text{ANC}_i, \rho^\text{ANC}_i), \\ \text{logit}(\rho^\text{ANC}_i) &= \text{logit}(\rho^\text{HS}_i) + b_i, \\ b_i &\sim \mathcal{N}(\beta_b, \sigma_b^2), \end{align}\] may be used to describe HIV prevalence amongst the sub-population of women attending ANCs. Reflecting the fact that prevalence in ANCs is related but importantly different to prevalence in the general population, bias terms \(b_i\) are used to offset ANC prevalence from HIV prevalence on the logit scale.
The number of people receiving treatment at district health facilities \(A_i\) provides further information about HIV prevalence. Districts with high prevalence are likely to have a greater number of people receiving treatment, and vice versa. ART coverage, defined to be the proportion of people living with HIV (PLHIV) currently on ART in district \(i\), is given by \(\alpha_i = A_i / (\rho^\text{HS}_i N_i)\), where \(N_i\) is the total population of district \(i\), assumed to be constant. As such, ART coverage may also be modelled using a binomial logistic regression model \[\begin{align} A_i &\sim \text{Bin}(N_i, \rho^\text{HS}_i \alpha_i), \\ \text{logit}(\alpha_i) &\sim \mathcal{N}(\beta_\alpha, \sigma_\alpha^2), \end{align}\] where the proportion of the population receiving ART is \(\rho^\text{HS}_i \alpha_i\). Here we assume no travel between districts to receive treatment.
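The three components can be combined into a single log joint density. The following is a minimal sketch in base R, assuming all three components are included and conditioning on the hyperparameters (no hyperpriors are stated in the text, so none are included). The function name, argument names, and list element names are illustrative assumptions and are not taken from the C++ templates.

```r
# Sketch: log joint density of the data and latent effects for the full model,
# conditional on hyperparameters. All names here are hypothetical.
log_joint_model4 <- function(logit_rho, b, logit_alpha, data, hyper) {
  rho_hs  <- plogis(logit_rho)       # general-population prevalence
  rho_anc <- plogis(logit_rho + b)   # ANC prevalence, offset on the logit scale
  alpha   <- plogis(logit_alpha)     # ART coverage

  # Gaussian densities for the latent effects
  lp <- sum(dnorm(logit_rho, hyper$beta_phi, hyper$sigma_phi, log = TRUE)) +
    sum(dnorm(b, hyper$beta_b, hyper$sigma_b, log = TRUE)) +
    sum(dnorm(logit_alpha, hyper$beta_alpha, hyper$sigma_alpha, log = TRUE))

  # Binomial likelihoods for the three data sources
  ll <- sum(dbinom(data$y_hs, data$m_hs, rho_hs, log = TRUE)) +
    sum(dbinom(data$y_anc, data$m_anc, rho_anc, log = TRUE)) +
    sum(dbinom(data$A, data$N, rho_hs * alpha, log = TRUE))

  lp + ll
}

# Toy usage with two areas (made-up numbers, for illustration only)
toy_data <- list(y_hs = c(20, 25), m_hs = c(250, 250),
                 y_anc = c(700, 800), m_anc = c(1e4, 1e4),
                 A = c(5000, 6000), N = c(1e5, 1e5))
toy_hyper <- list(beta_phi = -2.4, sigma_phi = 0.5, beta_b = -0.2,
                  sigma_b = 0.1, beta_alpha = 0.7, sigma_alpha = 0.35)
log_joint_model4(c(-2.4, -2.3), c(-0.2, -0.2), c(0.7, 0.7), toy_data, toy_hyper)
```

Dropping the ANC or ART terms from both sums gives the corresponding density for a model with fewer components.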
We consider five models, as described below. For each model we write a C++ template. As well as the standard TMB inference approach, this template allows the model to be fit using Stan via tmbstan and using adaptive Gauss-Hermite quadrature via aghq.
Model | Report | Components |
0 | prev-anc-art_model0 | Prevalence (no random effects) |
1 | prev-anc-art_model1 | Prevalence |
2 | prev-anc-art_model2 | Prevalence, ANC |
3 | prev-anc-art_model3 | Prevalence, ART |
4 | prev-anc-art_model4 | Prevalence, ANC, ART |
We perform inference on data simulated from Model 4 (generated in the report prev-anc-art_sim) with the following parameter values:
Parameter | Value |
\(n\) | 36 |
\(m^\text{HS}_i\) | 250 |
\(\beta_\phi\) | -2.4 |
\(\sigma_\phi\) | 0.5 |
\(m^\text{ANC}_i\) | 10^4 |
\(\beta_b\) | -0.2 |
\(\sigma_b\) | 0.1 |
\(N_i\) | 10^5 |
\(\beta_\alpha\) | 0.7 |
\(\sigma_\alpha\) | 0.35 |
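As an illustration of how data could be drawn from Model 4 with the values in the table, here is a minimal base R sketch. It is not the code of the prev-anc-art_sim report; the function name `simulate_model4`, its argument names, and the output column names are assumptions made for this example.

```r
# Sketch: simulate one dataset from Model 4. Defaults follow the parameter table;
# all function, argument and column names are hypothetical.
simulate_model4 <- function(n = 36, m_hs = 250, beta_prev = -2.4, sigma_prev = 0.5,
                            m_anc = 1e4, beta_b = -0.2, sigma_b = 0.1,
                            N = 1e5, beta_alpha = 0.7, sigma_alpha = 0.35) {
  logit_rho <- rnorm(n, beta_prev, sigma_prev)          # logit prevalence by area
  b         <- rnorm(n, beta_b, sigma_b)                # ANC bias terms
  alpha     <- plogis(rnorm(n, beta_alpha, sigma_alpha)) # ART coverage

  rho_hs  <- plogis(logit_rho)
  rho_anc <- plogis(logit_rho + b)

  data.frame(
    area  = seq_len(n),
    y_hs  = rbinom(n, m_hs, rho_hs),        # household survey cases
    m_hs  = m_hs,
    y_anc = rbinom(n, m_anc, rho_anc),      # ANC cases
    m_anc = m_anc,
    A     = rbinom(n, N, rho_hs * alpha),   # people receiving ART
    N     = N
  )
}

set.seed(1)
sim <- simulate_model4()
head(sim)
```

Repeated calls with different seeds would give independent simulation replicates of the kind summarised in the plots later in this report.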
knitr::opts_chunk$set(
  cache = TRUE,
  autodep = TRUE,
  cache.lazy = FALSE,
  cache.comments = FALSE
)
options(scipen = 999)
cbpalette <- multi.utils::cbpalette()
rstan::rstan_options(auto_write = TRUE)
options(mc.cores = parallel::detectCores())
Create true values dataframe:
true_values <- data.frame(
parameter = c(
true_value = c(
log(sqrt(1 / 4)),
log(sqrt(1/ 100)),
Read in results created in reports:
# model0 <- readRDS("depends/results_model0.rds")
model1 <- readRDS("depends/results_model1.rds")
model2 <- readRDS("depends/results_model2.rds")
model3 <- readRDS("depends/results_model3.rds")
model4 <- readRDS("depends/results_model4.rds")
Boxplots of the mean and standard deviation for each parameter, where each datapoint is a different simulation replicate. For now we will just focus on a subset of the parameters, as it is challenging to nicely present plots for very large numbers of parameters.
params <- c("beta_prev", "log_sigma_phi_prev")
draw_boxplots(model1, params = params) + labs(title = "Model 1")
draw_boxplots(model2, params = params) + labs(title = "Model 2")
draw_boxplots(model3, params = params) + labs(title = "Model 3")
draw_boxplots(model4, params = params) + labs(title = "Model 4")
Scatterplots of posterior summaries (such as mean and standard deviation) where the \(x\)-axis is the posterior summary as estimated by tmbstan and the \(y\)-axis is the posterior summary as estimated by either aghq or TMB. We treat tmbstan as the gold standard, so any deviation from the line \(y = x\) suggests inaccurate posterior summaries.
params <- c("beta_prev", "log_sigma_phi_prev")
draw_scatterplots(model1, params = params) + labs(title = "Model 1")
draw_scatterplots(model2, params = params) + labs(title = "Model 2")
draw_scatterplots(model3, params = params) + labs(title = "Model 3")
draw_scatterplots(model4, params = params) + labs(title = "Model 4")
We take samples from each posterior distribution, then compute the maximum difference \(D\) between empirical cumulative distribution functions (ECDFs), which is the two-sample Kolmogorov–Smirnov test statistic. On the \(y\)-axis we plot \(\text{KS}(\texttt{TMB}, \texttt{tmbstan})\) and on the \(x\)-axis \(\text{KS}(\texttt{aghq}, \texttt{tmbstan})\). Lower values of \(D\) (the minimum possible is zero) correspond to more similar distributions, and higher values of \(D\) (the maximum possible is one) correspond to more different distributions.
draw_ksplots_D(model1) + labs(title = "Model 1")
## Warning: Removed 4 rows containing missing values (geom_point).
draw_ksplots_D(model2) + labs(title = "Model 2")
## Warning: Removed 68 rows containing missing values (geom_point).
draw_ksplots_D(model3) + labs(title = "Model 3")
## Warning: Removed 2 rows containing missing values (geom_point).
draw_ksplots_D(model4) + labs(title = "Model 4")
## Warning: Removed 65 rows containing missing values (geom_point).
ggsave("ks-example.pdf", draw_ksplots_D_params(model4, params = "phi_prev[5]"), height = 4, width = 6.25)
## Warning: Removed 1 rows containing missing values (geom_point).
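To make the statistic \(D\) concrete, the sketch below computes it for a single parameter using `stats::ks.test`. The three sample vectors are hypothetical stand-ins drawn from normal distributions, not output from the fitted models, and `ks_D` is a helper name introduced here for illustration.

```r
set.seed(2)
# Hypothetical posterior draws for one parameter (stand-ins, not real results)
samples_tmbstan <- rnorm(1000, mean = 0, sd = 1)
samples_aghq    <- rnorm(1000, mean = 0.05, sd = 1)
samples_tmb     <- rnorm(1000, mean = 0.3, sd = 0.8)

# Two-sample Kolmogorov-Smirnov statistic: maximum absolute ECDF difference
ks_D <- function(x, y) unname(ks.test(x, y)$statistic)

D_aghq <- ks_D(samples_aghq, samples_tmbstan)  # x-axis value in the plots above
D_tmb  <- ks_D(samples_tmb, samples_tmbstan)   # y-axis value in the plots above
c(aghq = D_aghq, TMB = D_tmb)
```

A point above the line \(y = x\) in such a plot would indicate a parameter for which TMB is further from tmbstan than aghq is.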
We could also assess \(l\), the location at which the maximum ECDF difference \(D\) is attained. Patterns in this location, for example whether it tends to fall in the tails or near the centre of the distribution, could provide useful insights.
# draw_ksplots_l(model1) + labs(title = "Model 1")
# draw_ksplots_l(model2) + labs(title = "Model 2")
# draw_ksplots_l(model3) + labs(title = "Model 3")
# draw_ksplots_l(model4) + labs(title = "Model 4")
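One way the location \(l\) could be obtained is by evaluating both ECDFs on the pooled sample and recording where their absolute difference is largest. The sketch below is an assumption about how this might be done and is not the implementation behind `draw_ksplots_l`; the helper name `ks_location` and the sample vectors are hypothetical.

```r
# Sketch: maximum ECDF difference D and the location l at which it is attained
ks_location <- function(x, y) {
  # The ECDF difference only changes at observed values, so the pooled
  # sample is a sufficient evaluation grid
  grid <- sort(unique(c(x, y)))
  ecdf_diff <- abs(ecdf(x)(grid) - ecdf(y)(grid))
  i <- which.max(ecdf_diff)  # first index of the maximum if there are ties
  list(D = ecdf_diff[i], l = grid[i])
}

set.seed(4)
# Hypothetical draws: the second sample is shifted and narrower than the first
reference <- rnorm(1000, mean = 0, sd = 1)
approx    <- rnorm(1000, mean = 0.3, sd = 0.8)
ks_location(approx, reference)
```

The returned `D` agrees with the statistic reported by `ks.test`, since both are the supremum of the same step-function difference.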
When using Markov chain Monte Carlo (MCMC) methods, as we have for tmbstan, it's important to assess convergence.
One way to do this is via traceplots, which visualise the chains over the number of iterations specified.
map(model1, "mcmc_traceplots")
## [[1]]
## [[2]]
## [[3]]
## [[4]]
## [[5]]
map(model2, "mcmc_traceplots")
## [[1]]
## [[2]]
## [[3]]
## [[4]]
## [[5]]
## [[6]]
map(model3, "mcmc_traceplots")
## [[1]]
## [[2]]
## [[3]]
## [[4]]
## [[5]]
## [[6]]
map(model4, "mcmc_traceplots")
## [[1]]
## [[2]]
## [[3]]
## [[4]]
## [[5]]
## [[6]]
The \(\hat R\) statistic (“R hat”), which compares between-chain and within-chain variability, can also be used. A value of \(\hat R < 1.1\) is typically taken as sufficient evidence of convergence.
draw_rhatplot(model1) + labs(title = "Model 1")
draw_rhatplot(model2) + labs(title = "Model 2")
draw_rhatplot(model3) + labs(title = "Model 3")
draw_rhatplot(model4) + labs(title = "Model 4")
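For intuition about what \(\hat R\) measures, the sketch below implements the classic (non-split) Gelman-Rubin form in base R. This is an illustrative assumption: the values shown by `draw_rhatplot` come from the fitted Stan objects and may use a different variant, such as a split or rank-normalised version. The function name `rhat_basic` and the simulated chains are hypothetical.

```r
# Sketch: classic potential scale reduction factor for a single parameter.
# chains: numeric matrix with one row per iteration and one column per chain.
rhat_basic <- function(chains) {
  n <- nrow(chains)
  B <- n * var(colMeans(chains))       # between-chain variance
  W <- mean(apply(chains, 2, var))     # mean within-chain variance
  var_plus <- (n - 1) / n * W + B / n  # pooled estimate of the posterior variance
  sqrt(var_plus / W)
}

set.seed(3)
# Four hypothetical chains sampling the same distribution
mixed <- matrix(rnorm(4000), ncol = 4)
# The same chains with the fourth shifted, mimicking a chain stuck elsewhere
stuck <- sweep(mixed, 2, c(0, 0, 0, 3), "+")

rhat_basic(mixed)
rhat_basic(stuck)
```

When the chains agree, the between-chain term contributes little and the ratio is close to one; a shifted chain inflates the pooled variance relative to the within-chain variance and pushes the ratio upwards.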
## R version 4.2.0 (2022-04-22)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS Monterey 12.6
## Matrix products: default
## LAPACK: /Library/Frameworks/R.framework/Versions/4.2/Resources/lib/libRlapack.dylib
## locale:
## [1] en_GB.UTF-8/en_GB.UTF-8/en_GB.UTF-8/C/en_GB.UTF-8/en_GB.UTF-8
## attached base packages:
## [1] parallel stats graphics grDevices utils datasets methods base
## other attached packages:
## [1] multi.utils_0.1.0 tibble_3.1.7 tidyverse_1.3.1 aghq_0.4.0 tmbstan_1.0.4 rstan_2.21.5 StanHeaders_2.21.0-7 glmmTMB_1.1.4
## [9] TMB_1.9.1 stringr_1.4.0 purrr_0.3.4 tidyr_1.2.0 readr_2.1.2 INLA_22.05.07 sp_1.4-7 foreach_1.5.2
## [17] Matrix_1.5-1 ggplot2_3.3.6 forcats_0.5.1 dplyr_1.0.9
## loaded via a namespace (and not attached):
## [1] minqa_1.2.4 colorspace_2.0-3 ellipsis_0.3.2 class_7.3-20 rgdal_1.5-30 fs_1.5.2 rstudioapi_0.13 proxy_0.4-26
## [9] farver_2.1.0 fansi_1.0.3 lubridate_1.8.0 xml2_1.3.3 codetools_0.2-18 splines_4.2.0 knitr_1.39 jsonlite_1.8.0
## [17] nloptr_2.0.3 broom_0.8.0 dbplyr_2.1.1 compiler_4.2.0 httr_1.4.3 backports_1.4.1 assertthat_0.2.1 fastmap_1.1.0
## [25] cli_3.3.0 htmltools_0.5.2 prettyunits_1.1.1 tools_4.2.0 gtable_0.3.0 glue_1.6.2 Rcpp_1.0.8.3 jquerylib_0.1.4
## [33] cellranger_1.1.0 raster_3.5-15 vctrs_0.4.1 mvQuad_1.0-6 svglite_2.1.0 nlme_3.1-157 iterators_1.0.14 xfun_0.31
## [41] ps_1.7.0 lme4_1.1-30 rvest_1.0.2 bsae_0.2.7 lifecycle_1.0.1 statmod_1.4.36 orderly_1.4.3 terra_1.5-21
## [49] ids_1.0.1 MASS_7.3-56 scales_1.2.0 ragg_1.2.2 hms_1.1.1 inline_0.3.19 yaml_2.3.5 gridExtra_2.3
## [57] sass_0.4.1 loo_2.5.1 stringi_1.7.6 highr_0.9 e1071_1.7-9 boot_1.3-28 pkgbuild_1.3.1 rlang_1.0.2
## [65] pkgconfig_2.0.3 systemfonts_1.0.4 matrixStats_0.62.0 evaluate_0.15 lattice_0.20-45 sf_1.0-7 labeling_0.4.2 processx_3.5.3
## [73] tidyselect_1.1.2 magrittr_2.0.3 R6_2.5.1 generics_0.1.2 DBI_1.1.2 pillar_1.7.0 haven_2.5.0 withr_2.5.0
## [81] units_0.8-0 modelr_0.1.8 crayon_1.5.1 uuid_1.1-0 KernSmooth_2.23-20 utf8_1.2.2 tzdb_0.3.0 rmarkdown_2.14
## [89] grid_4.2.0 readxl_1.4.0 data.table_1.14.2 callr_3.7.0 webshot_0.5.3 reprex_2.0.1 digest_0.6.29 classInt_0.4-3
## [97] numDeriv_2016.8-1.1 textshaping_0.3.6 openssl_2.0.1 RcppParallel_5.1.5 stats4_4.2.0 munsell_0.5.0 viridisLite_0.4.0 kableExtra_1.3.4
## [105] bslib_0.3.1 askpass_1.1