gensbi.experimental.models#

Classes#

AutoEncoder1D

1D Autoencoder model with Gaussian latent space.

AutoEncoder2D

2D Autoencoder model with Gaussian latent space.

AutoEncoderParams

Configuration parameters for the AutoEncoder models.

Functions#

vae_loss_fn(model, x, key[, kl_weight])

Compute the VAE loss as the sum of the reconstruction loss and the KL divergence loss weighted by kl_weight.

Package Contents#

class gensbi.experimental.models.AutoEncoder1D(params)[source]#

Bases: flax.nnx.Module

1D Autoencoder model with Gaussian latent space.

Parameters:

params (AutoEncoderParams) – Configuration parameters for the autoencoder.

__call__(x, key=None)[source]#

Forward pass: encode and then decode the input.

Parameters:
  • x (Array) – Input tensor.

  • key (Array) – Optional PRNG key for sampling the latent variable. Defaults to None.

Returns:

Reconstructed output.

Return type:

Array

decode(z)[source]#

Decode latent representation back to data space.

Parameters:

z (Array) – Latent tensor.

Returns:

Reconstructed output.

Return type:

Array

encode(x, key=None)[source]#

Encode input data into the latent space.

Parameters:
  • x (Array) – Input tensor.

  • key (Array) – Optional PRNG key for sampling the latent variable. Defaults to None.

Returns:

Latent representation.

Return type:

Array
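
The `key` argument suggests that `encode` draws a sample from the Gaussian latent distribution. A minimal sketch of that kind of reparameterized sampling in JAX is given here. The helper name `sample_latent`, the assumption that the encoder produces a mean and a log-variance, and the fallback to the mean when `key` is None are all illustrative choices, not taken from the gensbi source.

```python
import jax
import jax.numpy as jnp


def sample_latent(mean, logvar, key=None):
    """Hypothetical helper: draw z ~ N(mean, exp(logvar)) by reparameterization."""
    if key is None:
        # Assumed deterministic fallback: return the mean of the distribution.
        return mean
    std = jnp.exp(0.5 * logvar)
    # Sample standard normal noise, then shift and scale it.
    eps = jax.random.normal(key, mean.shape, dtype=mean.dtype)
    return mean + std * eps


# Illustrative shapes: (batch, length, latent channels).
mean = jnp.zeros((2, 8, 4))
logvar = jnp.full((2, 8, 4), -2.0)

z = sample_latent(mean, logvar, key=jax.random.PRNGKey(0))  # stochastic
z_det = sample_latent(mean, logvar)                         # deterministic
```

Writing the sample as `mean + std * eps` keeps it differentiable with respect to `mean` and `logvar`, which is why this form is commonly used when training VAEs.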

Decoder1D#
Encoder1D#
latent_shape#
reg#
rngs#
scale_factor#
shift_factor#
class gensbi.experimental.models.AutoEncoder2D(params)[source]#

Bases: flax.nnx.Module

2D Autoencoder model with Gaussian latent space.

Parameters:

params (AutoEncoderParams) – Configuration parameters for the autoencoder.

__call__(x, key=None)[source]#

Forward pass: encode and then decode the input.

Parameters:
  • x (Array) – Input tensor.

  • key (Array) – Optional PRNG key for sampling the latent variable. Defaults to None.

Returns:

Reconstructed output.

Return type:

Array

decode(z)[source]#

Decode latent representation back to data space.

Parameters:

z (Array) – Latent tensor.

Returns:

Reconstructed output.

Return type:

Array

encode(x, key=None)[source]#

Encode input data into the latent space.

Parameters:
  • x (Array) – Input tensor.

  • key (Array) – Optional PRNG key for sampling the latent variable. Defaults to None.

Returns:

Latent representation.

Return type:

Array

Decoder2D#
Encoder2D#
latent_shape#
reg#
rngs#
scale_factor#
shift_factor#
class gensbi.experimental.models.AutoEncoderParams[source]#

Configuration parameters for the AutoEncoder models.

resolution#

The input feature dimension (length for 1D, height/width for 2D).

Type:

int

in_channels#

Number of input channels (e.g., 1 for scalar features, >1 for multi-channel).

Type:

int

ch#

Base number of channels for the first convolutional layer.

Type:

int

out_ch#

Number of output channels produced by the decoder (matches input channels for reconstruction).

Type:

int

ch_mult#

Multipliers for the number of channels at each resolution level (controls model width/depth).

Type:

list[int]

num_res_blocks#

Number of residual blocks per resolution level.

Type:

int

z_channels#

Number of latent channels in the bottleneck (size of encoded representation).

Type:

int

scale_factor#

Scaling factor applied to the latent representation (for normalization or data scaling).

Type:

float

shift_factor#

Shift factor applied to the latent representation (for normalization or data centering).

Type:

float
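
To illustrate how a scale and shift pair can normalize a latent, the sketch below centers the latent and then scales it, and inverts the operation before decoding. This ordering is a convention used in some latent-diffusion-style autoencoders. Whether gensbi applies the two factors in exactly this way is an assumption, and the function names are hypothetical.

```python
import jax.numpy as jnp


def normalize_latent(z, scale_factor, shift_factor):
    # Assumed convention: center first, then scale.
    return scale_factor * (z - shift_factor)


def denormalize_latent(z_norm, scale_factor, shift_factor):
    # Exact inverse of normalize_latent.
    return z_norm / scale_factor + shift_factor


z = jnp.array([0.5, 1.0, 1.5])
z_norm = normalize_latent(z, scale_factor=2.0, shift_factor=1.0)
z_back = denormalize_latent(z_norm, scale_factor=2.0, shift_factor=1.0)
```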

rngs#

Random number generators for parameter initialization and stochastic layers.

Type:

nnx.Rngs

param_dtype#

Data type for model parameters (e.g., jnp.float32, jnp.bfloat16).

Type:

DTypeLike

ch: int#
ch_mult: list[int]#
in_channels: int#
num_res_blocks: int#
out_ch: int#
param_dtype: jax.typing.DTypeLike#
resolution: int#
rngs: flax.nnx.Rngs#
scale_factor: float#
shift_factor: float#
z_channels: int#
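
As a rough guide to how `ch`, `ch_mult`, and `resolution` might interact, the sketch below computes the channel width at each resolution level and the spatial size at the bottleneck. It assumes that the width of a level is `ch` times its multiplier and that each level after the first halves the spatial size. Both are common conventions for this style of convolutional autoencoder and are not confirmed by this page. The helper names are hypothetical.

```python
def channels_per_level(ch, ch_mult):
    """Channel count at each level, assuming width = ch * multiplier."""
    return [ch * m for m in ch_mult]


def latent_resolution(resolution, ch_mult):
    """Bottleneck size, assuming one 2x downsampling between consecutive levels."""
    return resolution // 2 ** (len(ch_mult) - 1)


widths = channels_per_level(ch=32, ch_mult=[1, 2, 4])
bottleneck = latent_resolution(resolution=64, ch_mult=[1, 2, 4])
print(widths, bottleneck)
```

Under these assumptions, adding an entry to `ch_mult` adds a resolution level, which both deepens the model and shrinks the latent.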
gensbi.experimental.models.vae_loss_fn(model, x, key, kl_weight=0.1)[source]#

Compute the VAE loss as the sum of the reconstruction loss and the KL divergence loss weighted by kl_weight.

Parameters:
  • model – The VAE model.

  • x – Input data.

  • key – PRNG key for stochastic operations.

  • kl_weight – Weight for the KL divergence term. Defaults to 0.1.

Returns:

Scalar loss value combining reconstruction and KL losses.
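
For orientation, a loss of this form can be sketched in plain JAX as follows. This is not the gensbi implementation: the reconstruction term is assumed to be a mean squared error, the KL term is the closed form for a diagonal Gaussian against a standard normal prior, and the reductions (sum over latent axes, mean over the batch) are illustrative choices. The names `gaussian_kl` and `sketch_vae_loss` are hypothetical.

```python
import jax.numpy as jnp


def gaussian_kl(mean, logvar):
    """KL(N(mean, exp(logvar)) || N(0, 1)), summed over all non-batch axes."""
    kl = 0.5 * (jnp.exp(logvar) + mean**2 - 1.0 - logvar)
    return jnp.sum(kl, axis=tuple(range(1, kl.ndim)))


def sketch_vae_loss(x, x_rec, mean, logvar, kl_weight=0.1):
    """Illustrative loss: mean squared error plus a weighted KL term."""
    recon = jnp.mean((x - x_rec) ** 2)
    kl = jnp.mean(gaussian_kl(mean, logvar))  # average over the batch
    return recon + kl_weight * kl


# Illustrative data: (batch, length, channels) inputs and latents.
x = jnp.ones((4, 16, 1))
x_rec = jnp.zeros((4, 16, 1))
mean = jnp.zeros((4, 8, 2))
logvar = jnp.zeros((4, 8, 2))

loss = sketch_vae_loss(x, x_rec, mean, logvar)
```

A smaller `kl_weight` favors reconstruction quality, while a larger one pulls the latent distribution closer to the standard normal prior.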