gensbi.core.score_matching

Score matching generative method strategy.

Implements the GenerativeMethod interface for score matching diffusion. Sampling integrates either the reverse SDE or the probability flow ODE, and both variance-preserving (VP) and variance-exploding (VE) SDE formulations are supported.

Classes

ScoreMatchingMethod

Score matching strategy.

Module Contents

class gensbi.core.score_matching.ScoreMatchingMethod(sde_type='VP')

Bases: gensbi.core.generative_method.GenerativeMethod

Score matching strategy.

Supports two SDE formulations via the sde_type parameter:

  • "VP" — variance-preserving (default)

  • "VE" — variance-exploding
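For orientation, here is a minimal sketch of the marginal perturbation kernels p_t(x_t | x_0) for the two formulations. It assumes the standard parameterization of Song et al. (2021): a linear β(t) between beta_min and beta_max for VP, a geometric σ(t) between sigma_min and sigma_max for VE, and t ∈ [0, 1] running from data to noise. Whether gensbi uses exactly these schedules and this time direction is an assumption, and vp_kernel / ve_kernel are hypothetical helper names.

```python
import math


def vp_kernel(t, beta_min, beta_max):
    """Return (mean_scale, std) of p_t(x_t | x_0) for a VP SDE.

    Assumes beta(t) = beta_min + t * (beta_max - beta_min) on t in [0, 1].
    """
    # Closed-form integral of beta(s) from 0 to t.
    int_beta = beta_min * t + 0.5 * (beta_max - beta_min) * t**2
    mean_scale = math.exp(-0.5 * int_beta)
    # mean_scale**2 + std**2 == 1, so unit-variance data keeps unit variance.
    std = math.sqrt(1.0 - mean_scale**2)
    return mean_scale, std


def ve_kernel(t, sigma_min, sigma_max):
    """Return (mean_scale, std) of p_t(x_t | x_0) for a VE SDE.

    Assumes sigma(t) = sigma_min * (sigma_max / sigma_min) ** t on t in [0, 1].
    """
    sigma_t = sigma_min * (sigma_max / sigma_min) ** t
    # The mean is untouched; the variance grows from 0 to sigma_max**2 - sigma_min**2.
    return 1.0, math.sqrt(sigma_t**2 - sigma_min**2)


# A perturbed sample is x_t = mean_scale * x_0 + std * eps with eps ~ N(0, 1).
mean_scale, std = vp_kernel(1.0, beta_min=0.001, beta_max=3.0)
print(f"VP at t=1: mean_scale={mean_scale:.4f}, std={std:.4f}")
```

With beta_max = 3.0 the VP mean scale at t = 1 is still well above zero, so the forward process only approximately reaches N(0, I) at the end of the interval.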

Sampling can use either the reverse SDE (SMSDESolver, default) or the probability flow ODE (SMODESolver via ScoreToODEDrift).
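As background, the standard continuous-time formulation (Song et al., 2021) behind both samplers is summarized below, with forward drift f, diffusion coefficient g, and learned score s_θ(x, t) ≈ ∇_x log p_t(x). The exact drift and time conventions used inside SMSDESolver and ScoreToODEDrift are assumptions not confirmed by this page.

```latex
\begin{aligned}
\text{forward SDE:}\quad & dx = f(x,t)\,dt + g(t)\,dW \\
\text{reverse SDE:}\quad & dx = \big[f(x,t) - g(t)^2\, s_\theta(x,t)\big]\,dt + g(t)\,d\bar{W} \\
\text{probability flow ODE:}\quad & \frac{dx}{dt} = f(x,t) - \tfrac{1}{2}\, g(t)^2\, s_\theta(x,t) \\
\text{VP:}\quad & f(x,t) = -\tfrac{1}{2}\beta(t)\,x, \qquad g(t) = \sqrt{\beta(t)} \\
\text{VE:}\quad & f(x,t) = 0, \qquad g(t) = \sqrt{\tfrac{d}{dt}\,\sigma^2(t)}
\end{aligned}
```

The two samplers differ only in the factor ½ on the score term and in the injected noise; given the same score, both have the marginals p_t in continuous time.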

Parameters:

sde_type (str, optional) – SDE type. One of "VP" or "VE". Default is "VP".

Examples

>>> method = ScoreMatchingMethod(sde_type="VP")
>>> path = method.build_path(config={"beta_min": 0.001, "beta_max": 3.0}, event_shape=(3, 1))
>>> loss = method.build_loss(path)
build_log_prob_fn(model_wrapped, path, model_extras, step_size=0.01, method='Dopri5', atol=1e-05, rtol=1e-05, nsteps=1000, solver=None, exact_divergence=True, log_prior=None, **kwargs)

Build a log-probability closure for the score matching probability flow ODE.

Uses the continuous change-of-variables formula via SMODESolver (which inherits get_log_prob from ODESolver). Only SMODESolver is supported; passing SMSDESolver raises NotImplementedError.

Note

Unlike the default sampling behaviour of ScoreMatchingMethod, which integrates the reverse SDE (via SMSDESolver), log-probability evaluation requires the probability flow ODE formulation (via SMODESolver). Both formulations share the same marginal distributions, so a score model trained with the standard SM loss can be used in the probability flow ODE without retraining. When no solver is passed, SMODESolver is selected automatically.
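To make the change-of-variables idea concrete, here is a one-dimensional sketch, not gensbi's implementation. It integrates the VP probability flow ODE forward from data (t = 0) to noise (t = 1) with Euler steps, accumulates the divergence of the drift, and adds the prior log-density. It assumes a linear β(t) schedule and Gaussian data N(mu, s²), whose score is known in closed form, so the result can be checked against the exact log-density. Because beta_max = 3.0 (the value in the Examples section) does not fully reach N(0, 1), the exact t = 1 marginal is passed as the prior, in the spirit of the log_prior override. pf_ode_log_prob and the other helpers are hypothetical.

```python
import math


def pf_ode_log_prob(x_data, score, log_prior, beta_min, beta_max, nsteps=1000, h=1e-4):
    """log p_0(x_data) for a 1-D VP probability flow ODE (t=0 data, t=1 noise).

    Integrates x forward from t=0 to t=1 with explicit Euler steps and applies
    log p_0(x_0) = log p_1(x_1) + integral_0^1 div f(x_t, t) dt.
    """

    def drift(x, t):
        beta = beta_min + t * (beta_max - beta_min)
        # Probability flow drift f(x, t) - 0.5 * g(t)^2 * score(x, t) for the VP SDE.
        return -0.5 * beta * x - 0.5 * beta * score(x, t)

    dt = 1.0 / nsteps
    x, div_integral = x_data, 0.0
    for i in range(nsteps):
        t = i * dt
        # In 1-D the divergence is d(drift)/dx; estimate it with a central difference.
        div_integral += (drift(x + h, t) - drift(x - h, t)) / (2.0 * h) * dt
        x = x + drift(x, t) * dt
    return log_prior(x) + div_integral


# Gaussian data N(mu, s^2) has closed-form VP marginals, so the result can be checked.
mu, s, beta_min, beta_max = 2.0, 0.5, 0.001, 3.0


def marginal(t):
    a = math.exp(-0.5 * (beta_min * t + 0.5 * (beta_max - beta_min) * t**2))
    return a * mu, a**2 * s**2 + 1.0 - a**2  # (mean, variance) of p_t


def exact_score(x, t):
    m, v = marginal(t)
    return -(x - m) / v


def normal_log_pdf(x, m, v):
    return -0.5 * (math.log(2.0 * math.pi * v) + (x - m) ** 2 / v)


def log_prior(x):
    # Exact t=1 marginal used as the prior (plays the role of a log_prior override).
    return normal_log_pdf(x, *marginal(1.0))


approx = pf_ode_log_prob(2.3, exact_score, log_prior, beta_min, beta_max)
print(approx, normal_log_pdf(2.3, mu, s**2))
```

The two printed values agree up to the Euler discretization error, which shrinks as nsteps grows.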

Parameters:
  • model_wrapped – The wrapped score model.

  • path (SMPath) – The score matching path.

  • model_extras (dict) – Mode-specific extras (cond, obs_ids, etc.).

  • step_size (float, optional) – Step size for fixed-step solvers. Default is 0.01.

  • method (str or diffrax solver, optional) – Integration method. Default is "Dopri5".

  • atol (float, optional) – Absolute tolerance for adaptive solvers. Default is 1e-5.

  • rtol (float, optional) – Relative tolerance for adaptive solvers. Default is 1e-5.

  • nsteps (int, optional) – Number of integration steps (used for step_size calculation when a fixed-step solver is selected). Default is 1000.

  • solver (tuple of (type, dict), optional) – (SolverClass, kwargs). Must be an ODE-based solver (SMODESolver). If None, SMODESolver is used.

  • exact_divergence (bool, optional) – If True (default), compute exact divergence via full Jacobian. If False, use the Hutchinson estimator (requires a PRNG key at call time).

  • log_prior (callable, optional) – Override for the prior’s log_prob. If None, uses self.prior.log_prob. Used by the joint pipeline.
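The exact_divergence flag trades cost for variance. The NumPy sketch below, independent of gensbi, contrasts the full-Jacobian divergence (the trace of the Jacobian, which needs one pass per dimension) with Hutchinson's estimator E[εᵀ J ε] over Rademacher probes ε, which needs only one Jacobian-vector product per probe but is unbiased rather than exact. Finite differences stand in for autodiff here (in JAX, jax.jacfwd and jax.jvp would play those roles), a linear field f(x) = A x is used so the true divergence is tr(A), and an integer seed stands in for the PRNG key. The function names are hypothetical.

```python
import numpy as np


def full_jacobian_divergence(f, x, h=1e-5):
    """tr(J_f(x)) from the Jacobian diagonal, one central difference per dimension."""
    trace = 0.0
    for i in range(x.shape[0]):
        e = np.zeros_like(x)
        e[i] = h
        trace += (f(x + e)[i] - f(x - e)[i]) / (2.0 * h)
    return trace


def hutchinson_divergence(f, x, key, nprobes=1, h=1e-5):
    """Unbiased estimate of tr(J_f(x)): average of eps^T J eps over Rademacher probes."""
    rng = np.random.default_rng(key)
    total = 0.0
    for _ in range(nprobes):
        eps = rng.choice([-1.0, 1.0], size=x.shape)
        # Jacobian-vector product J @ eps by a central difference along eps.
        jvp = (f(x + h * eps) - f(x - h * eps)) / (2.0 * h)
        total += eps @ jvp
    return total / nprobes


# A linear field f(x) = A x has Jacobian A, so the true divergence is tr(A) = -1.5.
A = np.array([[1.0, 2.0, 0.0], [0.5, -3.0, 1.0], [0.0, 4.0, 0.5]])


def f(x):
    return A @ x


x = np.array([0.3, -1.2, 2.0])
print(full_jacobian_divergence(f, x))                      # close to -1.5
print(hutchinson_divergence(f, x, key=0, nprobes=1))       # a single noisy probe
print(hutchinson_divergence(f, x, key=0, nprobes=10_000))  # averages towards -1.5
```

A single probe can be far from the true value, which is why the Hutchinson option is cheaper per call but noisier than exact_divergence=True.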

Returns:

log_prob_fn(x_1, model_extras=None, *, key=None) -> log_prob.

Return type:

Callable

Raises:

NotImplementedError – If a non-ODE solver (e.g. SMSDESolver) is specified.
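The solver rules documented for build_log_prob_fn (fall back to SMODESolver when solver is None, reject the reverse-SDE solver) amount to a small dispatch on the (SolverClass, kwargs) tuple. The sketch below uses hypothetical stand-in classes instead of gensbi's real solvers and assumes the check is a subclass test against an ODE base class; the library's actual check may differ.

```python
class ODESolverStub:
    """Hypothetical stand-in for an ODE solver base class."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs


class SMODESolverStub(ODESolverStub):
    """Hypothetical stand-in for SMODESolver (probability flow ODE)."""


class SMSDESolverStub:
    """Hypothetical stand-in for SMSDESolver (reverse SDE); not an ODE solver."""

    def __init__(self, **kwargs):
        self.kwargs = kwargs


def resolve_log_prob_solver(solver=None):
    """Instantiate an ODE solver for log_prob, defaulting to the probability flow ODE."""
    solver_cls, solver_kwargs = solver if solver is not None else (SMODESolverStub, {})
    if not issubclass(solver_cls, ODESolverStub):
        # The change-of-variables formula needs deterministic ODE trajectories.
        raise NotImplementedError(
            f"{solver_cls.__name__} is not an ODE solver; log_prob requires SMODESolver."
        )
    return solver_cls(**solver_kwargs)


print(type(resolve_log_prob_solver()).__name__)  # SMODESolverStub
try:
    resolve_log_prob_solver((SMSDESolverStub, {}))
except NotImplementedError as err:
    print(err)
```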

build_loss(path, weights=None)

Build the score matching loss.

Wraps path.get_loss_fn() into a callable object.

Parameters:
  • path (SMPath) – The score matching path.

  • weights (Array, optional) – Per-dimension loss weights.

Returns:

A loss callable with signature (key, model, batch, condition_mask=None, model_extras=None) -> loss.

Return type:

SMLoss
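build_loss wraps path.get_loss_fn(), whose exact form is not documented on this page. As a hedged illustration, the standard denoising score matching objective for a VP path regresses the model onto the score of the perturbation kernel, -x_0 / std, weighted by std² (the common λ(t) = σ(t)² choice) and optionally by per-dimension weights. The NumPy sketch below assumes that objective, a linear β(t) schedule, the (batch_size, dim, ch) layout from prepare_batch, and a model called as model(x_t, t); dsm_loss and vp_kernel are hypothetical names. The loss is minimized, not zeroed, by the true marginal score.

```python
import numpy as np


def vp_kernel(t, beta_min=0.001, beta_max=3.0):
    """Mean scale and std of the VP perturbation kernel, assuming a linear beta(t)."""
    int_beta = beta_min * t + 0.5 * (beta_max - beta_min) * t**2
    mean_scale = np.exp(-0.5 * int_beta)
    return mean_scale, np.sqrt(1.0 - mean_scale**2)


def dsm_loss(model, x_1, x_0, t, weights=None):
    """Denoising score matching loss with weighting lambda(t) = std(t)**2.

    Names follow prepare_batch: x_1 is clean data (batch_size, dim, ch), x_0 is N(0, I)
    noise of the same shape and t has shape (batch_size, 1, 1), with t=1 the noisy end.
    """
    mean_scale, std = vp_kernel(t)
    x_t = mean_scale * x_1 + std * x_0
    # Regression target is the kernel score -x_0 / std; scaling by std gives this residual.
    sq_err = (std * model(x_t, t) + x_0) ** 2
    if weights is not None:
        sq_err = sq_err * weights  # per-dimension weights of shape (dim, ch)
    return sq_err.mean()


rng = np.random.default_rng(0)
x_1 = rng.standard_normal((4096, 3, 1))  # unit-variance data: VP marginals stay N(0, I)
x_0 = rng.standard_normal(x_1.shape)
t = rng.uniform(1e-3, 1.0, size=(4096, 1, 1))


def true_score(x_t, t):
    return -x_t  # exact marginal score of N(0, I) data


def zero_score(x_t, t):
    return np.zeros_like(x_t)


print(dsm_loss(true_score, x_1, x_0, t), dsm_loss(zero_score, x_1, x_0, t))
```

The exact score yields a clearly lower loss than a model that always returns zero (whose expected loss is 1), which is the property score matching training relies on.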

build_path(config, event_shape)

Build a score matching path.

Also constructs self.prior as a numpyro distribution.

Parameters:
  • config (dict) – Training configuration. Reads the scheduler hyperparameters (beta_min, beta_max for VP; sigma_min, sigma_max for VE), falling back to default values for any key that is missing.

  • event_shape (tuple of (int, int)) – (dim, ch) — feature and channel dimensions.

Returns:

The configured score matching path.

Return type:

SMPath
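build_path also constructs self.prior, which sample_init draws from. Under the usual convention (an assumption, not confirmed on this page), the VP prior is a standard normal over the event shape (dim, ch) and the VE prior is N(0, sigma_max² I). The NumPy sketch below shows log_prob and sampling for such an isotropic Gaussian prior without numpyro; prior_std, prior_log_prob and prior_sample are hypothetical helpers, and the sigma_max value is illustrative only.

```python
import numpy as np


def prior_std(sde_type, config):
    """Isotropic prior std: 1 for VP, sigma_max for VE (assumed convention)."""
    if sde_type == "VP":
        return 1.0
    if sde_type == "VE":
        return float(config["sigma_max"])
    raise ValueError(f"unknown sde_type: {sde_type!r}")


def prior_log_prob(x, std):
    """Log-density of N(0, std**2 I), summed over the event axes (dim, ch)."""
    z = x / std
    per_elem = -0.5 * z**2 - np.log(std) - 0.5 * np.log(2.0 * np.pi)
    return per_elem.sum(axis=(-2, -1))


def prior_sample(rng, nsamples, event_shape, std):
    """Draw prior samples with shape (nsamples, dim, ch)."""
    return std * rng.standard_normal((nsamples, *event_shape))


rng = np.random.default_rng(0)
event_shape = (3, 2)  # (dim, ch)
std = prior_std("VE", {"sigma_max": 10.0})  # illustrative sigma_max value
samples = prior_sample(rng, 5, event_shape, std)
print(samples.shape, prior_log_prob(samples, std).shape)  # (5, 3, 2) (5,)
```

Summing over both event axes mirrors treating (dim, ch) as a single event, so each sample gets one scalar log-density.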

build_sampler_fn(model_wrapped, path, model_extras, nsteps=1000, method='Euler', return_intermediates=False, solver=None, **kwargs)

Build a sampler closure for score matching.

Supports SMSDESolver (reverse SDE) and SMODESolver (probability flow ODE via ScoreToODEDrift).

Parameters:
  • model_wrapped – The wrapped score model.

  • path (SMPath) – The score matching path.

  • model_extras (dict) – Mode-specific extras (cond, obs_ids, etc.).

  • nsteps (int, optional) – Number of integration steps. Default is 1000.

  • method (str or diffrax solver, optional) – Integration method. Default is "Euler".

  • return_intermediates (bool, optional) – Whether to return intermediate steps. Default is False.

  • solver (tuple of (type, dict), optional) – (SolverClass, kwargs). Defaults to (SMSDESolver, {}).

Returns:

sampler_fn – A function (key, x_init) -> samples.

Return type:

Callable
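With the default solver, build_sampler_fn integrates the reverse SDE. The NumPy sketch below shows what such a closure can look like using Euler-Maruyama steps, assuming a VP SDE with a linear β(t), time running from t = 1 (noise) back to t = 0 (data), and a score function score(x, t). It mirrors the documented (key, x_init) -> samples signature and the return_intermediates flag, but make_reverse_sde_sampler is hypothetical, key is a NumPy seed rather than a JAX PRNGKey, and the real solver's step schedule may differ. For the check, x_init is drawn from the exact t = 1 marginal because beta_max = 3.0 does not fully reach N(0, 1).

```python
import numpy as np


def make_reverse_sde_sampler(score, beta_min, beta_max, nsteps=1000, return_intermediates=False):
    """Return sampler_fn(key, x_init) that integrates the VP reverse SDE with Euler-Maruyama."""

    def sampler_fn(key, x_init):
        rng = np.random.default_rng(key)
        dt = 1.0 / nsteps
        x = np.asarray(x_init, dtype=float)
        trajectory = [x]
        for i in range(nsteps):
            t = 1.0 - i * dt  # march backwards from t=1 (noise) towards t=0 (data)
            beta = beta_min + t * (beta_max - beta_min)
            # Reverse-time drift f - g^2 * score with VP f = -beta/2 * x and g^2 = beta.
            drift = -0.5 * beta * x - beta * score(x, t)
            # Time decreases by dt, so the drift enters with a minus sign.
            x = x - drift * dt + np.sqrt(beta * dt) * rng.standard_normal(x.shape)
            trajectory.append(x)
        return np.stack(trajectory) if return_intermediates else x

    return sampler_fn


# Check on Gaussian data N(mu, s^2), whose VP score is known in closed form.
mu, s, beta_min, beta_max = 2.0, 0.5, 0.001, 3.0


def marginal(t):
    a = np.exp(-0.5 * (beta_min * t + 0.5 * (beta_max - beta_min) * t**2))
    return a * mu, a**2 * s**2 + 1.0 - a**2  # (mean, variance) of p_t


def exact_score(x, t):
    m, v = marginal(t)
    return -(x - m) / v


m1, v1 = marginal(1.0)
x_init = m1 + np.sqrt(v1) * np.random.default_rng(1).standard_normal((2000, 1))
samples = make_reverse_sde_sampler(exact_score, beta_min, beta_max, nsteps=200)(0, x_init)
print(samples.mean(), samples.std())  # close to mu = 2.0 and s = 0.5
```

With return_intermediates=True the closure returns all nsteps + 1 states stacked along a new leading axis.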

build_solver(model_wrapped, path, solver=None)

Instantiate a score matching solver.

For the reverse SDE (SMSDESolver), wraps the model and extracts prior parameters. For the probability flow ODE (SMODESolver), wraps the score model with ScoreToODEDrift first.

Parameters:
  • model_wrapped – The wrapped score model.

  • path (SMPath) – The score matching path.

  • solver (tuple of (type, dict), optional) – (SolverClass, kwargs). Defaults to (SMSDESolver, {}).

Returns:

An instantiated solver.

Return type:

solver_instance
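For the probability flow ODE, build_solver first wraps the score model with ScoreToODEDrift. The sketch below illustrates the idea with a hypothetical ScoreToODEDriftSketch class that turns a VP score into the drift f(x, t) - ½ g(t)² s(x, t), plus a deterministic Euler integrator from t = 1 to t = 0. The VP schedule, the time direction and the wrapper's interface are assumptions; gensbi's ScoreToODEDrift may be parameterized differently. For 1-D Gaussian data the ODE flow is the quantile map, which gives an exact check.

```python
import math


class ScoreToODEDriftSketch:
    """Hypothetical wrapper turning a VP score s(x, t) into a probability flow ODE drift."""

    def __init__(self, score, beta_min, beta_max):
        self.score = score
        self.beta_min = beta_min
        self.beta_max = beta_max

    def __call__(self, x, t):
        beta = self.beta_min + t * (self.beta_max - self.beta_min)
        # f(x, t) - 0.5 * g(t)^2 * s(x, t) with VP f = -beta/2 * x and g^2 = beta.
        return -0.5 * beta * x - 0.5 * beta * self.score(x, t)


def solve_pf_ode(drift, x_init, nsteps=1000):
    """Deterministic Euler integration of dx/dt = drift(x, t) from t=1 back to t=0."""
    dt = 1.0 / nsteps
    x = x_init
    for i in range(nsteps):
        t = 1.0 - i * dt
        x = x - drift(x, t) * dt
    return x


mu, s, beta_min, beta_max = 2.0, 0.5, 0.001, 3.0


def marginal(t):
    a = math.exp(-0.5 * (beta_min * t + 0.5 * (beta_max - beta_min) * t**2))
    return a * mu, a**2 * s**2 + 1.0 - a**2  # (mean, variance) of p_t for N(mu, s^2) data


def exact_score(x, t):
    m, v = marginal(t)
    return -(x - m) / v


drift = ScoreToODEDriftSketch(exact_score, beta_min, beta_max)
m1, v1 = marginal(1.0)
z = 1.5
x0 = solve_pf_ode(drift, m1 + math.sqrt(v1) * z)
print(x0, mu + s * z)  # the z-quantile at t=1 maps to the z-quantile of the data
```

Because no noise is injected, the same x_init always produces the same output, which is what makes the ODE usable for log-probability evaluation.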

get_default_solver()

Return the default reverse SDE solver.

Returns:

(SMSDESolver, {})

Return type:

tuple

get_extra_training_config()

Return SM-specific training config defaults.

Returns:

Scheduler defaults for the selected SDE type.

Return type:

dict

prepare_batch(key, x_1, path)

Sample noise and diffusion time for a score matching training batch.

Parameters:
  • key (jax.random.PRNGKey) – Random key.

  • x_1 (Array) – Clean data of shape (batch_size, dim, ch).

  • path (SMPath) – The score matching path.

Returns:

(x_0, x_1, t) where x_0 is standard normal noise and t has shape (batch_size, 1, 1).

Return type:

tuple
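A minimal NumPy sketch of the documented prepare_batch contract: draw standard normal noise x_0 with the same shape as x_1 and one diffusion time per example, shaped (batch_size, 1, 1) so it broadcasts over (dim, ch). Sampling t uniformly from [t_min, 1) with a small t_min, to keep away from the degenerate std(0) = 0 end, is a common choice but an assumption here, as is using a NumPy seed in place of a JAX PRNGKey; prepare_batch_sketch is a hypothetical name.

```python
import numpy as np


def prepare_batch_sketch(key, x_1, t_min=1e-3):
    """Return (x_0, x_1, t): N(0, I) noise like x_1 and per-example times of shape (B, 1, 1)."""
    rng = np.random.default_rng(key)
    batch_size = x_1.shape[0]
    x_0 = rng.standard_normal(x_1.shape)
    # One time per example; the trailing singleton axes broadcast against (dim, ch).
    t = rng.uniform(t_min, 1.0, size=(batch_size, 1, 1))
    return x_0, x_1, t


x_1 = np.ones((8, 4, 2))  # clean data with shape (batch_size, dim, ch)
x_0, x_1, t = prepare_batch_sketch(0, x_1)
print(x_0.shape, t.shape)  # (8, 4, 2) (8, 1, 1)
```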

sample_init(key, nsamples)

Sample from the score matching prior.

Parameters:
  • key (jax.random.PRNGKey) – Random key.

  • nsamples (int) – Number of samples to draw.

Returns:

Sample from the prior.

Return type:

Array

prior = None
sde_type = 'VP'