# Source code for gensbi.recipes.conditional_pipeline

"""
Pipeline for training and using a Conditional model for simulation-based inference.
"""

import jax
import jax.numpy as jnp
from flax import nnx
import optax
from optax.contrib import reduce_on_plateau

from numpyro import distributions as dist
from tqdm.auto import tqdm
from functools import partial
import orbax.checkpoint as ocp

from typing import Union, Tuple

from gensbi.flow_matching.path import AffineProbPath
from gensbi.flow_matching.path.scheduler import CondOTScheduler

from gensbi.diffusion.path import EDMPath
from gensbi.diffusion.path.scheduler import EDMScheduler, VEEdmScheduler, VPEdmScheduler
from gensbi.diffusion.solver import EDMSolver

from gensbi.diffusion.path.sm_path import SMPath
from gensbi.diffusion.path.scheduler import VPSmScheduler, VESmScheduler


from gensbi.models import ConditionalWrapper

from einops import repeat

from gensbi.models.flux1 import model
from gensbi.utils.model_wrapping import _expand_dims

import os

import yaml

from gensbi.recipes.pipeline import AbstractPipeline

from gensbi.recipes.utils import _resolve_embedding_ids, build_edm_path, build_sm_path

import warnings


# ---------------------------------------------------------------------------
# Unified ConditionalPipeline (Phase 2)
# ---------------------------------------------------------------------------

from gensbi.core.generative_method import GenerativeMethod


