gensbi.core.flow_matching
Flow matching generative method strategy.
Implements GenerativeMethod using optimal-transport conditional flow matching with an affine probability path.
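The affine probability path behind this module can be sketched in a few lines. This is a NumPy sketch, not the library's code. It assumes the usual CondOT convention alpha_t = t and sigma_t = 1 - t, with x_0 drawn from the prior at t = 0 and x_1 the data at t = 1. Under that convention x_t = t x_1 + (1 - t) x_0, and the conditional target velocity is x_1 - x_0. The helper name condot_path_sample is hypothetical.

```python
import numpy as np

def condot_path_sample(x_0, x_1, t):
    """Sample x_t on the CondOT affine path and return the target velocity.

    Assumed convention (hypothetical helper): alpha_t = t, sigma_t = 1 - t,
    x_0 ~ prior at t = 0, x_1 ~ data at t = 1.
    x_0, x_1: shape (batch, dim, ch); t: shape (batch,).
    """
    # Broadcast t of shape (batch,) over the (dim, ch) event axes.
    t = t.reshape(-1, *([1] * (x_1.ndim - 1)))
    x_t = t * x_1 + (1.0 - t) * x_0   # alpha_t * x_1 + sigma_t * x_0
    u_t = x_1 - x_0                    # d/dt x_t for this schedule
    return x_t, u_t

rng = np.random.default_rng(0)
x_0 = rng.standard_normal((4, 5, 1))   # prior draws
x_1 = rng.standard_normal((4, 5, 1))   # "data"
t = np.array([0.0, 0.25, 0.5, 1.0])
x_t, u_t = condot_path_sample(x_0, x_1, t)
```

At t = 0 the path returns the prior draw and at t = 1 the data point. Because the target x_1 - x_0 does not depend on t, the regression target is constant along each conditional path.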
Classes
FlowMatchingMethod – Flow matching strategy using affine probability paths.
Module Contents
- class gensbi.core.flow_matching.FlowMatchingMethod(prior=None)
Bases: gensbi.core.generative_method.GenerativeMethod
Flow matching strategy using affine probability paths.
Uses the conditional optimal-transport scheduler and an ODE or SDE solver for sampling.
- Parameters:
prior (numpyro.distributions.Distribution, optional) – Source distribution. Must implement sample(key, shape) and log_prob(x). Validated against event_shape in build_path(). If None, a standard normal prior is constructed automatically.
Examples
>>> from gensbi.core.flow_matching import FlowMatchingMethod
>>> method = FlowMatchingMethod()
>>> path = method.build_path(config={}, event_shape=(5, 1))
>>> loss = method.build_loss(path)
Using a custom numpyro prior (x has shape (batch, dim_obs, ch_obs)):
>>> import jax.numpy as jnp
>>> import numpyro.distributions as dist
>>> dim_obs, ch_obs = 3, 1
>>> prior = dist.Independent(
...     dist.Normal(loc=jnp.zeros((dim_obs, ch_obs)), scale=jnp.ones((dim_obs, ch_obs))),
...     reinterpreted_batch_ndims=2,
... )
>>> method = FlowMatchingMethod(prior=prior)
- build_log_prob_fn(model_wrapped, path, model_extras, step_size=0.01, method='Dopri5', atol=1e-05, rtol=1e-05, time_grid=None, solver=None, exact_divergence=True, log_prior=None, **kwargs)
Build a log-probability closure for flow matching.
Uses the continuous change-of-variables formula via ODESolver. Only works with ODE solvers, not SDE solvers.
- Parameters:
model_wrapped – The wrapped velocity field model.
path – The probability path.
model_extras (dict) – Mode-specific extras (cond, obs_ids, etc.).
step_size (float, optional) – Step size for fixed-step solvers. Default is 0.01.
method (str or diffrax solver, optional) – Integration method. Default is "Dopri5".
atol (float, optional) – Absolute tolerance for adaptive solvers. Default is 1e-05.
rtol (float, optional) – Relative tolerance for adaptive solvers. Default is 1e-05.
time_grid (list, optional) – Time grid. Defaults to [1.0, 0.0], i.e. integration runs from the data (t = 1) back to the prior (t = 0).
solver (tuple of (type, dict), optional) – (SolverClass, kwargs). Must be an ODE solver.
exact_divergence (bool, optional) – If True (default), compute the exact divergence via the full Jacobian. If False, use the Hutchinson trace estimator, which requires a PRNG key at call time.
log_prior (callable, optional) – Override for the prior's log_prob. If None, uses self.prior.log_prob. Used by the joint pipeline to pass a user-supplied obs-space prior.
- Returns:
log_prob_fn – A function (x_1, model_extras, *, key=None) -> log_prob.
- Return type:
Callable
- Raises:
NotImplementedError – If a non-ODE solver is specified.
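The change-of-variables computation behind this closure can be illustrated with a small NumPy sketch. Several pieces are hypothetical stand-ins: the toy field velocity(x, t) = A x with diagonal A, explicit Euler in place of Dopri5, and central finite-difference Jacobian-vector products in place of JAX autodiff. The sketch integrates from t = 1 back to t = 0, matching the default time_grid. It uses log p_1(x_1) = log p_0(x_0) - ∫_0^1 div v(x_t, t) dt. The divergence is computed either exactly, as the trace of the Jacobian, or with the Hutchinson estimator E[z^T J z]. The Hutchinson estimator needs random probes, which is why that mode requires a key.

```python
import numpy as np

A = np.diag([0.5, -0.3])  # hypothetical toy field: v(x, t) = A @ x

def velocity(x, t):
    """Hypothetical time-independent linear velocity field."""
    return A @ x

def jvp_fd(f, x, t, v, eps=1e-5):
    """Jacobian-vector product J_f(x) @ v via central differences
    (a stand-in for jax.jvp in this NumPy sketch)."""
    return (f(x + eps * v, t) - f(x - eps * v, t)) / (2.0 * eps)

def divergence(f, x, t, rng=None, n_probes=16):
    if rng is None:
        # Exact: trace of the full Jacobian, one JVP per basis vector.
        return sum(jvp_fd(f, x, t, e)[i] for i, e in enumerate(np.eye(x.size)))
    # Hutchinson: unbiased estimate E[z^T J z] with Rademacher probes z.
    z = rng.choice([-1.0, 1.0], size=(n_probes, x.size))
    return float(np.mean([zk @ jvp_fd(f, x, t, zk) for zk in z]))

def std_normal_log_prob(x):
    return -0.5 * np.sum(x**2) - 0.5 * x.size * np.log(2.0 * np.pi)

def log_prob(x_1, log_prior=std_normal_log_prob, n_steps=1000, rng=None):
    """log p_1(x_1) = log p_0(x_0) - int_0^1 div v(x_t, t) dt,
    integrating from t = 1 (data) back to t = 0 (prior) with Euler steps."""
    h = 1.0 / n_steps
    x, int_div = np.array(x_1, dtype=float), 0.0
    for k in range(n_steps):
        t = 1.0 - k * h
        int_div += h * divergence(velocity, x, t, rng)
        x = x - h * velocity(x, t)  # one Euler step backward in time
    return log_prior(x) - int_div

x_1 = np.array([0.8, -1.2])
lp_exact = log_prob(x_1)
lp_hutch = log_prob(x_1, rng=np.random.default_rng(0))
# Analytic reference for this linear field: x_0 = exp(-A) x_1, div v = trace(A).
lp_ref = std_normal_log_prob(np.exp(-np.diag(A)) * x_1) - np.trace(A)
```

A is diagonal and the probes are Rademacher, so z^T A z equals trace(A) for every probe and both modes agree. For a general field the Hutchinson estimate is unbiased but noisy. It pays off in high dimensions because it avoids building the full Jacobian.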
- build_loss(path, weights=None)
Build the continuous flow matching loss.
- Parameters:
path (AffineProbPath) – The probability path.
weights (Array, optional) – Per-dimension loss weights.
- Returns:
A loss callable with the uniform interface (model, batch, condition_mask=None, model_extras=None) -> loss.
- Return type:
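The loss built here regresses the model's velocity onto the conditional target along the path. The NumPy sketch below rests on several assumptions: the CondOT path (x_t = t x_1 + (1 - t) x_0, target x_1 - x_0), a mean-squared error, and per-dimension weights that broadcast against the (dim, ch) event axes. The actual reduction and weighting in build_loss may differ, and the function name fm_loss is hypothetical.

```python
import numpy as np

def fm_loss(velocity_fn, x_0, x_1, t, weights=None):
    """Conditional flow matching loss on one batch (hypothetical helper).

    x_0, x_1: arrays of shape (batch, dim, ch); t: array of shape (batch,).
    """
    t_b = t.reshape(-1, *([1] * (x_1.ndim - 1)))   # broadcast t over (dim, ch)
    x_t = t_b * x_1 + (1.0 - t_b) * x_0            # point on the CondOT path
    target = x_1 - x_0                              # conditional velocity
    sq_err = (velocity_fn(x_t, t) - target) ** 2    # shape (batch, dim, ch)
    if weights is not None:
        sq_err = sq_err * weights                   # assumed shape (dim, ch)
    return float(np.mean(sq_err))

rng = np.random.default_rng(0)
x_0 = rng.standard_normal((8, 5, 1))
x_1 = rng.standard_normal((8, 5, 1))
t = rng.uniform(size=8)

zero_field = lambda x, t: np.zeros_like(x)
loss_zero = fm_loss(zero_field, x_0, x_1, t)
# An "oracle" returning the exact conditional target drives the loss to 0.
loss_oracle = fm_loss(lambda x, t: x_1 - x_0, x_0, x_1, t)
# All-zero weights mask every dimension out of the loss.
loss_masked = fm_loss(zero_field, x_0, x_1, t, weights=np.zeros((5, 1)))
```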
- build_path(config, event_shape)
Build an affine probability path with the CondOT scheduler.
Also constructs self.prior if it is None, or validates a user-supplied one.
- Parameters:
config (dict) – Training configuration (unused for flow matching).
event_shape (tuple of (int, int)) – (dim, ch): the feature and channel dimensions.
- Returns:
The probability path.
- Return type:
- Raises:
ValueError – If a user-supplied prior has a mismatched
event_shape.
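The prior handling described for build_path, either constructing a default or checking a user-supplied one, can be sketched as follows. This is a NumPy sketch. StandardNormalPrior and resolve_prior are hypothetical names, and a NumPy Generator stands in for the JAX key. It only relies on the prior exposing event_shape, sample and log_prob. A numpyro Independent(Normal(...), reinterpreted_batch_ndims=2) like the one in the class example has event_shape (dim_obs, ch_obs).

```python
import numpy as np

class StandardNormalPrior:
    """Hypothetical stand-in for the default prior: iid N(0, 1) over (dim, ch)."""

    def __init__(self, event_shape):
        self.event_shape = tuple(event_shape)

    def sample(self, key, shape=()):
        # `key` is a NumPy Generator here instead of a JAX PRNGKey.
        return key.standard_normal(tuple(shape) + self.event_shape)

    def log_prob(self, x):
        # Sum the per-element log density over the event axes.
        axes = tuple(range(-len(self.event_shape), 0))
        return np.sum(-0.5 * x**2 - 0.5 * np.log(2.0 * np.pi), axis=axes)

def resolve_prior(prior, event_shape):
    """Build a default prior, or check that a user prior matches event_shape."""
    event_shape = tuple(event_shape)
    if prior is None:
        return StandardNormalPrior(event_shape)
    if tuple(prior.event_shape) != event_shape:
        raise ValueError(
            f"prior event_shape {tuple(prior.event_shape)} does not match {event_shape}"
        )
    return prior

default_prior = resolve_prior(None, (5, 1))
samples = default_prior.sample(np.random.default_rng(0), (8,))
try:
    resolve_prior(StandardNormalPrior((3, 1)), (5, 1))
    mismatch_raised = False
except ValueError:
    mismatch_raised = True
```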
- build_sampler_fn(model_wrapped, path, model_extras, step_size=0.01, method='Euler', time_grid=None, solver=None, **kwargs)
Build a sampler closure for flow matching.
Supports ODE solvers (deterministic) and SDE solvers (stochastic: ZeroEndsSolver, NonSingularSolver). When an SDE solver is used, the sampler function accepts an extra random key and splits it internally.
- Parameters:
model_wrapped – The wrapped velocity field model.
path – The probability path.
model_extras (dict) – Mode-specific extras (cond, obs_ids, cond_ids, etc.).
step_size (float, optional) – Step size for fixed-step solvers. Default is 0.01.
method (str or diffrax solver, optional) – Integration method for the ODE/SDE solver. Default is "Euler". Other commonly used solvers are "Dopri5" (adaptive), diffrax.Heun(), and diffrax.Midpoint().
time_grid (Array, optional) – Time grid for integration. If None, uses [0, 1].
solver (tuple of (type, dict), optional) – (SolverClass, kwargs). Defaults to (ODESolver, {}).
- Returns:
sampler_fn – A function (key, x_init) -> samples.
- Return type:
Callable
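The deterministic case with the default "Euler" method can be sketched in NumPy. The sketch takes fixed steps of size step_size from t = 0 to t = 1, the default time grid. It assumes the velocity model is a plain function velocity_fn(x, t), and SDE solvers would add a noise term that is not shown. The check uses a field with a known answer: the exact CondOT marginal velocity when all data sit at a single point mu, v(x, t) = (mu - x) / (1 - t). On this field each Euler step scales x - mu by (1 - t_{k+1}) / (1 - t_k), so the product telescopes and the sampler lands on mu at t = 1.

```python
import numpy as np

def euler_sample(velocity_fn, x_init, step_size=0.01, time_grid=(0.0, 1.0)):
    """Integrate dx/dt = velocity_fn(x, t) with fixed explicit Euler steps."""
    t0, t1 = time_grid
    n_steps = max(1, int(round(abs(t1 - t0) / step_size)))
    h = (t1 - t0) / n_steps          # signed step, so a reversed grid also works
    x = np.asarray(x_init, dtype=float)
    for k in range(n_steps):
        t = t0 + k * h
        x = x + h * velocity_fn(x, t)
    return x

mu = np.array([[2.0], [-1.0], [0.5]])                    # (dim, ch) = (3, 1)
point_mass_velocity = lambda x, t: (mu - x) / (1.0 - t)  # never evaluated at t = 1

rng = np.random.default_rng(0)
x_init = rng.standard_normal((4, 3, 1))                  # prior draws, (batch, dim, ch)
samples = euler_sample(point_mass_velocity, x_init)
```

Explicit Euler needs no intermediate storage, which makes it a cheap default. An adaptive method such as "Dopri5" trades extra velocity evaluations for error control.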
- build_solver(model_wrapped, path, solver=None)
Instantiate a flow matching solver.
Supports both ODE solvers (ODESolver) and SDE solvers (ZeroEndsSolver, NonSingularSolver).
- Parameters:
model_wrapped – The wrapped velocity field model.
path – The probability path (unused by ODE solvers, but may be needed by SDE solvers).
solver (tuple of (type, dict), optional) – (SolverClass, kwargs). Defaults to (ODESolver, {}).
- Returns:
An instantiated solver.
- Return type:
solver_instance
- get_default_solver()
Return the default ODE solver.
- Returns:
(FMODESolver, {})
- Return type:
tuple
- prepare_batch(key, x_1, path)
Sample prior noise and times for a flow matching training batch.
- Parameters:
key (jax.random.PRNGKey) – Random key.
x_1 (Array) – Clean data of shape (batch_size, dim, ch).
path (AffineProbPath) – The probability path (unused, kept for interface consistency).
- Returns:
(x_0, x_1, t), where x_0 is drawn from the prior and t is uniform in [0, 1).
- Return type:
tuple
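The batch preparation step can be sketched in NumPy, with a NumPy Generator standing in for the JAX key. In JAX one would split the key and draw t with jax.random.uniform, whose default range is also [0, 1). The prior_sample callable, the helper names and the (batch_size,) shape of t are assumptions for illustration.

```python
import numpy as np

def prepare_batch_sketch(rng, x_1, prior_sample):
    """Pair clean data x_1 with prior draws x_0 and times t ~ U[0, 1)."""
    batch_size = x_1.shape[0]
    x_0 = prior_sample(rng, (batch_size,))          # same shape as x_1
    t = rng.uniform(0.0, 1.0, size=(batch_size,))   # Generator.uniform samples [low, high)
    return x_0, x_1, t

def std_normal_prior_sample(rng, sample_shape, event_shape=(5, 1)):
    # Hypothetical default prior: iid N(0, 1) over the (dim, ch) event axes.
    return rng.standard_normal(tuple(sample_shape) + tuple(event_shape))

rng = np.random.default_rng(0)
x_1 = rng.standard_normal((16, 5, 1))   # (batch_size, dim, ch)
x_0, x_1_out, t = prepare_batch_sketch(rng, x_1, std_normal_prior_sample)
```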