# Source code for gensbi.models.wrappers.conditional

"""
Conditional model wrapper for GenSBI.

This module provides a wrapper class for conditional models used in flow matching,
handling proper input expansion and calling conventions for conditional inference.
"""

from jax import Array


from gensbi.utils.model_wrapping import ModelWrapper, _expand_dims, _expand_time




class ConditionalWrapper(ModelWrapper):
    """
    Adapter that calls a conditional model through a fixed keyword interface.

    Before delegating to the wrapped model, every array input is passed
    through ``_expand_dims`` (``_expand_time`` for the time argument); the
    remaining arguments are forwarded untouched.

    Parameters
    ----------
    model:
        The conditional model instance to wrap.
    """

    def __init__(self, model):
        """
        Build the wrapper around ``model``.

        Parameters
        ----------
        model:
            The conditional model instance to wrap. It is handed straight to
            ``ModelWrapper.__init__``; this class adds no state of its own.
        """
        super().__init__(model)

    def __call__(
        self,
        t: Array,
        obs: Array,
        obs_ids: Array,
        cond: Array,
        cond_ids: Array,
        conditioned: bool | Array = True,
        guidance: Array | None = None,
        **kwargs,
    ) -> Array:
        """
        Expand the inputs and evaluate the wrapped model on them.

        Parameters
        ----------
        t : Array
            Time steps; expanded with ``_expand_time``.
        obs : Array
            Observations; expanded with ``_expand_dims``.
        obs_ids : Array
            Observation identifiers; expanded with ``_expand_dims``.
        cond : Array
            Conditioning values; expanded with ``_expand_dims``.
        cond_ids : Array
            Conditioning identifiers; expanded with ``_expand_dims``.
        conditioned : bool | Array, optional
            Whether to use conditioning. Forwarded unchanged.
            Defaults to True.
        guidance : Array | None, optional
            Optional guidance input. Forwarded unchanged.
        **kwargs
            Any extra keyword arguments, forwarded unchanged to the model.

        Returns
        -------
        Array
            Whatever the wrapped model returns for the expanded inputs.
        """
        # Dict literals evaluate top to bottom and keep insertion order, so the
        # helpers run, and the keywords reach the model, in a fixed sequence:
        # obs, t, cond, obs_ids, cond_ids.
        expanded_inputs = {
            "obs": _expand_dims(obs),
            "t": _expand_time(t),
            "cond": _expand_dims(cond),
            "obs_ids": _expand_dims(obs_ids),
            "cond_ids": _expand_dims(cond_ids),
        }

        # NOTE(review): ``self.model`` is presumably set by
        # ``ModelWrapper.__init__`` -- confirm against the base class.
        return self.model(
            **expanded_inputs,
            conditioned=conditioned,
            guidance=guidance,
            **kwargs,
        )