Validation Guide#

Validating the posterior estimates is a crucial step in Simulation-Based Inference (SBI). Since we do not have access to the ground truth posterior for complex simulators, we rely on statistical diagnostics to check if our model is well-calibrated.

GenSBI provides a compatibility layer to leverage the powerful validation tools from the sbi library.

This page focuses on three complementary diagnostics:

  • Simulation-Based Calibration (SBC): a global calibration check averaged over the prior predictive distribution.

  • Tests of Accuracy with Random Points (TARP): a global calibration/accuracy test based on expected coverage probabilities.

  • Local Classifier 2-Sample Test (L-C2ST): a local hypothesis test to assess posterior correctness for a specific observation.

Note

These diagnostics test different properties. In practice, they are most useful when combined with posterior predictive checks (PPCs): SBC/TARP can catch systematic miscalibration, while PPCs often reveal model misspecification.

Prerequisites#

To run these tests, you need to install the validation dependencies:

pip install "GenSBI[validation] @ git+https://github.com/aurelio-amerio/GenSBI.git" --extra-index-url https://download.pytorch.org/whl/cpu

The Posterior Wrapper#

The sbi library expects a specific interface for posterior objects. GenSBI provides PosteriorWrapper to adapt your trained pipeline.

from flax import nnx
from gensbi_validation import PosteriorWrapper

# Assuming 'pipeline' is your trained GenSBI pipeline
# We need a PRNGKey for sampling within the wrapper
posterior = PosteriorWrapper(pipeline, rngs=nnx.Rngs(1234))

The wrapper exposes (a subset of) the sbi posterior API (e.g. sample(), log_prob(), and sample_batched()), while internally running the GenSBI pipeline.
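
For example, once wrapped, the object can be queried with the usual sbi calling conventions. The sketch below assumes x_o is a 2D torch tensor of shape (1, obs_dim) and that the wrapper mirrors sbi's return shapes:

# Minimal usage sketch (x_o is an assumed observation tensor of shape (1, obs_dim))
samples = posterior.sample((1000,), x=x_o)       # -> tensor of shape (1000, theta_dim)
log_probs = posterior.log_prob(samples, x=x_o)   # -> tensor of shape (1000,)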

Preparing Validation Data#

Most diagnostics require a set of “true” parameters (\(\theta\)) drawn from the prior and their corresponding observations (\(x\)) generated by the simulator.

Important

sbi expects 2D tensors of shape (num_samples, features). GenSBI models often work with 3D arrays (batch, features, channels). You must flatten the feature dimensions before passing them to sbi functions.
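
For example, a (batch, features, channels) array can be flattened with a plain reshape; this is the same operation performed by the _ravel helper used below (arr is a placeholder name):

# Flatten trailing dimensions: e.g. an array of shape (1000, 3, 1) becomes (1000, 3)
arr_2d = arr.reshape(arr.shape[0], -1)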

import jax
import numpy as np
import torch

# 1. Generate data using your simulator
# This should be separate from your training data
key = jax.random.PRNGKey(1234)
test_data = simulator(key, 1000) # Shape: (1000, joint_dim, 1)

# 2. Split into parameters (theta) and observations (x)
# Adjust indices based on your data structure (here the first 3 entries of the joint vector are the parameters)
theta_dim = 3
thetas = test_data[:, :theta_dim, :]
xs = test_data[:, theta_dim:, :]

# 3. Flatten and convert to Torch tensors
# The wrapper provides a helper _ravel method, or you can use reshape
thetas = posterior._ravel(thetas) 
xs = posterior._ravel(xs)

thetas_torch = torch.Tensor(np.array(thetas))
xs_torch = torch.Tensor(np.array(xs))
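
A quick sanity check before handing the tensors to sbi (the exact feature counts depend on your simulator):

# sbi expects 2D tensors of shape (num_samples, num_features)
assert thetas_torch.ndim == 2 and xs_torch.ndim == 2
print(thetas_torch.shape, xs_torch.shape)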

Tip

Keep validation data separate from training data. Most diagnostics assume i.i.d. samples from the prior predictive distribution.

Suggested sample sizes (rule of thumb)#

These diagnostics trade compute for statistical power. The values below are commonly used starting points:

  • SBC:

    • prior-predictive pairs: ~200 (often enough for a first signal)

    • posterior samples per pair: ~1000

    • histogram bins: choose \(B\) such that \(N/B \approx 20\), where \(N\) is the number of SBC runs

  • TARP:

    • prior-predictive pairs: ~200

    • posterior samples per pair: ~1000 is a good default; using ~10_000 yields smoother curves but can be memory-heavy

  • L-C2ST (most expensive):

    • calibration pairs: at least in the thousands (often 10_000)

    • posterior samples used for evaluation at a single observation \(x_o\): typically ~10_000

    • number of posterior samples per calibration observation: commonly 1 (this is what sbi’s examples use)

If you are memory constrained, prefer fewer posterior samples or enable batched sampling where supported.

Note

These numbers match the typical settings used in the my_first_model notebook: SBC uses 200 pairs and 1,000 posterior samples per pair; TARP uses 200 pairs and (optionally) up to 10,000 posterior samples per pair; L-C2ST uses 10,000 calibration pairs and 10,000 posterior samples for evaluation at \(x_o\).
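
If you prefer to keep these choices in one place, a small settings block like the following can help. The values are the rule-of-thumb defaults above; the variable names are illustrative and not required by sbi:

# Rule-of-thumb defaults gathered in one place (illustrative names)
NUM_PAIRS = 200                   # prior-predictive pairs for SBC / TARP
NUM_POSTERIOR_SAMPLES = 1000      # posterior samples per pair for SBC / TARP
NUM_LC2ST_CAL_PAIRS = 10_000      # calibration pairs for L-C2ST
NUM_LC2ST_EVAL_SAMPLES = 10_000   # posterior samples at x_o for L-C2ST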

Simulation-Based Calibration (SBC)#

Simulation-Based Calibration (SBC) is a diagnostic tool to check if the posterior approximation is well-calibrated on average (over the prior). It relies on the self-consistency property of Bayesian inference: if we sample \(\theta \sim p(\theta)\) and \(x \sim p(x|\theta)\), then the rank of the true parameters \(\theta\) within the set of posterior samples drawn from \(p(\theta|x)\) should be uniformly distributed.

SBC is particularly useful for detecting:

  • Overconfidence (Under-dispersion): The posterior is too narrow. The rank histogram will look U-shaped.

  • Underconfidence (Over-dispersion): The posterior is too wide. The rank histogram will look like an inverted U (hump-shaped).

  • Bias: The posterior is systematically shifted. The rank histogram will be skewed.

Conceptually, for each simulated pair \((\theta_i, x_i)\) you:

  1. Draw posterior samples \(\theta_{i,1:S} \sim q(\theta\mid x_i)\).

  2. Compute the rank of each true parameter component \(\theta_i^{(d)}\) within the sampled marginal \(\{\theta_{i,1:S}^{(d)}\}\).

  3. Check whether the rank distribution is consistent with a discrete uniform distribution.
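
To make step 2 concrete, here is a minimal NumPy sketch of the per-pair rank computation (for intuition only; in practice sbi's run_sbc computes the ranks for you):

import numpy as np

def sbc_ranks(theta_true, posterior_samples):
    # theta_true: shape (theta_dim,), the parameters that generated x_i
    # posterior_samples: shape (S, theta_dim), samples drawn from q(theta | x_i)
    # Rank of each true component within its sampled marginal, an integer in {0, ..., S}.
    return (posterior_samples < theta_true).sum(axis=0)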

In sbi, check_sbc() additionally provides quantitative tests (e.g. a KS test on ranks) and a “data averaged posterior” (DAP) sanity check.

Note

SBC is a necessary condition for correctness, but not sufficient: a posterior can be uninformative and still pass SBC (for example, an estimator that simply returns the prior produces uniform ranks). Treat SBC as a calibration alarm bell, and complement it with PPCs.

from sbi.diagnostics import run_sbc, check_sbc
from sbi.analysis.plot import sbc_rank_plot
import matplotlib.pyplot as plt

num_posterior_samples = 1000

# Run SBC
ranks, dap_samples = run_sbc(
    thetas_torch,
    xs_torch,
    posterior,
    num_posterior_samples=num_posterior_samples,
    # `True` can be faster, but can increase memory use.
    use_batched_sampling=False,
)

# Check for uniformity
check_stats = check_sbc(
    ranks,
    thetas_torch,
    dap_samples,
    num_posterior_samples=num_posterior_samples,
)
print(check_stats)

# Plot ranks
f, ax = sbc_rank_plot(
    ranks,
    num_posterior_samples=num_posterior_samples,
    plot_type="hist",
    num_bins=20,
)
plt.show()

If the histogram is uniform (flat), the posterior is well-calibrated. U-shapes indicate under-dispersion (overconfidence), while inverse U-shapes indicate over-dispersion (underconfidence).

For more details, see the SBC guide and tutorial in the sbi documentation.

Tests of Accuracy with Random Points (TARP)#

TARP estimates the expected coverage probability of the posterior.

At a high level, TARP draws a random reference point for each simulated pair, computes the fraction of posterior samples that lie closer to that reference point than the true parameter does, and then checks how this per-pair coverage value is distributed across a grid of nominal coverage levels \(\alpha\). The output is a curve of expected coverage probability (ECP) versus \(\alpha\).
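
The sketch below illustrates this idea in plain NumPy, for intuition only: the choice of reference-point distribution, distance metric, and normalization follows the TARP paper and sbi's run_tarp implementation rather than this simplified version.

import numpy as np

def tarp_ecp_sketch(theta_true, post_samples, refs, alphas):
    # theta_true: (N, D) true parameters; post_samples: (N, S, D) posterior samples
    # refs: (N, D) random reference points; alphas: 1D grid of nominal coverage levels
    d_true = np.linalg.norm(theta_true - refs, axis=-1)                # distance of truth to reference, (N,)
    d_post = np.linalg.norm(post_samples - refs[:, None, :], axis=-1)  # distances of samples to reference, (N, S)
    f = (d_post < d_true[:, None]).mean(axis=1)                        # per-pair coverage value in [0, 1]
    # ECP(alpha): fraction of pairs whose coverage value is at most alpha; ideally ECP(alpha) = alpha
    return np.array([(f <= a).mean() for a in alphas])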

Note

TARP comes with strong theoretical guarantees in the original work, but its full diagnostic power depends on how reference points are chosen. In practice, it is typically interpreted similarly to SBC and combined with PPCs.

from sbi.diagnostics import run_tarp, check_tarp
from sbi.analysis.plot import plot_tarp

# Run TARP
ecp, alpha = run_tarp(
    thetas_torch,
    xs_torch,
    posterior,
    num_posterior_samples=1000,  # consider 10_000 for a smoother curve (more compute/memory)
)

atc, ks_pval = check_tarp(ecp, alpha)
print(atc, "Should be close to 0")
print(ks_pval, "Should be larger than 0.05")

# Plot coverage
plot_tarp(ecp, alpha)
plt.show()

Ideally, the curve should follow the diagonal. If the curve is above the diagonal, the model is under-confident. If below, it is over-confident.

For more details, see the TARP how-to in the sbi documentation.

Local Classifier 2-Sample Test (L-C2ST)#

L-C2ST (Local Classifier 2-Sample Test) is a local diagnostic: it aims to detect whether the posterior is accurate for a specific observation \(x_o\).

In sbi, L-C2ST works by training a classifier on a calibration dataset sampled from the joint distribution (prior + simulator), and then evaluating a hypothesis test at \(x_o\) using posterior samples from your estimator.

Important

L-C2ST requires additional simulations (calibration data) and training an auxiliary classifier. Its statistical power depends on the classifier and calibration set size, so treat it as a diagnostic tool (not an oracle).
The calibration dataset must be independent from the data used to train your posterior estimator.

Tip

L-C2ST has two key sample counts:

  • the calibration dataset size (often 10_000 pairs), and

  • the number of posterior samples used to evaluate a specific observation (often 10_000).

Using too few samples for either can make the test noisy.

from sbi.diagnostics.lc2st import LC2ST

# 0. Calibration data (theta_cal, x_cal) sampled from the prior predictive.
# You can reuse the `thetas_torch` / `xs_torch` prepared above if they were not
# used to train the posterior estimator.
theta_cal, x_cal = thetas_torch, xs_torch

# 1. Generate one posterior sample per calibration observation
posterior_samples_cal = posterior.sample_batched(
    (1,), # num_samples per observation
    x=x_cal,
)[0]

# 2. Train the classifier
lc2st = LC2ST(
    thetas=theta_cal,
    xs=x_cal,
    posterior_samples=posterior_samples_cal,
    classifier="mlp",
)

# Train under null hypothesis (permutation test) and on observed data
_ = lc2st.train_under_null_hypothesis()
_ = lc2st.train_on_observed_data()

# 3. Evaluate for a specific observation x_o
# Note: x_o must have a batch dimension: x_o.shape == (1, observation_dim)
x_o = x_cal[:1]  # or your real observed data, shaped as (1, obs_dim)
theta_o = posterior.sample((10_000,), x=x_o)

conf_alpha = 0.05
p_value = lc2st.p_value(theta_o, x_o)
reject = lc2st.reject_test(theta_o, x_o, alpha=conf_alpha)
print(f"p-value={p_value:.3f}, reject={reject}")

If L-C2ST rejects the null hypothesis (typically p-value < 0.05), it indicates that your approximate posterior is detectably different from the true posterior at this observation. sbi also provides qualitative plots (e.g. P-P plots) that can hint at over- vs under-confidence.

For more details, see the L-C2ST tutorial in the sbi documentation and the original L-C2ST paper listed in the references below.

For a complete GenSBI example workflow, refer to the my_first_model.ipynb notebook.

References and further reading#

Original papers:

  • Cook, S. R., Gelman, A., & Rubin, D. B. (2006). Validation of software for Bayesian models using posterior quantiles. Journal of Computational and Graphical Statistics.

  • Talts, S., Betancourt, M., Simpson, D., Vehtari, A., & Gelman, A. (2018). Validating Bayesian inference algorithms with simulation-based calibration. arXiv:1804.06788.

  • Lemos, P., Coogan, A., Hezaveh, Y., & Perreault-Levasseur, L. (2023). Sampling-based accuracy testing of posterior estimators for general inference. ICML (PMLR).

  • Linhart, J., Gramfort, A., & Rodrigues, P. (2023). L-C2ST: Local diagnostics for posterior approximations in simulation-based inference. NeurIPS.

Also recommended: the sbi documentation pages and tutorials on SBC, TARP, and L-C2ST.