# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
#
# --------------------------------------------------------------------------
# MODIFICATION NOTICE:
# This file was modified by Aurelio Amerio on 01-2026.
# Description: Ported implementation to use JAX instead of PyTorch.
# --------------------------------------------------------------------------
import warnings
from typing import Callable, Dict, List, Tuple, Union
from jax import Array
import jax.numpy as jnp
import jax
import numpy as np
from scipy.stats import kstest, uniform
from tqdm import tqdm
from gensbi.diagnostics.metrics import c2st
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Tuple,
TypeVar,
Union,
cast,
get_args,
)
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure, FigureBase
from scipy.stats import binom, gaussian_kde, iqr
[docs]
def run_sbc(
    thetas: Array,
    xs: Array,
    posterior_samples: Array,
    reduce_fns: Union[
        str,
        Callable[[Array, Array], Array],
        List[Callable[[Array, Array], Array]],
    ] = "marginals",
    show_progress_bar: bool = True,
    **kwargs,
) -> Tuple[Array, Array]:
    """Compute ranks for simulation-based calibration (SBC) or expected coverage.

    Two coverage diagnostics are available through ``reduce_fns``:

    - ``reduce_fns = "marginals"`` ranks each parameter dimension separately,
      i.e. SBC as proposed in Talts et al.
      (see https://arxiv.org/abs/1804.06788).
    - ``reduce_fns = posterior.log_prob`` gives sample-based expected coverage
      as proposed in Deistler et al.
      (see https://arxiv.org/abs/2210.04815).

    Parameters
    ----------
    thetas : Array
        Ground-truth parameters, shape (num_sbc_samples, dim_theta).
    xs : Array
        Observed data, one entry per row of ``thetas``.
    posterior_samples : Array
        Posterior samples, shape
        (num_posterior_samples, num_sbc_samples, dim_theta).
    reduce_fns : str or Callable or List[Callable], optional
        Function(s) mapping ``(theta, x)`` to a 1D statistic, or the string
        ``"marginals"``. Default is "marginals".
    show_progress_bar : bool, optional
        Whether to display a progress bar over SBC runs. Default is True.
    **kwargs
        Accepted for call compatibility; not used by this function.

    Returns
    -------
    ranks : Array
        Ranks of the ground-truth parameters among the posterior samples,
        one row per SBC run and one column per reduce function.
    dap_samples : Array
        One posterior sample per SBC run (the first draw along axis 0), used
        as samples from the data-averaged posterior.

    Raises
    ------
    AssertionError
        If ``posterior_samples`` does not have the expected 3D shape.
    """
    n_sbc, theta_dim = thetas.shape
    n_draws = posterior_samples.shape[0]

    # The leading axis is taken as given; this check pins the remaining axes
    # to the shape of `thetas` and enforces a 3D layout.
    expected_shape = (n_draws, n_sbc, theta_dim)
    assert posterior_samples.shape == expected_shape, (
        f"Wrong posterior samples shape for SBC: {posterior_samples.shape}, "
        f"expected ({n_draws}, {n_sbc}, {theta_dim})"
    )

    # Data-averaged posterior: keep the first posterior draw of every SBC run.
    dap_samples = posterior_samples[0, :, :]
    assert dap_samples.shape == (n_sbc, thetas.shape[1]), "Wrong DAP shape."

    sbc_ranks = _run_sbc(thetas, xs, posterior_samples, reduce_fns, show_progress_bar)
    return sbc_ranks, dap_samples
[docs]
def _run_sbc(
    thetas: Array,
    xs: Array,
    posterior_samples: Array,
    reduce_fns: Union[
        str,
        Callable[[Array, Array], Array],
        List[Callable[[Array, Array], Array]],
    ] = "marginals",
    show_progress_bar: bool = True,
) -> Array:
    """Rank every ground-truth parameter among its posterior samples.

    For each SBC run and each reduce function, the rank is the number of
    posterior samples whose reduced value is strictly smaller than the reduced
    value of the ground-truth parameter.

    Parameters
    ----------
    thetas : Array
        Ground-truth parameters, one row per SBC run.
    xs : Array
        Observed data, iterated in lockstep with ``thetas``.
    posterior_samples : Array
        Posterior samples indexed as ``[sample, sbc_run, parameter]``.
    reduce_fns : str or Callable or List[Callable], optional
        Function(s) reducing ``(theta, x)`` to 1D. Default is "marginals".
    show_progress_bar : bool, optional
        Whether to show a progress bar. Default is True.

    Returns
    -------
    Array
        NumPy array of shape (num_sbc_samples, number of reduce functions)
        holding the ranks.
    """
    n_runs = thetas.shape[0]
    reducers = _prepare_reduce_functions(reduce_fns, thetas.shape[1])
    rank_table = np.zeros((n_runs, len(reducers)))

    progress = tqdm(
        enumerate(zip(thetas, xs, strict=False)),
        total=n_runs,
        disable=not show_progress_bar,
        desc=f"Calculating ranks for {n_runs} SBC samples",
    )
    for run_idx, (theta_true, x_obs) in progress:
        # All posterior draws for this run, and the truth with a batch axis.
        draws = posterior_samples[:, run_idx, :]
        truth = theta_true[None, ...]
        for fn_idx, reducer in enumerate(reducers):
            is_below_truth = reducer(draws, x_obs) < reducer(truth, x_obs)
            rank_table[run_idx, fn_idx] = is_below_truth.sum().item()

    return rank_table
