"""
Pipeline for training and using a Unconditional model for simulation-based inference.
"""
import jax

from gensbi.models import UnconditionalWrapper
from gensbi.recipes.pipeline import AbstractPipeline
from gensbi.recipes.utils import init_ids_1d
# ---------------------------------------------------------------------------
# Unified UnconditionalPipeline (Phase 2)
# ---------------------------------------------------------------------------
from gensbi.core.generative_method import GenerativeMethod
class UnconditionalPipeline(AbstractPipeline):
"""Model-agnostic unconditional pipeline parameterized by a ``GenerativeMethod``.
Unlike the old method-specific pipeline classes, this
class works with **any** generative method and **any** user-provided
model that conforms to the ``UnconditionalWrapper`` interface.
Parameters
----------
model : nnx.Module
The model to be trained.
train_dataset : iterable
Training dataset yielding ``x_1`` batches (not tuples).
val_dataset : iterable
Validation dataset.
dim_obs : int
Dimension of the data space.
method : GenerativeMethod
Strategy object (e.g. ``FlowMatchingMethod()``,
``DiffusionEDMMethod()``, ``ScoreMatchingMethod()``).
ch_obs : int, optional
Number of channels per token. Default is 1.
params : optional
Model parameters (stored but not used directly).
training_config : dict, optional
Training configuration.
Examples
--------
>>> from gensbi.core import FlowMatchingMethod
>>> pipeline = UnconditionalPipeline(
... model=my_model,
... train_dataset=train_ds,
... val_dataset=val_ds,
... dim_obs=9,
... method=FlowMatchingMethod(),
... )
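
    After training, samples can be drawn directly from the pipeline. A
    minimal sketch (assumes the pipeline has already been trained):

    >>> import jax
    >>> samples = pipeline.sample(jax.random.PRNGKey(0), nsamples=1_000)
    >>> samples.shape
    (1000, 9, 1)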
"""
    def __init__(
        self,
        model,
        train_dataset,
        val_dataset,
        dim_obs: int,
        method: GenerativeMethod,
        ch_obs=1,
        params=None,
        training_config=None,
    ):
        if training_config is None:
            training_config = self.get_default_training_config()
        # Let the method contribute its own defaults without overriding
        # anything the user already set.
        extra = method.get_extra_training_config()
        for k, v in extra.items():
            training_config.setdefault(k, v)
        super().__init__(
            model=model,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            dim_obs=dim_obs,
            dim_cond=0,
            ch_obs=ch_obs,
            params=params,
            training_config=training_config,
        )
        # The strategy object is used throughout (loss, sampling, log-prob),
        # so keep a reference to it.
        self.method = method
        self.obs_ids, self.dim_obs = init_ids_1d(self.dim_obs)
        self.path = method.build_path(
            self.training_config, event_shape=(self.dim_obs, self.ch_obs)
        )
        self.loss_obj = method.build_loss(self.path)
# -- Factory stubs ------------------------------------------------------
@classmethod
def init_pipeline_from_config(cls, *args, **kwargs):
raise NotImplementedError(
"UnconditionalPipeline is model-agnostic. "
"Use model-specific pipelines for config init."
)
def _make_model(self):
raise NotImplementedError(
"UnconditionalPipeline is model-agnostic — the user provides the model."
)
@classmethod
def get_default_params(cls, *args, **kwargs):
raise NotImplementedError(
"UnconditionalPipeline is model-agnostic — the user provides model params."
)
# -- Core pipeline methods ----------------------------------------------
    def get_loss_fn(self):
        """Return the per-batch loss function used during training."""

        def loss_fn(model, batch, key):
            x_1 = batch
            # Delegate noising / path sampling to the generative method.
            prepared = self.method.prepare_batch(key, x_1, self.path)
            model_extras = {"node_ids": self.obs_ids}
            return self.loss_obj(
                model,
                prepared,
                condition_mask=None,
                model_extras=model_extras,
            )

        return loss_fn
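    # Sketch of how the returned loss function is typically consumed in a
    # training step. The actual step logic lives in ``AbstractPipeline``;
    # the names below are illustrative only:
    #
    #     loss_fn = pipeline.get_loss_fn()
    #     grads = nnx.grad(loss_fn)(model, batch, key)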
    def _wrap_model(self):
        """Wrap the raw and EMA models in the unconditional interface."""
        self.model_wrapped = UnconditionalWrapper(self.model)
        self.ema_model_wrapped = UnconditionalWrapper(self.ema_model)
def get_sampler(self, use_ema=True, **sampler_kwargs):
"""Get a sampler function.
Parameters
----------
use_ema : bool, optional
Whether to use the EMA model. Default is True.
**sampler_kwargs
Forwarded to ``method.build_sampler_fn``.
Returns
-------
Callable
``sampler(key, nsamples) -> samples``
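
        Examples
        --------
        A minimal usage sketch (assumes a trained pipeline):

        >>> sampler = pipeline.get_sampler()
        >>> samples = sampler(jax.random.PRNGKey(0), 1_000)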
"""
        model_wrapped = self.ema_model_wrapped if use_ema else self.model_wrapped
        model_extras = {"obs_ids": self.obs_ids}
        sampler_fn = self.method.build_sampler_fn(
            model_wrapped, self.path, model_extras, **sampler_kwargs,
        )

        def sampler(key, nsamples):
            key, key_init = jax.random.split(key)
            # Draw the initial noise for ``nsamples`` trajectories, then
            # integrate them with the method-specific sampler.
            x_init = self.method.sample_init(key_init, nsamples)
            return sampler_fn(key, x_init, model_extras)

        return sampler
def sample(self, key, nsamples=10_000, use_ema=True, **sampler_kwargs):
"""Draw samples from the model.
Parameters
----------
key : jax.random.PRNGKey
Random key.
nsamples : int, optional
Number of samples. Default is 10 000.
use_ema : bool, optional
Use the EMA model. Default is True.
**sampler_kwargs
Forwarded to :meth:`get_sampler`.
Returns
-------
Array
Samples of shape ``(nsamples, dim_obs, ch_obs)``.
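
        Examples
        --------
        A minimal usage sketch (assumes a trained pipeline):

        >>> key = jax.random.PRNGKey(0)
        >>> samples = pipeline.sample(key, nsamples=5_000)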
"""
sampler = self.get_sampler(use_ema=use_ema, **sampler_kwargs)
return sampler(key, nsamples)
def get_log_prob_fn(self, use_ema=True, **kwargs):
"""Get a log-probability function.
Parameters
----------
use_ema : bool, optional
Whether to use the EMA model. Default is True.
**kwargs
Forwarded to ``method.build_log_prob_fn``.
Returns
-------
Callable
``log_prob_fn(x_1) -> log_prob``
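
        Examples
        --------
        A minimal usage sketch (assumes a trained pipeline and evaluation
        data ``x``):

        >>> log_prob_fn = pipeline.get_log_prob_fn()
        >>> log_p = log_prob_fn(x)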
"""
model_wrapped = self.ema_model_wrapped if use_ema else self.model_wrapped
model_extras = {"obs_ids": self.obs_ids}
log_prob_fn = self.method.build_log_prob_fn(
model_wrapped, self.path, model_extras, **kwargs,
)
def _log_prob(x_1, *, key=None):
return log_prob_fn(x_1, model_extras, key=key)
return _log_prob
def log_prob(self, x_1, use_ema=True, *, key=None, **kwargs):
"""Compute log-probability of x_1.
Parameters
----------
x_1 : array-like
Data samples to evaluate.
use_ema : bool, optional
Use the EMA model. Default is True.
key : jax.random.PRNGKey, optional
Required when ``exact_divergence=False`` (Hutchinson).
**kwargs
Forwarded to :meth:`get_log_prob_fn`.
Returns
-------
Array
Log-probabilities.
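
        Examples
        --------
        With Hutchinson trace estimation a key is required (sketch;
        assumes the method accepts the ``exact_divergence`` flag and
        evaluation data ``x``):

        >>> key = jax.random.PRNGKey(0)
        >>> log_p = pipeline.log_prob(x, key=key)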
"""
log_prob_fn = self.get_log_prob_fn(use_ema=use_ema, **kwargs)
return log_prob_fn(x_1, key=key)
def sample_batched(self, *args, **kwargs):
raise NotImplementedError(
"Batched sampling not implemented for UnconditionalPipeline."
)