gensbi.core.flow_matching

Flow matching generative method strategy.

Implements GenerativeMethod using optimal-transport conditional flow matching with an affine probability path.
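The standard CondOT construction that this kind of affine path follows is summarized below. It assumes the usual coefficients alpha_t = t and sigma_t = 1 - t, so t = 0 is the prior and t = 1 is the data. That agrees with the time grids documented on this page, but the page does not state the coefficients explicitly. Here p_0 denotes the prior and v_theta the learned velocity field.

```latex
\begin{aligned}
x_t &= t\,x_1 + (1 - t)\,x_0, \qquad x_0 \sim p_0,\ x_1 \sim p_{\mathrm{data}},\ t \sim \mathcal{U}[0, 1),\\
u_t &= \frac{\mathrm{d}x_t}{\mathrm{d}t} = x_1 - x_0,\\
\mathcal{L}(\theta) &= \mathbb{E}_{t,\,x_0,\,x_1}\big\lVert v_\theta(x_t, t) - (x_1 - x_0)\big\rVert^2 .
\end{aligned}
```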

Classes

FlowMatchingMethod

Flow matching strategy using affine probability paths.

Module Contents

class gensbi.core.flow_matching.FlowMatchingMethod(prior=None)

Bases: gensbi.core.generative_method.GenerativeMethod

Flow matching strategy using affine probability paths.

Uses the conditional optimal-transport scheduler and an ODE or SDE solver for sampling.

Parameters:

prior (numpyro.distributions.Distribution, optional) – Source distribution. Must implement sample(key, shape) and log_prob(x). Validated against event_shape in build_path(). If None, a standard normal prior is constructed automatically.

Examples

>>> method = FlowMatchingMethod()
>>> path = method.build_path(config={}, event_shape=(5, 1))
>>> loss = method.build_loss(path)

Using a custom numpyro prior (x has shape (batch, dim_obs, ch_obs)):

>>> import jax.numpy as jnp
>>> import numpyro.distributions as dist
>>> dim_obs, ch_obs = 3, 1
>>> prior = dist.Independent(
...     dist.Normal(loc=jnp.zeros((dim_obs, ch_obs)), scale=jnp.ones((dim_obs, ch_obs))),
...     reinterpreted_batch_ndims=2,
... )
>>> method = FlowMatchingMethod(prior=prior)

build_log_prob_fn(model_wrapped, path, model_extras, step_size=0.01, method='Dopri5', atol=1e-05, rtol=1e-05, time_grid=None, solver=None, exact_divergence=True, log_prior=None, **kwargs)

Build a log-probability closure for flow matching.

Uses the continuous change-of-variables formula via ODESolver. Only works with ODE solvers (not SDE solvers).
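In its textbook form, the continuous change-of-variables formula is shown below. x_0 is obtained by integrating the ODE backwards from x_1, which matches the default time grid [1.0, 0.0]. The second line gives the exact trace used when exact_divergence=True and the Hutchinson identity behind exact_divergence=False. This is the standard statement, not a transcription of the library code.

```latex
\begin{aligned}
\frac{\mathrm{d}x_t}{\mathrm{d}t} &= v_\theta(x_t, t), \qquad
\log p_1(x_1) = \log p_0(x_0) - \int_0^1 \nabla \cdot v_\theta(x_t, t)\,\mathrm{d}t,\\
\nabla \cdot v_\theta &= \operatorname{tr}\!\left(\frac{\partial v_\theta}{\partial x}\right)
= \mathbb{E}_{\epsilon}\!\left[\epsilon^\top \frac{\partial v_\theta}{\partial x}\,\epsilon\right],
\qquad \mathbb{E}[\epsilon\epsilon^\top] = I .
\end{aligned}
```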

Parameters:
  • model_wrapped – The wrapped velocity field model.

  • path – The probability path.

  • model_extras (dict) – Mode-specific extras (cond, obs_ids, etc.).

  • step_size (float, optional) – Step size for fixed-step solvers. Default is 0.01.

  • method (str or diffrax solver, optional) – Integration method. Default is "Dopri5".

  • atol (float, optional) – Absolute tolerance for adaptive solvers. Default is 1e-5.

  • rtol (float, optional) – Relative tolerance for adaptive solvers. Default is 1e-5.

  • time_grid (list, optional) – Time grid for the reverse-time integration, from t=1 (data) to t=0 (prior). Defaults to [1.0, 0.0].

  • solver (tuple of (type, dict), optional) – (SolverClass, kwargs). Must be an ODE solver.

  • exact_divergence (bool, optional) – If True (default), compute exact divergence via full Jacobian. If False, use the Hutchinson estimator (requires a PRNG key at call time).

  • log_prior (callable, optional) – Override for the prior’s log_prob. If None, uses self.prior.log_prob. Used by the joint pipeline to pass a user-supplied obs-space prior.

Returns:

log_prob_fn(x_1, model_extras, *, key=None) -> log_prob.

Return type:

Callable

Raises:

NotImplementedError – If a non-ODE solver is specified.
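The following NumPy sketch shows what a log-probability closure computes. It assumes a standard normal prior and uses fixed-step Euler integration from t = 1 back to t = 0. All names are hypothetical. NumPy has no autodiff, so a central finite difference stands in for two things: the full Jacobian (exact_divergence=True) and the Jacobian-vector product in the one-sample Hutchinson estimator (exact_divergence=False). A numpy.random.Generator plays the role of the PRNG key that the Hutchinson path requires. The linear field v(x, t) = a·x has a closed-form answer to check against.

```python
import numpy as np


def std_normal_log_prob(x):
    """log N(x; 0, I), summed over every event dimension."""
    x = np.ravel(x)
    return -0.5 * np.sum(x ** 2) - 0.5 * x.size * np.log(2.0 * np.pi)


def divergence_exact(v, x, t, h=1e-5):
    """tr(dv/dx) by central finite differences (stand-in for a full Jacobian)."""
    div = 0.0
    for i in range(x.size):
        e = np.zeros(x.size)
        e[i] = h
        e = e.reshape(x.shape)
        div += (v(x + e, t) - v(x - e, t)).ravel()[i] / (2.0 * h)
    return div


def divergence_hutchinson(v, x, t, rng, h=1e-5):
    """One-sample Hutchinson estimate eps^T (dv/dx) eps with Rademacher eps."""
    eps = rng.choice([-1.0, 1.0], size=x.shape)
    jvp = (v(x + h * eps, t) - v(x - h * eps, t)) / (2.0 * h)  # approx (dv/dx) eps
    return float(np.sum(eps * jvp))


def log_prob_sketch(v, x_1, step_size=1e-3, log_prior=std_normal_log_prob,
                    exact_divergence=True, rng=None):
    """Euler-integrate from t=1 back to t=0, accumulating the divergence."""
    if not exact_divergence and rng is None:
        raise ValueError("the Hutchinson estimator needs a random generator")
    n_steps = max(1, int(round(1.0 / step_size)))
    dt = 1.0 / n_steps
    x = np.asarray(x_1, dtype=float)
    div_integral = 0.0
    for k in range(n_steps):
        t = 1.0 - k * dt
        if exact_divergence:
            div = divergence_exact(v, x, t)
        else:
            div = divergence_hutchinson(v, x, t, rng)
        div_integral += dt * div
        x = x - dt * v(x, t)                   # one Euler step backwards in time
    # log p_1(x_1) = log p_0(x_0) - integral_0^1 div v(x_t, t) dt
    return log_prior(x) - div_integral


# Toy check with v(x, t) = a * x, whose flow is x_t = x_0 * exp(a * t):
# x_0 = x_1 * exp(-a) and the divergence is a * dim at every point.
a = 0.5
x_1 = np.array([[0.3], [-1.2], [0.8]])         # event shape (3, 1)


def linear_velocity(x, t):
    return a * x


analytic = std_normal_log_prob(x_1 * np.exp(-a)) - a * x_1.size
lp_exact = log_prob_sketch(linear_velocity, x_1)
lp_hutch = log_prob_sketch(linear_velocity, x_1, exact_divergence=False,
                           rng=np.random.default_rng(0))
```

For this linear field the Hutchinson estimate happens to be exact, because the Jacobian is diagonal and every Rademacher entry squares to 1. For a general network it is unbiased but noisy, which is the usual trade-off against the cost of a full Jacobian.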

build_loss(path, weights=None)

Build the continuous flow matching loss.

Parameters:
  • path (AffineProbPath) – The probability path.

  • weights (Array, optional) – Per-dimension loss weights.

Returns:

A loss callable with uniform interface (model, batch, condition_mask=None, model_extras=None) -> loss.

Return type:

FMLoss
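Below is a minimal NumPy sketch of the kind of loss build_loss is documented to produce. It assumes the standard CondOT path x_t = t·x_1 + (1 - t)·x_0 with target velocity x_1 - x_0. The function cfm_loss and its signature are hypothetical and do not match the FMLoss interface above. The way weights enter (elementwise scaling of squared errors before averaging) is also an assumption.

```python
import numpy as np


def cfm_loss(velocity_fn, x_0, x_1, t, weights=None):
    """Conditional flow matching loss on a CondOT affine path (sketch).

    x_0, x_1: arrays of shape (batch, dim, ch); t: array of shape (batch,).
    weights: optional per-dimension weights broadcastable to (dim, ch).
    """
    t_b = t[:, None, None]                 # broadcast t over (dim, ch)
    x_t = t_b * x_1 + (1.0 - t_b) * x_0    # point on the path at time t
    target = x_1 - x_0                     # CondOT target velocity dx_t/dt
    sq_err = (velocity_fn(x_t, t) - target) ** 2
    if weights is not None:
        sq_err = sq_err * weights          # assumed: weights scale squared errors
    return float(np.mean(sq_err))


rng = np.random.default_rng(0)
x_1 = rng.standard_normal((8, 5, 1))       # "clean data", event shape (5, 1)
x_0 = rng.standard_normal((8, 5, 1))       # prior draws
t = rng.uniform(0.0, 1.0, size=8)


def zero_model(x, s):
    return np.zeros_like(x)                # predicts no motion


def oracle(x, s):
    return x_1 - x_0                       # cheats: reads the true pairing


loss_zero = cfm_loss(zero_model, x_0, x_1, t)
loss_oracle = cfm_loss(oracle, x_0, x_1, t)
```

The oracle, which is handed the true x_1 - x_0, reaches zero loss. The zero model pays the mean squared displacement. A trained model only sees (x_t, t), so in general it cannot reach zero.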

build_path(config, event_shape)

Build an affine probability path with the CondOT scheduler.

Also constructs or validates self.prior.

Parameters:
  • config (dict) – Training configuration (unused for flow matching).

  • event_shape (tuple of (int, int)) – (dim, ch) — feature and channel dimensions.

Returns:

The probability path.

Return type:

AffineProbPath

Raises:

ValueError – If a user-supplied prior has a mismatched event_shape.

build_sampler_fn(model_wrapped, path, model_extras, step_size=0.01, method='Euler', time_grid=None, solver=None, **kwargs)

Build a sampler closure for flow matching.

Supports ODE solvers (deterministic) and SDE solvers (stochastic; ZeroEndsSolver, NonSingularSolver). When an SDE solver is used, the sampler function accepts and splits an extra random key.

Parameters:
  • model_wrapped – The wrapped velocity field model.

  • path – The probability path.

  • model_extras (dict) – Mode-specific extras (cond, obs_ids, cond_ids, etc.).

  • step_size (float, optional) – Step size for fixed-step solvers. Default is 0.01.

  • method (str or diffrax solver, optional) – Integration method for the ODE/SDE solver. Default is "Euler". Other commonly used solvers are "Dopri5" (adaptive), diffrax.Heun(), and diffrax.Midpoint().

  • time_grid (Array, optional) – Time grid for integration. If None, uses [0, 1].

  • solver (tuple of (type, dict), optional) – (SolverClass, kwargs). Defaults to (ODESolver, {}).

Returns:

sampler_fn – A function (key, x_init) -> samples.

Return type:

Callable
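For the deterministic ODE case, a fixed-step Euler integrator over the default time grid [0, 1] looks roughly like the sketch below. euler_sample is a hypothetical NumPy helper, not the ODESolver API, and SDE solvers are not covered. The toy velocity (mu - x) / (1 - t) is the exact CondOT marginal velocity when the data distribution is a point mass at mu, so every prior draw should land on mu.

```python
import numpy as np


def euler_sample(velocity_fn, x_init, step_size=0.01, t0=0.0, t1=1.0):
    """Integrate dx/dt = velocity_fn(x, t) from t0 to t1 with fixed Euler steps."""
    n_steps = max(1, int(round(abs(t1 - t0) / step_size)))
    dt = (t1 - t0) / n_steps
    x = np.asarray(x_init, dtype=float)
    for k in range(n_steps):
        t = t0 + k * dt               # left endpoint; never evaluates exactly at t1
        x = x + dt * velocity_fn(x, t)
    return x


# Toy check: for point-mass data at mu, the CondOT marginal velocity is
# (mu - x) / (1 - t), so every trajectory should end at mu.
mu = 2.0


def point_mass_velocity(x, t):
    return (mu - x) / (1.0 - t)


rng = np.random.default_rng(1)
x_init = rng.standard_normal((4, 3, 1))   # prior draws, event shape (3, 1)
samples = euler_sample(point_mass_velocity, x_init, step_size=0.1)
```

Evaluating the velocity at the left endpoint of each step avoids the singularity at t = 1. Adaptive methods such as "Dopri5" choose their own step sizes instead.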

build_solver(model_wrapped, path, solver=None)

Instantiate a flow matching solver.

Supports both ODE solvers (ODESolver) and SDE solvers (ZeroEndsSolver, NonSingularSolver).

Parameters:
  • model_wrapped – The wrapped velocity field model.

  • path – The probability path (unused by ODE solver, but may be needed by SDE solvers).

  • solver (tuple of (type, dict), optional) – (SolverClass, kwargs). Defaults to (ODESolver, {}).

Returns:

An instantiated solver.

Return type:

solver_instance
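The (SolverClass, kwargs) convention used by build_solver, build_sampler_fn and get_default_solver can be illustrated with a stand-in class. ToyODESolver, default_solver_spec and build_solver_sketch are invented for this sketch. The real solvers take their own constructor arguments, and SDE solvers may also need the probability path.

```python
from typing import Any, Callable, Optional, Tuple


class ToyODESolver:
    """Hypothetical stand-in for an ODE solver class (not gensbi's ODESolver)."""

    def __init__(self, velocity_model: Callable, **kwargs: Any):
        self.velocity_model = velocity_model
        self.options = dict(kwargs)


SolverSpec = Tuple[type, dict]


def default_solver_spec() -> SolverSpec:
    # Mirrors the documented default shape: (solver class, empty kwargs).
    return (ToyODESolver, {})


def build_solver_sketch(model: Callable, solver: Optional[SolverSpec] = None):
    # Unpack the (SolverClass, kwargs) tuple, falling back to the default.
    solver_cls, solver_kwargs = solver if solver is not None else default_solver_spec()
    return solver_cls(model, **solver_kwargs)


def velocity(x, t):
    return -x


default_solver = build_solver_sketch(velocity)
tuned_solver = build_solver_sketch(velocity, (ToyODESolver, {"atol": 1e-6}))
```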

get_default_solver()

Return the default ODE solver.

Returns:

(FMODESolver, {})

Return type:

tuple

prepare_batch(key, x_1, path)

Draw prior samples and training times for a flow matching batch.

Parameters:
  • key (jax.random.PRNGKey) – Random key.

  • x_1 (Array) – Clean data of shape (batch_size, dim, ch).

  • path (AffineProbPath) – The probability path (unused, kept for interface consistency).

Returns:

(x_0, x_1, t) where x_0 is drawn from the prior and t is uniform in [0, 1).

Return type:

tuple
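The sampling step described above can be sketched in NumPy as follows. The sketch assumes a standard normal prior, which is the default when no prior is passed. prepare_batch_sketch is a hypothetical name, and a numpy.random.Generator stands in for the JAX PRNG key. In JAX you would typically split the key with jax.random.split and draw each quantity from its own subkey.

```python
import numpy as np


def prepare_batch_sketch(rng, x_1):
    """Return (x_0, x_1, t) for one flow matching training step.

    Assumes a standard normal prior over the (dim, ch) event shape and uses a
    NumPy Generator in place of a JAX PRNG key.
    """
    x_1 = np.asarray(x_1)
    x_0 = rng.standard_normal(size=x_1.shape)      # x_0 ~ N(0, I), same shape as data
    t = rng.uniform(0.0, 1.0, size=x_1.shape[0])   # one time per example, t ~ U[0, 1)
    return x_0, x_1, t


rng = np.random.default_rng(42)
data = np.ones((16, 5, 1))                         # batch of 16, event shape (5, 1)
x_0, x_1, t = prepare_batch_sketch(rng, data)
```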

sample_init(key, nsamples)

Sample from the prior distribution.

Parameters:
  • key (jax.random.PRNGKey) – Random key.

  • nsamples (int) – Number of samples to draw.

Returns:

Samples drawn from the prior.

Return type:

Array

_user_prior = None
prior = None