[docs]
def _prepare_reduce_functions(
    reduce_fns: Union[
        str,
        Callable[[Array, Array], Array],
        List[Callable[[Array, Array], Array]],
    ],
    param_dim: int,
) -> List[Callable[[Array, Array], Array]]:
    """Normalise ``reduce_fns`` into a list of reduction callables.

    Parameters
    ----------
    reduce_fns : str or Callable or List[Callable]
        Either the string ``"marginals"``, a single callable taking
        ``(theta, x)``, or a list of such callables.
    param_dim : int
        Dimensionality of the parameter space; determines how many marginal
        selectors are created for ``"marginals"``.

    Returns
    -------
    List[Callable]
        For ``"marginals"``: one function per parameter dimension, each
        returning ``theta[:, i]``. For a single callable: a one-element list.
        Anything else is returned unchanged.

    Raises
    ------
    ValueError
        If ``reduce_fns`` is a string other than ``"marginals"``.
    """

    def _marginal_selector(index: int) -> Callable[[Array, Array], Array]:
        # A factory binds `index` per selector; a bare lambda in the
        # comprehension would late-bind and always pick the last dimension.
        def _select(theta, x):
            return theta[:, index]

        return _select

    if isinstance(reduce_fns, str):
        if reduce_fns != "marginals":
            raise ValueError(
                "`reduce_fn` must either be the string `marginals` or a Callable or a "
                "List of Callables."
            )
        # Plain closures replace the former eval() of a formatted source string.
        return [_marginal_selector(i) for i in range(param_dim)]

    if callable(reduce_fns):
        return [reduce_fns]

    return reduce_fns
[docs]
def check_sbc(
    ranks: Array,
    prior_samples: Array,
    dap_samples: Array,
    num_posterior_samples: int = 1000,
    num_c2st_repetitions: int = 1,
) -> Dict[str, Array]:
    """Run the uniformity checks and the prior-vs-DAP check for SBC.

    Parameters
    ----------
    ranks : Array
        Ranks for each SBC run and each model parameter,
        shape (N, dim_parameters).
    prior_samples : Array
        N samples from the prior.
    dap_samples : Array
        N samples from the data-averaged posterior.
    num_posterior_samples : int, optional
        Number of posterior samples used for SBC ranking. Default is 1000.
    num_c2st_repetitions : int, optional
        Number of C2ST repetitions forwarded to the rank uniformity check.
        Default is 1.

    Returns
    -------
    Dict[str, Array]
        Dictionary with the keys:

        - ``ks_pvals``: result of ``check_uniformity_frequentist``.
        - ``c2st_ranks``: result of ``check_uniformity_c2st``.
        - ``c2st_dap``: result of ``check_prior_vs_dap``.

    Warns
    -----
    UserWarning
        If fewer than 100 SBC runs are provided.
    """
    num_runs = ranks.shape[0]
    if num_runs < 100:
        warnings.warn(
            "You are computing SBC checks with less than 100 samples. These checks "
            "should be based on a large number of test samples theta_o, x_o. We "
            "recommend using at least 100.",
            stacklevel=2,
        )

    results: Dict[str, Array] = {}
    # Uniformity of the ranks.
    results["ks_pvals"] = check_uniformity_frequentist(ranks, num_posterior_samples)
    results["c2st_ranks"] = check_uniformity_c2st(
        ranks, num_posterior_samples, num_repetitions=num_c2st_repetitions
    )
    # Prior against data-averaged posterior.
    results["c2st_dap"] = check_prior_vs_dap(prior_samples, dap_samples)
    return results
