gensbi.core
Core module for GenSBI.
Provides the GenerativeMethod strategy pattern and its concrete
implementations for flow matching, EDM diffusion, and score matching.
These strategy objects encapsulate the generative framework (path, loss,
solver, batch preparation) and are composed into mode-specific pipelines
(Conditional, Joint, Unconditional) in the recipes module.
Public API (all lazy-loaded to avoid circular imports with solver subclasses that inherit from core base classes):
from gensbi.core import FlowMatchingMethod
from gensbi.core import DiffusionEDMMethod
from gensbi.core import ScoreMatchingMethod
from gensbi.core import GenerativeMethod
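Lazy loading of this kind is commonly implemented with a module-level __getattr__ hook (PEP 562). The sketch below only illustrates that general technique and is not taken from gensbi: the module name demo_core, the helper make_lazy_module, and the name-to-module mapping are all hypothetical, and standard-library targets stand in for the real submodules.

```python
import importlib
import types

# Hypothetical map: public name -> module that defines it.
# Standard-library targets stand in for the real (unshown) submodules.
_LAZY_EXPORTS = {"sqrt": "math", "OrderedDict": "collections"}


def make_lazy_module(name, lazy_exports):
    """Build a module whose listed attributes are imported on first access."""
    module = types.ModuleType(name)

    def __getattr__(attr):
        # Called only when normal attribute lookup on the module fails.
        try:
            target = lazy_exports[attr]
        except KeyError:
            raise AttributeError(
                f"module {name!r} has no attribute {attr!r}"
            ) from None
        value = getattr(importlib.import_module(target), attr)
        setattr(module, attr, value)  # cache so later lookups skip this hook
        return value

    module.__getattr__ = __getattr__
    return module


demo_core = make_lazy_module("demo_core", _LAZY_EXPORTS)
```

In a real package the hook would normally be defined at the top level of the package's __init__.py rather than attached to a module object by hand. Deferring the import until first attribute access is what lets two modules refer to each other without a circular import at load time.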
Classes
GenerativeMethod | Strategy that encapsulates a generative framework.
Package Contents
- class gensbi.core.GenerativeMethod
Bases: abc.ABC
Strategy that encapsulates a generative framework.
Concrete implementations handle:
Path construction — the probability path defining the forward process
Loss creation — the training objective
Batch preparation — sampling noise / time for training batches
Solver construction — building the sampler for inference
Initial sample generation — drawing from the prior
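The division of responsibilities listed above can be sketched as a small strategy pattern. This is an illustrative toy under stated assumptions, not the gensbi implementation: ToyMethod, ToyFlowMethod, and ToyPipeline are hypothetical names, only two of the abstract methods are mirrored, and plain Python objects stand in for real path and solver types.

```python
from abc import ABC, abstractmethod


class ToyMethod(ABC):
    """Hypothetical stand-in for a strategy base class; not the gensbi class."""

    @abstractmethod
    def build_path(self, config, event_shape):
        """Return an object describing the probability path."""

    @abstractmethod
    def get_default_solver(self):
        """Return a (SolverClass, kwargs) pair."""


class ToyFlowMethod(ToyMethod):
    def build_path(self, config, event_shape):
        # A plain dict stands in for a real path object.
        return {"kind": "toy-affine", "event_shape": tuple(event_shape)}

    def get_default_solver(self):
        # dict stands in for a solver class purely for illustration.
        return (dict, {"step_size": 0.01})


class ToyPipeline:
    """Composes a method; it never needs to know which framework it wraps."""

    def __init__(self, method, config, event_shape):
        self.method = method
        self.path = method.build_path(config, event_shape)
        self.solver_cls, self.solver_kwargs = method.get_default_solver()


pipeline = ToyPipeline(ToyFlowMethod(), config={}, event_shape=(3, 1))
```

Because the pipeline only calls methods declared on the abstract base, swapping in a different concrete method changes the generative framework without touching the pipeline code. Instantiating the abstract base directly raises TypeError.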
- abstractmethod build_log_prob_fn(model_wrapped, path, model_extras, **kwargs)
Build a log-probability closure for inference.
Only supported by methods whose solver can evaluate the continuous change-of-variables formula (e.g., flow matching + ODE solver).
- Parameters:
model_wrapped – The wrapped model.
path – The probability path.
model_extras (dict) – Mode-specific extras (cond, obs_ids, cond_ids, etc.).
**kwargs – Method-specific arguments (step_size, method, etc.).
- Returns:
log_prob_fn – A function (x_1, model_extras) -> log_prob.
- Return type:
Callable
- Raises:
NotImplementedError – If the method does not support log-probability computation.
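For reference, the continuous change-of-variables formula mentioned above is usually written as follows for a flow driven by a velocity field v_t. The notation is an assumption chosen to match this page's convention (x_0 is a prior sample, x_1 is data, and time runs from 0 to 1). How the divergence term is evaluated in gensbi (exactly or with an estimator) is not stated here.

```latex
\frac{\mathrm{d}x_t}{\mathrm{d}t} = v_t(x_t), \qquad
\log p_1(x_1) = \log p_0(x_0) - \int_0^1 \nabla \cdot v_t(x_t)\,\mathrm{d}t
```

Evaluating the right-hand side requires integrating the ODE along the trajectory from x_1 back to x_0, which is why this is tied to an ODE solver.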
- abstractmethod build_loss(path, weights=None)
Create the loss callable for this method.
- Parameters:
path – The probability path returned by build_path().
weights (Array, optional) – Per-dimension loss weights (e.g. for rebalancing joint parameter/observation dimensions).
- Returns:
A loss object whose __call__ computes the training loss.
- Return type:
loss
- abstractmethod build_path(config, event_shape)
Create the probability path and construct the prior.
- Parameters:
config (dict) – Training configuration dictionary (may contain scheduler hyperparameters such as sigma_min, sigma_max).
event_shape (tuple of (int, int)) – (dim, ch), the feature and channel dimensions. Used to construct self.prior as a numpyro distribution.
- Returns:
A probability path object (e.g. AffineProbPath, EDMPath, SMPath).
- Return type:
- abstractmethod build_sampler_fn(model_wrapped, path, model_extras, **kwargs)
Build a sampler closure for inference.
The pipeline should use sample_init() to generate the initial noise x_init before calling the returned sampler.
- Parameters:
model_wrapped – The wrapped model.
path – The probability path.
model_extras (dict) – Mode-specific extras (cond, obs_ids, cond_ids, etc.).
**kwargs – Method-specific sampler arguments (step_size, nsteps, time_grid, solver, etc.).
- Returns:
sampler_fn – A function (key, x_init) -> samples.
- Return type:
Callable
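The closure shape described above, a builder that captures model_extras and returns a function (key, x_init) -> samples, can be sketched with a toy fixed-step Euler integrator. Everything here is hypothetical and dependency-free: build_toy_sampler_fn and toy_velocity are illustrative names, plain Python floats stand in for JAX arrays, and the key argument is accepted but unused because this toy ODE sampler is deterministic.

```python
def toy_velocity(x, t, cond):
    # Straight-line field that carries any x to cond at t = 1.
    return (cond - x) / (1.0 - t)


def build_toy_sampler_fn(velocity, model_extras, nsteps=50):
    """Return a closure (key, x_init) -> samples using fixed-step Euler."""
    dt = 1.0 / nsteps

    def sampler_fn(key, x_init):
        # key is unused here; a stochastic sampler would consume it.
        xs = list(x_init)
        for i in range(nsteps):
            t = i * dt  # t stays strictly below 1, so 1 - t never hits zero
            xs = [x + dt * velocity(x, t, **model_extras) for x in xs]
        return xs

    return sampler_fn


sampler_fn = build_toy_sampler_fn(toy_velocity, {"cond": 2.0}, nsteps=50)
samples = sampler_fn(None, [-1.0, 0.0, 5.0])
```

In this toy every initial point is transported to the conditioning value, which makes the result easy to check. The point of the sketch is the calling convention: the initial noise is produced separately and passed in, so the same sampler can be reused with different x_init draws.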
- abstractmethod build_solver(model_wrapped, path, solver=None)
Instantiate a solver.
- Parameters:
model_wrapped – The wrapped model (velocity field or score model).
path – The probability path.
solver (tuple of (type, dict), optional) – (SolverClass, kwargs) override. If None, uses get_default_solver().
- Returns:
An instantiated solver.
- Return type:
solver_instance
- abstractmethod get_default_solver()
Return the default (SolverClass, kwargs) for this method.
- Returns:
solver – A (SolverClass, default_kwargs) pair.
- Return type:
tuple
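The fallback from an explicit (SolverClass, kwargs) override to the method's default pair can be sketched as below. This is a hedged illustration with hypothetical names (ToyEuler, ToyHeun, resolve_solver). The source does not say whether the kwargs are consumed by the solver constructor or by the sampling call, so the sketch only resolves the pair and leaves that choice open.

```python
class ToyEuler:
    """Hypothetical solver class used only for this illustration."""

    def __init__(self, model):
        self.model = model


class ToyHeun(ToyEuler):
    """Second hypothetical solver, used as the override."""


def resolve_solver(solver, default):
    """Pick the override if given, else the default (SolverClass, kwargs)."""
    solver_cls, solver_kwargs = default if solver is None else solver
    # Copy so callers cannot mutate the default kwargs in place.
    return solver_cls, dict(solver_kwargs)


DEFAULT = (ToyEuler, {"step_size": 0.01})

cls_a, kwargs_a = resolve_solver(None, DEFAULT)
cls_b, kwargs_b = resolve_solver((ToyHeun, {"step_size": 0.1}), DEFAULT)
solver_instance = cls_b(model=lambda x, t: -x)
```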
- get_extra_training_config()
Return method-specific training config defaults.
Override in subclasses to supply extra defaults (e.g. sigma_min).
- Returns:
Extra configuration parameters.
- Return type:
dict
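One plausible way a pipeline might combine these extra defaults with a user-supplied training config is a dictionary merge in which user values win. Both the precedence and the helper name merge_training_config are assumptions for illustration; the numeric values are placeholders, not gensbi defaults.

```python
def merge_training_config(user_config, method_defaults):
    """Fill in method defaults without overriding user-supplied values."""
    merged = dict(method_defaults)
    merged.update(user_config)
    return merged


# Illustrative values only; not taken from the library.
extra = {"sigma_min": 0.002, "sigma_max": 80.0}
config = merge_training_config({"sigma_max": 50.0, "batch_size": 256}, extra)
```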
- abstractmethod prepare_batch(key, x_1, path)
Prepare a training batch from clean data.
All methods return a uniform (x_0, x_1, t_or_sigma) tuple where x_0 is noise sampled from the prior.
- Parameters:
key (jax.random.PRNGKey) – Random key.
x_1 (Array) – Clean data batch of shape (batch_size, dim, ch).
path – The probability path.
- Returns:
batch – (x_0, x_1, t_or_sigma) prepared training batch.
- Return type:
tuple
- abstractmethod sample_init(key, nsamples)
Sample initial noise for the generative process.
Uses self.prior (constructed during build_path()).
- Parameters:
key (jax.random.PRNGKey) – Random key.
nsamples (int) – Number of samples to draw.
- Returns:
x_init – Initial noise sample of shape (nsamples, *event_shape).
- Return type:
Array
- property has_custom_prior
Whether the user provided a custom prior at construction time.