Source code for gensbi.models.wrappers.unconditional

"""
Unconditional model wrapper for GenSBI.

This module provides a wrapper class for unconditional models used in flow matching,
handling proper input expansion and calling conventions.
"""
from jax import Array
from typing import Optional


import jax.numpy as jnp
from jax import Array
from jax.typing import DTypeLike

from gensbi.utils.model_wrapping import ModelWrapper, _expand_dims, _expand_time



[docs] class UnconditionalWrapper(ModelWrapper): """ Wrapper for unconditional models to handle input expansion and calling convention. Parameters ---------- model: The unconditional model instance to wrap. """ def __init__(self, model): """ Initialize the UnconditionalWrapper. Parameters ---------- model: The unconditional model instance to wrap. """ super().__init__(model)
[docs] def __call__( self, t: Array, obs: Array, obs_ids: Array, **kwargs, ) -> Array: """ Call the wrapped model with expanded inputs. Parameters ---------- t : Array Time steps. obs : Array Observations. obs_ids : Array Observation identifiers. **kwargs: Additional keyword arguments passed to the model. Returns ------- Array Model output. """ t = _expand_time(t) obs = _expand_dims(obs) obs_ids = _expand_dims(obs_ids) return self.model( obs=obs, t=t, node_ids=obs_ids, condition_mask=jnp.zeros(obs.shape, dtype=jnp.bool_), **kwargs, )