gensbi.experimental.models

Submodules

Classes

AutoEncoder1D

1D Autoencoder model with Gaussian latent space.

AutoEncoder2D

2D Autoencoder model with Gaussian latent space.

AutoEncoderParams

Configuration parameters for the AutoEncoder models.

Embedded1DModel

Wraps a VAE encoder and a Flux1 SBI model for 1D sequence conditioning.

Embedded2DModel

Wraps a VAE encoder and a Flux1 SBI model for 2D image conditioning.

FieldDiT

Conditional flow-matching network for 2D fields (pixel space).

FieldDiTParams

Configuration for FieldDiT (mirrors the style of Flux1Params).

PixelDiT

Dual-level pixel-space DiT for conditional flow matching on 2D fields.

PixelDiTParams

Configuration for PixelDiT (mirrors the style of FieldDiTParams).

Functions

vae_loss_fn(model, x, key[, kl_weight])

Compute the VAE loss as the sum of reconstruction and KL divergence losses.

Package Contents

class gensbi.experimental.models.AutoEncoder1D(params)

Bases: flax.nnx.Module

1D Autoencoder model with Gaussian latent space.

Parameters:

params (AutoEncoderParams) – Configuration parameters for the autoencoder.

__call__(x, key=None)

Forward pass: encode and then decode the input.

Parameters:
  • x (Array) – Input tensor.

  • key (Array, optional) – PRNG key for sampling the latent variable during encoding.

Returns:

Reconstructed output.

Return type:

Array

decode(z)

Decode latent representation back to data space.

Parameters:

z (Array) – Latent tensor.

Returns:

Reconstructed output.

Return type:

Array

encode(x, key=None)

Encode input data into the latent space.

Parameters:
  • x (Array) – Input tensor.

  • key (Array, optional) – PRNG key for sampling the latent variable.

Returns:

Latent representation.

Return type:

Array

Decoder1D
Encoder1D
latent_shape
reg
rngs
scale_factor
shift_factor
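The attribute names reg, scale_factor and shift_factor suggest the latent handling used by the original Flux autoencoder: the encoder output is split into a mean and a log-variance, a latent is sampled (or the mean is taken), and the result is shifted and scaled before use. The NumPy sketch below illustrates that convention under this assumption; it is not taken from the gensbi source, and all helper names are hypothetical.

```python
import numpy as np


def sample_diagonal_gaussian(moments, rng=None):
    """Hypothetical helper: split encoder moments into (mean, logvar) and draw a latent.

    Assumption: the last axis holds [mean, logvar] (2 * z_channels features).
    With rng=None the mean is returned, i.e. deterministic encoding.
    """
    mean, logvar = np.split(moments, 2, axis=-1)
    if rng is None:
        return mean
    std = np.exp(0.5 * logvar)
    return mean + std * rng.standard_normal(mean.shape)


def normalize_latent(z, scale_factor, shift_factor):
    # Hypothetical encode-side step: centre by shift_factor, then rescale.
    return scale_factor * (z - shift_factor)


def denormalize_latent(z, scale_factor, shift_factor):
    # Hypothetical decode-side step: exact inverse of normalize_latent.
    return z / scale_factor + shift_factor


rng = np.random.default_rng(0)
moments = rng.standard_normal((4, 16, 8))        # (batch, length, 2 * z_channels)
z_det = sample_diagonal_gaussian(moments)        # deterministic: the mean
z_rand = sample_diagonal_gaussian(moments, rng)  # stochastic sample
z_norm = normalize_latent(z_det, scale_factor=1.5, shift_factor=0.1)
z_back = denormalize_latent(z_norm, scale_factor=1.5, shift_factor=0.1)
print(z_det.shape, np.allclose(z_back, z_det))  # (4, 16, 4) True
```

Passing no key corresponds to deterministic encoding in this sketch; whether the library's encode(x, key=None) behaves the same way should be checked against its source.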
class gensbi.experimental.models.AutoEncoder2D(params)

Bases: flax.nnx.Module

2D Autoencoder model with Gaussian latent space.

Parameters:

params (AutoEncoderParams) – Configuration parameters for the autoencoder.

__call__(x, key=None)

Forward pass: encode and then decode the input.

Parameters:
  • x (Array) – Input tensor.

  • key (Array, optional) – PRNG key for sampling the latent variable during encoding.

Returns:

Reconstructed output.

Return type:

Array

decode(z)

Decode latent representation back to data space.

Parameters:

z (Array) – Latent tensor.

Returns:

Reconstructed output.

Return type:

Array

encode(x, key=None)

Encode input data into the latent space.

Parameters:
  • x (Array) – Input tensor.

  • key (Array, optional) – PRNG key for sampling the latent variable.

Returns:

Latent representation.

Return type:

Array

Decoder2D
Encoder2D
latent_shape
reg
rngs
scale_factor
shift_factor
class gensbi.experimental.models.AutoEncoderParams

Configuration parameters for the AutoEncoder models.

resolution

The input feature dimension (length for 1D, height/width for 2D).

Type:

int

in_channels

Number of input channels (e.g., 1 for scalar features, >1 for multi-channel).

Type:

int

ch

Base number of channels for the first convolutional layer.

Type:

int

out_ch

Number of output channels produced by the decoder (matches input channels for reconstruction).

Type:

int

ch_mult

Multipliers for the number of channels at each resolution level (controls model width/depth).

Type:

list[int]

num_res_blocks

Number of residual blocks per resolution level.

Type:

int

z_channels

Number of latent channels in the bottleneck (size of encoded representation).

Type:

int

scale_factor

Scaling factor applied to the latent representation (for normalization or data scaling).

Type:

float

shift_factor

Shift factor applied to the latent representation (for normalization or data centering).

Type:

float

rngs

Random number generators for parameter initialization and stochastic layers.

Type:

nnx.Rngs

param_dtype

Data type for model parameters (e.g., jnp.float32, jnp.bfloat16).

Type:

DTypeLike

ch: int
ch_mult: list[int]
in_channels: int
num_res_blocks: int
out_ch: int
param_dtype: jax.typing.DTypeLike
resolution: int
rngs: flax.nnx.Rngs
scale_factor: float
shift_factor: float
z_channels: int
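To relate these fields to the latent_shape attribute of the autoencoders, the sketch below assumes a Flux-style encoder that halves the spatial size once per ch_mult level except the last, and emits z_channels latent channels (channels last). Both the downsampling rule and the helper expected_latent_shape are assumptions for illustration; compare against latent_shape on a constructed model before relying on them.

```python
def expected_latent_shape(resolution, ch_mult, z_channels, ndim=2):
    """Hypothetical helper: per-sample latent shape, channels last.

    ndim=1 for AutoEncoder1D (length), ndim=2 for AutoEncoder2D (height, width).
    Assumes one 2x downsampling per ch_mult level except the last.
    """
    factor = 2 ** (len(ch_mult) - 1)
    if resolution % factor != 0:
        raise ValueError(f"resolution={resolution} is not divisible by {factor}")
    side = resolution // factor
    return (side,) * ndim + (z_channels,)


print(expected_latent_shape(32, [1, 2, 4], z_channels=4))             # (8, 8, 4)
print(expected_latent_shape(64, [1, 2, 2, 4], z_channels=8, ndim=1))  # (8, 8)
```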
class gensbi.experimental.models.Embedded1DModel(vae, sbi_model)

Bases: flax.nnx.Module

Wraps a VAE encoder and a Flux1 SBI model for 1D sequence conditioning.

Encodes raw 1D conditioning sequences into latent space via the VAE and forwards everything to the underlying SBI model. Unlike Embedded2DModel, no patchification step is applied after encoding.

Parameters:
  • vae (AutoEncoder1D) – Variational autoencoder used to encode conditioning sequences into latents.

  • sbi_model (Flux1) – Simulation-based inference model that receives the encoded conditioning.

__call__(t, obs, obs_ids, cond, cond_ids, conditioned=True, guidance=None, encoder_key=None)

Encode conditioning sequences and run the SBI model forward pass.

Parameters:
  • t (Array) – Diffusion/flow timestep, shape (batch,) or scalar.

  • obs (Array) – Observation tokens (noisy samples) passed directly to the SBI model.

  • obs_ids (Array) – Positional IDs for the observation tokens.

  • cond (Array) – Raw 1D conditioning sequences to be encoded by the VAE, shape (batch, L, C).

  • cond_ids (Array) – Positional IDs for the conditioning tokens.

  • conditioned (bool or Array, optional) – Whether to apply conditioning (classifier-free guidance mask). Defaults to True.

  • guidance (Array or None, optional) – Guidance scale for classifier-free guidance. None disables it.

  • encoder_key (jax.random.PRNGKey or None, optional) – PRNG key forwarded to the VAE encoder for stochastic encoding. Pass None for deterministic encoding.

Returns:

Output of the SBI model with the VAE-encoded conditioning.

Return type:

Array

sbi_model
vae
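Both wrappers expose a conditioned flag and a guidance argument. A common way to use such a flag at sampling time is explicit classifier-free guidance: evaluate the network with and without conditioning and extrapolate between the two velocities. The sketch below shows that formula with a toy stand-in for the model. Whether gensbi combines two passes this way, or instead embeds guidance as a model input as Flux does, is not stated here, so cfg_velocity and toy_model are hypothetical.

```python
import numpy as np


def cfg_velocity(model_fn, t, obs, cond, guidance_scale):
    """Hypothetical helper. Classifier-free guidance: v_uncond + w * (v_cond - v_uncond)."""
    v_cond = model_fn(t, obs, cond, conditioned=True)
    v_uncond = model_fn(t, obs, cond, conditioned=False)
    return v_uncond + guidance_scale * (v_cond - v_uncond)


def toy_model(t, obs, cond, conditioned=True):
    # Hypothetical stand-in: the conditional branch adds the (already encoded) cond.
    return obs + (cond if conditioned else 0.0)


obs = np.zeros((2, 3))
cond = np.ones((2, 3))
print(cfg_velocity(toy_model, 0.5, obs, cond, guidance_scale=1.0)[0])  # [1. 1. 1.]
print(cfg_velocity(toy_model, 0.5, obs, cond, guidance_scale=2.0)[0])  # [2. 2. 2.]
```

With guidance_scale=1.0 the formula reduces to the purely conditional prediction; values above 1 push the velocity further in the direction the conditioning suggests.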
class gensbi.experimental.models.Embedded2DModel(vae, sbi_model)

Bases: flax.nnx.Module

Wraps a VAE encoder and a Flux1 SBI model for 2D image conditioning.

Encodes raw 2D conditioning images into latent space via the VAE, patchifies the result, and forwards everything to the underlying SBI model.

Parameters:
  • vae (AutoEncoder2D) – Variational autoencoder used to encode conditioning images into latents.

  • sbi_model (Flux1) – Simulation-based inference model that receives the encoded conditioning.

__call__(t, obs, obs_ids, cond, cond_ids, conditioned=True, guidance=None, encoder_key=None)

Encode conditioning images and run the SBI model forward pass.

Parameters:
  • t (Array) – Diffusion/flow timestep, shape (batch,) or scalar.

  • obs (Array) – Observation tokens (noisy samples) passed directly to the SBI model.

  • obs_ids (Array) – Positional IDs for the observation tokens.

  • cond (Array) – Raw 2D conditioning images to be encoded by the VAE, shape (batch, H, W, C).

  • cond_ids (Array) – Positional IDs for the conditioning tokens after patchification.

  • conditioned (bool or Array, optional) – Whether to apply conditioning (classifier-free guidance mask). Defaults to True.

  • guidance (Array or None, optional) – Guidance scale for classifier-free guidance. None disables it.

  • encoder_key (jax.random.PRNGKey or None, optional) – PRNG key forwarded to the VAE encoder for stochastic encoding. Pass None for deterministic encoding.

Returns:

Output of the SBI model with the VAE-encoded conditioning.

Return type:

Array

sbi_model
vae
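Unlike Embedded1DModel, this wrapper patchifies the VAE latents before handing them to the SBI model. A minimal NumPy sketch of a Flux-style patchify step is shown below, assuming channels-last latents and 2x2 patches. The patch size, the token ordering and the (row, col) layout produced by the hypothetical patch_ids helper are assumptions; the IDs the library actually expects in cond_ids may differ.

```python
import numpy as np


def patchify(latents, p=2):
    """(B, H, W, C) -> (B, (H // p) * (W // p), p * p * C), patches in row-major order."""
    b, h, w, c = latents.shape
    if h % p or w % p:
        raise ValueError("H and W must be divisible by the patch size")
    x = latents.reshape(b, h // p, p, w // p, p, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (B, H/p, W/p, p, p, C)
    return x.reshape(b, (h // p) * (w // p), p * p * c)


def patch_ids(h, w, p=2):
    """Hypothetical (row, col) position of each patch token, same order as patchify."""
    rows, cols = np.meshgrid(np.arange(h // p), np.arange(w // p), indexing="ij")
    return np.stack([rows, cols], axis=-1).reshape(-1, 2)


latents = np.arange(2 * 8 * 8 * 4, dtype=np.float32).reshape(2, 8, 8, 4)
tokens = patchify(latents)
ids = patch_ids(8, 8)
print(tokens.shape, ids.shape)  # (2, 16, 16) (16, 2)
```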
class gensbi.experimental.models.FieldDiT(params)

Bases: flax.nnx.Module

Conditional flow-matching network for 2D fields (pixel space).

Forward: (t, obs=field, cond) -> velocity field of the same shape as obs. The convolutional encoder is modulated by the time embedding only (or by the full vec when cond_modulates_encoder=True); the decoder and the MMDiT core are modulated by vec = time embedding (+ cond summary when use_cond_summary_in_vec=True) (+ guidance embedding when guidance_embed=True). The 2D RoPE position IDs for the obs tokens on the meeting grid and the absolute cond IDs are built internally (obs_ids and cond_ids default to None).

Parameters:

params (FieldDiTParams) – Configuration for the model.

__call__(t, obs, cond, obs_ids=None, cond_ids=None, conditioned=True, guidance=None)
cond_dim
cond_embedder
cond_ids
cond_modulates_encoder
core
decoder
encoder
field_shape
guidance_embed
in_channels
obs_ids
param_dtype
time_in
token_grid
tokenizer
untokenizer
use_cond_summary_in_vec
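The modulation rule described for FieldDiT can be sketched as follows. The sinusoidal time embedding follows the Flux convention (times scaled by 1000, cosine half first), and every term is assumed to have already been projected to the model width by its own MLP (time_in and the other embedders), which the sketch omits. build_vec and encoder_vec are hypothetical helpers, not library functions.

```python
import numpy as np


def timestep_embedding(t, dim=256, max_period=10000.0, time_factor=1000.0):
    """Sinusoidal embedding of flow times t in [0, 1] (Flux-style convention, assumed)."""
    t = time_factor * np.asarray(t, dtype=np.float32)
    half = dim // 2
    freqs = np.exp(-np.log(max_period) * np.arange(half, dtype=np.float32) / half)
    args = t[:, None] * freqs[None, :]
    return np.concatenate([np.cos(args), np.sin(args)], axis=-1)


def build_vec(time_vec, cond_summary=None, guidance_vec=None):
    """Hypothetical: vec = time (+ cond summary) (+ guidance), all at the model width."""
    vec = time_vec
    if cond_summary is not None:  # corresponds to use_cond_summary_in_vec=True
        vec = vec + cond_summary
    if guidance_vec is not None:  # corresponds to guidance_embed=True
        vec = vec + guidance_vec
    return vec


def encoder_vec(time_vec, full_vec, cond_modulates_encoder=False):
    # Hypothetical: the conv encoder sees only time unless cond_modulates_encoder=True.
    return full_vec if cond_modulates_encoder else time_vec


t_vec = timestep_embedding(np.array([0.0, 0.5]))  # stand-in for time_in(...)
summary = np.full_like(t_vec, 0.25)               # stand-in for a cond summary
vec = build_vec(t_vec, cond_summary=summary)
print(t_vec.shape, np.allclose(encoder_vec(t_vec, vec), t_vec))  # (2, 256) True
```

Summing the terms, rather than concatenating them, keeps vec at a fixed width regardless of which optional flags are enabled, which is the pattern Flux uses for its own vec.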
class gensbi.experimental.models.FieldDiTParams

Configuration for FieldDiT (mirrors the style of Flux1Params).

The meeting-grid token count is not set directly; it is derived from field_shape, encoder_widths (whose length sets the encoder depth) and patch_size: tokens = (H / (2**D * p)) * (W / (2**D * p)), where (H, W) = field_shape, p = patch_size and D = len(encoder_widths) - 1.
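The formula can be evaluated before building a model. The helper below is a hypothetical sketch that computes the token count directly from field_shape, encoder_widths and patch_size; the divisibility check is an added assumption (a non-integer grid size would not make sense), not documented library behavior.

```python
def meeting_grid_tokens(field_shape, encoder_widths, patch_size=2):
    """Hypothetical helper: tokens = (H / (2**D * p)) * (W / (2**D * p)), D = len(encoder_widths) - 1."""
    h, w = field_shape
    depth = len(encoder_widths) - 1    # D: number of 2x downsampling stages
    stride = 2 ** depth * patch_size   # total reduction per spatial axis
    if h % stride or w % stride:
        raise ValueError(f"field_shape {field_shape} is not divisible by {stride}")
    return (h // stride) * (w // stride)


print(meeting_grid_tokens((64, 64), (32, 64, 128)))  # 64
print(meeting_grid_tokens((32, 48), (16, 32)))       # 96
```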

Note: rngs is a live nnx.Rngs stream (mirrors Flux1Params). Constructing two models from the same params object yields different weights, because the stream advances; build a fresh FieldDiTParams (or a fresh nnx.Rngs(seed)) per model for reproducibility.

__post_init__()
axes_dim: List[int] | None = None
cond_dim: int
cond_in_channels: int = 1
cond_modulates_encoder: bool = False
depth: int = 2
depth_single_blocks: int = 2
encoder_widths: Tuple[int, ...]
field_shape: Tuple[int, int]
guidance_embed: bool = False
in_channels: int
mlp_ratio: float = 4.0
norm_groups: int = 8
num_heads: int = 12
param_dtype: jax.typing.DTypeLike
patch_size: int = 2
qkv_bias: bool = False
res_blocks_down: int = 2
res_blocks_up: int = 2
rngs: flax.nnx.Rngs
theta: int | None = None
use_cond_summary_in_vec: bool = True
vec_in_dim: int | None = None
class gensbi.experimental.models.PixelDiT(params)

Bases: flax.nnx.Module

Dual-level pixel-space DiT for conditional flow matching on 2D fields.

Forward: (t, obs=field, cond) -> velocity field of the same shape. Patch-level MMDiT blocks attend over patch tokens jointly with cond tokens, producing per-patch conditioning s_cond; pixel-level PiT blocks then refine every pixel inside each patch under that conditioning.

Parameters:

params (PixelDiTParams) – Configuration for the model.

__call__(t, obs, cond, obs_ids=None, cond_ids=None, conditioned=True, guidance=None)
cond_dim
cond_embedder
cond_in_channels
field_shape
final_layer
in_channels
param_dtype
patch_blocks
patch_size
pe_patch
pe_pit
pixel_blocks
pixel_embedder
s_embedder
t_conditioner
token_grid
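The two levels of PixelDiT can be pictured as two views of the same field: one token per patch for the MMDiT blocks, and p*p pixel tokens inside each patch for the PiT blocks, with each patch's s_cond shared by all of its pixels. The NumPy sketch below shows that regrouping and an adaLN-style shift/scale broadcast. The helper names, the (B, N, 2*D) layout assumed for s_cond and the shift/scale form are assumptions for illustration, not the library's implementation.

```python
import numpy as np


def to_patch_pixels(field, p):
    """Hypothetical: (B, H, W, C) -> (B, N_patches, p*p, C), pixels grouped by patch."""
    b, h, w, c = field.shape
    x = field.reshape(b, h // p, p, w // p, p, c).transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(b, (h // p) * (w // p), p * p, c)


def from_patch_pixels(x, h, w, p):
    """Hypothetical inverse of to_patch_pixels: (B, N_patches, p*p, C) -> (B, H, W, C)."""
    b, _, _, c = x.shape
    x = x.reshape(b, h // p, w // p, p, p, c).transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(b, h, w, c)


def modulate_pixels(pixel_tokens, s_cond):
    """Hypothetical adaLN-style broadcast of per-patch conditioning to its pixels.

    pixel_tokens: (B, N, p*p, D); s_cond: (B, N, 2*D), split into shift and scale.
    """
    shift, scale = np.split(s_cond, 2, axis=-1)
    return pixel_tokens * (1.0 + scale[:, :, None, :]) + shift[:, :, None, :]


field = np.random.default_rng(0).standard_normal((2, 8, 8, 1))
pix = to_patch_pixels(field, p=4)   # (2, 4, 16, 1): 4 patches of 16 pixels
s_cond = np.zeros((2, 4, 2))        # zero shift and zero scale
out = modulate_pixels(pix, s_cond)
print(pix.shape, np.allclose(from_patch_pixels(out, 8, 8, 4), field))  # (2, 4, 16, 1) True
```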
class gensbi.experimental.models.PixelDiTParams

Configuration for PixelDiT (mirrors the style of FieldDiTParams).

Note: rngs is a live nnx.Rngs stream (mirrors FieldDiTParams / Flux1Params). Constructing two models from the same params object yields different weights, because the stream advances; build a fresh PixelDiTParams (or a fresh nnx.Rngs(seed)) per model for reproducibility.

__post_init__()
cond_dim: int
cond_id_embedding: str = 'absolute'
cond_in_channels: int = 1
field_shape: Tuple[int, int]
hidden_size: int = 384
in_channels: int
mlp_ratio: float = 4.0
num_heads: int = 6
param_dtype: jax.typing.DTypeLike
patch_depth: int = 6
patch_size: int = 4
pit_post_modulation: bool = False
pixel_attn_hidden_size: int | None = None
pixel_depth: int = 2
pixel_hidden_size: int = 16
pixel_num_heads: int | None = None
rngs: flax.nnx.Rngs
rope_scale: float = 16.0
theta: float = 10000.0
use_cond_rope: bool = True
use_pixel_abs_pos: bool = True
zero_init_blocks: bool = True
gensbi.experimental.models.vae_loss_fn(model, x, key, kl_weight=0.1)

Compute the VAE loss as the sum of reconstruction and KL divergence losses.

Parameters:
  • model – The VAE model.

  • x – Input data.

  • key – PRNG key for stochastic operations.

  • kl_weight (float, optional) – Weight for the KL divergence term. Defaults to 0.1.

Returns:

Scalar loss value combining the reconstruction and KL divergence losses.
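A loss of this form can be sketched in NumPy as below. Mean-squared error for the reconstruction term, the closed-form KL divergence against a standard normal prior, and the reductions (sum over latent dimensions, mean over the batch) are all assumptions. vae_loss_sketch and gaussian_kl are hypothetical names; the library's vae_loss_fn takes the model and a PRNG key rather than precomputed moments.

```python
import numpy as np


def gaussian_kl(mean, logvar):
    """Hypothetical: KL(N(mean, exp(logvar)) || N(0, I)), summed over latent dims, averaged over batch."""
    kl = 0.5 * (np.exp(logvar) + mean**2 - 1.0 - logvar)
    return kl.reshape(kl.shape[0], -1).sum(axis=1).mean()


def vae_loss_sketch(x, x_rec, mean, logvar, kl_weight=0.1):
    """Hypothetical: reconstruction (MSE) plus kl_weight times the KL divergence."""
    recon = np.mean((x - x_rec) ** 2)
    return recon + kl_weight * gaussian_kl(mean, logvar)


x = np.ones((2, 8))
mean = np.ones((2, 4))     # each latent dim contributes 0.5 * (1 + 1 - 1 - 0) = 0.5
logvar = np.zeros((2, 4))
print(vae_loss_sketch(x, x, mean, logvar))  # 0.2 (perfect reconstruction, KL = 2.0)
```

A smaller kl_weight favors reconstruction quality over a latent distribution close to the prior; the default of 0.1 in vae_loss_fn weights the KL term well below the reconstruction term.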