[docs]
def check_prior_vs_dap(prior_samples: Array, dap_samples: Array) -> Array:
    """Compare prior and data-averaged posterior (DAP) samples with C2ST.

    ``c2st`` is evaluated separately for every parameter dimension (every
    column of the inputs).

    According to simulation-based calibration, a well-calibrated inference
    method yields DAP samples that follow the prior, i.e. a C2ST score close
    to 0.5; a score far from 0.5 suggests miscalibration (see Talts et al.,
    "Simulation-based calibration").

    Parameters
    ----------
    prior_samples : Array
        Samples from the prior distribution.
    dap_samples : Array
        Samples from the data-averaged posterior, same shape as
        ``prior_samples``.

    Returns
    -------
    Array
        One C2ST score per parameter dimension.

    Raises
    ------
    ValueError
        If the two inputs do not have the same shape.
    """
    if prior_samples.shape != dap_samples.shape:
        raise ValueError("Prior and DAP samples must have the same shape")

    per_dim_scores = []
    # Transposing iterates over parameter dimensions instead of samples.
    for prior_col, dap_col in zip(prior_samples.T, dap_samples.T, strict=False):
        per_dim_scores.append(c2st(prior_col[:, None], dap_col[:, None]))
    return jnp.array(per_dim_scores)
# plotting utilities
[docs]
def sbc_rank_plot(
    ranks: Union[Array, np.ndarray, List[Array], List[np.ndarray]],
    num_posterior_samples: int,
    num_bins: Optional[int] = None,
    plot_type: str = "cdf",
    parameter_labels: Optional[List[str]] = None,
    ranks_labels: Optional[List[str]] = None,
    colors: Optional[List[str]] = None,
    fig: Optional[Figure] = None,
    ax: Optional[Axes] = None,
    figsize: Optional[tuple] = None,
    **kwargs,
) -> Tuple[Figure, Axes]:
    """Plot SBC ranks as empirical CDFs or histograms.

    Public entry point that forwards every argument to ``_sbc_rank_plot``;
    further plotting options of that function can be supplied via ``kwargs``.

    Parameters
    ----------
    ranks : Array or List[Array]
        Ranks of shape (num_sbc_runs, num_parameters), or a list of such
        arrays when comparing several sets of ranks (e.g. different methods).
    num_posterior_samples : int
        Number of posterior samples used for ranking.
    num_bins : int, optional
        Number of bins used for binning the ranks. Default is
        num_sbc_runs / 20.
    plot_type : str, optional
        Either "hist" (histograms) or "cdf" (empirical CDFs). Default is "cdf".
    parameter_labels : List[str], optional
        Labels for each parameter dimension.
    ranks_labels : List[str], optional
        Labels for each set of ranks.
    colors : List[str], optional
        Colors for each parameter dimension, or each set of ranks.
    fig : Figure, optional
        Figure object to plot in.
    ax : Axes, optional
        Axis object to plot in.
    figsize : tuple, optional
        Dimensions of the figure object.
    **kwargs
        Additional keyword arguments passed to ``_sbc_rank_plot``.

    Returns
    -------
    fig : Figure
        Figure object.
    ax : Axes
        Axis object.
    """
    return _sbc_rank_plot(
        ranks=ranks,
        num_posterior_samples=num_posterior_samples,
        num_bins=num_bins,
        plot_type=plot_type,
        parameter_labels=parameter_labels,
        ranks_labels=ranks_labels,
        colors=colors,
        fig=fig,
        ax=ax,
        figsize=figsize,
        **kwargs,
    )