class ConditionalPipeline(AbstractPipeline):
    """Model-agnostic conditional pipeline parameterized by a ``GenerativeMethod``.

    Unlike the old method-specific pipeline classes, this class works with
    **any** generative method and **any** user-provided model that conforms
    to the ``ConditionalWrapper`` interface.

    Parameters
    ----------
    model : nnx.Module
        The model to be trained.
    train_dataset : iterable
        Training dataset yielding ``(obs, cond)`` batches.
    val_dataset : iterable
        Validation dataset yielding ``(obs, cond)`` batches.
    dim_obs : int or tuple of int
        Dimension of the observation/parameter space.
    dim_cond : int or tuple of int
        Dimension of the conditioning space.
    method : GenerativeMethod
        Strategy object (e.g. ``FlowMatchingMethod()``, ``DiffusionEDMMethod()``,
        ``ScoreMatchingMethod()``).
    ch_obs : int, optional
        Number of channels per observation token. Default is 1.
    ch_cond : int, optional
        Number of channels per conditioning token. Default is 1.
    id_embedding_strategy : tuple of str, optional
        Embedding strategy for observation and conditioning IDs.
        Default is ``("absolute", "absolute")``.
    params : optional
        Model parameters (stored but not used directly).
    training_config : dict, optional
        Training configuration. If ``None``, uses defaults augmented by
        ``method.get_extra_training_config()``.

    Examples
    --------
    >>> from gensbi.core import FlowMatchingMethod
    >>> pipeline = ConditionalPipeline(
    ...     model=my_model,
    ...     train_dataset=train_ds,
    ...     val_dataset=val_ds,
    ...     dim_obs=5, dim_cond=3,
    ...     method=FlowMatchingMethod(),
    ... )
    """

    def __init__(
        self,
        model,
        train_dataset,
        val_dataset,
        dim_obs,
        dim_cond,
        method: GenerativeMethod,
        ch_obs=1,
        ch_cond=1,
        id_embedding_strategy=("absolute", "absolute"),
        params=None,
        training_config=None,
    ):
        self.method = method
        # Merge method-specific defaults before super().__init__ which
        # computes derived values from training_config.
        if training_config is None:
            training_config = self.get_default_training_config()
        extra = method.get_extra_training_config()
        for k, v in extra.items():
            # setdefault: user-supplied config keys win over method defaults.
            training_config.setdefault(k, v)
        super().__init__(
            model=model,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            dim_obs=dim_obs,
            dim_cond=dim_cond,
            ch_obs=ch_obs,
            ch_cond=ch_cond,
            params=params,
            training_config=training_config,
        )
        # Resolve token-id embeddings; these may also rewrite dim_obs/dim_cond
        # (e.g. flattening a tuple of dims), so reassign both.
        self.obs_ids, self.dim_obs = _resolve_embedding_ids(
            dim_obs, id_embedding_strategy[0], semantic_id=0
        )
        self.cond_ids, self.dim_cond = _resolve_embedding_ids(
            dim_cond, id_embedding_strategy[1], semantic_id=1
        )
        # Path/loss are built AFTER id resolution so event_shape reflects the
        # resolved dim_obs.
        self.path = method.build_path(
            self.training_config, event_shape=(self.dim_obs, self.ch_obs)
        )
        self.loss_obj = method.build_loss(self.path)

    # -- Factory stubs (model-agnostic: user provides model) ----------------

    @classmethod
    def init_pipeline_from_config(cls, *args, **kwargs):
        """Not supported: this pipeline is model-agnostic.

        Raises
        ------
        NotImplementedError
            Always; use a model-specific pipeline for config-based init.
        """
        raise NotImplementedError(
            "ConditionalPipeline is model-agnostic. "
            "Use model-specific pipelines (e.g. Flux1FlowPipeline) for config init."
        )

    def _make_model(self):
        """Not supported: the user must provide the model instance."""
        raise NotImplementedError(
            "ConditionalPipeline is model-agnostic — the user provides the model."
        )

    @classmethod
    def get_default_params(cls, *args, **kwargs):
        """Not supported: the user must provide model params."""
        raise NotImplementedError(
            "ConditionalPipeline is model-agnostic — the user provides model params."
        )

    # -- Core pipeline methods -----------------------------------------------

    def get_loss_fn(self):
        """Build the per-batch training loss closure.

        Returns
        -------
        Callable
            ``loss_fn(model, batch, key) -> loss`` where ``batch`` is an
            ``(obs, cond)`` pair. The batch is first preprocessed by
            ``method.prepare_batch`` (e.g. noising/time sampling) before
            being handed to the method's loss object.
        """

        def loss_fn(model, batch, key):
            obs, cond = batch
            prepared = self.method.prepare_batch(key, obs, self.path)
            model_extras = {
                "cond": cond,
                "obs_ids": self.obs_ids,
                "cond_ids": self.cond_ids,
            }
            return self.loss_obj(model, prepared, model_extras=model_extras)

        return loss_fn

    def _wrap_model(self):
        """Wrap both the live and EMA models in ``ConditionalWrapper``."""
        self.model_wrapped = ConditionalWrapper(self.model)
        self.ema_model_wrapped = ConditionalWrapper(self.ema_model)

    def get_sampler(self, x_o, use_ema=True, model_extras=None, **sampler_kwargs):
        """Get a sampler function.

        Parameters
        ----------
        x_o : array-like
            Conditioning variable (observed data).
        use_ema : bool, optional
            Whether to use the EMA model. Default is True.
        model_extras : dict, optional
            Additional keyword arguments passed to the model during sampling
            (e.g. ``{"edge_mask": mask}``). Cannot override the protected
            keys ``cond``, ``obs_ids``, ``cond_ids``.
        **sampler_kwargs
            Forwarded to ``method.build_sampler_fn`` (e.g. ``step_size``,
            ``nsteps``, ``solver``, ``time_grid``).

        Returns
        -------
        Callable
            ``sampler(key, nsamples) -> samples``

        Raises
        ------
        ValueError
            If ``model_extras`` tries to override a protected key.
        """
        model_wrapped = self.ema_model_wrapped if use_ema else self.model_wrapped
        cond = _expand_dims(x_o)

        _PROTECTED = {"cond", "obs_ids", "cond_ids"}
        extras = {
            "cond": cond,
            "obs_ids": self.obs_ids,
            "cond_ids": self.cond_ids,
        }
        if model_extras:
            conflict = _PROTECTED & model_extras.keys()
            if conflict:
                raise ValueError(
                    f"model_extras cannot override protected keys: {conflict}"
                )
            extras.update(model_extras)

        sampler_fn = self.method.build_sampler_fn(
            model_wrapped,
            self.path,
            extras,
            **sampler_kwargs,
        )

        def sampler(key, nsamples, model_extras=None):
            # A per-call model_extras REPLACES the prebuilt extras wholesale
            # (including cond/ids) rather than merging with them.
            _extras = model_extras if model_extras is not None else extras
            key, key_init = jax.random.split(key)
            x_init = self.method.sample_init(
                key_init,
                nsamples,
            )
            return sampler_fn(key, x_init, _extras)

        return sampler

    def sample(self, key, x_o, nsamples=10_000, use_ema=True, **sampler_kwargs):
        """Draw samples from the model.

        Parameters
        ----------
        key : jax.random.PRNGKey
            Random key.
        x_o : array-like
            Conditioning variable.
        nsamples : int, optional
            Number of samples. Default is 10 000.
        use_ema : bool, optional
            Use the EMA model. Default is True.
        **sampler_kwargs
            Forwarded to :meth:`get_sampler`.

        Returns
        -------
        Array
            Samples of shape ``(nsamples, dim_obs, ch_obs)``.
        """
        # NOTE(review): this heuristic checks the leading axis only, so a
        # 1-D x_o of length dim_cond (> 1, no batch axis) also triggers the
        # warning — confirm whether that is intended.
        x_o_shape = x_o.shape[0] if hasattr(x_o, "shape") else len(x_o)
        if x_o_shape > 1:
            warnings.warn(
                f"x_o has batch dimension {x_o_shape} > 1. "
                "sample() draws all samples for a single condition. "
                "To sample for multiple conditions, use sample_batched() instead.",
                UserWarning,
                stacklevel=2,
            )
        sampler = self.get_sampler(x_o, use_ema=use_ema, **sampler_kwargs)
        return sampler(key, nsamples)

    def get_log_prob_fn(self, x_o, use_ema=True, model_extras=None, **kwargs):
        """Get a log-probability function.

        Parameters
        ----------
        x_o : array-like
            Conditioning variable (observed data).
        use_ema : bool, optional
            Whether to use the EMA model. Default is True.
        model_extras : dict, optional
            Additional model extras. Cannot override protected keys.
        **kwargs
            Forwarded to ``method.build_log_prob_fn``.

        Returns
        -------
        Callable
            ``log_prob_fn(x_1) -> log_prob``

        Raises
        ------
        ValueError
            If ``model_extras`` tries to override a protected key.
        """
        model_wrapped = self.ema_model_wrapped if use_ema else self.model_wrapped
        cond = _expand_dims(x_o)

        _PROTECTED = {"cond", "obs_ids", "cond_ids"}
        extras = {
            "cond": cond,
            "obs_ids": self.obs_ids,
            "cond_ids": self.cond_ids,
        }
        # TODO: this branch is not currently tested, as we don't really ever
        # use it. Add tests when we find a good usage for this, same below.
        if model_extras:
            conflict = _PROTECTED & model_extras.keys()
            if conflict:
                raise ValueError(
                    f"model_extras cannot override protected keys: {conflict}"
                )
            extras.update(model_extras)

        log_prob_fn = self.method.build_log_prob_fn(
            model_wrapped,
            self.path,
            extras,
            **kwargs,
        )

        def _log_prob(x_1, model_extras=None, *, key=None):
            # Same replace-not-merge semantics as the sampler closure.
            _extras = model_extras if model_extras is not None else extras
            return log_prob_fn(x_1, _extras, key=key)

        return _log_prob

    def log_prob(self, x_1, x_o, use_ema=True, *, key=None, **kwargs):
        """Compute log-probability of x_1 given x_o.

        Parameters
        ----------
        x_1 : array-like
            Data samples to evaluate.
        x_o : array-like
            Conditioning variable.
        use_ema : bool, optional
            Use the EMA model. Default is True.
        key : jax.random.PRNGKey, optional
            Required when ``exact_divergence=False`` (Hutchinson).
        **kwargs
            Forwarded to :meth:`get_log_prob_fn`.

        Returns
        -------
        Array
            Log-probabilities.
        """
        log_prob_fn = self.get_log_prob_fn(x_o, use_ema=use_ema, **kwargs)
        return log_prob_fn(x_1, key=key)