Methods and Samplers in GenSBI#

GenSBI supports three generative methods for simulation-based inference:

  1. Flow Matching — learns a velocity field and integrates an ODE from noise to data.

  2. EDM Diffusion — learns a denoiser in σ-space (Karras et al., 2022).

  3. Score Matching — learns the score function ∇ log p_t(x) (Song et al., 2021).

Each method has a default solver and one or more alternatives that can be swapped at sample time without retraining. This example trains one model per method and demonstrates all available sampling strategies.

We use the unified ConditionalPipeline API throughout, which is model-agnostic and parameterized by a GenerativeMethod object.

# automatically install dependencies if running on Colab
try:  # detect Colab; if so, install the required packages
    import google.colab
    colab = True
except ImportError:
    colab = False

if colab: # you may have to restart the runtime after installing the packages
    !uv pip install --quiet "gensbi[cuda12, examples] @ git+https://github.com/aurelio-amerio/GenSBI"
    !git clone --depth 1 https://github.com/aurelio-amerio/GenSBI-examples
    %cd GenSBI-examples/examples/methods_and_samplers
import os

# Set JAX backend (use 'cuda' for GPU, 'cpu' otherwise)
os.environ["JAX_PLATFORMS"] = "cuda"

import grain
import numpy as np
import jax
from jax import numpy as jnp
from numpyro import distributions as dist
from flax import nnx

# Unified pipeline and generative methods
from gensbi.recipes import ConditionalPipeline
from gensbi.core import FlowMatchingMethod, DiffusionEDMMethod, ScoreMatchingMethod

# Model
from gensbi.models import Flux1, Flux1Params

# Plotting
from gensbi.utils.plotting import plot_marginals
import matplotlib.pyplot as plt

Shared Setup#

We use a simple 3D toy problem throughout: the simulator draws parameters θ from a uniform prior on [-2, 2]³ and produces observations x = θ + 1 + ε, with ε ~ N(0, 0.1²). This is identical to the conditional_pipeline.py example, so we can focus on the methods and samplers.
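Because the simulator is linear with Gaussian noise, the reference posterior is known in closed form, which makes this toy useful for eyeballing the sampler outputs later on. A minimal sketch in plain numpy (the `analytic_posterior` helper and the demo values are illustrative, not part of GenSBI):

```python
import numpy as np

# With x = theta + 1 + eps, eps ~ N(0, 0.1^2), and a uniform prior, the
# posterior over each theta_i given one observation x_o is a Gaussian with
# mean x_o - 1 and std 0.1, truncated to the prior box [-2, 2].
def analytic_posterior(x_o, noise_std=0.1):
    return x_o - 1.0, noise_std * np.ones_like(x_o)

x_o_demo = np.array([2.3, 2.1, 0.2])  # hypothetical observation
post_mean, post_std = analytic_posterior(x_o_demo)
print(post_mean, post_std)  # mean [1.3, 1.1, -0.8], std 0.1 per dimension
```

The learned posteriors plotted below should concentrate around this analytic mean.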

theta_prior = dist.Uniform(
    low=jnp.array([-2.0, -2.0, -2.0]), high=jnp.array([2.0, 2.0, 2.0])
)

dim_obs = 3
dim_cond = 3
dim_joint = dim_obs + dim_cond


def simulator(key, nsamples):
    theta_key, sample_key = jax.random.split(key, 2)
    thetas = theta_prior.sample(theta_key, (nsamples,))
    xs = thetas + 1 + jax.random.normal(sample_key, thetas.shape) * 0.1

    thetas = thetas[..., None]
    xs = xs[..., None]

    # For the conditional pipeline, thetas (observations) come first
    data = jnp.concatenate([thetas, xs], axis=1)
    return data
train_data = simulator(jax.random.PRNGKey(0), 100_000)
val_data = simulator(jax.random.PRNGKey(1), 2000)
# Normalizing the data to zero mean and unit variance is important for stable training.
means = jnp.mean(train_data, axis=0)
stds = jnp.std(train_data, axis=0)


def normalize(data, means, stds):
    return (data - means) / stds


def unnormalize(data, means, stds):
    return data * stds + means
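As a quick sanity check (toy numpy arrays, not the simulator output), standardizing with per-feature statistics gives roughly zero mean and unit variance, and the two helpers invert each other exactly:

```python
import numpy as np

# Toy check mirroring normalize/unnormalize above: standardize with per-feature
# statistics, then map back and confirm the round trip recovers the data.
demo = np.random.default_rng(0).normal(loc=3.0, scale=2.0, size=(1000, 3))
demo_means = demo.mean(axis=0)
demo_stds = demo.std(axis=0)
normed = (demo - demo_means) / demo_stds    # normalize
restored = normed * demo_stds + demo_means  # unnormalize
print(np.allclose(restored, demo))  # True
```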
# The conditional pipeline expects each batch to be a tuple of (observations, conditions).
def split_obs_cond(data):
    data = normalize(data, means, stds)
    return (
        data[:, :dim_obs],
        data[:, dim_obs:],
    )
batch_size = 256

train_dataset_grain = (
    grain.MapDataset.source(np.array(train_data))
    .shuffle(42)
    .repeat()
    .to_iter_dataset()
    .batch(batch_size)
    .map(split_obs_cond)
)

val_dataset_grain = (
    grain.MapDataset.source(np.array(val_data))
    .shuffle(42)
    .repeat()
    .to_iter_dataset()
    .batch(batch_size)
    .map(split_obs_cond)
)
# We generate one observation from the simulator and extract the true parameters
# and the conditioning data x_o. This will be reused across all methods.
new_sample = simulator(jax.random.PRNGKey(20), 1)
true_theta = new_sample[:, :dim_obs, :]

new_sample_norm = normalize(new_sample, means, stds)
x_o = new_sample_norm[:, dim_obs:, :]

# Plotting range (same for all methods)
plot_range = [(1, 3), (1, 3), (-0.6, 0.5)]

Section 1: Flow Matching#

Flow Matching learns a velocity field \(v_\theta(t, x)\) that transports samples from a simple prior (Gaussian noise at \(t=0\)) to the data distribution (at \(t=1\)) via an ordinary differential equation (ODE).

  • Default solver: ODESolver — deterministic ODE integration (Euler or Dopri5).

  • Alternative solvers (SDE-based, stochastic):

    • ZeroEndsSolver — diffusion vanishes at both time endpoints (arXiv:2410.02217).

    • NonSingularSolver — non-singular diffusion coefficient (arXiv:2410.02217).

The SDE solvers can sometimes improve sample diversity at the cost of additional stochasticity. They require prior statistics (mu0, sigma0) and a diffusion strength parameter alpha.
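To make the ODE picture concrete, here is a self-contained toy (plain numpy, not the GenSBI API): fixed-step Euler integration of a velocity field that transports N(0, 1) at t = 0 to N(μ, s²) at t = 1 along the straight-line path x_t = (1 − t)x₀ + t·x₁. For Gaussian endpoints the marginal velocity has the closed form below; in GenSBI a trained network plays the role of `velocity`.

```python
import numpy as np

mu, s = 2.0, 0.5  # target mean and std (hypothetical values)

def velocity(t, x):
    # Marginal OT velocity for N(0,1) -> N(mu, s^2):
    # mean_t = t*mu, sigma_t = (1-t) + t*s, v = mean_t' + (sigma_t'/sigma_t)(x - mean_t)
    sigma_t = (1.0 - t) + t * s
    return mu + (s - 1.0) / sigma_t * (x - t * mu)

rng = np.random.default_rng(0)
x = rng.standard_normal(50_000)  # samples from the noise prior at t=0
nsteps = 100
dt = 1.0 / nsteps
for i in range(nsteps):
    x = x + dt * velocity(i * dt, x)  # fixed-step Euler update

print(x.mean(), x.std())  # should approach mu = 2.0 and s = 0.5
```

Adaptive solvers such as Dopri5 replace the fixed Euler step with error-controlled steps but integrate the same field.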

params_fm = Flux1Params(
    in_channels=1,
    vec_in_dim=None,
    context_in_dim=1,
    mlp_ratio=3,
    num_heads=2,
    depth=4,
    depth_single_blocks=8,
    axes_dim=[10],
    qkv_bias=True,
    dim_obs=dim_obs,
    dim_cond=dim_cond,
    id_embedding_strategy=("absolute", "absolute"),
    theta=10 * dim_joint,
    rngs=nnx.Rngs(default=42),
    param_dtype=jnp.float32,
)

model_fm = Flux1(params_fm)

method_fm = FlowMatchingMethod()

training_config_fm = ConditionalPipeline.get_default_training_config()
training_config_fm["nsteps"] = 10000
training_config_fm["checkpoint_dir"] = os.path.join(os.getcwd(), "checkpoints", "flow")

pipeline_fm = ConditionalPipeline(
    model_fm,
    train_dataset_grain,
    val_dataset_grain,
    dim_obs=dim_obs,
    dim_cond=dim_cond,
    method=method_fm,
    training_config=training_config_fm,
)
# Uncomment the following lines to train the model.
# Once trained, the model is saved to checkpoints/flow and can be restored below.
rngs = nnx.Rngs(42)
# pipeline_fm.train(rngs, save_model=True)
pipeline_fm.restore_model()
WARNING:absl:CheckpointManagerOptions.read_only=True, setting save_interval_steps=0.
WARNING:absl:CheckpointManagerOptions.read_only=True, setting create=False.
WARNING:absl:Given directory is read only=/content/GenSBI-examples/examples/methods_and_samplers/checkpoints/flow
WARNING:absl:CheckpointManagerOptions.read_only=True, setting save_interval_steps=0.
WARNING:absl:CheckpointManagerOptions.read_only=True, setting create=False.
WARNING:absl:Given directory is read only=/content/GenSBI-examples/examples/methods_and_samplers/checkpoints/flow/ema
Restored model from checkpoint
# The default solver for flow matching is the ODESolver, which performs deterministic
# ODE integration from noise (t=0) to data (t=1).
samples_fm = pipeline_fm.sample(rngs.sample(), x_o, nsamples=100_000)
samples_fm = unnormalize(samples_fm, means[:dim_obs], stds[:dim_obs])
plot_marginals(
    np.array(samples_fm[..., 0]),
    gridsize=30,
    true_param=np.array(true_theta[0, :, 0]),
    range=plot_range,
)
plt.suptitle("Flow Matching — ODE Solver (default)", y=1.02)
plt.savefig("fm_ode_marginals.png", dpi=100, bbox_inches="tight")
plt.show()
../_images/a110edea36339a72930517289c9e062de50e632e4693c0f2e6659948e5bac677.png
# The ZeroEndsSolver adds stochastic noise during sampling. The diffusion coefficient
# vanishes at both t=0 and t=1, ensuring clean endpoints.
# Required kwargs: mu0 (prior mean), sigma0 (prior std), alpha (diffusion strength).
from gensbi.flow_matching.solver import ZeroEndsSolver

solver_kwargs_ze = {
    "mu0": jnp.zeros((dim_obs, 1)),  # prior mean (data is normalized)
    "sigma0": jnp.ones((dim_obs, 1)),  # prior std
    "alpha": 0.2,  # diffusion strength
}

samples_fm_ze = pipeline_fm.sample(
    rngs.sample(),
    x_o,
    nsamples=100_000,
    solver=(ZeroEndsSolver, solver_kwargs_ze),
)
samples_fm_ze = unnormalize(samples_fm_ze, means[:dim_obs], stds[:dim_obs])
plot_marginals(
    np.array(samples_fm_ze[..., 0]),
    gridsize=30,
    true_param=np.array(true_theta[0, :, 0]),
    range=plot_range,
)
plt.suptitle("Flow Matching — ZeroEndsSolver (SDE)", y=1.02)
plt.savefig("fm_zeroends_marginals.png", dpi=100, bbox_inches="tight")
plt.show()
../_images/8a8ea214909d6bd530a2918bb1e197363f010de8caeb076a22908bf6019753a3.png
# The NonSingularSolver uses a diffusion coefficient that stays non-singular over
# the whole integration interval, which can trade off stochasticity and sample
# quality differently from ZeroEndsSolver. It takes the same kwargs.
from gensbi.flow_matching.solver import NonSingularSolver

solver_kwargs_ns = {
    "mu0": jnp.zeros((dim_obs, 1)),
    "sigma0": jnp.ones((dim_obs, 1)),
    "alpha": 0.2,
}

samples_fm_ns = pipeline_fm.sample(
    rngs.sample(),
    x_o,
    nsamples=100_000,
    solver=(NonSingularSolver, solver_kwargs_ns),
)
samples_fm_ns = unnormalize(samples_fm_ns, means[:dim_obs], stds[:dim_obs])
plot_marginals(
    np.array(samples_fm_ns[..., 0]),
    gridsize=30,
    true_param=np.array(true_theta[0, :, 0]),
    range=plot_range,
)
plt.suptitle("Flow Matching — NonSingularSolver (SDE)", y=1.02)
plt.savefig("fm_nonsingular_marginals.png", dpi=100, bbox_inches="tight")
plt.show()
../_images/7d7d963f201e0136d5b1d3eea28da05c9a67a72bf4f84c29f2ecb5a7ca92c016.png
del model_fm, pipeline_fm

Section 2: EDM Diffusion#

EDM Diffusion (Karras et al., 2022) learns a denoiser \(D_\theta(x; \sigma)\) in \(\sigma\)-space. The training noise schedule can use one of three prescriptions:

  • DiffusionEDMMethod() — default EDM scheduler (recommended)

  • DiffusionEDMMethod(sde="VP") — Variance Preserving scheduler

  • DiffusionEDMMethod(sde="VE") — Variance Exploding scheduler

The model can be trained with any of these three prescriptions and then sampled with any of them, although the EDM scheduler is recommended for both training and sampling. Each variant changes the noise schedule applied along the diffusion process.

Solver: EDMSolver is the only available solver for EDM. It implements the stochastic denoising sampler from Karras et al., 2022.
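The EDM sampler discretizes noise levels with the ρ-schedule of Karras et al., 2022 (their Eq. 5), which spaces σ-values densely near σ_min. A short sketch with the paper's default hyperparameters (not necessarily the values GenSBI's EDMSolver uses internally):

```python
import numpy as np

# Karras et al. (2022) rho-schedule: interpolate uniformly in sigma^(1/rho),
# then raise back to the rho-th power, giving a decreasing noise ladder.
def karras_sigmas(n=18, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    ramp = np.arange(n) / (n - 1)
    return (sigma_max ** (1 / rho)
            + ramp * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho

sig = karras_sigmas()
print(sig[0], sig[-1])  # 80.0 down to 0.002
```

Sampling then runs the denoiser from `sig[0]` (pure noise) down to `sig[-1]` (data).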

params_edm = Flux1Params(
    in_channels=1,
    vec_in_dim=None,
    context_in_dim=1,
    mlp_ratio=3,
    num_heads=2,
    depth=4,
    depth_single_blocks=8,
    axes_dim=[10],
    qkv_bias=True,
    dim_obs=dim_obs,
    dim_cond=dim_cond,
    id_embedding_strategy=("absolute", "absolute"),
    theta=10 * dim_joint,
    rngs=nnx.Rngs(default=42),
    param_dtype=jnp.float32,
)

model_edm = Flux1(params_edm)

# Default EDM scheduler (recommended for both training and sampling)
method_edm = DiffusionEDMMethod()
# Alternative training schedulers (uncomment to use):
# method_edm = DiffusionEDMMethod(sde="VP")  # Variance Preserving
# method_edm = DiffusionEDMMethod(sde="VE")  # Variance Exploding

training_config_edm = ConditionalPipeline.get_default_training_config()
training_config_edm["nsteps"] = 10000
training_config_edm["checkpoint_dir"] = os.path.join(os.getcwd(), "checkpoints", "edm")

pipeline_edm = ConditionalPipeline(
    model_edm,
    train_dataset_grain,
    val_dataset_grain,
    dim_obs=dim_obs,
    dim_cond=dim_cond,
    method=method_edm,
    training_config=training_config_edm,
)
# Uncomment the following lines to train the model.
# Once trained, the model is saved to checkpoints/edm and can be restored below.
rngs = nnx.Rngs(42)
# pipeline_edm.train(rngs, save_model=True)
pipeline_edm.restore_model()
Restored model from checkpoint
# The EDMSolver implements the stochastic denoising sampler from Karras et al., 2022.
# It progressively denoises samples following a noise schedule from high to low sigma.
samples_edm = pipeline_edm.sample(rngs.sample(), x_o, nsamples=100_000)
samples_edm = unnormalize(samples_edm, means[:dim_obs], stds[:dim_obs])
plot_marginals(
    np.array(samples_edm[..., 0]),
    gridsize=30,
    true_param=np.array(true_theta[0, :, 0]),
    range=plot_range,
)
plt.suptitle("EDM Diffusion — EDMSolver (default)", y=1.02)
plt.savefig("edm_marginals.png", dpi=100, bbox_inches="tight")
plt.show()
../_images/82b6a3b1a499d34014f425b6a4c6913502ea14b54719f3a73232ebc6da4596aa.png
del model_edm, pipeline_edm

Section 3: Score Matching#

Score Matching (Song et al., 2021) learns the score function \(\nabla \log p_t(x)\), which points toward regions of higher data density. Samples are generated by running a reverse-time SDE from noise back to data.

The SDE formulation can be either:

  • ScoreMatchingMethod() — Variance Preserving (VP) SDE (default)

  • ScoreMatchingMethod(sde_type="VE") — Variance Exploding (VE) SDE

Solvers:

  • SMSolver (default) — reverse-time SDE, generates stochastic samples.

  • SMPFSolver — probability flow ODE, generates deterministic samples from the same learned score function. Useful when reproducibility or lower variance is desired.
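The two solvers consume the same learned score and differ only in the update rule: the reverse-time SDE re-injects noise at every step, while the probability flow ODE halves the drift and drops the noise term. A self-contained toy (plain numpy, not the GenSBI solvers) where the score is known in closed form: for a VE process x_t = x₀ + t·z with data x₀ ~ N(0, 1), the marginal is p_t = N(0, 1 + t²).

```python
import numpy as np

def score(t, x):
    # Exact score of p_t = N(0, 1 + t^2); a trained network would replace this.
    return -x / (1.0 + t ** 2)

rng = np.random.default_rng(0)
T, nsteps = 3.0, 500
dt = T / nsteps
x_sde = rng.standard_normal(50_000) * np.sqrt(1.0 + T ** 2)  # noise at t = T
x_ode = x_sde.copy()
for k in range(nsteps):
    t = T - k * dt
    g2 = 2.0 * t  # g(t)^2 = d(sigma^2)/dt for sigma(t) = t
    # Probability flow ODE step: deterministic, half the score drift.
    x_ode += 0.5 * g2 * score(t, x_ode) * dt
    # Reverse-time SDE step (Euler-Maruyama): full drift plus fresh noise.
    x_sde += g2 * score(t, x_sde) * dt \
        + np.sqrt(g2 * dt) * rng.standard_normal(x_sde.shape)

print(x_sde.std(), x_ode.std())  # both should approach the data std of 1.0
```

Rerunning the SDE branch with a new seed gives different samples; the ODE branch is reproducible given the same initial noise.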

params_sm = Flux1Params(
    in_channels=1,
    vec_in_dim=None,
    context_in_dim=1,
    mlp_ratio=3,
    num_heads=2,
    depth=4,
    depth_single_blocks=8,
    axes_dim=[10],
    qkv_bias=True,
    dim_obs=dim_obs,
    dim_cond=dim_cond,
    id_embedding_strategy=("absolute", "absolute"),
    theta=10 * dim_joint,
    rngs=nnx.Rngs(default=42),
    param_dtype=jnp.float32,
)

model_sm = Flux1(params_sm)

# Default: Variance Preserving (VP) SDE
method_sm = ScoreMatchingMethod()
# Alternative: Variance Exploding (VE) SDE (uncomment to use)
# method_sm = ScoreMatchingMethod(sde_type="VE")

training_config_sm = ConditionalPipeline.get_default_training_config()
training_config_sm["nsteps"] = 50000
training_config_sm["checkpoint_dir"] = os.path.join(os.getcwd(), "checkpoints", "sm")

pipeline_sm = ConditionalPipeline(
    model_sm,
    train_dataset_grain,
    val_dataset_grain,
    dim_obs=dim_obs,
    dim_cond=dim_cond,
    method=method_sm,
    training_config=training_config_sm,
)
# Uncomment the following lines to train the model.
# Once trained, the model is saved to checkpoints/sm and can be restored below.
rngs = nnx.Rngs(42)
# pipeline_sm.train(rngs, save_model=True)
pipeline_sm.restore_model()
Restored model from checkpoint
# The default solver for score matching generates stochastic samples by running the
# reverse-time SDE. Each call with a different key produces a different set of samples.
samples_sm = pipeline_sm.sample(rngs.sample(), x_o, nsamples=100_000)
samples_sm = unnormalize(samples_sm, means[:dim_obs], stds[:dim_obs])
plot_marginals(
    np.array(samples_sm[..., 0]),
    gridsize=30,
    true_param=np.array(true_theta[0, :, 0]),
    range=plot_range,
)
plt.suptitle("Score Matching — SMSolver (default, reverse SDE)", y=1.02)
plt.savefig("sm_sde_marginals.png", dpi=100, bbox_inches="tight")
plt.show()
../_images/f894b63dd6f9455c97bfcf87ff33c88fb0831a1f6a88273810d5008a80419455.png
# The probability flow ODE produces deterministic samples from the same learned score
# function. This can be useful when you want reproducible results or lower variance.
from gensbi.diffusion.solver import SMPFSolver

samples_sm_pf = pipeline_sm.sample(
    rngs.sample(),
    x_o,
    nsamples=100_000,
    solver=(SMPFSolver, {}),
)
samples_sm_pf = unnormalize(samples_sm_pf, means[:dim_obs], stds[:dim_obs])
plot_marginals(
    np.array(samples_sm_pf[..., 0]),
    gridsize=30,
    true_param=np.array(true_theta[0, :, 0]),
    range=plot_range,
)
plt.suptitle("Score Matching — SMPFSolver (probability flow ODE)", y=1.02)
plt.savefig("sm_pf_marginals.png", dpi=100, bbox_inches="tight")
plt.show()
../_images/97aa417410c133d61b644e592457251d4b99698d7a78a31941fe7c3a5890fa0b.png
del model_sm, pipeline_sm