[docs]
def _sbc_rank_plot(
    ranks: Union[Array, np.ndarray, List[Array], List[np.ndarray]],
    num_posterior_samples: int,
    num_bins: Optional[int] = None,
    plot_type: str = "cdf",
    parameter_labels: Optional[List[str]] = None,
    ranks_labels: Optional[List[str]] = None,
    colors: Optional[List[str]] = None,
    num_repeats: int = 50,
    line_alpha: float = 0.8,
    show_uniform_region: bool = True,
    uniform_region_alpha: float = 0.3,
    xlim_offset_factor: float = 0.1,
    num_cols: int = 4,
    params_in_subplots: bool = False,
    show_ylabel: bool = False,
    sharey: bool = False,
    fig: Optional[FigureBase] = None,
    legend_kwargs: Optional[Dict] = None,
    ax=None,  # no type hint to avoid hassle with pyright. Should be `array(Axes).`
    figsize: Optional[tuple] = None,
) -> Tuple[Figure, Axes]:
    """Plot simulation-based calibration ranks as empirical CDFs or histograms.

    Parameters
    ----------
    ranks : Array or List[Array]
        Array of ranks of shape (num_sbc_runs, num_parameters), or a list of
        such arrays when comparing several sets of ranks, e.g. ranks obtained
        from different methods. JAX arrays are converted to NumPy arrays; a
        list passed by the caller is not modified.
    num_posterior_samples : int
        Number of posterior samples used for ranking.
    num_bins : int, optional
        Number of bins used for binning the ranks. Default is
        ``num_sbc_runs // 20``, but at least 1.
    plot_type : str, optional
        Type of SBC plot, histograms ("hist") or empirical cdfs ("cdf").
        Default is "cdf".
    parameter_labels : List[str], optional
        List of labels for each parameter dimension.
    ranks_labels : List[str], optional
        List of labels for each set of ranks.
    colors : List[str], optional
        List of colors for each parameter dimension, or each set of ranks.
    num_repeats : int, optional
        Number of repeats for each empirical CDF step (resolution).
        Default is 50.
    line_alpha : float, optional
        Alpha for cdf lines or histograms. Default is 0.8.
    show_uniform_region : bool, optional
        Whether to plot the region showing the cdfs expected under uniformity.
        Default is True.
    uniform_region_alpha : float, optional
        Alpha for the region expected under uniformity. Default is 0.3.
    xlim_offset_factor : float, optional
        Factor for empty space left and right of the histogram. Default is 0.1.
    num_cols : int, optional
        Number of subplot columns, e.g. when plotting ranks of many
        parameters. Default is 4.
    params_in_subplots : bool, optional
        Whether to show each parameter in a separate subplot, or all in one.
        Forced to True for several rank sets or for ``plot_type="hist"``.
        Default is False.
    show_ylabel : bool, optional
        Whether to show ylabels and ticks (used for histograms).
        Default is False.
    sharey : bool, optional
        Whether to share the y-labels, ticks, and limits across subplots.
        Default is False.
    fig : Figure, optional
        Figure object to plot in. Only used when ``ax`` is given as well;
        otherwise a new figure is created.
    legend_kwargs : Dict, optional
        Kwargs for the legend. Default is ``loc="best", handlelength=0.8``.
    ax : Axes, optional
        Axis object; in subplot mode an array containing at least as many
        subplots as parameters. Only used when ``fig`` is given as well.
    figsize : tuple, optional
        Dimensions of the figure object. Default is (8, 5) for a single plot
        and (num_parameters * 4, num_rows * 5) when using subplots.

    Returns
    -------
    fig : Figure
        Figure object.
    ax : Axes
        Axis object.

    Raises
    ------
    AssertionError
        If ``ranks`` has an unsupported type, the rank sets differ in shape,
        ``plot_type`` is unknown, or a provided ``ax`` has too few subplots.
    """
    if isinstance(ranks, (Array, np.ndarray)):
        rank_sets = [ranks]
    else:
        assert isinstance(ranks, list)
        rank_sets = ranks

    # Convert into a fresh list so the caller's list is never mutated.
    # np.asarray handles JAX arrays (they have no `.numpy()` method) and leaves
    # NumPy arrays untouched.
    ranks_list: List[np.ndarray] = []
    for rank in rank_sets:
        assert isinstance(rank, (Array, np.ndarray))
        ranks_list.append(np.asarray(rank))

    plot_types = ["hist", "cdf"]
    assert plot_type in plot_types, (
        f"plot type {plot_type} not implemented, use one in {plot_types}."
    )

    if legend_kwargs is None:
        legend_kwargs = dict(loc="best", handlelength=0.8)

    num_sbc_runs, num_parameters = ranks_list[0].shape
    num_ranks = len(ranks_list)

    # For multiple methods, and for the hist plots plot each param in a separate subplot
    if num_ranks > 1 or plot_type == "hist":
        params_in_subplots = True

    for ranki in ranks_list:
        assert ranki.shape == ranks_list[0].shape, (
            "all ranks in list must have the same shape."
        )

    num_rows = int(np.ceil(num_parameters / num_cols))
    if figsize is None:
        figsize = (num_parameters * 4, num_rows * 5) if params_in_subplots else (8, 5)

    if parameter_labels is None:
        parameter_labels = [f"dim {i + 1}" for i in range(num_parameters)]
    if ranks_labels is None:
        ranks_labels = [f"rank set {i + 1}" for i in range(num_ranks)]
    if num_bins is None:
        # Recommendation from Talts et al.; clamped so that fewer than 20 SBC
        # runs do not produce zero bins, which the histogram calls reject.
        num_bins = max(1, num_sbc_runs // 20)

    # Plot one row subplot for each parameter, different "methods" on top of each other.
    if params_in_subplots:
        if fig is None or ax is None:
            fig, ax = plt.subplots(
                num_rows,
                min(num_parameters, num_cols),
                figsize=figsize,
                sharey=sharey,
            )
            ax = np.atleast_1d(ax)  # type: ignore
        else:
            assert ax.size >= num_parameters, (
                "There must be at least as many subplots as parameters."
            )
            num_rows = ax.shape[0] if ax.ndim > 1 else 1
        assert ax is not None

        col_idx, row_idx = 0, 0
        for ii, ranki in enumerate(ranks_list):
            for jj in range(num_parameters):
                col_idx = jj if num_rows == 1 else jj % num_cols
                row_idx = jj // num_cols
                plt.sca(ax[col_idx] if num_rows == 1 else ax[row_idx, col_idx])

                if plot_type == "cdf":
                    _plot_ranks_as_cdf(
                        ranki[:, jj],  # type: ignore
                        num_bins,
                        num_repeats,
                        ranks_label=ranks_labels[ii],
                        color=f"C{ii}" if colors is None else colors[ii],
                        xlabel=f"posterior ranks {parameter_labels[jj]}",
                        # Show ylabel only in first subplot.
                        show_ylabel=jj == 0,
                        alpha=line_alpha,
                    )
                    # The uniform band is identical for all rank sets: draw once.
                    if ii == 0 and show_uniform_region:
                        _plot_cdf_region_expected_under_uniformity(
                            num_sbc_runs,
                            num_bins,
                            num_repeats,
                            alpha=uniform_region_alpha,
                        )
                elif plot_type == "hist":
                    _plot_ranks_as_hist(
                        ranki[:, jj],  # type: ignore
                        num_bins,
                        num_posterior_samples,
                        ranks_label=ranks_labels[ii],
                        color="firebrick" if colors is None else colors[ii],
                        xlabel=f"posterior rank {parameter_labels[jj]}",
                        show_ylabel=show_ylabel,
                        alpha=line_alpha,
                        xlim_offset_factor=xlim_offset_factor,
                    )
                    # Plot expected uniform band.
                    _plot_hist_region_expected_under_uniformity(
                        num_sbc_runs,
                        num_bins,
                        num_posterior_samples,
                        alpha=uniform_region_alpha,
                    )
                    # show legend only in first subplot.
                    if jj == 0 and ranks_labels[ii] is not None:
                        plt.legend(**legend_kwargs)
                else:
                    raise ValueError(
                        f"plot_type {plot_type} not defined, use one in {plot_types}"
                    )

        # Remove empty subplots: switch off the unused axes that follow the
        # last parameter in the final row of the grid.
        col_idx += 1
        while num_rows > 1 and col_idx < num_cols:
            ax[row_idx, col_idx].axis("off")
            col_idx += 1

    # When there is only one set of ranks show all params in a single subplot.
    else:
        if fig is None or ax is None:
            fig, ax = plt.subplots(1, 1, figsize=figsize)
        plt.sca(ax)

        ranki = ranks_list[0]
        for jj in range(num_parameters):
            _plot_ranks_as_cdf(
                ranki[:, jj],  # type: ignore
                num_bins,
                num_repeats,
                ranks_label=parameter_labels[jj],
                color=f"C{jj}" if colors is None else colors[jj],
                xlabel="posterior rank",
                # Plot ylabel and legend at last.
                show_ylabel=jj == (num_parameters - 1),
                alpha=line_alpha,
            )
        if show_uniform_region:
            _plot_cdf_region_expected_under_uniformity(
                num_sbc_runs,
                num_bins,
                num_repeats,
                alpha=uniform_region_alpha,
            )
        # One legend for the shared axis, listing every parameter.
        plt.legend(**legend_kwargs)

    return fig, ax  # pyright: ignore[reportReturnType]
[docs]
def _plot_ranks_as_hist(
    ranks: np.ndarray,
    num_bins: int,
    num_posterior_samples: int,
    ranks_label: Optional[str] = None,
    xlabel: Optional[str] = None,
    color: str = "firebrick",
    alpha: float = 0.8,
    show_ylabel: bool = False,
    num_ticks: int = 3,
    xlim_offset_factor: float = 0.1,
) -> None:
    """Draw a histogram of SBC ranks on the current matplotlib axis.

    Parameters
    ----------
    ranks : np.ndarray
        SBC ranks in shape (num_sbc_runs, ).
    num_bins : int
        Number of histogram bins, recommendation is num_sbc_runs / 20.
    num_posterior_samples : int
        Number of posterior samples used for ranking; sets the x-axis range.
    ranks_label : str, optional
        Legend label for the ranks, e.g. when comparing different methods.
    xlabel : str, optional
        x-axis label; "posterior rank" is used when omitted.
    color : str, optional
        Histogram color, default from Talts et al. Default is "firebrick".
    alpha : float, optional
        Histogram transparency. Default is 0.8.
    show_ylabel : bool, optional
        If True, label the y-axis "counts"; otherwise remove the y-ticks.
        Default is False.
    num_ticks : int, optional
        Number of ticks on the x-axis. Default is 3.
    xlim_offset_factor : float, optional
        Fraction of ``num_posterior_samples`` left empty on each side of the
        histogram. Default is 0.1.
    """
    margin = int(num_posterior_samples * xlim_offset_factor)

    plt.hist(ranks, bins=num_bins, label=ranks_label, color=color, alpha=alpha)

    if not show_ylabel:
        plt.yticks([])
    else:
        plt.ylabel("counts")

    # Pad the rank range [0, num_posterior_samples] by `margin` on both sides.
    lower_limit = -margin
    upper_limit = num_posterior_samples + margin
    plt.xlim(lower_limit, upper_limit)

    tick_positions = np.linspace(0, num_posterior_samples, num_ticks)
    plt.xticks(tick_positions)

    axis_label = xlabel if xlabel is not None else "posterior rank"
    plt.xlabel(axis_label)
[docs]
def _plot_ranks_as_cdf(
    ranks: np.ndarray,
    num_bins: int,
    num_repeats: int,
    ranks_label: Optional[str] = None,
    xlabel: Optional[str] = None,
    color: Optional[str] = None,
    alpha: float = 0.8,
    show_ylabel: bool = True,
    num_ticks: int = 3,
) -> None:
    """Draw the empirical CDF of SBC ranks on the current matplotlib axis.

    The ranks are binned into ``num_bins`` bins; the normalised cumulative
    counts are plotted against the bin index, so the x-axis runs from 0 to
    ``num_bins``.

    Parameters
    ----------
    ranks : np.ndarray
        SBC ranks in shape (num_sbc_runs, ).
    num_bins : int
        Number of histogram bins, recommendation is num_sbc_runs / 20.
    num_repeats : int
        Number of repeats of each CDF step, i.e. resolution of the eCDF.
    ranks_label : str, optional
        Legend label for the ranks, e.g. when comparing different methods.
    xlabel : str, optional
        x-axis label; "posterior rank" is used when omitted.
    color : str, optional
        Line color for the cdf.
    alpha : float, optional
        Line transparency. Default is 0.8.
    show_ylabel : bool, optional
        If True, show y-tick labels and the y-label "empirical CDF";
        otherwise draw unlabeled y-ticks only. Default is True.
    num_ticks : int, optional
        Number of ticks on the x-axis. Default is 3.
    """
    # Bin the ranks and accumulate the counts.
    bin_counts, *_ = np.histogram(ranks, bins=num_bins, density=False)
    cumulative = bin_counts.cumsum()

    # Repeat every step `num_repeats` times so the line reads as a staircase.
    x_positions = np.linspace(0, num_bins, num_repeats * num_bins)
    ecdf_steps = np.repeat(cumulative / cumulative.max(), num_repeats)
    plt.plot(
        x_positions,
        ecdf_steps,
        label=ranks_label,
        color=color,
        alpha=alpha,
    )

    y_tick_positions = np.linspace(0, 1, 3)
    if show_ylabel:
        plt.yticks(y_tick_positions)
        plt.ylabel("empirical CDF")
    else:
        # Keep the tick marks but drop their labels.
        plt.yticks(y_tick_positions, [])

    plt.ylim(0, 1)
    plt.xlim(0, num_bins)
    plt.xticks(np.linspace(0, num_bins, num_ticks))

    axis_label = xlabel if xlabel is not None else "posterior rank"
    plt.xlabel(axis_label)