gensbi.core.generative_method
Abstract base class for generative method strategies.
A GenerativeMethod encapsulates the mathematical framework used for
training and sampling (e.g., flow matching, EDM diffusion, or score matching).
It is composed into mode-specific pipelines (Conditional, Joint, Unconditional)
via the strategy pattern.
Classes
GenerativeMethod | Strategy that encapsulates a generative framework.
Module Contents
- class gensbi.core.generative_method.GenerativeMethod
Bases: abc.ABC
Strategy that encapsulates a generative framework.
Concrete implementations handle:
Path construction — the probability path defining the forward process
Loss creation — the training objective
Batch preparation — sampling noise / time for training batches
Solver construction — building the sampler for inference
Initial sample generation — drawing from the prior
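The composition described above can be sketched with the standard library alone. This is a minimal illustration of the strategy pattern, not gensbi code: ToyMethod, ToyStrategy, and ToyPipeline are hypothetical stand-ins, an integer seed replaces the JAX PRNG key, and only two of the documented hooks are mirrored.

```python
import abc
import random


class ToyMethod(abc.ABC):
    """Hypothetical stand-in mirroring two hooks of GenerativeMethod."""

    @abc.abstractmethod
    def sample_init(self, key, nsamples):
        """Draw initial noise from the prior."""

    @abc.abstractmethod
    def build_sampler_fn(self):
        """Return a closure (key, x_init) -> samples."""


class ToyStrategy(ToyMethod):
    """Toy strategy: Gaussian noise in, shifted noise out."""

    def sample_init(self, key, nsamples):
        rng = random.Random(key)  # an int seed stands in for a PRNG key
        return [rng.gauss(0.0, 1.0) for _ in range(nsamples)]

    def build_sampler_fn(self):
        def sampler_fn(key, x_init):
            # A real method would integrate an ODE or SDE here.
            return [x + 1.0 for x in x_init]

        return sampler_fn


class ToyPipeline:
    """Hypothetical pipeline that delegates to whichever method it is given."""

    def __init__(self, method):
        self.method = method

    def sample(self, key, nsamples):
        x_init = self.method.sample_init(key, nsamples)
        return self.method.build_sampler_fn()(key, x_init)


samples = ToyPipeline(ToyStrategy()).sample(key=0, nsamples=4)
```

Because the pipeline only talks to the abstract interface, swapping in a different strategy object changes the mathematics without touching the pipeline code.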
- abstractmethod build_log_prob_fn(model_wrapped, path, model_extras, **kwargs)
Build a log-probability closure for inference.
Only supported by methods whose solver can evaluate the continuous change-of-variables formula (e.g., flow matching + ODE solver).
- Parameters:
model_wrapped – The wrapped model.
path – The probability path.
model_extras (dict) – Mode-specific extras (cond, obs_ids, cond_ids, etc.).
**kwargs – Method-specific arguments (step_size, method, etc.).
- Returns:
log_prob_fn – A function (x_1, model_extras) -> log_prob.
- Return type:
Callable
- Raises:
NotImplementedError – If the method does not support log-probability computation.
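As a rough illustration of the continuous change-of-variables formula, the sketch below integrates a velocity field backward from t = 1 to t = 0 and accumulates its divergence, so that log p_1(x_1) = log p_0(x_0) - ∫ div v dt. Everything here is an assumption made for illustration: make_log_prob_fn and velocity are hypothetical names, the prior is a standard normal, the solver is plain Euler, and model_extras is omitted. It is not the gensbi implementation.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def velocity(x, t):
    # Hypothetical linear velocity field standing in for the wrapped model.
    return 0.5 * x


def make_log_prob_fn(velocity, nsteps=200):
    """Hypothetical builder returning a closure x_1 -> log_prob."""
    dt = 1.0 / nsteps

    def divergence(x, t):
        # Exact divergence as the Jacobian trace (affordable in low dimensions).
        return jnp.trace(jax.jacfwd(velocity)(x, t))

    def log_prob_fn(x_1):
        def step(carry, t):
            x, div_int = carry
            div_int = div_int + divergence(x, t) * dt
            x = x - velocity(x, t) * dt  # Euler step backward in time
            return (x, div_int), None

        ts = jnp.linspace(1.0, dt, nsteps)  # t = 1, 1 - dt, ..., dt
        (x_0, div_int), _ = jax.lax.scan(step, (x_1, jnp.zeros(())), ts)
        log_p0 = norm.logpdf(x_0).sum()  # assumed standard normal prior
        return log_p0 - div_int

    return log_prob_fn


x_1 = jnp.array([0.3, -1.2])
log_prob = make_log_prob_fn(velocity)(x_1)
```

For this linear field the exact density at t = 1 is a zero-mean Gaussian with standard deviation exp(0.5), which gives a convenient check on the numerical result.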
- abstractmethod build_loss(path, weights=None)
Create the loss callable for this method.
- Parameters:
path – The probability path returned by build_path().
weights (Array, optional) – Per-dimension loss weights (e.g. for rebalancing joint parameter/observation dimensions).
- Returns:
A loss object whose __call__ computes the training loss.
- Return type:
loss
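One way per-dimension weights could enter a loss is sketched below. This assumes a squared-error objective and weights of shape (dim,) broadcast over the batch and channel axes. The name make_weighted_mse is hypothetical, and the actual loss objects returned by build_loss may differ.

```python
import jax.numpy as jnp


def make_weighted_mse(weights=None):
    """Hypothetical loss factory: per-dimension weighted squared error."""

    def loss_fn(pred, target):
        # pred and target have shape (batch_size, dim, ch).
        sq_err = (pred - target) ** 2
        if weights is not None:
            # Broadcast (dim,) weights over the batch and channel axes.
            sq_err = sq_err * weights[None, :, None]
        return jnp.mean(sq_err)

    return loss_fn


pred = jnp.ones((2, 3, 1))
target = jnp.zeros((2, 3, 1))

unweighted = make_weighted_mse()(pred, target)
weighted = make_weighted_mse(jnp.array([2.0, 0.0, 0.0]))(pred, target)
```

With these toy inputs every squared error equals 1, so the weighted loss is simply the mean of the weights.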
- abstractmethod build_path(config, event_shape)
Create the probability path and construct the prior.
- Parameters:
config (dict) – Training configuration dictionary (may contain scheduler hyperparameters such as sigma_min, sigma_max).
event_shape (tuple of (int, int)) – (dim, ch), the feature and channel dimensions. Used to construct self.prior as a numpyro distribution.
- Returns:
A probability path object (e.g. AffineProbPath, EDMPath, SMPath).
- Return type:
- abstractmethod build_sampler_fn(model_wrapped, path, model_extras, **kwargs)
Build a sampler closure for inference.
The pipeline should use sample_init() to generate the initial noise x_init before calling the returned sampler.
- Parameters:
model_wrapped – The wrapped model.
path – The probability path.
model_extras (dict) – Mode-specific extras (cond, obs_ids, cond_ids, etc.).
**kwargs – Method-specific sampler arguments (step_size, nsteps, time_grid, solver, etc.).
- Returns:
sampler_fn – A function (key, x_init) -> samples.
- Return type:
Callable
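The calling convention described above (draw x_init first, then pass it to the returned closure) might look like the following sketch. The helper names make_sampler_fn and velocity, the standard normal prior, the event shape (2, 1), and the fixed-step Euler integrator are all assumptions for illustration. The real method takes model_wrapped, path, and model_extras as well.

```python
import jax
import jax.numpy as jnp


def sample_init(key, nsamples, event_shape=(2, 1)):
    # Stand-in for the documented sample_init(): assumed standard normal prior.
    return jax.random.normal(key, (nsamples, *event_shape))


def velocity(x, t):
    # Toy constant field: moves every sample by +1 over t in [0, 1].
    return jnp.ones_like(x)


def make_sampler_fn(velocity, nsteps=100):
    """Hypothetical builder returning a closure (key, x_init) -> samples."""
    dt = 1.0 / nsteps

    def sampler_fn(key, x_init):
        # key is unused by this deterministic Euler ODE sampler;
        # a stochastic sampler would consume it.
        def step(x, t):
            return x + velocity(x, t) * dt, None

        ts = jnp.arange(nsteps) * dt  # t = 0, dt, ..., 1 - dt
        samples, _ = jax.lax.scan(step, x_init, ts)
        return samples

    return sampler_fn


key_init, key_sample = jax.random.split(jax.random.PRNGKey(0))
x_init = sample_init(key_init, nsamples=8)
samples = make_sampler_fn(velocity)(key_sample, x_init)
```

Keeping the noise draw outside the closure lets the pipeline control the prior and the batch shape independently of the solver.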
- abstractmethod build_solver(model_wrapped, path, solver=None)
Instantiate a solver.
- Parameters:
model_wrapped – The wrapped model (velocity field or score model).
path – The probability path.
solver (tuple of (type, dict), optional) – (SolverClass, kwargs) override. If None, uses get_default_solver().
- Returns:
An instantiated solver.
- Return type:
solver_instance
- abstractmethod get_default_solver()
Return the default (SolverClass, kwargs) for this method.
- Returns:
solver – A (SolverClass, default_kwargs) pair.
- Return type:
tuple
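The interplay between build_solver() and get_default_solver() can be illustrated with a small standalone sketch. EulerSolver and HeunSolver are hypothetical classes invented for this example, the functions are written as free functions instead of methods, and the path argument is omitted. Only the (SolverClass, kwargs) fallback logic is meant to carry over.

```python
class EulerSolver:
    """Hypothetical solver class, for illustration only."""

    def __init__(self, model, step_size=0.01):
        self.model = model
        self.step_size = step_size


class HeunSolver(EulerSolver):
    """Second hypothetical solver, used here as the override."""


def get_default_solver():
    # The default is a (SolverClass, default_kwargs) pair.
    return EulerSolver, {"step_size": 0.01}


def build_solver(model_wrapped, solver=None):
    # Fall back to the method's default when no override is given.
    solver_cls, solver_kwargs = solver if solver is not None else get_default_solver()
    return solver_cls(model_wrapped, **solver_kwargs)


default = build_solver(model_wrapped=None)
custom = build_solver(model_wrapped=None, solver=(HeunSolver, {"step_size": 0.05}))
```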
- get_extra_training_config()
Return method-specific training config defaults.
Override in subclasses to supply extra defaults (e.g. sigma_min).
- Returns:
Extra configuration parameters.
- Return type:
dict
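How these defaults are combined with the user's configuration is not stated here. A plausible sketch, assuming user-supplied values take precedence over method defaults, is shown below. The resolve_config helper and the numeric values are hypothetical.

```python
def get_extra_training_config():
    # Hypothetical method-specific defaults (values are illustrative only).
    return {"sigma_min": 0.002, "sigma_max": 80.0}


def resolve_config(user_config):
    # Assumed merge order: method defaults first, user values override them.
    return {**get_extra_training_config(), **user_config}


config = resolve_config({"sigma_max": 50.0, "batch_size": 128})
```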
- abstractmethod prepare_batch(key, x_1, path)
Prepare a training batch from clean data.
All methods return a tuple of the same form, (x_0, x_1, t_or_sigma), where x_0 is noise sampled from the prior.
- Parameters:
key (jax.random.PRNGKey) – Random key.
x_1 (Array) – Clean data batch of shape (batch_size, dim, ch).
path – The probability path.
- Returns:
batch – (x_0, x_1, t_or_sigma) prepared training batch.
- Return type:
tuple
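For a flow-matching-style method, batch preparation might look like the sketch below. It assumes a standard normal prior and times drawn uniformly from [0, 1), with one time per example. The path argument is omitted, and a diffusion-style method would draw a noise level sigma in place of t. This is an illustration of the return convention, not the library's implementation.

```python
import jax
import jax.numpy as jnp


def prepare_batch(key, x_1):
    """Hypothetical flow-matching-style batch preparation."""
    key_noise, key_time = jax.random.split(key)
    # Noise with the same shape as the data, from an assumed standard normal prior.
    x_0 = jax.random.normal(key_noise, x_1.shape)
    # One time value in [0, 1) per example in the batch.
    t = jax.random.uniform(key_time, (x_1.shape[0],))
    return x_0, x_1, t


x_1 = jnp.zeros((4, 3, 1))  # (batch_size, dim, ch)
x_0, x_1_out, t = prepare_batch(jax.random.PRNGKey(0), x_1)
```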
- abstractmethod sample_init(key, nsamples)
Sample initial noise for the generative process.
Uses self.prior (constructed during build_path()).
- Parameters:
key (jax.random.PRNGKey) – Random key.
nsamples (int) – Number of samples to draw.
- Returns:
x_init – Initial noise sample of shape (nsamples, *event_shape).
- Return type:
Array