gensbi.core

Core module for GenSBI.

Provides the GenerativeMethod strategy pattern and its concrete implementations for flow matching, EDM diffusion, and score matching.

These strategy objects encapsulate the generative framework (path, loss, solver, batch preparation) and are composed into mode-specific pipelines (Conditional, Joint, Unconditional) in the recipes module.

Public API (all lazy-loaded to avoid circular imports with solver subclasses that inherit from core base classes):

from gensbi.core import FlowMatchingMethod
from gensbi.core import DiffusionEDMMethod
from gensbi.core import ScoreMatchingMethod
from gensbi.core import GenerativeMethod
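Lazy loading of this kind is commonly done in Python with a module-level `__getattr__` hook (PEP 562). The sketch below is not GenSBI's actual implementation. It builds a throwaway module named `lazy_demo` (hypothetical, as are `_LAZY` and the loader function) to show the pattern: a name is resolved only on first access, so the import that would otherwise create a cycle is deferred.

```python
import sys
import types

# Hypothetical stand-in package; not GenSBI's real module layout.
demo = types.ModuleType("lazy_demo")
load_log = []  # records when the deferred load actually happens


def _load_flow_matching_method():
    # In a real package this would be a deferred submodule import,
    # executed only on first access to break the import cycle.
    load_log.append("FlowMatchingMethod")
    return type("FlowMatchingMethod", (), {})


_LAZY = {"FlowMatchingMethod": _load_flow_matching_method}


def _module_getattr(name):
    # PEP 562: called only when normal attribute lookup on the module fails.
    if name in _LAZY:
        value = _LAZY[name]()
        setattr(demo, name, value)  # cache so later lookups skip this hook
        return value
    raise AttributeError(f"module 'lazy_demo' has no attribute {name!r}")


demo.__getattr__ = _module_getattr
sys.modules["lazy_demo"] = demo

assert load_log == []  # nothing has been loaded yet
from lazy_demo import FlowMatchingMethod  # triggers the hook once
from lazy_demo import FlowMatchingMethod as again  # served from the cache
```

After the first access the class is stored on the module itself, so the hook runs at most once per name.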

Classes

GenerativeMethod

Strategy that encapsulates a generative framework.

Package Contents

class gensbi.core.GenerativeMethod

Bases: abc.ABC

Strategy that encapsulates a generative framework.

Concrete implementations handle:

  • Path construction — the probability path defining the forward process

  • Loss creation — the training objective

  • Batch preparation — sampling noise / time for training batches

  • Solver construction — building the sampler for inference

  • Initial sample generation — drawing from the prior
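Because GenerativeMethod derives from abc.ABC and declares most of these hooks as abstract, Python refuses to instantiate a subclass until every abstract method has been overridden. The sketch below shows that behavior with a hypothetical, cut-down stand-in (`MiniMethod`) rather than the real class. The empty default returned by its `get_extra_training_config` is an assumption made for illustration.

```python
from abc import ABC, abstractmethod


class MiniMethod(ABC):
    """Hypothetical, cut-down stand-in for a strategy base class."""

    @abstractmethod
    def build_path(self, config, event_shape):
        """Subclasses must implement this."""

    @abstractmethod
    def sample_init(self, key, nsamples):
        """Subclasses must implement this."""

    def get_extra_training_config(self):
        # Non-abstract hook with a default that subclasses may override.
        return {}


class Incomplete(MiniMethod):
    def build_path(self, config, event_shape):
        return "path"
    # sample_init is missing, so this class is still abstract.


class Complete(Incomplete):
    def sample_init(self, key, nsamples):
        return [0.0] * nsamples


try:
    Incomplete()
    instantiation_error = None
except TypeError as exc:
    instantiation_error = exc  # raised because sample_init is not overridden

method = Complete()  # every abstract method is now implemented
```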

abstractmethod build_log_prob_fn(model_wrapped, path, model_extras, **kwargs)

Build a log-probability closure for inference.

Only supported by methods whose solver can evaluate the continuous change-of-variables formula (e.g., flow matching + ODE solver).

Parameters:
  • model_wrapped – The wrapped model.

  • path – The probability path.

  • model_extras (dict) – Mode-specific extras (cond, obs_ids, cond_ids, etc.).

  • **kwargs – Method-specific arguments (step_size, method, etc.).

Returns:

log_prob_fn – A function (x_1, model_extras) -> log_prob.

Return type:

Callable

Raises:

NotImplementedError – If the method does not support log-probability computation.
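For an ODE dx/dt = v_t(x) that transports the prior p_0 at t=0 to the data distribution p_1 at t=1, the continuous change-of-variables formula reads log p_1(x_1) = log p_0(x_0) - ∫_0^1 div v_t(x_t) dt, where x_0 is found by integrating the ODE backward from x_1. The following is a minimal one-dimensional sketch of that computation, not GenSBI's solver. The linear velocity field, the coefficient `A`, the standard normal prior and the Euler integrator are all assumptions, chosen so that the result can be checked against a closed form.

```python
import math

A = 0.5  # hypothetical velocity-field coefficient: v_t(x) = A * x


def velocity(t, x):
    return A * x


def divergence(t, x):
    # In 1D the divergence is dv/dx, which is constant for this field.
    return A


def log_standard_normal(x):
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)


def log_prob(x_1, nsteps=1000):
    """Integrate the ODE backward from t=1 to t=0 with Euler steps,
    accumulating the divergence integral along the trajectory."""
    dt = 1.0 / nsteps
    x, div_integral = x_1, 0.0
    for i in range(nsteps):
        t = 1.0 - i * dt
        div_integral += divergence(t, x) * dt
        x -= velocity(t, x) * dt  # step toward the prior at t=0
    return log_standard_normal(x) - div_integral


def exact_log_prob(x_1):
    # Closed form for comparison: x_1 = exp(A) * x_0, so p_1 = N(0, exp(2A)).
    return log_standard_normal(x_1 * math.exp(-A)) - A
```

In higher dimensions the divergence is the trace of the Jacobian of the velocity field, which is why this route needs a solver that can evaluate or estimate that trace.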

abstractmethod build_loss(path, weights=None)

Create the loss callable for this method.

Parameters:
  • path – The probability path returned by build_path().

  • weights (Array, optional) – Per-dimension loss weights (e.g. for rebalancing joint parameter/observation dimensions).

Returns:

A loss object whose __call__ computes the training loss.

Return type:

loss
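To illustrate what per-dimension weights can do, here is a minimal sketch of a weighted squared-error loss in plain Python. It is an assumption about how such weights might enter a loss, not the formula used by GenSBI's loss objects. `weighted_mse` is a hypothetical helper, and inputs are nested lists of shape (batch_size, dim).

```python
def weighted_mse(pred, target, weights=None):
    """Mean over batch and dimensions of squared errors, with each
    dimension scaled by its weight (all ones if weights is None)."""
    dim = len(pred[0])
    if weights is None:
        weights = [1.0] * dim
    total = 0.0
    for p_row, t_row in zip(pred, target):
        total += sum(w * (p - t) ** 2 for w, p, t in zip(weights, p_row, t_row))
    return total / (len(pred) * dim)


pred = [[1.0, 2.0]]
target = [[0.0, 0.0]]
unweighted = weighted_mse(pred, target)
# Emphasize the first dimension and switch the second one off.
reweighted = weighted_mse(pred, target, weights=[2.0, 0.0])
```

In a joint setting, raising the weights of the parameter dimensions relative to the observation dimensions shifts the training signal toward the parameters.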

abstractmethod build_path(config, event_shape)

Create the probability path and construct the prior.

Parameters:
  • config (dict) – Training configuration dictionary (may contain scheduler hyperparameters such as sigma_min, sigma_max).

  • event_shape (tuple of (int, int)) – (dim, ch) — feature and channel dimensions. Used to construct self.prior as a numpyro distribution.

Returns:

A probability path object (e.g. AffineProbPath, EDMPath, SMPath).

Return type:

path

abstractmethod build_sampler_fn(model_wrapped, path, model_extras, **kwargs)

Build a sampler closure for inference.

The pipeline should use sample_init() to generate the initial noise x_init before calling the returned sampler.

Parameters:
  • model_wrapped – The wrapped model.

  • path – The probability path.

  • model_extras (dict) – Mode-specific extras (cond, obs_ids, cond_ids, etc.).

  • **kwargs – Method-specific sampler arguments (step_size, nsteps, time_grid, solver, etc.).

Returns:

sampler_fn – A function (key, x_init) -> samples.

Return type:

Callable
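The calling order described above (draw x_init with sample_init(), then pass it to the returned closure) can be sketched with a toy stand-in. Everything here is hypothetical: `ToyMethod`, `toy_velocity`, the scalar samples, the Euler loop, and the use of an integer seed in place of a JAX PRNG key. Only the two call signatures follow the documentation.

```python
import random


def toy_velocity(t, x, cond):
    # Hypothetical "wrapped model": pushes every sample toward cond.
    return cond - x


class ToyMethod:
    """Hypothetical stand-in that mimics the documented call signatures."""

    def sample_init(self, key, nsamples):
        rng = random.Random(key)  # integer seed stands in for a PRNG key
        return [rng.gauss(0.0, 1.0) for _ in range(nsamples)]

    def build_sampler_fn(self, model_wrapped, path, model_extras, nsteps=10):
        dt = 1.0 / nsteps

        def sampler_fn(key, x_init):
            # Deterministic Euler integration, so `key` is unused here.
            samples = list(x_init)
            for i in range(nsteps):
                t = i * dt
                samples = [
                    x + model_wrapped(t, x, **model_extras) * dt for x in samples
                ]
            return samples

        return sampler_fn


method = ToyMethod()
x_init = method.sample_init(key=0, nsamples=4)
sampler_fn = method.build_sampler_fn(
    toy_velocity, path=None, model_extras={"cond": 2.0}, nsteps=10
)
samples = sampler_fn(1, x_init)
```

Keeping the initial noise outside the closure lets the pipeline control the prior draw separately from the integration.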

abstractmethod build_solver(model_wrapped, path, solver=None)

Instantiate a solver.

Parameters:
  • model_wrapped – The wrapped model (velocity field or score model).

  • path – The probability path.

  • solver (tuple of (type, dict), optional) – (SolverClass, kwargs) override. If None, uses get_default_solver().

Returns:

An instantiated solver.

Return type:

solver_instance

abstractmethod get_default_solver()

Return the default (SolverClass, kwargs) for this method.

Returns:

solver – A (SolverClass, default_kwargs) pair.

Return type:

tuple
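The interplay between build_solver() and get_default_solver() can be sketched as follows. The solver classes, `ToyMethod`, and the choice to pass the kwargs to the solver constructor are assumptions made for illustration. The real implementations may route the kwargs differently.

```python
class EulerSolver:
    """Hypothetical solver that just records how it was constructed."""

    def __init__(self, model_wrapped, path, step_size=0.1):
        self.model_wrapped = model_wrapped
        self.path = path
        self.step_size = step_size


class HeunSolver(EulerSolver):
    """Second hypothetical solver, used to demonstrate an override."""


class ToyMethod:
    def get_default_solver(self):
        return EulerSolver, {"step_size": 0.05}

    def build_solver(self, model_wrapped, path, solver=None):
        # Fall back to the method's default pair when no override is given.
        solver_cls, solver_kwargs = (
            solver if solver is not None else self.get_default_solver()
        )
        return solver_cls(model_wrapped, path, **solver_kwargs)


method = ToyMethod()
default_solver = method.build_solver("model", "path")
custom_solver = method.build_solver(
    "model", "path", solver=(HeunSolver, {"step_size": 0.01})
)
```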

get_extra_training_config()

Return method-specific training config defaults.

Override in subclasses to supply extra defaults (e.g. sigma_min).

Returns:

Extra configuration parameters.

Return type:

dict
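One plausible way for a pipeline to combine these defaults with a user-supplied configuration is a dictionary merge in which user values take precedence. This is an assumption about usage, not documented GenSBI behavior. `merge_training_config` is a hypothetical helper and the numeric values are illustrative only.

```python
def merge_training_config(method_defaults, user_config):
    """Return a new dict: method defaults first, user values override them."""
    return {**method_defaults, **user_config}


defaults = {"sigma_min": 0.002, "sigma_max": 80.0}  # illustrative values only
merged = merge_training_config(defaults, {"sigma_max": 50.0, "nsteps": 1000})
```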

abstractmethod prepare_batch(key, x_1, path)

Prepare a training batch from clean data.

All methods return a tuple with the same layout, (x_0, x_1, t_or_sigma), where x_0 is noise sampled from the prior and the third element is the time or noise level drawn for each batch element.

Parameters:
  • key (jax.random.PRNGKey) – Random key.

  • x_1 (Array) – Clean data batch of shape (batch_size, dim, ch).

  • path – The probability path.

Returns:

batch – (x_0, x_1, t_or_sigma), the prepared training batch.

Return type:

tuple
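For a flow-matching style method, batch preparation might look like the sketch below. It assumes a standard normal prior and times drawn uniformly from [0, 1). Neither choice is confirmed by this page. It uses nested lists of shape (batch_size, dim, ch) and an integer seed in place of JAX arrays and PRNG keys, and the `path` argument is accepted only to mirror the documented signature.

```python
import random


def prepare_batch(key, x_1, path=None):
    """Return (x_0, x_1, t): Gaussian noise shaped like x_1 and one
    uniform time per batch element. `path` is unused in this sketch."""
    rng = random.Random(key)  # integer seed stands in for a JAX PRNG key
    x_0 = [[[rng.gauss(0.0, 1.0) for _ in row] for row in sample] for sample in x_1]
    t = [rng.random() for _ in x_1]
    return x_0, x_1, t


# Batch of 2 samples with (dim, ch) = (3, 1).
x_1 = [[[0.1], [0.2], [0.3]], [[0.4], [0.5], [0.6]]]
x_0, x_1_out, t = prepare_batch(0, x_1)
```

A diffusion-style method would return a noise level sigma in the third slot instead of a time, which is why the documented tuple names it t_or_sigma.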

abstractmethod sample_init(key, nsamples)

Sample initial noise for the generative process.

Uses self.prior (constructed during build_path()).

Parameters:
  • key (jax.random.PRNGKey) – Random key.

  • nsamples (int) – Number of samples to draw.

Returns:

x_init – Initial noise sample of shape (nsamples, *event_shape).

Return type:

Array

property has_custom_prior

Whether the user provided a custom prior at construction time.