Validation Guide#
Validating the posterior estimates is a crucial step in Simulation-Based Inference (SBI). Since we do not have access to the ground truth posterior for complex simulators, we rely on statistical diagnostics to check if our model is well-calibrated.
GenSBI provides a native JAX-based diagnostics module. This module is inspired by and is partially a port of the diagnostics submodule of the sbi library.
We also provide a PosteriorWrapper from gensbi.diagnostics that can be used to create a posterior object with an interface similar to an sbi posterior.
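As a rough usage sketch (the constructor arguments and method names shown here are assumptions, not the documented API; check the API reference), the wrapper can be applied to a trained pipeline like the one used in the examples below:
from gensbi.diagnostics import PosteriorWrapper
# Hypothetical sketch: wrap a trained GenSBI pipeline so that code expecting
# an sbi-style posterior can use it directly. The constructor argument and the
# sampling call below are assumptions, not the exact signatures.
posterior = PosteriorWrapper(pipeline)
samples = posterior.sample((1000,), x=x_o)  # sbi-style conditional sampling (illustrative)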
Note
This is a JAX module. To use it with PyTorch, you will need a thin compatibility layer (for example, converting PyTorch tensors to NumPy or JAX arrays).
This page focuses on three complementary diagnostics:
Simulation-Based Calibration (SBC): a global calibration check averaged over the prior predictive distribution.
Tests of Accuracy with Random Points (TARP): a global calibration/accuracy test based on expected coverage probabilities.
Local Classifier 2-Sample Test (L-C2ST): a local hypothesis test to assess posterior correctness for a specific observation.
Note
These diagnostics test different properties. In practice, they are most useful when combined with posterior predictive checks (PPCs): SBC/TARP can catch systematic miscalibration, while PPCs often reveal model misspecification.
Preparing Validation Data#
Most diagnostics require a set of “true” parameters (\(\theta\)) drawn from the prior and their corresponding observations (\(x\)) generated by the simulator.
Important
GenSBI diagnostics expect flattened arrays (numpy or JAX arrays) of shape (num_samples, features). GenSBI models often work with 3D arrays (batch, features, channels). You must flatten the feature dimensions before passing them to diagnostic functions.
import jax
import numpy as np
import matplotlib.pyplot as plt
from gensbi.diagnostics import run_sbc, sbc_rank_plot
from gensbi.diagnostics import run_tarp, plot_tarp
from gensbi.diagnostics import LC2ST, plot_lc2st
# 1. Generate data using your simulator
# This should be separate from your training data
key = jax.random.PRNGKey(1234)
test_data = simulator(key, 1000) # Shape: (1000, dim_joint, 1)
# 2. Split into parameters (theta) and observations (x)
# Adjust indices based on your data structure
dim_theta = 3
thetas = test_data[:, :dim_theta, :]
xs = test_data[:, dim_theta:, :]
# 3. Flatten for diagnostics
# (batch, dim, channel) -> (batch, dim * channel)
thetas = thetas.reshape(thetas.shape[0], -1)
xs = xs.reshape(xs.shape[0], -1)
Tip
Keep validation data separate from training data. Most diagnostics assume i.i.d. samples from the prior predictive distribution.
Suggested sample sizes (rule of thumb)#
These diagnostics trade compute for statistical power. The values below are commonly used starting points:
SBC:
prior-predictive pairs: ~200 (often enough for a first signal)
posterior samples per pair: ~1000
histogram bins: choose \(B\) such that \(N/B \approx 20\), where \(N\) is the number of SBC runs
TARP:
prior-predictive pairs: ~200
posterior samples per pair: ~1000 is a good default; using ~10_000 yields smoother curves but can be memory-heavy
L-C2ST (most expensive):
calibration pairs: at least in the thousands (often 10_000)
posterior samples used for evaluation at a single observation \(x_o\): typically ~10_000
posterior samples per calibration observation: commonly 1 (this is what sbi's examples use)
If you are memory constrained, prefer fewer posterior samples or enable batched sampling where supported.
Note
These numbers match the typical settings used in the my_first_model notebook: SBC uses 200 pairs and 1,000 posterior samples per pair; TARP uses 200 pairs and (optionally) up to 10,000 posterior samples per pair; L-C2ST uses 10,000 calibration pairs and 10,000 posterior samples for evaluation at \(x_o\).
Simulation-Based Calibration (SBC)#
Simulation-Based Calibration (SBC) is a diagnostic tool to check if the posterior approximation is well-calibrated on average (over the prior). It relies on the self-consistency property of Bayesian inference: if we sample \(\theta \sim p(\theta)\) and \(x \sim p(x|\theta)\), then the rank of the true parameters \(\theta\) within the set of posterior samples drawn from \(p(\theta|x)\) should be uniformly distributed.
SBC is particularly useful for detecting:
Overconfidence (Under-dispersion): The posterior is too narrow. The rank histogram will look U-shaped.
Underconfidence (Over-dispersion): The posterior is too wide. The rank histogram will look like an inverted U (hump-shaped).
Bias: The posterior is systematically shifted. The rank histogram will be skewed.
Conceptually, for each simulated pair \((\theta_i, x_i)\) you:
Draw posterior samples \(\theta_{i,1:S} \sim q(\theta\mid x_i)\).
Compute the rank of each true parameter component \(\theta_i^{(d)}\) within the sampled marginal \(\{\theta_{i,1:S}^{(d)}\}\).
Check whether the rank distribution is consistent with a discrete uniform distribution.
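To make the rank computation concrete, here is a minimal NumPy sketch of these steps on synthetic stand-in arrays (purely illustrative; run_sbc performs this for you):
import numpy as np
# Stand-in arrays with the shapes used throughout this guide:
# thetas_demo: (N, D) true parameters, samples_demo: (N, S, D) posterior samples.
rng = np.random.default_rng(0)
N, S, D = 200, 1000, 3
thetas_demo = rng.normal(size=(N, D))
samples_demo = rng.normal(size=(N, S, D))
# Rank of each true parameter component among its S posterior samples,
# i.e. how many posterior samples are smaller than the true value (0..S).
ranks_demo = (samples_demo < thetas_demo[:, None, :]).sum(axis=1)  # shape (N, D)
# For a well-calibrated posterior, each column of `ranks_demo` should be
# approximately uniform over {0, ..., S}.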
check_sbc() additionally provides quantitative tests (e.g. a KS test on ranks) and a “data averaged posterior” (DAP) sanity check.
Note
SBC is a necessary condition for correctness, but not sufficient: a posterior can be uninformative and still pass SBC. Treat SBC as a calibration alarm bell, and complement it with PPCs.
from gensbi.diagnostics import run_sbc, sbc_rank_plot, check_sbc
num_posterior_samples = 1000
# 1. Generate posterior samples for each observation in `xs`
# We use `sample_batched` for efficiency
posterior_samples_ = pipeline.sample_batched(
jax.random.PRNGKey(42),
xs, # Validation observations
num_samples=num_posterior_samples,
chunk_size=10, # Adjust based on memory
)
# Flatten posterior samples: (batch, num_samples, dim, 1) -> (batch, num_samples, dim)
posterior_samples = posterior_samples_.reshape(
posterior_samples_.shape[0], posterior_samples_.shape[1], -1
)
# 2. Run SBC
ranks, dap_samples = run_sbc(thetas, xs, posterior_samples)
# 3. Check for uniformity
check_stats = check_sbc(
ranks,
thetas,
dap_samples,
num_posterior_samples=num_posterior_samples,
)
print(check_stats)
# 4. Plot ranks
f, ax = sbc_rank_plot(
ranks,
num_posterior_samples=num_posterior_samples,
plot_type="hist",
num_bins=20
)
plt.show()
If the histogram is uniform (flat), the posterior is well-calibrated. U-shapes indicate under-dispersion (overconfidence), while inverse U-shapes indicate over-dispersion (underconfidence).
For more details, see the sbi guide and tutorial:
SBC how-to: https://sbi.readthedocs.io/en/latest/how_to_guide/16_sbc.html
SBC tutorial: https://sbi.readthedocs.io/en/latest/advanced_tutorials/11_diagnostics_simulation_based_calibration.html
Tests of Accuracy with Random Points (TARP)#
TARP estimates the expected coverage probability of the posterior.
At a high level, TARP compares (for many simulated pairs) the empirical fraction of posterior samples falling “closer” to a reference point than the true parameter, across a grid of nominal coverage levels \(\alpha\). The output is a curve of expected coverage probability (ECP) versus \(\alpha\).
Note
TARP comes with strong theoretical guarantees in the original work, but its full diagnostic power depends on how reference points are chosen. In practice, it is typically interpreted similarly to SBC and combined with PPCs.
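To illustrate the idea, here is a minimal NumPy sketch of an expected-coverage estimate from distances to random reference points (illustrative only; run_tarp handles this internally and may differ in details such as the distance metric and how references are drawn):
import numpy as np
# Stand-in arrays: thetas_demo (N, D) true parameters,
# samples_demo (N, S, D) posterior samples, references (N, D) random points.
rng = np.random.default_rng(0)
N, S, D = 200, 1000, 3
thetas_demo = rng.normal(size=(N, D))
samples_demo = rng.normal(size=(N, S, D))
references = rng.normal(size=(N, D))
# Distances from posterior samples and from the true parameter to the reference.
d_samples = np.linalg.norm(samples_demo - references[:, None, :], axis=-1)  # (N, S)
d_true = np.linalg.norm(thetas_demo - references, axis=-1)                  # (N,)
# Fraction of posterior samples closer to the reference than the true theta.
f = (d_samples < d_true[:, None]).mean(axis=1)  # (N,)
# Empirical coverage at each nominal level alpha: P(f < alpha).
alpha_grid = np.linspace(0.0, 1.0, 51)
ecp = (f[None, :] < alpha_grid[:, None]).mean(axis=1)
# For a calibrated posterior, ecp tracks alpha_grid (the diagonal).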
# Run TARP using the same posterior samples as for SBC
ecp, alpha = run_tarp(
thetas,
posterior_samples,
references=None, # Automatically calculated if None
)
# Plot coverage
plot_tarp(ecp, alpha)
plt.show()
Ideally, the curve should follow the diagonal. If the curve is above the diagonal, the model is under-confident. If below, it is over-confident.
For more details, see the sbi how-to: https://sbi.readthedocs.io/en/latest/how_to_guide/17_tarp.html
Local Classifier 2-Sample Test (L-C2ST)#
L-C2ST (Local Classifier 2-Sample Test) is a local diagnostic: it aims to detect whether the posterior is accurate for a specific observation \(x_o\).
It works by training a classifier on a calibration dataset sampled from the joint distribution (prior + simulator), and then evaluating a hypothesis test at \(x_o\) using posterior samples from your estimator.
Important
The calibration dataset must be independent from the data used to train your posterior estimator.
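To sketch the idea (one common formulation; the LC2ST class below assembles these classes internally), the classifier is trained to distinguish pairs drawn from the true joint from pairs in which the parameter is replaced by a posterior sample:
import numpy as np
# Illustrative construction of the two classes compared by the classifier
# (stand-in arrays; the LC2ST class does this for you).
rng = np.random.default_rng(0)
N, D_theta, D_x = 1000, 3, 5
theta_joint = rng.normal(size=(N, D_theta))   # theta ~ p(theta)
x_joint = rng.normal(size=(N, D_x))           # x ~ p(x | theta)
theta_post = rng.normal(size=(N, D_theta))    # one posterior sample per x
class_joint = np.concatenate([theta_joint, x_joint], axis=1)  # label 0
class_post = np.concatenate([theta_post, x_joint], axis=1)    # label 1
# A binary classifier is trained on these two classes; its predicted
# probabilities are then evaluated locally at a specific observation x_o.
# If it cannot beat chance (0.5) there, the approximate posterior is
# indistinguishable from the true posterior at x_o.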
import jax.numpy as jnp
# 0. Reuse calibration data (thetas, xs) or generate new ones.
theta_cal, x_cal = thetas, xs
# 1. Generate one posterior sample per calibration observation
# Here we reuse the previous samples, taking just the first one
posterior_samples_cal = posterior_samples[:, 0, :]
# 2. Train the classifier
lc2st = LC2ST(
thetas=theta_cal[:-1], # Train set parameters
xs=x_cal[:-1], # Train set observations
posterior_samples=posterior_samples_cal[:-1], # Posterior samples for train set
classifier="mlp",
num_ensemble=1,
)
# Train under null hypothesis (permutation test) and on observed data
_ = lc2st.train_under_null_hypothesis()
_ = lc2st.train_on_observed_data()
# 3. Evaluate for a specific observation x_o
# Take the last item as our "test" observation x_o
x_o = x_cal[-1:]
theta_o = thetas[-1:] # True parameter
# Generate many posterior samples for this specific observation
post_samples_star_ = pipeline.sample(
jax.random.PRNGKey(42),
x_o,
nsamples=10_000
)
post_samples_star = post_samples_star_.reshape(post_samples_star_.shape[0], -1)
# Plot results
fig, ax = plot_lc2st(
lc2st,
post_samples_star,
x_o,
)
plt.show()
If L-C2ST rejects the null hypothesis (typically p-value < 0.05), it indicates that your approximate posterior is detectably different from the true posterior at this observation.
For more details, see:
sbi how-to: https://sbi.readthedocs.io/en/latest/how_to_guide/13_diagnostics_lc2st.html
sbi advanced tutorial: https://sbi.readthedocs.io/en/latest/advanced_tutorials/13_diagnostics_lc2st.html
For a complete GenSBI example workflow, refer to the my_first_model.ipynb notebook.
References and further reading#
Original papers:
Cook, S. R., Gelman, A., & Rubin, D. B. (2006). Validation of software for Bayesian models using posterior quantiles. Journal of Computational and Graphical Statistics.
Talts, S., Betancourt, M., Simpson, D., Vehtari, A., & Gelman, A. (2018). Validating Bayesian inference algorithms with simulation-based calibration. arXiv:1804.06788.
Lemos, P., Coogan, A., Hezaveh, Y., & Perreault-Levasseur, L. (2023). Sampling-based accuracy testing of posterior estimators for general inference. ICML (PMLR).
Linhart, J., Gramfort, A., & Rodrigues, P. (2023). L-C2ST: Local diagnostics for posterior approximations in simulation-based inference. NeurIPS.
sbi documentation pages (recommended):
SBC how-to: https://sbi.readthedocs.io/en/latest/how_to_guide/16_sbc.html
SBC tutorial: https://sbi.readthedocs.io/en/latest/advanced_tutorials/11_diagnostics_simulation_based_calibration.html
TARP how-to: https://sbi.readthedocs.io/en/latest/how_to_guide/17_tarp.html
L-C2ST how-to: https://sbi.readthedocs.io/en/latest/how_to_guide/13_diagnostics_lc2st.html
L-C2ST tutorial: https://sbi.readthedocs.io/en/latest/advanced_tutorials/13_diagnostics_lc2st.html