gensbi.diagnostics.tarp

Implementation taken from Lemos et al., ‘Sampling-Based Accuracy Testing of Posterior Estimators for General Inference’, https://arxiv.org/abs/2302.03026.

TARP (Tests of Accuracy with Random Points) is a global diagnostic that checks a trained posterior estimator against a set of true values of theta.

Classes

TARPResult

Result of the TARP diagnostic.

Functions

_compute_tarp(posterior_samples, thetas, references[, ...])

Core TARP computation (JIT-compiled).

_plot_tarp_confidence(result, ax[, title])

Internal function to plot confidence mode (Z-scores).

_plot_tarp_credibility(result, ax[, title])

Internal function to plot credibility mode.

_run_tarp_bootstrap(rng_key, posterior_samples, ...)

Bootstrap TARP: resample (theta, posterior) pairs with replacement.

_run_tarp_single(rng_key, posterior_samples, thetas, ...)

Runs a single iteration of TARP.

check_tarp(result)

Quantitative check of the TARP coverage curve.

get_tarp_references(key, thetas)

Sample reference points uniformly from the bounding box of theta.

plot_tarp(result[, title, figsize, mode])

Plot the expected coverage probability (ECP).

run_tarp(thetas, posterior_samples[, seed, ...])

Estimates the coverage of posterior samples against the true values thetas using the TARP method.

Module Contents

class gensbi.diagnostics.tarp.TARPResult

Result of the TARP diagnostic.

Stores the Expected Coverage Probability (ECP) curve and its uncertainty bounds. Provides z-score properties for the confidence-level view.

Parameters:
  • ecp (Array) – Expected coverage probability. shape: (num_bootstrap, num_bins + 1) or (num_bins + 1,)

  • alpha (Array) – Credibility levels (histogram bin edges). shape: (num_bins + 1,)

  • ecp_mean (Array) – Mean ECP across bootstrap iterations (or identical to ecp if no bootstrap).

  • ecp_lower (Optional[Array]) – Lower bound of ECP (2.5th percentile or lower bound of the Jeffreys interval). shape: (num_bins + 1,)

  • ecp_upper (Optional[Array]) – Upper bound of ECP (97.5th percentile or upper bound of the Jeffreys interval). shape: (num_bins + 1,)

_to_z(p)

Convert coverage probability p to z-score via probit(0.5 + p/2).

This maps p ∈ [0, 1] to z ∈ [0, ∞) such that p = 0.6827 → z = 1 (1σ), p = 0.9545 → z = 2 (2σ), etc. Values are clipped to [eps, 1-eps] to avoid infinities at the boundaries.

Parameters:

p (jax.Array)

Return type:

jax.Array
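
The mapping described for _to_z can be illustrated with a minimal sketch. The helper name to_z and the value of eps below are assumptions made for this example; the clipping constant actually used by the library is not stated here.

```python
import jax.numpy as jnp
from jax.scipy.special import ndtri  # inverse standard-normal CDF (probit)


def to_z(p, eps=1e-6):
    """Hypothetical stand-in for _to_z: two-sided coverage p -> z-score."""
    # Clip so that p = 0 or p = 1 cannot produce an infinite z-score.
    # eps = 1e-6 is an assumed value, chosen only for this sketch.
    p = jnp.clip(p, eps, 1.0 - eps)
    # A central interval with coverage p ends at the (0.5 + p/2) quantile.
    return ndtri(0.5 + 0.5 * p)


# Coverage levels of the 1-sigma, 2-sigma and 3-sigma central intervals.
z = to_z(jnp.array([0.6827, 0.9545, 0.9973]))
```

Applying the same function to both alpha and ecp_mean would give the two axes of a confidence-level view, where a calibrated posterior again lies on the diagonal.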

alpha: jax.Array
ecp: jax.Array
ecp_lower: jax.Array | None = None
ecp_mean: jax.Array
ecp_upper: jax.Array | None = None
property z_alpha: jax.Array

Z-scores corresponding to alpha (nominal coverage).

Return type:

jax.Array

property z_lower: jax.Array | None

Z-scores of the lower bound.

Return type:

Optional[jax.Array]

property z_mean: jax.Array

Z-scores corresponding to ecp_mean (empirical coverage).

Return type:

jax.Array

property z_upper: jax.Array | None

Z-scores of the upper bound.

Return type:

Optional[jax.Array]

gensbi.diagnostics.tarp._compute_tarp(posterior_samples, thetas, references, distance=l2, num_bins=30, z_score_theta=False)

Core TARP computation (JIT-compiled).

For each simulation i, computes f_i = fraction of posterior samples that are closer to the reference point than the true parameter theta_i (Algorithm 2, Eq. 4 in Lemos et al.). Under perfect calibration, f_i ~ Uniform(0, 1), so the empirical CDF of {f_i} should follow the diagonal.

Parameters:
  • posterior_samples (jax.Array)

  • thetas (jax.Array)

  • references (jax.Array)

  • distance (Callable)

  • num_bins (Optional[int])

  • z_score_theta (bool)

Return type:

Tuple[jax.Array, jax.Array]
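
The coverage computation described above can be sketched in a few lines of JAX. This is an illustration, not the library's implementation: the function name tarp_ecp, the array layout (posterior_samples of shape (num_samples, num_sims, dim); thetas and references of shape (num_sims, dim)), the fixed L2 distance, and the empirical-CDF binning are all assumptions, and z-scoring of theta is omitted.

```python
import jax
import jax.numpy as jnp


def tarp_ecp(posterior_samples, thetas, references, num_bins=30):
    """Sketch of the TARP coverage curve under assumed array shapes.

    posterior_samples: (num_samples, num_sims, dim)
    thetas, references: (num_sims, dim)
    """
    # L2 distance from each reference to the true parameter: (num_sims,)
    d_true = jnp.linalg.norm(thetas - references, axis=-1)
    # L2 distance from each reference to every posterior sample:
    # (num_samples, num_sims)
    d_samples = jnp.linalg.norm(posterior_samples - references[None], axis=-1)
    # f_i: fraction of posterior samples closer to the reference than theta_i.
    f = jnp.mean(d_samples < d_true[None], axis=0)
    # Empirical CDF of {f_i} on a regular grid of credibility levels.
    alpha = jnp.linspace(0.0, 1.0, num_bins + 1)
    ecp = jnp.mean(f[None, :] <= alpha[:, None], axis=1)
    return ecp, alpha


# Calibrated toy problem: the truth and the posterior samples for each
# simulation are drawn from the same distribution N(mu_i, I).
key = jax.random.PRNGKey(0)
k_mu, k_theta, k_post, k_ref = jax.random.split(key, 4)
num_samples, num_sims, dim = 500, 1000, 2
mu = jax.random.normal(k_mu, (num_sims, dim))
thetas = mu + jax.random.normal(k_theta, (num_sims, dim))
posterior_samples = mu[None] + jax.random.normal(k_post, (num_samples, num_sims, dim))
references = jax.random.uniform(k_ref, (num_sims, dim), minval=-4.0, maxval=4.0)

ecp, alpha = tarp_ecp(posterior_samples, thetas, references)
```

In this toy setting the rank of the true distance among the sample distances is uniform, so ecp should track alpha up to sampling noise.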

gensbi.diagnostics.tarp._plot_tarp_confidence(result, ax, title=None)

Internal function to plot confidence mode (Z-scores).

Parameters:
  • result (TARPResult)

  • ax (matplotlib.axes.Axes)

  • title (Optional[str])

gensbi.diagnostics.tarp._plot_tarp_credibility(result, ax, title=None)

Internal function to plot credibility mode.

Parameters:
  • result (TARPResult)

  • ax (matplotlib.axes.Axes)

  • title (Optional[str])

gensbi.diagnostics.tarp._run_tarp_bootstrap(rng_key, posterior_samples, thetas, references, distance, num_bins, z_score_theta, num_bootstrap)

Bootstrap TARP: resample (theta, posterior) pairs with replacement.

Each iteration draws a bootstrap sample, optionally generates new reference points, and runs _compute_tarp. Uses jax.lax.scan instead of vmap so only one iteration is materialized at a time, avoiding OOM when the dataset is large.

Parameters:
  • rng_key (jax.Array)

  • posterior_samples (jax.Array)

  • thetas (jax.Array)

  • references (Optional[jax.Array])

  • distance (Callable)

  • num_bins (int)

  • z_score_theta (bool)

  • num_bootstrap (int)

Return type:

Tuple[jax.Array, jax.Array]
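
The scan-based bootstrap pattern can be sketched as follows. The names coverage_curve and bootstrap_ecp, the array layout, and the L2 distance are assumptions for illustration. As a simplification, this sketch reuses the supplied references (resampled together with their simulations) rather than generating new ones in each iteration.

```python
import jax
import jax.numpy as jnp


def coverage_curve(posterior_samples, thetas, references, num_bins):
    # Distance-based coverage fractions with an L2 distance (assumed layout:
    # posterior_samples (num_samples, num_sims, dim), others (num_sims, dim)).
    d_true = jnp.linalg.norm(thetas - references, axis=-1)
    d_samples = jnp.linalg.norm(posterior_samples - references[None], axis=-1)
    f = jnp.mean(d_samples < d_true[None], axis=0)
    alpha = jnp.linspace(0.0, 1.0, num_bins + 1)
    return jnp.mean(f[None, :] <= alpha[:, None], axis=1)


def bootstrap_ecp(key, posterior_samples, thetas, references,
                  num_bins=30, num_bootstrap=100):
    num_sims = thetas.shape[0]

    def step(carry, step_key):
        # Resample simulation indices with replacement; indexing all three
        # arrays with the same idx keeps (theta, posterior) pairs aligned.
        idx = jax.random.randint(step_key, (num_sims,), 0, num_sims)
        ecp = coverage_curve(
            posterior_samples[:, idx], thetas[idx], references[idx], num_bins
        )
        return carry, ecp

    keys = jax.random.split(key, num_bootstrap)
    # scan runs the iterations one after another, so only one resampled
    # copy of the data is live at a time (vmap would build all of them).
    _, ecps = jax.lax.scan(step, None, keys)
    return ecps  # shape: (num_bootstrap, num_bins + 1)


# Toy inputs, used only to exercise the shapes.
key = jax.random.PRNGKey(0)
k_data, k_post, k_ref, k_boot = jax.random.split(key, 4)
thetas = jax.random.normal(k_data, (200, 2))
posterior_samples = thetas[None] + jax.random.normal(k_post, (100, 200, 2))
references = jax.random.uniform(k_ref, (200, 2), minval=-3.0, maxval=3.0)

ecps = bootstrap_ecp(k_boot, posterior_samples, thetas, references,
                     num_bins=10, num_bootstrap=20)
# Percentile bounds across bootstrap iterations.
ecp_lower = jnp.percentile(ecps, 2.5, axis=0)
ecp_upper = jnp.percentile(ecps, 97.5, axis=0)
```

The trade-off is the usual one between scan and vmap: scan gives up parallelism across bootstrap iterations in exchange for peak memory that does not grow with num_bootstrap.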

gensbi.diagnostics.tarp._run_tarp_single(rng_key, posterior_samples, thetas, references, distance, num_bins, z_score_theta)

Runs a single iteration of TARP.

Parameters:
  • rng_key (jax.Array)

  • posterior_samples (jax.Array)

  • thetas (jax.Array)

  • references (Optional[jax.Array])

  • distance (Callable)

  • num_bins (int)

  • z_score_theta (bool)

Return type:

Tuple[jax.Array, jax.Array]

gensbi.diagnostics.tarp.check_tarp(result)

Quantitative check of the TARP coverage curve.

Returns:

  • atc (float) – Area To Curve for \(\alpha > 0.5\). Positive means conservative (ECP above diagonal), negative means overconfident.

  • ks_prob (float) – Two-sample KS test p-value. Low values indicate the ECP differs significantly from the ideal diagonal.

Parameters:

result (TARPResult)

Return type:

Tuple[float, float]
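
The sign convention of the area-to-curve statistic can be illustrated with a short sketch. The function name area_to_curve and the use of the trapezoidal rule are assumptions; the integration scheme used by check_tarp is not specified here, and the KS test is not reproduced.

```python
import jax.numpy as jnp


def area_to_curve(ecp_mean, alpha):
    """Signed area between the ECP curve and the diagonal for alpha > 0.5."""
    # Boolean-mask indexing works on concrete arrays (not under jax.jit).
    mask = alpha > 0.5
    a = alpha[mask]
    d = (ecp_mean - alpha)[mask]
    # Trapezoidal rule, written out explicitly.
    return float(jnp.sum(0.5 * (d[1:] + d[:-1]) * jnp.diff(a)))


alpha = jnp.linspace(0.0, 1.0, 31)
atc_calibrated = area_to_curve(alpha, alpha)               # curve on the diagonal
atc_conservative = area_to_curve(jnp.sqrt(alpha), alpha)   # curve above the diagonal
atc_overconfident = area_to_curve(alpha ** 2, alpha)       # curve below the diagonal
```

The three synthetic curves are chosen only so that they lie on, above, and below the diagonal; they are not outputs of a real posterior.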

gensbi.diagnostics.tarp.get_tarp_references(key, thetas)

Sample reference points uniformly from the bounding box of theta.

Parameters:

thetas (jax.Array)

Return type:

jax.Array
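
Uniform sampling from the bounding box of theta can be sketched as below. The function name sample_references is hypothetical, and the sketch assumes thetas has shape (num_sims, dim) and that one reference is drawn per simulation.

```python
import jax
import jax.numpy as jnp


def sample_references(key, thetas):
    """Draw one reference per row of thetas, uniform in the bounding box."""
    # Per-dimension bounds of the true parameters: each has shape (dim,)
    lo = jnp.min(thetas, axis=0)
    hi = jnp.max(thetas, axis=0)
    # Uniform draws in [0, 1), rescaled to [lo, hi) in every dimension.
    u = jax.random.uniform(key, thetas.shape)
    return lo + u * (hi - lo)


key = jax.random.PRNGKey(0)
k_theta, k_ref = jax.random.split(key)
thetas = jax.random.normal(k_theta, (50, 3))
references = sample_references(k_ref, thetas)
```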

gensbi.diagnostics.tarp.plot_tarp(result, title=None, figsize=None, mode='both')

Plot the expected coverage probability (ECP).

Parameters:
  • result (TARPResult) – Results from run_tarp.

  • title (str, optional) – Title of the plot.

  • figsize (tuple, optional) – Figure size.

  • mode (str, optional) – “credibility”, “confidence”, or “both”. Default is “both”. “credibility” plots ECP vs alpha; “confidence” plots z(ECP) vs z(alpha).

Return type:

Tuple[matplotlib.figure.Figure, Union[matplotlib.axes.Axes, jax.Array]]

gensbi.diagnostics.tarp.run_tarp(thetas, posterior_samples, seed=1, references=None, distance=l2, num_bins=30, z_score_theta=True, bootstrap=False, num_bootstrap=100)

Estimates the coverage of posterior samples against the true values thetas using the TARP method.

Parameters:
  • thetas (jax.Array)

  • posterior_samples (jax.Array)

  • seed (int)

  • references (Optional[jax.Array])

  • distance (Callable)

  • num_bins (Optional[int])

  • z_score_theta (bool)

  • bootstrap (bool)

  • num_bootstrap (int)

Return type:

TARPResult