gensbi.core.score_matching
Score matching generative method strategy.
Implements GenerativeMethod using standard score matching diffusion, sampled with either a reverse SDE or a probability flow ODE solver, and supports both the VP and VE SDE formulations.
Classes
ScoreMatchingMethod: Score matching strategy.
Module Contents
- class gensbi.core.score_matching.ScoreMatchingMethod(sde_type='VP')
Bases: gensbi.core.generative_method.GenerativeMethod
Score matching strategy.
Supports two SDE formulations via the sde_type parameter: "VP" (variance-preserving, the default) and "VE" (variance-exploding).
Sampling uses either the reverse SDE (SMSDESolver, the default) or the probability flow ODE (SMODESolver via ScoreToODEDrift).
- Parameters:
sde_type (str, optional) – SDE type. One of "VP" or "VE". Default is "VP".
Examples
>>> method = ScoreMatchingMethod(sde_type="VP")
>>> path = method.build_path(config={"beta_min": 0.001, "beta_max": 3.0}, event_shape=(2, 1))  # (dim, ch); values illustrative
>>> loss = method.build_loss(path)
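To make the two sde_type options concrete, the NumPy sketch below computes the Gaussian perturbation kernel p(x_t | x_data) = N(mean_coeff * x_data, std^2 I) for each formulation. It assumes the linear beta(t) schedule for VP and the geometric sigma(t) schedule for VE popularised by Song et al. (2021), with t = 0 at the data; gensbi's SMPath may parametrise time and schedules differently, and vp_marginal, ve_marginal and perturb are hypothetical names. The VP kernel keeps mean_coeff^2 + std^2 = 1 (hence "variance-preserving"), while the VE kernel leaves the mean alone and lets the noise scale grow (hence "variance-exploding").

```python
import numpy as np


def vp_marginal(t, beta_min, beta_max):
    """Return (mean_coeff, std) of p(x_t | x_data) for a VP SDE with linear beta(t)."""
    # Integral of beta(s) = beta_min + s * (beta_max - beta_min) over [0, t].
    int_beta = beta_min * t + 0.5 * (beta_max - beta_min) * t**2
    mean_coeff = np.exp(-0.5 * int_beta)
    std = np.sqrt(1.0 - np.exp(-int_beta))
    return mean_coeff, std


def ve_marginal(t, sigma_min, sigma_max):
    """Return (mean_coeff, std) of p(x_t | x_data) for a VE SDE with geometric sigma(t)."""
    sigma_t = sigma_min * (sigma_max / sigma_min) ** t
    # The mean is left untouched; the variance added since t = 0 is sigma_t^2 - sigma_min^2.
    std = np.sqrt(np.maximum(sigma_t**2 - sigma_min**2, 0.0))
    return np.ones_like(std), std


def perturb(x_data, t, noise, marginal, **schedule):
    """Draw x_t = mean_coeff * x_data + std * noise from the chosen kernel."""
    mean_coeff, std = marginal(t, **schedule)
    return mean_coeff * x_data + std * noise


# The VP pair matches the Examples entry above; the VE pair is illustrative only.
t = np.linspace(0.0, 1.0, 5)
vp_mean, vp_std = vp_marginal(t, beta_min=0.001, beta_max=3.0)
ve_mean, ve_std = ve_marginal(t, sigma_min=0.01, sigma_max=10.0)
```

Both kernels are closed-form, which is what makes it cheap to draw a noisy x_t at an arbitrary time during training.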
- build_log_prob_fn(model_wrapped, path, model_extras, step_size=0.01, method='Dopri5', atol=1e-05, rtol=1e-05, nsteps=1000, solver=None, exact_divergence=True, log_prior=None, **kwargs)
Build a log-probability closure for the score matching probability flow ODE.
Uses the continuous change-of-variables formula via SMODESolver (which inherits get_log_prob from ODESolver). Only works with SMODESolver; passing SMSDESolver raises NotImplementedError.
Note
Unlike the default sampling behaviour of ScoreMatchingMethod, which integrates the reverse SDE (via SMSDESolver), log-probability evaluation requires the probability flow ODE formulation (via SMODESolver). Both formulations share the same marginal distributions, so a score model trained with the standard score matching loss can be used directly in the probability flow ODE. When no solver is passed, SMODESolver is selected automatically.
- Parameters:
model_wrapped – The wrapped score model.
path (SMPath) – The score matching path.
model_extras (dict) – Mode-specific extras (cond, obs_ids, etc.).
step_size (float, optional) – Step size for fixed-step solvers. Default is 0.01.
method (str or diffrax solver, optional) – Integration method. Default is "Dopri5".
atol (float, optional) – Absolute tolerance for adaptive solvers. Default is 1e-05.
rtol (float, optional) – Relative tolerance for adaptive solvers. Default is 1e-05.
nsteps (int, optional) – Number of integration steps (used for step_size calculation when a fixed-step solver is selected). Default is 1000.
solver (tuple of (type, dict), optional) – (SolverClass, kwargs). Must be an ODE-based solver (SMODESolver).
exact_divergence (bool, optional) – If True (default), compute the exact divergence via the full Jacobian. If False, use the Hutchinson trace estimator (requires a PRNG key at call time).
log_prior (callable, optional) – Override for the prior’s log_prob. If None, uses self.prior.log_prob. Used by the joint pipeline.
- Returns:
log_prob_fn – A function (x_1, model_extras=None, *, key=None) -> log_prob.
- Return type:
Callable
- Raises:
NotImplementedError – If a non-ODE solver (e.g. SMSDESolver) is specified.
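The log-probability closure rests on the instantaneous change-of-variables formula: integrating the probability flow ODE from the data to the prior gives log p_data(x) = log p_prior(x(1)) + ∫ div f(x(t), t) dt. The NumPy sketch below illustrates this with an explicit Euler loop and two divergence options that mirror exact_divergence=True (trace of the full Jacobian, here via central finite differences) and exact_divergence=False (Hutchinson estimator with Rademacher probes, which needs a random generator just as the real closure needs a PRNG key). It is not gensbi's implementation, which delegates to SMODESolver and diffrax; the function names are hypothetical and the sketch assumes t = 0 at the data and t = 1 at the prior.

```python
import numpy as np


def exact_divergence(f, x, t, eps=1e-5):
    """Trace of the Jacobian of f(., t) at x, one central finite difference per coordinate."""
    div = 0.0
    for i in range(x.shape[0]):
        e = np.zeros_like(x)
        e[i] = eps
        div += (f(x + e, t)[i] - f(x - e, t)[i]) / (2 * eps)
    return div


def hutchinson_divergence(f, x, t, rng, n_probes=1, eps=1e-5):
    """Unbiased estimate of the trace, E[v^T J v], with Rademacher probe vectors v."""
    total = 0.0
    for _ in range(n_probes):
        v = rng.choice([-1.0, 1.0], size=x.shape)
        jvp = (f(x + eps * v, t) - f(x - eps * v, t)) / (2 * eps)  # approximates J @ v
        total += v @ jvp
    return total / n_probes


def ode_log_prob(f, x_data, log_prior, nsteps=1000, divergence=exact_divergence, **div_kwargs):
    """Euler-integrate x and its accumulated divergence from t = 0 (data) to t = 1 (prior)."""
    x = np.asarray(x_data, dtype=float)
    dt = 1.0 / nsteps
    int_div = 0.0
    for k in range(nsteps):
        t = k * dt
        int_div += divergence(f, x, t, **div_kwargs) * dt
        x = x + f(x, t) * dt
    # Change of variables: log p_data(x_data) = log p_prior(x(1)) + integral of div f.
    return log_prior(x) + int_div


def std_normal_log_prob(z):
    """Log density of a standard normal prior."""
    return -0.5 * (z @ z) - 0.5 * z.shape[0] * np.log(2 * np.pi)
```

The design trade-off behind exact_divergence is cost: the exact trace needs one Jacobian-vector product per dimension, while each Hutchinson probe needs only one, at the price of a noisy (though unbiased) estimate.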
- build_loss(path, weights=None)
Build the score matching loss.
Wraps path.get_loss_fn() into a callable object.
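For reference, a common form of the score matching objective is denoising score matching: perturb clean data with the Gaussian kernel x_t = mean_coeff * x_data + std * noise, whose conditional score is -noise / std, and regress the model onto that target. The NumPy sketch below uses the std(t)^2 weighting, under which the residual reduces to std * score + noise. This is an illustrative assumption: the exact weighting and reduction used by path.get_loss_fn(), and the role of the weights argument, are not documented here, and dsm_loss is a hypothetical name.

```python
import numpy as np


def dsm_loss(score_fn, x_data, noise, t, mean_coeff, std):
    """Denoising score matching loss with std(t)^2 weighting, averaged over the batch.

    The perturbed sample x_t = mean_coeff * x_data + std * noise has conditional
    score -noise / std, so the weighted residual is std * score(x_t, t) + noise.
    """
    x_t = mean_coeff * x_data + std * noise
    residual = std * score_fn(x_t, t) + noise
    # Sum squared errors over all non-batch axes, then average over the batch.
    per_example = np.sum(residual**2, axis=tuple(range(1, residual.ndim)))
    return np.mean(per_example)


# Sanity check on Gaussian data N(0, I): at a fixed t with mean_coeff^2 + std^2 = 1,
# the true marginal score is -x, which should beat a model that always predicts zero.
rng = np.random.default_rng(0)
x_data = rng.standard_normal((50_000, 2))
noise = rng.standard_normal((50_000, 2))
loss_true = dsm_loss(lambda x, t: -x, x_data, noise, 0.5, mean_coeff=0.6, std=0.8)
loss_zero = dsm_loss(lambda x, t: np.zeros_like(x), x_data, noise, 0.5, mean_coeff=0.6, std=0.8)
```

In this check the true score leaves an irreducible residual of mean_coeff^2 per dimension (0.72 in total here), whereas the zero model pays the full noise variance (about 2.0), so the loss ranks the two models correctly even though its minimum is not zero.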
- build_path(config, event_shape)
Build a score matching path.
Also constructs self.prior as a numpyro distribution.
- Parameters:
config (dict) – Training configuration. Reads the scheduler hyperparameters (beta_min, beta_max for VP; sigma_min, sigma_max for VE), falling back to defaults for any that are missing.
event_shape (tuple of (int, int)) – (dim, ch): feature and channel dimensions.
- Returns:
The configured score matching path.
- Return type:
SMPath
- build_sampler_fn(model_wrapped, path, model_extras, nsteps=1000, method='Euler', return_intermediates=False, solver=None, **kwargs)
Build a sampler closure for score matching.
Supports SMSDESolver (reverse SDE) and SMODESolver (probability flow ODE via ScoreToODEDrift).
- Parameters:
model_wrapped – The wrapped score model.
path (SMPath) – The score matching path.
model_extras (dict) – Mode-specific extras (cond, obs_ids, etc.).
nsteps (int, optional) – Number of integration steps. Default is 1000.
method (str or diffrax solver, optional) – Integration method. Default is "Euler".
return_intermediates (bool, optional) – Whether to return intermediate steps. Default is False.
solver (tuple of (type, dict), optional) – (SolverClass, kwargs). Defaults to (SMSDESolver, {}).
- Returns:
sampler_fn – A function (key, x_init) -> samples.
- Return type:
Callable
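The default sampler integrates the reverse-time SDE dx = [f(x, t) - g(t)^2 score(x, t)] dt + g(t) dw_bar from the prior back to the data. A minimal Euler-Maruyama version in NumPy is sketched below; drift_fn and diffusion_fn stand for the forward SDE's f and g, and the function name, argument order and time grid (t = 1 at the prior down to t = 0 at the data) are assumptions rather than SMSDESolver's actual interface.

```python
import numpy as np


def reverse_sde_sample(rng, score_fn, drift_fn, diffusion_fn, x_init, nsteps=1000):
    """Euler-Maruyama integration of the reverse SDE from t = 1 (prior) to t = 0 (data).

    Reverse SDE: dx = [f(x, t) - g(t)^2 * score(x, t)] dt + g(t) dw_bar, with dt < 0.
    """
    x = np.asarray(x_init, dtype=float)
    ts = np.linspace(1.0, 0.0, nsteps + 1)
    for t, t_next in zip(ts[:-1], ts[1:]):
        h = t - t_next  # positive step size; time runs backwards
        g = diffusion_fn(t)
        drift = drift_fn(x, t) - g**2 * score_fn(x, t)
        x = x - drift * h + g * np.sqrt(h) * rng.standard_normal(x.shape)
    return x


# Illustrative VP SDE with constant beta and N(0, I) data, whose exact score is -x.
beta = 1.0
rng = np.random.default_rng(0)
x_init = rng.standard_normal((20_000, 1))
samples = reverse_sde_sample(
    rng,
    score_fn=lambda x, t: -x,
    drift_fn=lambda x, t: -0.5 * beta * x,
    diffusion_fn=lambda t: np.sqrt(beta),
    x_init=x_init,
)
```

With Gaussian data under a VP SDE the marginals stay N(0, I) at every t, so samples started from N(0, I) should keep roughly zero mean and unit variance through the loop, which makes a convenient sanity check for the sign conventions.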
- build_solver(model_wrapped, path, solver=None)
Instantiate a score matching solver.
For the reverse SDE (SMSDESolver), wraps the model and extracts prior parameters. For the probability flow ODE (SMODESolver), wraps the score model with ScoreToODEDrift first.
- Parameters:
model_wrapped – The wrapped score model.
path (SMPath) – The score matching path.
solver (tuple of (type, dict), optional) – (SolverClass, kwargs). Defaults to (SMSDESolver, {}).
- Returns:
An instantiated solver.
- Return type:
solver_instance
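The ScoreToODEDrift wrapping step corresponds to the standard probability flow construction: the ODE dx/dt = f(x, t) - 0.5 g(t)^2 score(x, t) has the same marginals as the reverse SDE, differing only by the factor 0.5 on the score term and the absence of noise. The NumPy sketch below is a hypothetical closure-based version of that conversion, not gensbi's class; f and g are passed in as plain functions, and the constant-beta VP SDE is illustrative.

```python
import numpy as np


def probability_flow_drift(score_fn, drift_fn, diffusion_fn):
    """Build the probability flow ODE drift f(x, t) - 0.5 * g(t)^2 * score(x, t)."""

    def ode_drift(x, t):
        g = diffusion_fn(t)
        return drift_fn(x, t) - 0.5 * g**2 * score_fn(x, t)

    return ode_drift


# VP SDE with a constant beta (illustrative): f(x, t) = -0.5 * beta * x, g(t) = sqrt(beta).
beta = 2.0
vp_drift = lambda x, t: -0.5 * beta * x
vp_diffusion = lambda t: np.sqrt(beta)

# For N(0, I) data the VP marginals stay N(0, I), so the exact score is -x
# and the probability flow ODE leaves the samples where they are.
ode_drift = probability_flow_drift(lambda x, t: -x, vp_drift, vp_diffusion)
x = np.array([[0.3, -1.0], [2.0, 0.5]])
```

Because the resulting drift is deterministic, any ODE solver (fixed-step or adaptive) can integrate it, which is also what makes the log-probability computation through the change-of-variables formula possible.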
- get_default_solver()
Return the default reverse SDE solver.
- Returns:
(SMSDESolver, {})
- Return type:
tuple
- get_extra_training_config()
Return the score-matching-specific training config defaults.
- Returns:
Scheduler defaults for the selected SDE type.
- Return type:
dict
- prepare_batch(key, x_1, path)
Sample noise and diffusion time for a score matching training batch.
- Parameters:
key (jax.random.PRNGKey) – Random key.
x_1 (Array) – Clean data of shape (batch_size, dim, ch).
path (SMPath) – The score matching path.
- Returns:
(x_0, x_1, t), where x_0 is standard normal noise and t has shape (batch_size, 1, 1).
- Return type:
tuple
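A NumPy sketch of the batch preparation described above, keeping the documented return order (x_0, x_1, t). It assumes diffusion times are drawn uniformly from [t_min, 1) with a small hypothetical cutoff t_min > 0, a common choice because the perturbation std vanishes at the data end; the distribution gensbi actually uses is not stated here, and a NumPy Generator stands in for the JAX PRNG key. The (batch_size, 1, 1) time shape lets per-example scalars broadcast against data of shape (batch_size, dim, ch).

```python
import numpy as np


def prepare_batch(rng, x_1, t_min=1e-3):
    """Sample noise x_0 and per-example diffusion times t for a batch of clean data x_1."""
    batch_size = x_1.shape[0]
    # Standard normal noise with the same (batch_size, dim, ch) shape as the data.
    x_0 = rng.standard_normal(x_1.shape)
    # One time per example, shaped (batch_size, 1, 1) so it broadcasts over dim and ch.
    t = rng.uniform(t_min, 1.0, size=(batch_size, 1, 1))
    return x_0, x_1, t


rng = np.random.default_rng(0)
x_1 = rng.standard_normal((8, 3, 2))  # batch_size=8, dim=3, ch=2
x_0, x_1, t = prepare_batch(rng, x_1)
# Broadcasting illustration only: each example's noise scaled by its own t.
noise_scaled = t * x_0
```

Keeping the trailing singleton axes on t avoids reshaping inside the loss, since schedule quantities such as mean_coeff(t) and std(t) inherit the same (batch_size, 1, 1) shape.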