gensbi.experimental.models.glue.embedder

Classes

Embedded1DModel

Wraps a VAE encoder and a Flux1 SBI model for 1D sequence conditioning.

Embedded2DModel

Wraps a VAE encoder and a Flux1 SBI model for 2D image conditioning.

Module Contents

class gensbi.experimental.models.glue.embedder.Embedded1DModel(vae, sbi_model)

Bases: flax.nnx.Module

Wraps a VAE encoder and a Flux1 SBI model for 1D sequence conditioning.

Encodes raw 1D conditioning sequences into latent space via the VAE and forwards everything to the underlying SBI model. Unlike Embedded2DModel, no patchification step is applied after encoding.
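The data flow described above can be sketched with plain NumPy stand-ins. Everything here is hypothetical: `ToyEncoder1D`, `ToySBIModel`, `ToyEmbedded1D`, and the `encode` method name are illustrations, not the gensbi classes, and the toy encoder simply halves the sequence length. The only point is the order of operations: encode `cond`, then forward all arguments to the wrapped model.

```python
import numpy as np


class ToyEncoder1D:
    """Hypothetical VAE-encoder stand-in: (batch, L, C) -> (batch, L // 2, latent_dim)."""

    def __init__(self, in_channels, latent_dim, seed=0):
        rng = np.random.default_rng(seed)
        self.w = rng.normal(size=(2 * in_channels, latent_dim))

    def encode(self, x, key=None):
        # The toy encoder is deterministic; `key` is accepted only to mirror
        # the documented signature.
        b, length, c = x.shape
        pairs = x.reshape(b, length // 2, 2 * c)  # merge neighbouring steps
        return pairs @ self.w                     # project to the latent width


class ToySBIModel:
    """Hypothetical SBI-model stand-in that shifts obs by the mean of the conditioning."""

    def __call__(self, t, obs, obs_ids, cond, cond_ids, conditioned=True, guidance=None):
        ctx = cond.mean(axis=(1, 2), keepdims=True) if conditioned else 0.0
        return obs + ctx


class ToyEmbedded1D:
    """Encode the raw conditioning, then forward everything to the wrapped model."""

    def __init__(self, vae, sbi_model):
        self.vae = vae
        self.sbi_model = sbi_model

    def __call__(self, t, obs, obs_ids, cond, cond_ids,
                 conditioned=True, guidance=None, encoder_key=None):
        cond_latent = self.vae.encode(cond, encoder_key)  # no patchification in 1D
        return self.sbi_model(t, obs, obs_ids, cond_latent, cond_ids,
                              conditioned=conditioned, guidance=guidance)


batch, n_obs, dim, length, channels = 4, 3, 1, 16, 2
model = ToyEmbedded1D(ToyEncoder1D(channels, latent_dim=8), ToySBIModel())
obs = np.zeros((batch, n_obs, dim))
cond = np.ones((batch, length, channels))
out = model(np.zeros(batch), obs, None, cond, None)
```

The caller passes raw sequences and never handles latents directly; the wrapper owns the encoding step.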

Parameters:
  • vae (AutoEncoder1D) – Variational autoencoder used to encode conditioning sequences into latents.

  • sbi_model (Flux1) – Simulation-based inference model that receives the encoded conditioning.

__call__(t, obs, obs_ids, cond, cond_ids, conditioned=True, guidance=None, encoder_key=None)

Encode conditioning sequences and run the SBI model forward pass.

Parameters:
  • t (Array) – Diffusion/flow timestep, shape (batch,) or scalar.

  • obs (Array) – Observation tokens (noisy samples) passed directly to the SBI model.

  • obs_ids (Array) – Positional IDs for the observation tokens.

  • cond (Array) – Raw 1D conditioning sequences to be encoded by the VAE, shape (batch, L, C).

  • cond_ids (Array) – Positional IDs for the conditioning tokens.

  • conditioned (bool or Array, optional) – Whether to apply conditioning (classifier-free guidance mask). Defaults to True.

  • guidance (Array or None, optional) – Guidance scale for classifier-free guidance. Defaults to None, which disables it.

  • encoder_key (jax.random.PRNGKey or None, optional) – PRNG key forwarded to the VAE encoder for stochastic encoding. Defaults to None, which selects deterministic encoding.

Returns:

Output of the SBI model evaluated on the VAE-encoded conditioning.

Return type:

Array
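The difference between stochastic and deterministic encoding can be illustrated with the usual VAE reparameterisation. This is a sketch under assumptions: `sample_latent` is a hypothetical helper, a NumPy `Generator` stands in for the JAX PRNG key, and it is assumed (not confirmed by this page) that deterministic encoding returns the posterior mean.

```python
import numpy as np


def sample_latent(mean, logvar, rng=None):
    """Return the mean when rng is None, else a reparameterised sample."""
    if rng is None:
        return mean  # assumed behaviour of deterministic encoding
    eps = rng.standard_normal(mean.shape)
    return mean + np.exp(0.5 * logvar) * eps  # z = mu + sigma * eps


mean = np.zeros((2, 4, 8))          # (batch, latent_len, latent_dim)
logvar = np.full((2, 4, 8), -2.0)

z_det = sample_latent(mean, logvar)                           # no key
z_a = sample_latent(mean, logvar, np.random.default_rng(0))   # seeded
z_b = sample_latent(mean, logvar, np.random.default_rng(0))   # same seed
```

Reusing the same seed reproduces the same latent, which is the property an explicit key argument is meant to give.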

sbi_model
vae

class gensbi.experimental.models.glue.embedder.Embedded2DModel(vae, sbi_model)

Bases: flax.nnx.Module

Wraps a VAE encoder and a Flux1 SBI model for 2D image conditioning.

Encodes raw 2D conditioning images into latent space via the VAE, patchifies the result, and forwards everything to the underlying SBI model.
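A patchification step of the kind mentioned above can be sketched as follows. The patch size of 2, the channels-last layout, and the `patchify` helper itself are assumptions for illustration; the library's actual routine may differ in patch size or token ordering.

```python
import numpy as np


def patchify(latents, patch=2):
    """Turn (batch, H, W, C) latents into (batch, num_patches, patch*patch*C) tokens."""
    b, h, w, c = latents.shape
    assert h % patch == 0 and w % patch == 0, "H and W must be divisible by patch"
    x = latents.reshape(b, h // patch, patch, w // patch, patch, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (b, H/p, W/p, p, p, C)
    return x.reshape(b, (h // patch) * (w // patch), patch * patch * c)


latents = np.arange(2 * 4 * 4 * 3, dtype=np.float32).reshape(2, 4, 4, 3)
tokens = patchify(latents)  # shape (2, 4, 12): four 2x2 patches per image
```

Each token is one flattened patch, so the conditioning reaches the SBI model as a sequence rather than an image grid.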

Parameters:
  • vae (AutoEncoder2D) – Variational autoencoder used to encode conditioning images into latents.

  • sbi_model (Flux1) – Simulation-based inference model that receives the encoded conditioning.

__call__(t, obs, obs_ids, cond, cond_ids, conditioned=True, guidance=None, encoder_key=None)

Encode conditioning images and run the SBI model forward pass.

Parameters:
  • t (Array) – Diffusion/flow timestep, shape (batch,) or scalar.

  • obs (Array) – Observation tokens (noisy samples) passed directly to the SBI model.

  • obs_ids (Array) – Positional IDs for the observation tokens.

  • cond (Array) – Raw 2D conditioning images to be encoded by the VAE, shape (batch, H, W, C).

  • cond_ids (Array) – Positional IDs for the conditioning tokens after patchification.

  • conditioned (bool or Array, optional) – Whether to apply conditioning (classifier-free guidance mask). Defaults to True.

  • guidance (Array or None, optional) – Guidance scale for classifier-free guidance. Defaults to None, which disables it.

  • encoder_key (jax.random.PRNGKey or None, optional) – PRNG key forwarded to the VAE encoder for stochastic encoding. Defaults to None, which selects deterministic encoding.

Returns:

Output of the SBI model evaluated on the VAE-encoded, patchified conditioning.

Return type:

Array
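Because cond_ids refer to tokens after patchification, there is one ID per patch rather than per pixel. The sketch below builds one possible ID layout, a (row, col) pair per patch in row-major order. The `grid_ids` helper, the two-component ID format, and the patch size of 2 are all hypothetical; the ID format the library expects is not specified on this page.

```python
import numpy as np


def grid_ids(batch, h, w, patch=2):
    """Build (batch, num_patches, 2) integer IDs holding each patch's (row, col)."""
    rows, cols = np.meshgrid(np.arange(h // patch), np.arange(w // patch), indexing="ij")
    ids = np.stack([rows, cols], axis=-1).reshape(-1, 2)  # row-major patch order
    return np.broadcast_to(ids, (batch,) + ids.shape)


ids = grid_ids(batch=2, h=4, w=4)  # a 4x4 latent grid gives a 2x2 grid of patches
```

The number of IDs per sample equals the number of patch tokens, here (4 / 2) * (4 / 2) = 4.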

sbi_model
vae