gensbi.core.generative_method

Abstract base class for generative method strategies.

A GenerativeMethod encapsulates the mathematical framework used for training and sampling (e.g., flow matching, EDM diffusion, or score matching). It is composed into mode-specific pipelines (Conditional, Joint, Unconditional) via the strategy pattern.
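The strategy pattern described above can be sketched with the standard library alone. The classes below (Method, FlowMatchingLike, ScoreMatchingLike, ConditionalPipelineSketch) are hypothetical stand-ins, not part of gensbi. The only point is that the pipeline calls hooks on whatever method object it receives and never branches on the method's type.

```python
from abc import ABC, abstractmethod


class Method(ABC):
    """Hypothetical stand-in for GenerativeMethod, reduced to one hook."""

    @abstractmethod
    def name(self) -> str:
        ...


class FlowMatchingLike(Method):
    def name(self) -> str:
        return "flow_matching"


class ScoreMatchingLike(Method):
    def name(self) -> str:
        return "score_matching"


class ConditionalPipelineSketch:
    """Hypothetical pipeline: it only calls the hooks the strategy exposes."""

    def __init__(self, method: Method):
        self.method = method

    def describe(self) -> str:
        return f"conditional pipeline using {self.method.name()}"


fm = ConditionalPipelineSketch(FlowMatchingLike())
sm = ConditionalPipelineSketch(ScoreMatchingLike())
print(fm.describe())  # conditional pipeline using flow_matching
print(sm.describe())  # conditional pipeline using score_matching
```

Swapping the generative framework means passing a different method object; the pipeline code is unchanged.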

Classes

GenerativeMethod

Strategy that encapsulates a generative framework.

Module Contents

class gensbi.core.generative_method.GenerativeMethod

Bases: abc.ABC

Strategy that encapsulates a generative framework.

Concrete implementations handle:

  • Path construction (build_path()): the probability path defining the forward process

  • Loss creation (build_loss()): the training objective

  • Batch preparation (prepare_batch()): sampling noise and time (or noise level) for training batches

  • Solver construction (build_solver(), build_sampler_fn()): building the sampler for inference

  • Initial sample generation (sample_init()): drawing from the prior
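A minimal sketch of the order in which a pipeline might drive these hooks, assuming only the dependencies stated in the method descriptions below: the loss and the batch use the path returned by build_path(), and sampling starts from sample_init(). RecordingMethod and run_once are hypothetical, and the recorded call list shows one plausible ordering rather than gensbi's actual control flow.

```python
class RecordingMethod:
    """Hypothetical method that records which hooks are called, in order."""

    def __init__(self):
        self.calls = []

    def build_path(self, config, event_shape):
        self.calls.append("build_path")
        return {"event_shape": event_shape}

    def build_loss(self, path, weights=None):
        self.calls.append("build_loss")
        return lambda batch: 0.0  # dummy loss

    def prepare_batch(self, key, x_1, path):
        self.calls.append("prepare_batch")
        return (None, x_1, None)  # dummy (x_0, x_1, t_or_sigma)

    def sample_init(self, key, nsamples):
        self.calls.append("sample_init")
        return [0.0] * nsamples

    def build_sampler_fn(self, model_wrapped, path, model_extras, **kwargs):
        self.calls.append("build_sampler_fn")
        return lambda key, x_init: x_init  # identity "sampler"


def run_once(method, x_1, nsamples):
    """Hypothetical pipeline: one training-loss evaluation, then one sampling call."""
    path = method.build_path({}, (2, 1))
    loss_fn = method.build_loss(path)
    loss = loss_fn(method.prepare_batch(None, x_1, path))
    x_init = method.sample_init(None, nsamples)
    sampler_fn = method.build_sampler_fn(None, path, {})
    return loss, sampler_fn(None, x_init)


m = RecordingMethod()
loss, samples = run_once(m, x_1=[1.0, 2.0], nsamples=3)
print(m.calls)
# ['build_path', 'build_loss', 'prepare_batch', 'sample_init', 'build_sampler_fn']
```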

abstractmethod build_log_prob_fn(model_wrapped, path, model_extras, **kwargs)

Build a log-probability closure for inference.

Supported only by methods whose solver can evaluate the continuous-time change-of-variables formula (e.g., flow matching with an ODE solver).

Parameters:
  • model_wrapped – The wrapped model.

  • path – The probability path.

  • model_extras (dict) – Mode-specific extras (cond, obs_ids, cond_ids, etc.).

  • **kwargs – Method-specific arguments (step_size, method, etc.).

Returns:

log_prob_fn – A function (x_1, model_extras) -> log_prob.

Return type:

Callable

Raises:

NotImplementedError – If the method does not support log-probability computation.
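For flow matching, the quantity such a closure has to evaluate is the instantaneous change-of-variables formula. Assuming the convention used by prepare_batch() (noise x_0 drawn from the prior at t = 0, data x_1 at t = 1) and a learned velocity field v_t, the log-density of a data point follows from integrating the divergence of v_t along the ODE trajectory, with x_0 found by solving the ODE backward from x_1. Needing this integral is why the capability is tied to an ODE solver.

```latex
\frac{\mathrm{d}x_t}{\mathrm{d}t} = v_t(x_t),
\qquad
\frac{\mathrm{d}}{\mathrm{d}t}\log p_t(x_t) = -\nabla_x \cdot v_t(x_t),
\qquad
\log p_1(x_1) = \log p_0(x_0) - \int_0^1 \nabla_x \cdot v_t(x_t)\,\mathrm{d}t
```

Here p_0 is the prior density.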

abstractmethod build_loss(path, weights=None)

Create the loss callable for this method.

Parameters:
  • path – The probability path returned by build_path().

  • weights (Array, optional) – Per-dimension loss weights (e.g. for rebalancing joint parameter/observation dimensions).

Returns:

A loss object whose __call__ computes the training loss.

Return type:

loss
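As an illustration of what the returned loss object might compute, here is a hypothetical flow-matching loss with optional per-dimension weights. It assumes the linear path x_t = t * x_1 + (1 - t) * x_0, whose target velocity is x_1 - x_0, and a model called as model(x_t, t). Neither the class name nor this call signature is gensbi's API, and NumPy stands in for JAX.

```python
import numpy as np


class WeightedFlowMatchingLoss:
    """Hypothetical loss object: __call__ returns a scalar training loss.

    Assumes the linear path x_t = t * x_1 + (1 - t) * x_0, whose time
    derivative (the regression target) is x_1 - x_0.
    """

    def __init__(self, weights=None):
        # Optional per-dimension weights, broadcastable to (dim, ch).
        self.weights = weights

    def __call__(self, model, batch):
        x_0, x_1, t = batch                  # (B, dim, ch), (B, dim, ch), (B,)
        t_b = t[:, None, None]               # broadcast time over (dim, ch)
        x_t = t_b * x_1 + (1.0 - t_b) * x_0  # point on the path
        target = x_1 - x_0                   # conditional velocity
        sq_err = (model(x_t, t) - target) ** 2
        if self.weights is not None:
            sq_err = sq_err * self.weights   # rebalance dimensions
        return float(np.mean(sq_err))


x0 = np.zeros((4, 2, 1))
x1 = np.ones((4, 2, 1))
t = np.full(4, 0.5)
zero_model = lambda x_t, t: np.zeros_like(x_t)  # always predicts zero velocity
print(WeightedFlowMatchingLoss()(zero_model, (x0, x1, t)))  # 1.0
```

With weights of shape (dim, 1), every channel of a feature shares one weight, which is one way to upweight, say, parameter dimensions relative to observation dimensions in a joint model.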

abstractmethod build_path(config, event_shape)

Create the probability path and construct the prior.

Parameters:
  • config (dict) – Training configuration dictionary; may contain scheduler hyperparameters such as sigma_min and sigma_max.

  • event_shape (tuple of (int, int)) – The (dim, ch) pair of feature and channel dimensions. Used to construct self.prior as a numpyro distribution.

Returns:

A probability path object (e.g. AffineProbPath, EDMPath, SMPath).

Return type:

path
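A hypothetical sketch of build_path() for a flow-matching-style method. It reads sigma_min from config, returns an affine path x_t = t * x_1 + (1 - (1 - sigma_min) * t) * x_0 (the conditional optimal-transport path used in flow matching), and stores a prior over event_shape on self. The class names are illustrative, and a NumPy sampling lambda replaces the numpyro distribution the real method constructs.

```python
import numpy as np


class AffinePathSketch:
    """Hypothetical affine path with alpha_t = t and
    sigma_t = 1 - (1 - sigma_min) * t."""

    def __init__(self, sigma_min=0.0):
        self.sigma_min = sigma_min

    def sample(self, x_0, x_1, t):
        t = t[:, None, None]                     # broadcast over (dim, ch)
        sigma_t = 1.0 - (1.0 - self.sigma_min) * t
        return t * x_1 + sigma_t * x_0


class FlowMatchingSketch:
    def build_path(self, config, event_shape):
        # Read an optional scheduler hyperparameter from the config.
        path = AffinePathSketch(sigma_min=config.get("sigma_min", 0.0))
        # Stand-in for a numpyro prior: standard normal over event_shape.
        self.event_shape = tuple(event_shape)
        self.prior = lambda rng, n: rng.standard_normal((n, *self.event_shape))
        return path


method = FlowMatchingSketch()
path = method.build_path({"sigma_min": 0.0}, (3, 1))
x0 = method.prior(np.random.default_rng(0), 5)   # shape (5, 3, 1)
```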

abstractmethod build_sampler_fn(model_wrapped, path, model_extras, **kwargs)

Build a sampler closure for inference.

The pipeline should use sample_init() to generate the initial noise x_init before calling the returned sampler.

Parameters:
  • model_wrapped – The wrapped model.

  • path – The probability path.

  • model_extras (dict) – Mode-specific extras (cond, obs_ids, cond_ids, etc.).

  • **kwargs – Method-specific sampler arguments (step_size, nsteps, time_grid, solver, etc.).

Returns:

sampler_fn – A function (key, x_init) -> samples.

Return type:

Callable
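The (key, x_init) -> samples contract can be sketched as a fixed-step Euler integrator for a flow-matching ODE, running from t = 0 (noise) to t = 1 (data). build_euler_sampler_fn and the velocity_fn(x, t, **model_extras) calling convention are hypothetical, and plain NumPy replaces JAX; a real sampler would use the solver from build_solver() and honor arguments such as step_size or time_grid.

```python
import numpy as np


def build_euler_sampler_fn(velocity_fn, model_extras, nsteps=100):
    """Hypothetical sketch: return a closure (key, x_init) -> samples that
    integrates dx/dt = velocity_fn(x, t, **model_extras) from t=0 to t=1
    with fixed-step explicit Euler."""

    def sampler_fn(key, x_init):
        # key is unused: this ODE sampler is deterministic given x_init.
        x = x_init
        dt = 1.0 / nsteps
        for i in range(nsteps):
            t = i * dt
            x = x + dt * velocity_fn(x, t, **model_extras)
        return x

    return sampler_fn


# Toy velocity field that ignores x and t and returns the conditioning value.
const_velocity = lambda x, t, cond: np.full_like(x, cond)
sampler_fn = build_euler_sampler_fn(const_velocity, {"cond": 2.0})
samples = sampler_fn(None, np.zeros((4, 2, 1)))  # approximately 2.0 everywhere
```

Keeping key in the signature lets deterministic (ODE) and stochastic (SDE) samplers share one interface; this Euler sketch simply ignores it.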

abstractmethod build_solver(model_wrapped, path, solver=None)

Instantiate a solver.

Parameters:
  • model_wrapped – The wrapped model (velocity field or score model).

  • path – The probability path.

  • solver (tuple of (type, dict), optional) – A (SolverClass, kwargs) pair that overrides the default solver. If None, the pair returned by get_default_solver() is used.

Returns:

An instantiated solver.

Return type:

solver_instance

abstractmethod get_default_solver()

Return the default (SolverClass, kwargs) pair for this method.

Returns:

solver – A (SolverClass, default_kwargs) pair.

Return type:

tuple
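How build_solver() and get_default_solver() fit together, as a hypothetical sketch: an override, when given, is used as-is; otherwise the method's default (SolverClass, kwargs) pair is unpacked and instantiated. Whether the real implementation merges override kwargs with the defaults rather than replacing them is not stated here; this sketch replaces them. EulerSolverSketch and MethodSketch are illustrative names.

```python
class EulerSolverSketch:
    """Hypothetical solver class standing in for a real ODE solver."""

    def __init__(self, velocity_model, path, nsteps=50):
        self.velocity_model = velocity_model
        self.path = path
        self.nsteps = nsteps


class MethodSketch:
    def get_default_solver(self):
        # (SolverClass, default_kwargs) pair.
        return EulerSolverSketch, {"nsteps": 50}

    def build_solver(self, model_wrapped, path, solver=None):
        # Fall back to the method's default when no override is given.
        solver_cls, solver_kwargs = solver if solver is not None else self.get_default_solver()
        return solver_cls(model_wrapped, path, **solver_kwargs)


method = MethodSketch()
default_solver = method.build_solver("model", "path")
custom_solver = method.build_solver("model", "path", solver=(EulerSolverSketch, {"nsteps": 10}))
```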

get_extra_training_config()

Return method-specific training config defaults.

Override in subclasses to supply extra defaults (e.g. sigma_min).

Returns:

Extra configuration parameters.

Return type:

dict
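A sketch of how a pipeline might layer these defaults, assuming user-supplied settings take precedence over method defaults, which in turn override generic defaults. The class names, resolve_config(), and the sigma values (taken from commonly used EDM settings) are illustrative; the base class presumably returns an empty dict.

```python
class BaseMethodSketch:
    def get_extra_training_config(self):
        return {}  # base class: no extra defaults


class EDMLikeSketch(BaseMethodSketch):
    def get_extra_training_config(self):
        # Hypothetical method-specific defaults.
        return {"sigma_min": 0.002, "sigma_max": 80.0}


def resolve_config(method, user_config, base_defaults):
    # Assumed precedence: base defaults < method defaults < user settings.
    return {**base_defaults, **method.get_extra_training_config(), **user_config}


config = resolve_config(EDMLikeSketch(), {"sigma_max": 10.0}, {"learning_rate": 1e-3})
# {'learning_rate': 0.001, 'sigma_min': 0.002, 'sigma_max': 10.0}
```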

abstractmethod prepare_batch(key, x_1, path)

Prepare a training batch from clean data.

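All methods return a tuple with the same structure, (x_0, x_1, t_or_sigma), where x_0 is noise sampled from the prior, x_1 is the clean data, and t_or_sigma holds the time or the noise level, depending on the method.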

Parameters:
  • key (jax.random.PRNGKey) – Random key.

  • x_1 (Array) – Clean data batch of shape (batch_size, dim, ch).

  • path – The probability path.

Returns:

batch – The prepared training batch (x_0, x_1, t_or_sigma).

Return type:

tuple
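A minimal sketch of prepare_batch() for a flow-matching-style method, assuming a standard normal prior and one uniformly drawn time per sample. A NumPy Generator stands in for the jax.random.PRNGKey (in JAX one would split the key for the two draws), and the function name is hypothetical.

```python
import numpy as np


def prepare_batch_sketch(rng, x_1, path=None):
    """Hypothetical flow-matching batch preparation (path is unused here).

    Draws x_0 from a standard normal prior with the same shape as x_1 and a
    time t ~ Uniform[0, 1) per sample. A diffusion-style method would put a
    noise level sigma in the last slot instead.
    """
    batch_size = x_1.shape[0]
    x_0 = rng.standard_normal(x_1.shape)            # noise from the prior
    t = rng.uniform(0.0, 1.0, size=batch_size)      # one time per sample
    return x_0, x_1, t


rng = np.random.default_rng(0)
x_1 = np.ones((8, 3, 2))                            # (batch_size, dim, ch)
x_0, x_1_out, t = prepare_batch_sketch(rng, x_1)
```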

abstractmethod sample_init(key, nsamples)

Sample initial noise for the generative process.

Uses self.prior (constructed during build_path()).

Parameters:
  • key (jax.random.PRNGKey) – Random key.

  • nsamples (int) – Number of samples to draw.

Returns:

x_init – Initial noise sample of shape (nsamples, *event_shape).

Return type:

Array
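Drawing the initial noise amounts to requesting nsamples independent draws from the prior. In this sketch PriorSketch is a hypothetical NumPy stand-in whose sample(rng, sample_shape) mirrors the (key, sample_shape) form of numpyro's Distribution.sample, so the result has shape (nsamples, *event_shape).

```python
import numpy as np


class PriorSketch:
    """Hypothetical stand-in for a numpyro prior with event shape (dim, ch)."""

    def __init__(self, event_shape):
        self.event_shape = tuple(event_shape)

    def sample(self, rng, sample_shape=()):
        # Batch dimensions first, event dimensions last.
        return rng.standard_normal((*sample_shape, *self.event_shape))


def sample_init_sketch(prior, rng, nsamples):
    # nsamples independent draws: shape (nsamples, *event_shape).
    return prior.sample(rng, (nsamples,))


x_init = sample_init_sketch(PriorSketch((3, 2)), np.random.default_rng(1), 7)
```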

property has_custom_prior

Whether the user provided a custom prior at construction time.
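One plausible use of this flag, sketched under the assumption that build_path() constructs a default prior only when the user did not supply one. MethodWithPriorSketch and the dictionary placeholder standing in for the default prior are hypothetical.

```python
class MethodWithPriorSketch:
    """Hypothetical method that keeps a user-supplied prior if one is given."""

    def __init__(self, prior=None):
        self._custom_prior = prior
        self.prior = prior

    @property
    def has_custom_prior(self):
        return self._custom_prior is not None

    def build_path(self, config, event_shape):
        if not self.has_custom_prior:
            # Placeholder description of a default standard normal prior.
            self.prior = {"kind": "standard_normal", "event_shape": tuple(event_shape)}
        return "path"


default_method = MethodWithPriorSketch()
default_method.build_path({}, (2, 1))

user_prior = object()
custom_method = MethodWithPriorSketch(prior=user_prior)
custom_method.build_path({}, (2, 1))  # keeps user_prior
```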