Source code for gensbi.diagnostics.tarp

# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
#
# --------------------------------------------------------------------------
# MODIFICATION NOTICE:
# This file was modified by Aurelio Amerio on 01-2026.
# Description: Ported implementation to use JAX instead of PyTorch.
# --------------------------------------------------------------------------

"""
Implementation taken from Lemos et al, 'Sampling-Based Accuracy Testing of
Posterior Estimators for General Inference' https://arxiv.org/abs/2302.03026

The TARP diagnostic is a global diagnostic which can be used to check a
trained posterior against a set of true values of theta.
"""

from typing import Callable, Optional, Tuple

from scipy.stats import kstest
import jax
from jax import numpy as jnp
from jax import Array

import numpy as np

from matplotlib.axes import Axes
from matplotlib.figure import Figure
import matplotlib.pyplot as plt


from gensbi.diagnostics.metrics import l1, l2


def run_tarp(
    thetas: Array,
    posterior_samples: Array,
    seed: int = 1,
    references: Optional[Array] = None,
    distance: Callable = l2,
    num_bins: Optional[int] = 30,
    z_score_theta: bool = True,
) -> Tuple[Array, Array]:
    """
    Estimates coverage of samples given true values `thetas` with the TARP method.

    Reference
    ---------
    Lemos, Coogan et al. (2023). "Sampling-Based Accuracy Testing of Posterior
    Estimators for General Inference". https://arxiv.org/abs/2302.03026

    Parameters
    ----------
    thetas : Array
        Ground-truth parameters for TARP, simulated from the prior.
        Shape: (num_tarp_samples, dim_theta).
    posterior_samples : Array
        Posterior samples.
        Shape: (num_posterior_samples, num_tarp_samples, dim_theta).
    seed : int, optional
        Random seed for sampling reference points. Default is 1.
    references : Array, optional
        Reference points for the coverage regions. If None, reference points
        are chosen uniformly from the parameter space.
    distance : Callable, optional
        Distance metric used to compare samples and reference points. Should
        accept two arrays and return distance values. Possible values:
        ``gensbi.diagnostics.metrics.l1`` or ``gensbi.diagnostics.metrics.l2``.
        ``l2`` is the default.
    num_bins : int, optional
        Number of bins to use for the credibility values. If None,
        ``num_tarp_samples // 10`` bins are used. Default is 30.
    z_score_theta : bool, optional
        Whether to normalize parameters before the coverage test. Default is True.

    Returns
    -------
    ecp : Array
        Expected coverage probability, see equation 4 of the paper.
    alpha : Array
        Credibility values, see equation 2 of the paper.
    """
    key = jax.random.PRNGKey(seed)

    num_tarp_samples, dim_theta = thetas.shape
    num_posterior_samples = posterior_samples.shape[0]

    assert posterior_samples.shape == (
        num_posterior_samples,
        num_tarp_samples,
        dim_theta,
    ), (
        f"Wrong posterior samples shape for TARP: {posterior_samples.shape}, "
        f"expected {(num_posterior_samples, num_tarp_samples, dim_theta)}"
    )

    # Sample reference points uniformly if not provided.
    if references is None:
        references = get_tarp_references(key, thetas)

    return _run_tarp(
        posterior_samples, thetas, references, distance, num_bins, z_score_theta
    )

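# Usage sketch (illustrative addition, not part of the original module): the
# toy shapes and the Gaussian "posterior" below are assumptions chosen only to
# show the input layout ``run_tarp`` expects.
#
#     key = jax.random.PRNGKey(0)
#     # thetas: (num_tarp_samples, dim_theta)
#     thetas = jax.random.normal(key, (200, 3))
#     # posterior_samples: (num_posterior_samples, num_tarp_samples, dim_theta)
#     posterior_samples = thetas[None, :, :] + 0.1 * jax.random.normal(
#         key, (500, 200, 3)
#     )
#     ecp, alpha = run_tarp(thetas, posterior_samples, num_bins=30)
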
def _run_tarp(
    posterior_samples: Array,
    thetas: Array,
    references: Array,
    distance: Callable = l2,
    num_bins: Optional[int] = 30,
    z_score_theta: bool = False,
) -> Tuple[Array, Array]:
    """
    Estimates coverage of samples given true values `thetas` with the TARP method.

    Reference
    ---------
    Lemos, Coogan et al. (2023). "Sampling-Based Accuracy Testing of Posterior
    Estimators for General Inference". https://arxiv.org/abs/2302.03026

    Parameters
    ----------
    posterior_samples : Array
        Predicted parameter samples to compute the coverage of.
        Shape: (num_posterior_samples, num_tarp_samples, dim_theta).
    thetas : Array
        True parameter values. Shape: (num_tarp_samples, dim_theta).
    references : Array
        Reference points for the coverage regions.
        Shape: (num_tarp_samples, dim_theta).
    distance : Callable, optional
        Distance metric used to compare samples and reference points. Should
        accept two arrays and return distance values. Possible values:
        ``gensbi.diagnostics.metrics.l1`` or ``gensbi.diagnostics.metrics.l2``.
        ``l2`` is the default.
    num_bins : int, optional
        Number of bins to use for the credibility values. If None,
        ``num_tarp_samples // 10`` bins are used. Default is 30.
    z_score_theta : bool, optional
        Whether to normalize parameters before the coverage test. Default is False.

    Returns
    -------
    ecp : Array
        Expected coverage probability, see equation 4 of the paper.
    alpha : Array
        Grid of credibility values, see equation 2 of the paper.
    """
    num_posterior_samples, num_tarp_samples, _ = posterior_samples.shape

    assert (
        references.shape == thetas.shape
    ), "references must have the same shape as thetas"

    if num_bins is None:
        num_bins = num_tarp_samples // 10

    if z_score_theta:
        # Min-max normalize parameters over the batch dimension.
        lo = thetas.min(axis=0, keepdims=True)  # min over batch
        hi = thetas.max(axis=0, keepdims=True)  # max over batch
        posterior_samples = (posterior_samples - lo) / (hi - lo + 1e-10)
        thetas = (thetas - lo) / (hi - lo + 1e-10)

    # Distances between references and posterior samples.
    sample_dists = distance(references, posterior_samples)

    # Distances between references and true values.
    theta_dists = distance(references, thetas)

    # Compute coverage, f in algorithm 2 of the paper.
    coverage_values = (
        jnp.sum(sample_dists < theta_dists, axis=0) / num_posterior_samples
    )
    hist, alpha_grid = jnp.histogram(coverage_values, density=True, bins=num_bins)

    # Calculate the empirical CDF via cumsum and normalize.
    ecp = jnp.cumsum(hist, axis=0) / hist.sum()

    # Prepend 0 to the ecp curve so it matches the alpha grid.
    ecp = jnp.concatenate([jnp.zeros((1,)), ecp])

    return ecp, alpha_grid

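# Note on the ``distance`` callable (inferred from the call pattern above, not
# a claim about how ``gensbi.diagnostics.metrics.l2`` is implemented): it must
# broadcast over leading axes and reduce over the trailing parameter axis,
# mapping (num_tarp_samples, dim_theta) against
# (num_posterior_samples, num_tarp_samples, dim_theta) to
# (num_posterior_samples, num_tarp_samples). A minimal sketch of a compatible
# L2 distance, assuming that contract, would be:
#
#     def _example_l2(a: Array, b: Array) -> Array:
#         return jnp.sqrt(jnp.sum((a - b) ** 2, axis=-1))
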
def get_tarp_references(key, thetas: Array) -> Array:
    """Returns reference points for the TARP diagnostic, sampled from a uniform."""
    # Obtain min/max per dimension of theta.
    lo = thetas.min(axis=0)  # min for each theta dimension
    hi = thetas.max(axis=0)  # max for each theta dimension

    # Sample one reference point for each entry in theta.
    samples = jax.random.uniform(key, thetas.shape, minval=lo, maxval=hi)

    return samples

def check_tarp(
    ecp: Array,
    alpha: Array,
) -> Tuple[float, float]:
    r"""
    Check the obtained TARP credibility levels and expected coverage probabilities.

    This diagnostic helps to uncover underdispersed, well-covering, or
    overdispersed posteriors.

    Let :math:`\mathrm{ecp}` be the expected coverage probability computed with
    the TARP method, and :math:`\alpha` the credibility levels (second output of
    ``run_tarp``). The area to curve (ATC) is defined as:

    .. math::
        \mathrm{ATC} = \sum_{i: \alpha_i > 0.5} \left( \mathrm{ecp}_i - \alpha_i \right)

    where values close to zero indicate well-calibrated posteriors. Values larger
    than zero indicate overdispersed distributions (the estimated posterior is
    too wide), while values smaller than zero indicate underdispersed
    distributions (the estimated posterior is too narrow). This property can also
    indicate if the posterior is biased (see Figure 2 of the reference paper).

    A two-sample Kolmogorov-Smirnov test is performed between :math:`\mathrm{ecp}`
    and :math:`\alpha` to test the null hypothesis that both distributions are
    identical (produced by one common CDF). The p-value should be close to 1 for
    well-calibrated posteriors. Commonly, the null is rejected if the p-value is
    below 0.05.

    Reference
    ---------
    Lemos, Coogan et al. (2023). "Sampling-Based Accuracy Testing of Posterior
    Estimators for General Inference". https://arxiv.org/abs/2302.03026

    Parameters
    ----------
    ecp : array-like
        Expected coverage probabilities computed with the TARP method (first
        output of ``run_tarp``).
    alpha : array-like
        Credibility levels :math:`\alpha` (second output of ``run_tarp``).

    Returns
    -------
    atc : float
        Area to curve, the difference between the ecp and alpha curve for
        :math:`\alpha > 0.5`.
    ks_prob : float
        p-value for a two-sample Kolmogorov-Smirnov test between ecp and alpha.
    """
    # Index of the middle of the alpha grid.
    midindex = alpha.shape[0] // 2

    # Area to curve: difference between ecp and alpha above 0.5.
    atc = (ecp[midindex:] - alpha[midindex:]).sum().item()

    # Kolmogorov-Smirnov test between ecp and alpha.
    kstest_pvals: float = kstest(np.array(ecp), np.array(alpha))[1]  # type: ignore

    return atc, kstest_pvals

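# Interpretation sketch (illustrative addition): with ``ecp`` and ``alpha``
# from ``run_tarp``, the two summary numbers can be read as follows.
#
#     atc, ks_pval = check_tarp(ecp, alpha)
#     # |atc| close to 0 and ks_pval > 0.05: no evidence of miscalibration.
#     # atc > 0: overdispersed (too wide); atc < 0: underdispersed (too narrow).
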
def plot_tarp(
    ecp: Array, alpha: Array, title: Optional[str] = None
) -> Tuple[Figure, Axes]:
    """
    Plot the expected coverage probability (ECP) against the credibility level (alpha).

    Parameters
    ----------
    ecp : array-like
        Array of expected coverage probabilities.
    alpha : array-like
        Array of credibility levels.
    title : str, optional
        Title for the plot. Default is None, which results in an empty title.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure object.
    ax : matplotlib.axes.Axes
        The axes object.
    """
    fig = plt.figure(figsize=(6, 6))
    ax: Axes = plt.gca()

    ecp = np.array(ecp)
    alpha = np.array(alpha)

    ax.plot(alpha, ecp, color="blue", label="TARP")
    ax.plot(alpha, alpha, color="black", linestyle="--", label="ideal")
    ax.set_xlabel(r"Credibility Level $\alpha$")
    ax.set_ylabel(r"Expected Coverage Probability")
    ax.set_xlim(0.0, 1.0)
    ax.set_ylim(0.0, 1.0)
    ax.set_title(title or "")
    ax.legend()

    return fig, ax  # type: ignore

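# ---------------------------------------------------------------------------
# Illustrative demo (an addition, not part of the original module): running
# this file directly exercises the full TARP pipeline on synthetic data. The
# toy Gaussian "posterior" centred on the true parameters is an assumption
# made purely for demonstration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    demo_key = jax.random.PRNGKey(0)
    key_theta, key_post = jax.random.split(demo_key)

    num_tarp_samples, num_posterior_samples, dim_theta = 200, 500, 3

    # Ground-truth parameters, standing in for draws from the prior.
    demo_thetas = jax.random.normal(key_theta, (num_tarp_samples, dim_theta))

    # Toy posterior samples: true values plus Gaussian noise, shaped
    # (num_posterior_samples, num_tarp_samples, dim_theta) as expected by
    # ``run_tarp``.
    demo_posterior_samples = demo_thetas[None, :, :] + 0.5 * jax.random.normal(
        key_post, (num_posterior_samples, num_tarp_samples, dim_theta)
    )

    demo_ecp, demo_alpha = run_tarp(demo_thetas, demo_posterior_samples)
    demo_atc, demo_ks_pval = check_tarp(demo_ecp, demo_alpha)
    print(f"ATC: {demo_atc:.3f}, KS p-value: {demo_ks_pval:.3f}")

    plot_tarp(demo_ecp, demo_alpha, title="TARP demo")
    plt.show()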