"""
Pipeline for training and using a Joint model for simulation-based inference.
"""

import warnings

import jax
import jax.numpy as jnp

from gensbi.core.generative_method import GenerativeMethod
from gensbi.core.prior import make_gaussian_prior
from gensbi.models import JointWrapper
from gensbi.recipes.pipeline import AbstractPipeline, ModelEMA
from gensbi.recipes.utils import init_ids_joint
from gensbi.utils.model_wrapping import _expand_dims

def sample_structured_conditional_mask(
    key,
    num_samples,
    theta_dim,
    x_dim,
    p_joint=0.2,
    p_posterior=0.2,
    p_likelihood=0.2,
    p_rnd1=0.2,
    p_rnd2=0.2,
    rnd1_prob=0.3,
    rnd2_prob=0.7,
):
    """
    Sample structured conditional masks for the Joint model.

    Each sample draws one of five mask types: joint (nothing conditioned),
    posterior (x conditioned), likelihood (theta conditioned), or one of two
    random Bernoulli masks.

    Parameters
    ----------
    key : jax.random.PRNGKey
        Random key for sampling.
    num_samples : int
        Number of masks to generate.
    theta_dim : int
        Dimension of the parameter space.
    x_dim : int
        Dimension of the observation space.
    p_joint : float
        Probability of selecting the joint mask.
    p_posterior : float
        Probability of selecting the posterior mask.
    p_likelihood : float
        Probability of selecting the likelihood mask.
    p_rnd1 : float
        Probability of selecting the first random mask.
    p_rnd2 : float
        Probability of selecting the second random mask.
    rnd1_prob : float
        Probability of a True entry in the first random mask.
    rnd2_prob : float
        Probability of a True entry in the second random mask.

    Returns
    -------
    condition_mask : jnp.ndarray
        Boolean array of shape ``(num_samples, theta_dim + x_dim, 1)``.
    """
    # Joint, posterior, likelihood, random1, random2
    key1, key2, key3 = jax.random.split(key, 3)

    joint_mask = jnp.array([False] * (theta_dim + x_dim), dtype=jnp.bool_)
    posterior_mask = jnp.array([False] * theta_dim + [True] * x_dim, dtype=jnp.bool_)
    likelihood_mask = jnp.array([True] * theta_dim + [False] * x_dim, dtype=jnp.bool_)

    random1_mask = jax.random.bernoulli(
        key2, rnd1_prob, shape=(theta_dim + x_dim,)
    ).astype(jnp.bool_)
    random2_mask = jax.random.bernoulli(
        key3, rnd2_prob, shape=(theta_dim + x_dim,)
    ).astype(jnp.bool_)

    mask_options = jnp.stack(
        [joint_mask, posterior_mask, likelihood_mask, random1_mask, random2_mask],
        axis=0,
    )  # (5, theta_dim + x_dim)

    idx = jax.random.choice(
        key1,
        5,
        shape=(num_samples,),
        p=jnp.array([p_joint, p_posterior, p_likelihood, p_rnd1, p_rnd2]),
    )
    condition_mask = mask_options[idx]

    # A fully conditioned mask leaves nothing to generate, so fall back to
    # the joint (all-False) mask in that case.
    all_ones_mask = jnp.all(condition_mask, axis=-1)
    condition_mask = jnp.where(all_ones_mask[..., None], False, condition_mask)

    return condition_mask[..., None]
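
# Illustrative sketch (not part of the pipeline API): drawing a batch of
# structured masks for a hypothetical problem with 3 parameters and 5
# observables.
def _example_structured_masks():
    key = jax.random.PRNGKey(0)
    masks = sample_structured_conditional_mask(
        key, num_samples=4, theta_dim=3, x_dim=5
    )
    # Boolean masks over the joint (theta, x) tokens, with a channel dim.
    assert masks.shape == (4, 8, 1)
    return masks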
def sample_condition_mask(
    key,
    num_samples,
    theta_dim,
    x_dim,
    kind="structured",
):
    """
    Sample condition masks of a given kind.

    ``kind`` selects between the structured mixture of masks
    (:func:`sample_structured_conditional_mask`) and the fixed
    ``"posterior"``, ``"likelihood"``, and ``"joint"`` masks.

    Returns a boolean array of shape ``(num_samples, theta_dim + x_dim, 1)``.
    """
    if kind == "structured":
        return sample_structured_conditional_mask(
            key, num_samples, theta_dim, x_dim
        )

    if kind == "posterior":
        mask = [False] * theta_dim + [True] * x_dim
    elif kind == "likelihood":
        mask = [True] * theta_dim + [False] * x_dim
    elif kind == "joint":
        mask = [False] * (theta_dim + x_dim)
    else:
        raise ValueError(f"Unknown kind {kind!r} for condition mask.")

    condition_mask = jnp.array(mask, dtype=jnp.bool_).reshape(1, -1, 1)
    return jnp.broadcast_to(condition_mask, (num_samples, theta_dim + x_dim, 1))
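
# Illustrative sketch: for the fixed "posterior" kind, all x tokens are
# conditioned and all theta tokens are generated, so the first theta_dim
# entries are False and the rest True.
def _example_posterior_mask():
    key = jax.random.PRNGKey(0)  # unused for the fixed kinds
    mask = sample_condition_mask(
        key, num_samples=2, theta_dim=3, x_dim=5, kind="posterior"
    )
    assert mask.shape == (2, 8, 1)
    assert not mask[0, :3].any() and mask[0, 3:].all()
    return mask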
# ---------------------------------------------------------------------------
class JointPipeline(AbstractPipeline):
    """Model-agnostic joint pipeline parameterized by a ``GenerativeMethod``.

    Unlike the old method-specific pipeline classes, this class works with
    **any** generative method and **any** user-provided model that conforms
    to the ``JointWrapper`` interface.

    Parameters
    ----------
    model : nnx.Module
        The model to be trained.
    train_dataset : iterable
        Training dataset yielding concatenated ``x_1`` batches (obs and cond
        concatenated along the token dimension).
    val_dataset : iterable
        Validation dataset.
    dim_obs : int
        Dimension of the observation/parameter space.
    dim_cond : int
        Dimension of the conditioning space.
    method : GenerativeMethod
        Strategy object (e.g. ``FlowMatchingMethod()``,
        ``DiffusionEDMMethod()``, ``ScoreMatchingMethod()``).
    ch_obs : int, optional
        Number of channels per token. Default is 1.
    condition_mask_kind : str, optional
        Kind of condition mask. One of ``"structured"`` or ``"posterior"``.
        Default is ``"structured"``.
    params : optional
        Model parameters (stored but not used directly).
    training_config : dict, optional
        Training configuration.

    Examples
    --------
    >>> from gensbi.core import FlowMatchingMethod
    >>> pipeline = JointPipeline(
    ...     model=my_model,
    ...     train_dataset=train_ds,
    ...     val_dataset=val_ds,
    ...     dim_obs=2, dim_cond=7,
    ...     method=FlowMatchingMethod(),
    ... )
    """

    def __init__(
        self,
        model,
        train_dataset,
        val_dataset,
        dim_obs: int,
        dim_cond: int,
        method: GenerativeMethod,
        ch_obs=1,
        condition_mask_kind="structured",
        params=None,
        training_config=None,
    ):
        self.method = method

        if training_config is None:
            training_config = self.get_default_training_config()
        extra = method.get_extra_training_config()
        for k, v in extra.items():
            training_config.setdefault(k, v)

        # Validate arguments before building paths and losses (fail fast).
        if dim_cond == 0:
            raise ValueError(
                "JointPipeline initialized with dim_cond=0. "
                "Use UnconditionalPipeline instead."
            )
        if condition_mask_kind not in ("structured", "posterior"):
            raise ValueError(
                f"condition_mask_kind must be one of ['structured', 'posterior'], "
                f"got {condition_mask_kind}."
            )

        super().__init__(
            model=model,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            dim_obs=dim_obs,
            dim_cond=dim_cond,
            ch_obs=ch_obs,
            params=params,
            training_config=training_config,
        )

        self.dim_joint = self.dim_obs + self.dim_cond
        self.node_ids, self.obs_ids, self.cond_ids = init_ids_joint(
            self.dim_obs, self.dim_cond
        )

        self.path = method.build_path(
            self.training_config, event_shape=(self.dim_joint, self.ch_obs)
        )

        # OneFlow reweighting (Eq. 4, https://arxiv.org/pdf/2601.22951v1),
        # currently disabled: upweight parameter (obs) dimensions by
        # d_cond / d_obs to balance gradient magnitudes when d_cond >> d_obs.
        # loss_weights = jnp.ones(dim_obs + dim_cond)
        # loss_weights = loss_weights.at[jnp.arange(dim_obs)].set(dim_cond / dim_obs)
        # # reshape to (1, dim_joint, 1) to broadcast over (batch, dim_joint, ch)
        # loss_weights = loss_weights.reshape(1, -1, 1)
        loss_weights = None

        self.loss_obj = method.build_loss(self.path, weights=loss_weights)

        self.condition_mask_kind = condition_mask_kind
    # -- Factory stubs ------------------------------------------------------

    @classmethod
    def init_pipeline_from_config(cls, *args, **kwargs):
        raise NotImplementedError(
            "JointPipeline is model-agnostic. "
            "Use model-specific pipelines for config init."
        )

    def _make_model(self):
        raise NotImplementedError(
            "JointPipeline is model-agnostic — the user provides the model."
        )

    @classmethod
    def get_default_params(cls, *args, **kwargs):
        raise NotImplementedError(
            "JointPipeline is model-agnostic — the user provides model params."
        )
# -- Core pipeline methods ----------------------------------------------
    def get_loss_fn(self):
        """Build the training loss closure ``loss_fn(model, x_1, key)``."""

        def loss_fn(model, x_1, key):
            batch_size = x_1.shape[0]
            rng_batch, rng_condition = jax.random.split(key)

            prepared = self.method.prepare_batch(rng_batch, x_1, self.path)

            condition_mask = sample_condition_mask(
                rng_condition,
                batch_size,
                self.dim_obs,
                self.dim_cond,
                kind=self.condition_mask_kind,
            )
            model_extras = {
                "node_ids": self.node_ids,
                "condition_mask": condition_mask,
            }
            return self.loss_obj(
                model,
                prepared,
                condition_mask=condition_mask,
                model_extras=model_extras,
            )

        return loss_fn
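
    # Sketch (illustrative; ``pipe`` is a placeholder for a constructed
    # pipeline): the closure returned above is what the training loop is
    # expected to consume, e.g.
    #
    #     loss_fn = pipe.get_loss_fn()
    #     value = loss_fn(pipe.model, x_1_batch, jax.random.PRNGKey(0))
    #
    # where ``x_1_batch`` has shape (batch, dim_obs + dim_cond, ch_obs).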
    def _wrap_model(self):
        self.model_wrapped = JointWrapper(self.model)
        self.ema_model_wrapped = JointWrapper(self.ema_model)
    def get_sampler(self, x_o, use_ema=True, model_extras=None, **sampler_kwargs):
        """Get a sampler function.

        Parameters
        ----------
        x_o : array-like
            Conditioning variable (observed data).
        use_ema : bool, optional
            Whether to use the EMA model. Default is True.
        model_extras : dict, optional
            Additional keyword arguments passed to the model during sampling
            (e.g. ``{"edge_mask": mask}``). Cannot override the protected
            keys ``cond``, ``obs_ids``, ``cond_ids``.
        **sampler_kwargs
            Forwarded to ``method.build_sampler_fn``.

        Returns
        -------
        Callable
            ``sampler(key, nsamples, model_extras=None) -> samples``
        """
        model_wrapped = self.ema_model_wrapped if use_ema else self.model_wrapped
        cond = _expand_dims(x_o)

        _PROTECTED = {"cond", "obs_ids", "cond_ids"}
        extras = {
            "cond": cond,
            "obs_ids": self.obs_ids,
            "cond_ids": self.cond_ids,
        }
        if model_extras:
            conflict = _PROTECTED & model_extras.keys()
            if conflict:
                raise ValueError(
                    f"model_extras cannot override protected keys: {conflict}"
                )
            extras.update(model_extras)

        sampler_fn = self.method.build_sampler_fn(
            model_wrapped,
            self.path,
            extras,
            **sampler_kwargs,
        )

        def sampler(key, nsamples, model_extras=None):
            _extras = model_extras if model_extras is not None else extras
            key, key_init = jax.random.split(key)
            x_init_joint = self.method.sample_init(key_init, nsamples)
            # Marginalize the joint prior to the obs dims. This matches
            # JointWrapper.conditioned(), which assumes obs = first dim_obs
            # dims and cond = remaining dims. When supporting arbitrary
            # condition masks, update both this and the wrapper.
            x_init = x_init_joint[:, : self.dim_obs, :]
            return sampler_fn(key, x_init, _extras)

        return sampler
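
    # Usage sketch (illustrative; ``pipe`` and ``x_obs`` are placeholders for
    # a trained pipeline and a single observation):
    #
    #     sampler = pipe.get_sampler(x_obs)
    #     theta = sampler(jax.random.PRNGKey(0), nsamples=1_000)
    #
    # yielding samples of shape (1_000, dim_obs, ch_obs).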
    def sample(self, key, x_o, nsamples=10_000, use_ema=True, **sampler_kwargs):
        """Draw samples from the model.

        Parameters
        ----------
        key : jax.random.PRNGKey
            Random key.
        x_o : array-like
            Conditioning variable.
        nsamples : int, optional
            Number of samples. Default is 10_000.
        use_ema : bool, optional
            Use the EMA model. Default is True.
        **sampler_kwargs
            Forwarded to :meth:`get_sampler`.

        Returns
        -------
        Array
            Samples of shape ``(nsamples, dim_obs, ch_obs)``.
        """
        x_o_batch = x_o.shape[0] if hasattr(x_o, "shape") else len(x_o)
        if x_o_batch > 1:
            warnings.warn(
                f"x_o has batch dimension {x_o_batch} > 1. "
                "sample() draws all samples for a single condition. "
                "To sample for multiple conditions, use sample_batched() instead.",
                UserWarning,
                stacklevel=2,
            )
        sampler = self.get_sampler(x_o, use_ema=use_ema, **sampler_kwargs)
        return sampler(key, nsamples)
    def get_log_prob_fn(
        self, x_o, use_ema=True, prior=None, model_extras=None, **kwargs
    ):
        """Get a log-probability function.

        Parameters
        ----------
        x_o : array-like
            Conditioning variable (observed data).
        use_ema : bool, optional
            Whether to use the EMA model. Default is True.
        prior : numpyro.distributions.Distribution, optional
            Obs-space prior for log-probability evaluation. The method's
            prior lives on the full joint space ``(dim_joint, ch)`` and
            cannot be automatically marginalized for arbitrary priors.

            - **Default Gaussian**: auto-constructed — no need to provide.
            - **Custom prior**: must supply the correct obs-space marginal.
        model_extras : dict, optional
            Additional model extras. Cannot override protected keys.
        **kwargs
            Forwarded to ``method.build_log_prob_fn``.

        Returns
        -------
        Callable
            ``log_prob_fn(x_1) -> log_prob``

        Raises
        ------
        ValueError
            If the joint prior is non-Gaussian and no ``prior`` is provided.
        """
        if prior is not None:
            log_p0 = prior.log_prob
        elif not self.method.has_custom_prior:
            # Auto-constructed default prior — we know the marginal is a
            # standard Gaussian on the obs dims.
            obs_prior = make_gaussian_prior(self.dim_obs, self.ch_obs)
            log_p0 = obs_prior.log_prob
        else:
            raise ValueError(
                "Joint pipeline with a custom prior requires an explicit "
                "`prior` for log_prob — the obs-space marginal of the joint "
                "prior. Pass a numpyro distribution whose event_shape "
                "matches (dim_obs, ch_obs)."
            )

        model_wrapped = self.ema_model_wrapped if use_ema else self.model_wrapped
        cond = _expand_dims(x_o)

        _PROTECTED = {"cond", "obs_ids", "cond_ids"}
        extras = {
            "cond": cond,
            "obs_ids": self.obs_ids,
            "cond_ids": self.cond_ids,
        }
        if model_extras:
            conflict = _PROTECTED & model_extras.keys()
            if conflict:
                raise ValueError(
                    f"model_extras cannot override protected keys: {conflict}"
                )
            extras.update(model_extras)

        log_prob_fn = self.method.build_log_prob_fn(
            model_wrapped,
            self.path,
            extras,
            log_prior=log_p0,
            **kwargs,
        )

        def _log_prob(x_1, model_extras=None, *, key=None):
            _extras = model_extras if model_extras is not None else extras
            return log_prob_fn(x_1, _extras, key=key)

        return _log_prob
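
    # Sketch (illustrative): supplying a custom obs-space prior. The Normal
    # stand-in below uses the standard numpyro expand/to_event pattern and
    # only demonstrates the required event_shape; ``pipe`` and ``x_obs`` are
    # placeholders:
    #
    #     import numpyro.distributions as dist
    #     obs_prior = (
    #         dist.Normal(0.0, 1.0)
    #         .expand([pipe.dim_obs, pipe.ch_obs])
    #         .to_event(2)
    #     )
    #     log_prob_fn = pipe.get_log_prob_fn(x_obs, prior=obs_prior)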
    def log_prob(self, x_1, x_o, use_ema=True, prior=None, *, key=None, **kwargs):
        """Compute the log-probability of ``x_1`` given ``x_o``.

        Parameters
        ----------
        x_1 : array-like
            Data samples to evaluate.
        x_o : array-like
            Conditioning variable.
        use_ema : bool, optional
            Use the EMA model. Default is True.
        prior : numpyro.distributions.Distribution, optional
            Obs-space prior distribution. See :meth:`get_log_prob_fn` for
            details.
        key : jax.random.PRNGKey, optional
            Required when ``exact_divergence=False`` (Hutchinson estimator).
        **kwargs
            Forwarded to :meth:`get_log_prob_fn`.

        Returns
        -------
        Array
            Log-probabilities.
        """
        log_prob_fn = self.get_log_prob_fn(
            x_o, use_ema=use_ema, prior=prior, **kwargs
        )
        return log_prob_fn(x_1, key=key)
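
# Illustrative end-to-end sketch: sampling and density evaluation with a
# constructed pipeline. ``pipe``, ``x_obs``, and ``key`` are placeholders for
# a trained JointPipeline, a single observation, and a PRNG key; see the
# class docstring for construction.
def _example_sample_and_log_prob(pipe, x_obs, key):
    # Draw posterior samples for a single observation ...
    samples = pipe.sample(key, x_obs, nsamples=5_000)
    # ... then evaluate their log-probability under the model (``key`` is
    # only needed when the divergence is estimated stochastically).
    logp = pipe.log_prob(samples, x_obs, key=key)
    return samples, logp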