# Source code for gensbi.core.generative_method

"""
Abstract base class for generative method strategies.

A ``GenerativeMethod`` encapsulates the mathematical framework used for
training and sampling (e.g., flow matching, EDM diffusion, or score matching).
It is composed into mode-specific pipelines (Conditional, Joint, Unconditional)
via the strategy pattern.
"""

from abc import ABC, abstractmethod
from typing import Any, Callable, Optional, Tuple

from jax import Array


class GenerativeMethod(ABC):
    """Strategy that encapsulates a generative framework.

    Concrete implementations handle:

    - **Path construction** — the probability path defining the forward
      process
    - **Loss creation** — the training objective
    - **Batch preparation** — sampling noise / time for training batches
    - **Solver construction** — building the sampler for inference
    - **Initial sample generation** — drawing from the prior

    Every method decorated with ``@abstractmethod`` must be overridden
    before a subclass can be instantiated. :meth:`build_log_prob_fn` and
    :meth:`get_extra_training_config` are optional hooks with default
    implementations.
    """

    @property
    def has_custom_prior(self) -> bool:
        """Whether the user provided a custom prior at construction time.

        Returns
        -------
        bool
            ``True`` when the instance carries a ``_user_prior`` attribute
            whose value is not ``None``; ``False`` otherwise.
        """
        # ``getattr`` with a default keeps this safe for subclasses that
        # never assign ``_user_prior`` at all.
        return getattr(self, "_user_prior", None) is not None

    @abstractmethod
    def build_path(self, config: dict, event_shape: Tuple[int, int]) -> Any:
        """Create the probability path and construct the prior.

        Parameters
        ----------
        config : dict
            Training configuration dictionary (may contain scheduler
            hyperparameters such as ``sigma_min``, ``sigma_max``).
        event_shape : tuple of (int, int)
            ``(dim, ch)`` — feature and channel dimensions. Used to
            construct ``self.prior`` as a numpyro distribution.

        Returns
        -------
        path
            A probability path object (e.g. ``AffineProbPath``,
            ``EDMPath``, ``SMPath``).
        """
        ...  # pragma: no cover

    @abstractmethod
    def build_loss(self, path: Any, weights: Optional[Array] = None) -> Callable:
        """Create the loss callable for this method.

        Parameters
        ----------
        path
            The probability path returned by :meth:`build_path`.
        weights : Array, optional
            Per-dimension loss weights (e.g. for rebalancing joint
            parameter/observation dimensions).

        Returns
        -------
        loss
            A loss object whose ``__call__`` computes the training loss.
        """
        ...  # pragma: no cover

    @abstractmethod
    def prepare_batch(
        self, key: Array, x_1: Array, path: Any
    ) -> Tuple[Array, Array, Any]:
        """Prepare a training batch from clean data.

        All methods return a uniform ``(x_0, x_1, t_or_sigma)`` tuple where
        ``x_0`` is noise sampled from the prior.

        Parameters
        ----------
        key : jax.random.PRNGKey
            Random key.
        x_1 : Array
            Clean data batch of shape ``(batch_size, dim, ch)``.
        path
            The probability path.

        Returns
        -------
        batch : tuple
            ``(x_0, x_1, t_or_sigma)`` prepared training batch.
        """
        ...  # pragma: no cover

    @abstractmethod
    def get_default_solver(self) -> Tuple[type, dict]:
        """Return the default ``(SolverClass, kwargs)`` for this method.

        Returns
        -------
        solver : tuple
            A ``(SolverClass, default_kwargs)`` pair.
        """
        ...  # pragma: no cover

    @abstractmethod
    def build_solver(
        self,
        model_wrapped: Any,
        path: Any,
        solver: Optional[Tuple[type, dict]] = None,
    ) -> Any:
        """Instantiate a solver.

        Parameters
        ----------
        model_wrapped
            The wrapped model (velocity field or score model).
        path
            The probability path.
        solver : tuple of (type, dict), optional
            ``(SolverClass, kwargs)`` override. If ``None``, uses
            :meth:`get_default_solver`.

        Returns
        -------
        solver_instance
            An instantiated solver.
        """
        ...  # pragma: no cover

    @abstractmethod
    def sample_init(self, key: Array, nsamples: int) -> Array:
        """Sample initial noise for the generative process.

        Uses ``self.prior`` (constructed during :meth:`build_path`).

        Parameters
        ----------
        key : jax.random.PRNGKey
            Random key.
        nsamples : int
            Number of samples to draw.

        Returns
        -------
        x_init : Array
            Initial noise sample of shape ``(nsamples, *event_shape)``.
        """
        ...  # pragma: no cover

    @abstractmethod
    def build_sampler_fn(
        self, model_wrapped: Any, path: Any, model_extras: dict, **kwargs: Any
    ) -> Callable:
        """Build a sampler closure for inference.

        The pipeline should use :meth:`sample_init` to generate the initial
        noise ``x_init`` before calling the returned sampler.

        Parameters
        ----------
        model_wrapped
            The wrapped model.
        path
            The probability path.
        model_extras : dict
            Mode-specific extras (``cond``, ``obs_ids``, ``cond_ids``,
            etc.).
        **kwargs
            Method-specific sampler arguments (``step_size``, ``nsteps``,
            ``time_grid``, ``solver``, etc.).

        Returns
        -------
        sampler_fn : Callable
            A function ``(key, x_init) -> samples``.
        """
        ...  # pragma: no cover

    def build_log_prob_fn(
        self, model_wrapped: Any, path: Any, model_extras: dict, **kwargs: Any
    ) -> Callable:
        """Build a log-probability closure for inference.

        Only supported by methods whose solver can evaluate the continuous
        change-of-variables formula (e.g., flow matching + ODE solver).
        This base implementation always raises; methods that support
        log-probability evaluation override it.

        Parameters
        ----------
        model_wrapped
            The wrapped model.
        path
            The probability path.
        model_extras : dict
            Mode-specific extras (``cond``, ``obs_ids``, ``cond_ids``,
            etc.).
        **kwargs
            Method-specific arguments (``step_size``, ``method``, etc.).

        Returns
        -------
        log_prob_fn : Callable
            A function ``(x_1, model_extras) -> log_prob``.

        Raises
        ------
        NotImplementedError
            If the method does not support log-probability computation.
        """
        # Deliberately not abstract: log-prob support is optional, so the
        # default is a clear error naming the concrete subclass.
        raise NotImplementedError(
            f"{type(self).__name__} does not support log-probability computation."
        )

    def get_extra_training_config(self) -> dict:
        """Return method-specific training config defaults.

        Override in subclasses to supply extra defaults (e.g.
        ``sigma_min``).

        Returns
        -------
        dict
            Extra configuration parameters; a new empty dict on every call
            in this base implementation.
        """
        return {}