gensbi.experimental.models#
Submodules#
Classes#
AutoEncoder1D | 1D Autoencoder model with Gaussian latent space.
AutoEncoder2D | 2D Autoencoder model with Gaussian latent space.
AutoEncoderParams | Configuration parameters for the AutoEncoder models.
Embedded1DModel | Wraps a VAE encoder and a Flux1 SBI model for 1D sequence conditioning.
Embedded2DModel | Wraps a VAE encoder and a Flux1 SBI model for 2D image conditioning.
FieldDiT | Conditional flow-matching network for 2D fields (pixel space).
FieldDiTParams | Configuration for FieldDiT.
PixelDiT | Dual-level pixel-space DiT for conditional flow matching on 2D fields.
PixelDiTParams | Configuration for PixelDiT.
Functions#
vae_loss_fn(model, x, key, kl_weight=0.1) | Compute the VAE loss as the sum of reconstruction and KL divergence losses.
Package Contents#
- class gensbi.experimental.models.AutoEncoder1D(params)[source]#
Bases: flax.nnx.Module
1D Autoencoder model with Gaussian latent space.
- Parameters:
params (AutoEncoderParams) – Configuration parameters for the autoencoder.
- __call__(x, key=None)[source]#
Forward pass: encode and then decode the input.
- Parameters:
x (Array) – Input tensor.
- Returns:
Reconstructed output.
- Return type:
Array
- decode(z)[source]#
Decode latent representation back to data space.
- Parameters:
z (Array) – Latent tensor.
- Returns:
Reconstructed output.
- Return type:
Array
- encode(x, key=None)[source]#
Encode input data into the latent space.
- Parameters:
x (Array) – Input tensor.
key (Array) – PRNG key for sampling the latent variable.
- Returns:
Latent representation.
- Return type:
Array
- Decoder1D#
- Encoder1D#
- latent_shape#
- reg#
- rngs#
- scale_factor#
- shift_factor#
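The `key` argument of `encode` suggests the usual reparameterized draw from a Gaussian latent. The block below is a minimal pure-Python sketch of that idea, not the library's implementation: `sample_latent`, its `mean`/`logvar` inputs, and the use of `random.Random` in place of a JAX PRNG key are all assumptions made for illustration.

```python
import math
import random


def sample_latent(mean, logvar, rng=None):
    """Draw z = mean + std * eps for a diagonal Gaussian (hypothetical helper).

    With rng=None the mean is returned, i.e. a deterministic encoding.
    """
    if rng is None:
        return list(mean)
    return [
        m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
        for m, lv in zip(mean, logvar)
    ]


# Illustrative values only; a real encoder would predict these from x.
mean = [0.5, -1.0, 2.0]
logvar = [0.0, -2.0, -4.0]

z_det = sample_latent(mean, logvar)                      # no key: deterministic
z_a = sample_latent(mean, logvar, rng=random.Random(0))  # stochastic draw
z_b = sample_latent(mean, logvar, rng=random.Random(0))  # same seed, same draw
```

Reusing the same seed reproduces the same latent, which mirrors how an explicit PRNG key makes stochastic encoding repeatable.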
- class gensbi.experimental.models.AutoEncoder2D(params)[source]#
Bases: flax.nnx.Module
2D Autoencoder model with Gaussian latent space.
- Parameters:
params (AutoEncoderParams) – Configuration parameters for the autoencoder.
- __call__(x, key=None)[source]#
Forward pass: encode and then decode the input.
- Parameters:
x (Array) – Input tensor.
- Returns:
Reconstructed output.
- Return type:
Array
- decode(z)[source]#
Decode latent representation back to data space.
- Parameters:
z (Array) – Latent tensor.
- Returns:
Reconstructed output.
- Return type:
Array
- encode(x, key=None)[source]#
Encode input data into the latent space.
- Parameters:
x (Array) – Input tensor.
key (Array) – PRNG key for sampling the latent variable.
- Returns:
Latent representation.
- Return type:
Array
- Decoder2D#
- Encoder2D#
- latent_shape#
- reg#
- rngs#
- scale_factor#
- shift_factor#
- class gensbi.experimental.models.AutoEncoderParams[source]#
Configuration parameters for the AutoEncoder models.
- resolution#
The input feature dimension (length for 1D, height/width for 2D).
- Type:
int
- in_channels#
Number of input channels (e.g., 1 for scalar features, >1 for multi-channel).
- Type:
int
- ch#
Base number of channels for the first convolutional layer.
- Type:
int
- out_ch#
Number of output channels produced by the decoder (matches input channels for reconstruction).
- Type:
int
- ch_mult#
Multipliers for the number of channels at each resolution level (controls model width/depth).
- Type:
list[int]
- num_res_blocks#
Number of residual blocks per resolution level.
- Type:
int
- z_channels#
Number of latent channels in the bottleneck (size of encoded representation).
- Type:
int
- scale_factor#
Scaling factor applied to the latent representation (for normalization or data scaling).
- Type:
float
- shift_factor#
Shift factor applied to the latent representation (for normalization or data centering).
- Type:
float
- rngs#
Random number generators for parameter initialization and stochastic layers.
- Type:
nnx.Rngs
- param_dtype#
Data type for model parameters (e.g., jnp.float32, jnp.bfloat16).
- Type:
DTypeLike
- ch: int#
- ch_mult: list[int]#
- in_channels: int#
- num_res_blocks: int#
- out_ch: int#
- param_dtype: jax.typing.DTypeLike#
- resolution: int#
- rngs: flax.nnx.Rngs#
- scale_factor: float#
- shift_factor: float#
- z_channels: int#
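To make the roles of `ch`, `ch_mult`, and `resolution` concrete, the sketch below computes the channel width at each resolution level. The helper names and the numeric values are hypothetical, and the latent-resolution rule (one factor-of-two downsampling between consecutive levels) is an assumption about the architecture, not something stated above.

```python
def level_widths(ch, ch_mult):
    """Channel width at each resolution level: base width times the multiplier."""
    return [ch * m for m in ch_mult]


def latent_resolution(resolution, ch_mult):
    """ASSUMPTION: one stride-2 downsampling between consecutive levels."""
    return resolution // 2 ** (len(ch_mult) - 1)


# Hypothetical configuration values, chosen only for illustration.
widths = level_widths(ch=32, ch_mult=[1, 2, 4])
latent_res = latent_resolution(resolution=64, ch_mult=[1, 2, 4])
```

Under these assumptions, three levels with a base width of 32 give widths of 32, 64, and 128 channels, and a length-64 input is reduced to a length-16 latent.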
- class gensbi.experimental.models.Embedded1DModel(vae, sbi_model)[source]#
Bases: flax.nnx.Module
Wraps a VAE encoder and a Flux1 SBI model for 1D sequence conditioning.
Encodes raw 1D conditioning sequences into latent space via the VAE and forwards everything to the underlying SBI model. Unlike Embedded2DModel, no patchification step is applied after encoding.
- Parameters:
vae (AutoEncoder1D) – Variational autoencoder used to encode conditioning sequences into latents.
sbi_model (Flux1) – Simulation-based inference model that receives the encoded conditioning.
- __call__(t, obs, obs_ids, cond, cond_ids, conditioned=True, guidance=None, encoder_key=None)[source]#
Encode conditioning sequences and run the SBI model forward pass.
- Parameters:
t (Array) – Diffusion/flow timestep, shape (batch,) or scalar.
obs (Array) – Observation tokens (noisy samples) passed directly to the SBI model.
obs_ids (Array) – Positional IDs for the observation tokens.
cond (Array) – Raw 1D conditioning sequences to be encoded by the VAE, shape (batch, L, C).
cond_ids (Array) – Positional IDs for the conditioning tokens.
conditioned (bool or Array, optional) – Whether to apply conditioning (classifier-free guidance mask). Defaults to True.
guidance (Array or None, optional) – Guidance scale for classifier-free guidance. None disables it.
encoder_key (jax.random.PRNGKey or None, optional) – PRNG key forwarded to the VAE encoder for stochastic encoding. Pass None for deterministic encoding.
- Returns:
Output of the SBI model with the VAE-encoded conditioning.
- Return type:
Array
- sbi_model#
- vae#
- class gensbi.experimental.models.Embedded2DModel(vae, sbi_model)[source]#
Bases: flax.nnx.Module
Wraps a VAE encoder and a Flux1 SBI model for 2D image conditioning.
Encodes raw 2D conditioning images into latent space via the VAE, patchifies the result, and forwards everything to the underlying SBI model.
- Parameters:
vae (AutoEncoder2D) – Variational autoencoder used to encode conditioning images into latents.
sbi_model (Flux1) – Simulation-based inference model that receives the encoded conditioning.
- __call__(t, obs, obs_ids, cond, cond_ids, conditioned=True, guidance=None, encoder_key=None)[source]#
Encode conditioning images and run the SBI model forward pass.
- Parameters:
t (Array) – Diffusion/flow timestep, shape (batch,) or scalar.
obs (Array) – Observation tokens (noisy samples) passed directly to the SBI model.
obs_ids (Array) – Positional IDs for the observation tokens.
cond (Array) – Raw 2D conditioning images to be encoded by the VAE, shape (batch, H, W, C).
cond_ids (Array) – Positional IDs for the conditioning tokens after patchification.
conditioned (bool or Array, optional) – Whether to apply conditioning (classifier-free guidance mask). Defaults to True.
guidance (Array or None, optional) – Guidance scale for classifier-free guidance. None disables it.
encoder_key (jax.random.PRNGKey or None, optional) – PRNG key forwarded to the VAE encoder for stochastic encoding. Pass None for deterministic encoding.
- Returns:
Output of the SBI model with the VAE-encoded conditioning.
- Return type:
Array
- sbi_model#
- vae#
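The patchification step mentioned for `Embedded2DModel` turns a 2D latent into a flat sequence of tokens, which is why `cond_ids` must describe the tokens after patchification. The following is a rough pure-Python sketch on nested lists for a single sample. The `patchify` helper, the patch size of 2, and the row-major ordering are assumptions for illustration and may differ from what the library does.

```python
def patchify(latent, patch_size=2):
    """Split an (H, W, C) nested list into H/p * W/p tokens of length p*p*C.

    ASSUMPTIONS: non-overlapping square patches visited in row-major order;
    the default patch size of 2 is hypothetical.
    """
    height, width = len(latent), len(latent[0])
    if height % patch_size or width % patch_size:
        raise ValueError("latent height and width must be divisible by patch_size")
    tokens = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            token = []
            for dy in range(patch_size):
                for dx in range(patch_size):
                    # Append all channels of one pixel inside the patch.
                    token.extend(latent[top + dy][left + dx])
            tokens.append(token)
    return tokens


# A 4x4 latent with 3 channels; each pixel stores (row, col, 0) for easy inspection.
latent = [[[r, c, 0] for c in range(4)] for r in range(4)]
tokens = patchify(latent, patch_size=2)
```

Here a 4x4x3 latent becomes 4 tokens of length 12, so a matching set of positional IDs would need one entry per token rather than one per latent pixel.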
- class gensbi.experimental.models.FieldDiT(params)[source]#
Bases: flax.nnx.Module
Conditional flow-matching network for 2D fields (pixel space).
Forward: (t, obs=field, cond) -> velocity field of the same shape. The conv encoder is modulated by time only (or by the full vec when cond_modulates_encoder=True); the decoder and the MMDiT core are modulated by vec = time (+ cond summary if flagged-C) (+ guidance). The meeting-grid rope2d obs ids and absolute cond ids are built internally.
- Parameters:
params (FieldDiTParams)
- cond_dim#
- cond_embedder#
- cond_ids#
- cond_modulates_encoder#
- core#
- decoder#
- encoder#
- field_shape#
- guidance_embed#
- in_channels#
- obs_ids#
- param_dtype#
- time_in#
- token_grid#
- tokenizer#
- untokenizer#
- use_cond_summary_in_vec#
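Because the forward pass returns a velocity field with the same shape as the input field, samples are typically obtained by integrating that velocity over time. The sketch below shows a fixed-step Euler integrator in pure Python. The `euler_integrate` helper, the integration direction from t=0 to t=1, and the toy constant velocity that stands in for a trained network are all assumptions for illustration.

```python
def euler_integrate(velocity_fn, x0, num_steps=10):
    """Integrate dx/dt = velocity_fn(t, x) from t=0 to t=1 with fixed Euler steps."""
    x = list(x0)
    dt = 1.0 / num_steps
    for step in range(num_steps):
        t = step * dt
        v = velocity_fn(t, x)
        x = [xi + dt * vi for xi, vi in zip(x, v)]
    return x


# Toy stand-in for the velocity a trained model would predict from (t, obs, cond):
# the constant velocity of a straight path from `start` to `target`.
start = [0.0, 1.0, -2.0]
target = [1.0, 3.0, 0.5]


def toy_velocity(t, x):
    return [b - a for a, b in zip(start, target)]


sample = euler_integrate(toy_velocity, start, num_steps=10)
```

For this straight-line toy velocity, Euler integration lands on `target` up to floating-point error; a learned velocity field generally needs more steps or a higher-order solver.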
- class gensbi.experimental.models.FieldDiTParams[source]#
Configuration for FieldDiT (mirrors the style of Flux1Params).
The meeting-grid token count is not prescribed; it is derived from field_shape, encoder_widths (depth) and patch_size: tokens = (H / (2**D * p)) * (W / (2**D * p)), where (H, W) = field_shape, p = patch_size, and D = len(encoder_widths) - 1.
Note: rngs is a live nnx.Rngs stream (mirrors Flux1Params). Constructing two models from the same params object yields different weights, because the stream advances; build a fresh FieldDiTParams (or a fresh nnx.Rngs(seed)) per model for reproducibility.
- axes_dim: List[int] | None = None#
- cond_dim: int#
- cond_in_channels: int = 1#
- cond_modulates_encoder: bool = False#
- depth: int = 2#
- depth_single_blocks: int = 2#
- encoder_widths: Tuple[int, ...]#
- field_shape: Tuple[int, int]#
- guidance_embed: bool = False#
- in_channels: int#
- mlp_ratio: float = 4.0#
- norm_groups: int = 8#
- num_heads: int = 12#
- param_dtype: jax.typing.DTypeLike#
- patch_size: int = 2#
- qkv_bias: bool = False#
- res_blocks_down: int = 2#
- res_blocks_up: int = 2#
- rngs: flax.nnx.Rngs#
- theta: int | None = None#
- use_cond_summary_in_vec: bool = True#
- vec_in_dim: int | None = None#
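The token-count formula given for `FieldDiTParams` can be checked with a few lines of plain Python. The helper name `meeting_grid_tokens` and the example values are hypothetical, and the divisibility check is an added assumption (the text above does not say how non-divisible shapes are handled).

```python
def meeting_grid_tokens(field_shape, encoder_widths, patch_size=2):
    """tokens = (H / (2**D * p)) * (W / (2**D * p)), with D = len(encoder_widths) - 1."""
    height, width = field_shape
    depth = len(encoder_widths) - 1
    stride = 2 ** depth * patch_size  # total downsampling factor per axis
    if height % stride or width % stride:
        # ASSUMPTION: shapes that do not divide evenly are rejected in this sketch.
        raise ValueError(f"field_shape {field_shape} is not divisible by {stride}")
    return (height // stride) * (width // stride)


# Hypothetical configuration: a 64x64 field, three encoder levels, 2x2 patches.
tokens = meeting_grid_tokens(
    field_shape=(64, 64), encoder_widths=(32, 64, 128), patch_size=2
)
```

With three encoder widths the depth is D = 2, so each axis shrinks by 2**2 * 2 = 8 and a 64x64 field yields an 8x8 grid of 64 tokens.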
- class gensbi.experimental.models.PixelDiT(params)[source]#
Bases: flax.nnx.Module
Dual-level pixel-space DiT for conditional flow matching on 2D fields.
Forward: (t, obs=field, cond) -> velocity field of the same shape. Patch-level MMDiT blocks attend over patch tokens jointly with cond tokens, producing per-patch conditioning s_cond; pixel-level PiT blocks then refine every pixel inside each patch under that conditioning.
- Parameters:
params (PixelDiTParams)
- cond_dim#
- cond_embedder#
- cond_in_channels#
- field_shape#
- final_layer#
- in_channels#
- param_dtype#
- patch_blocks#
- patch_size#
- pe_patch#
- pe_pit#
- pixel_blocks#
- pixel_embedder#
- s_embedder#
- t_conditioner#
- token_grid#
- class gensbi.experimental.models.PixelDiTParams[source]#
Configuration for PixelDiT (mirrors the style of FieldDiTParams).
Note: rngs is a live nnx.Rngs stream (mirrors FieldDiTParams / Flux1Params). Constructing two models from the same params object yields different weights, because the stream advances; build a fresh PixelDiTParams (or a fresh nnx.Rngs(seed)) per model for reproducibility.
- cond_dim: int#
- cond_id_embedding: str = 'absolute'#
- cond_in_channels: int = 1#
- field_shape: Tuple[int, int]#
- in_channels: int#
- mlp_ratio: float = 4.0#
- num_heads: int = 6#
- param_dtype: jax.typing.DTypeLike#
- patch_depth: int = 6#
- patch_size: int = 4#
- pit_post_modulation: bool = False#
- pixel_depth: int = 2#
- pixel_num_heads: int | None = None#
- rngs: flax.nnx.Rngs#
- rope_scale: float = 16.0#
- theta: float = 10000.0#
- use_cond_rope: bool = True#
- use_pixel_abs_pos: bool = True#
- zero_init_blocks: bool = True#
- gensbi.experimental.models.vae_loss_fn(model, x, key, kl_weight=0.1)[source]#
Compute the VAE loss as the sum of reconstruction and KL divergence losses.
- Parameters:
model – The VAE model.
x – Input data.
key – PRNG key for stochastic operations.
kl_weight – Weight for the KL divergence term. Defaults to 0.1.
- Returns:
Scalar loss value combining reconstruction and KL losses.
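To illustrate how a reconstruction term and a weighted KL term combine into one scalar, here is a pure-Python sketch. It is not the body of `vae_loss_fn`: the helper `vae_loss_sketch` takes precomputed quantities instead of `(model, x, key)`, and both the mean-squared-error reconstruction and the diagonal-Gaussian KL against a standard normal prior are assumptions.

```python
import math


def vae_loss_sketch(x, recon, mean, logvar, kl_weight=0.1):
    """Reconstruction + kl_weight * KL (hypothetical stand-in for vae_loss_fn).

    ASSUMPTIONS: mean-squared-error reconstruction and a diagonal Gaussian
    posterior N(mean, exp(logvar)) compared against a standard normal prior.
    """
    recon_loss = sum((a - b) ** 2 for a, b in zip(x, recon)) / len(x)
    # KL(N(mean, var) || N(0, 1)) summed over latent dimensions.
    kl = 0.5 * sum(m * m + math.exp(lv) - 1.0 - lv for m, lv in zip(mean, logvar))
    return recon_loss + kl_weight * kl


# Perfect reconstruction with a posterior equal to the prior: both terms vanish.
perfect = vae_loss_sketch(
    x=[1.0, 2.0], recon=[1.0, 2.0], mean=[0.0, 0.0], logvar=[0.0, 0.0]
)
# Shifting one latent mean by 1 adds a KL of 0.5, weighted by kl_weight.
shifted = vae_loss_sketch(
    x=[1.0, 2.0], recon=[1.0, 2.0], mean=[1.0, 0.0], logvar=[0.0, 0.0]
)
```

A smaller `kl_weight` favors reconstruction quality, while a larger one pulls the latent distribution closer to the prior.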