gensbi.models.embedding

Submodules

Classes

Embed

Wrapper around nnx.Embed that handles 3D input by removing the last dimension.

FeatureEmbedder

General feature embedder supporting learned, 1D sinusoidal, and 2D sinusoidal embeddings.

GaussianFourierEmbedding

Gaussian Fourier embedding for time values.

MLPEmbedder

MLP-based embedder with skip connections.

SimpleTimeEmbedding

Simple time embedding module using cosine and sine transformations.

SinusoidalPosEmbed1D

1D sinusoidal position embedder for sequences.

SinusoidalPosEmbed2D

2D sinusoidal position embedder for grid-like data.

SinusoidalTimeEmbedding

Sinusoidal time embedding.

Package Contents

class gensbi.models.embedding.Embed(*args, **kwargs)

Bases: flax.nnx.Module

Wrapper around nnx.Embed that accepts 3D input of shape (batch, seq_len, 1) by removing the trailing dimension before the lookup.

Parameters:
  • *args – Positional arguments passed to nnx.Embed.

  • **kwargs – Keyword arguments passed to nnx.Embed.

__call__(ids)

Apply embedding to input IDs.

Parameters:

ids (Array) – Input IDs with shape (batch, seq_len, 1).

Returns:

Embedded output.

Return type:

Array

embed
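As a rough illustration of the shape handling, the NumPy sketch below assumes the wrapper squeezes the trailing size-1 axis and then performs an ordinary row lookup. `lookup_3d` is a hypothetical helper and the random `table` stands in for the learned nnx.Embed weights; this is not the library's implementation.

```python
import numpy as np


def lookup_3d(table: np.ndarray, ids: np.ndarray) -> np.ndarray:
    """Embed integer IDs of shape (batch, seq_len, 1) with a lookup table.

    Sketch only: the trailing singleton axis is dropped, then each ID selects
    one row of `table`, whose shape is (num_embeddings, features).
    """
    if ids.ndim != 3 or ids.shape[-1] != 1:
        raise ValueError(f"expected shape (batch, seq_len, 1), got {ids.shape}")
    ids_2d = np.squeeze(ids, axis=-1)  # (batch, seq_len)
    return table[ids_2d]  # (batch, seq_len, features)


rng = np.random.default_rng(0)
table = rng.normal(size=(10, 4))  # stand-in for learned embedding weights
ids = np.array([[[1], [3], [3]], [[0], [9], [2]]])  # shape (2, 3, 1)
out = lookup_3d(table, ids)
print(out.shape)  # (2, 3, 4)
```

Repeated IDs map to identical rows, which is why the trailing axis can be dropped without losing information.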
class gensbi.models.embedding.FeatureEmbedder(num_embeddings, hidden_size, *, kind='absolute', param_dtype=jnp.float32, rngs=None, **kwargs)

Bases: flax.nnx.Module

General feature embedder supporting learned, 1D sinusoidal, and 2D sinusoidal embeddings. 1D sinusoidal embeddings suit sequences, while 2D sinusoidal embeddings suit grid-like data (e.g., images).

Parameters:
  • num_embeddings (int) – Number of embeddings.

  • hidden_size (int) – Hidden size/embedding dimension.

  • kind (str, optional) – Type of embedding: ‘absolute’ (learned), ‘pos1d’ (1D sinusoidal), or ‘pos2d’ (2D sinusoidal). Defaults to ‘absolute’.

  • param_dtype (DTypeLike, optional) – Data type for parameters. Defaults to jnp.float32.

  • rngs (nnx.Rngs, optional) – Random number generators for initialization.

  • **kwargs – Additional keyword arguments specific to the embedding type.

__call__(ids)

Apply feature embedding to input IDs.

Parameters:

ids (Array) – Input IDs.

Returns:

Embedded features.

Return type:

Array

class gensbi.models.embedding.GaussianFourierEmbedding(output_dim=128, learnable=False, *, rngs, param_dtype=jnp.float32)

Bases: flax.nnx.Module

Gaussian Fourier embedding for time values.

Parameters:
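  • output_dim (int) – Output embedding dimension. Defaults to 128.

  • learnable (bool) – If True, the embedding parameters are trained. Defaults to False.

  • rngs (flax.nnx.Rngs) – Random number generators for initialization.

  • param_dtype (jax.typing.DTypeLike) – Data type for parameters. Defaults to jnp.float32.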

__call__(t)

Compute Gaussian Fourier time embedding.

Parameters:

t (Array) – Time values.

Returns:

Gaussian Fourier time embeddings.

Return type:

Array

B
output_dim = 128
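The idea behind a Gaussian Fourier embedding can be sketched in NumPy as below. This assumes the common random-Fourier-feature convention: `B` (presumably the attribute listed above) holds output_dim // 2 frequencies drawn from a Gaussian, each time is projected as 2πtB, and the sine features precede the cosine features. The helper `gaussian_fourier_embed` and the `scale` value are hypothetical, and the real module may differ in any of these details.

```python
import numpy as np


def gaussian_fourier_embed(t: np.ndarray, B: np.ndarray) -> np.ndarray:
    """Map times t of shape (batch,) to [sin(2*pi*t*B), cos(2*pi*t*B)].

    B holds output_dim // 2 frequencies, so the result has shape (batch, output_dim).
    """
    proj = 2.0 * np.pi * t[:, None] * B[None, :]  # (batch, output_dim // 2)
    return np.concatenate([np.sin(proj), np.cos(proj)], axis=-1)


output_dim = 128
scale = 1.0  # hypothetical standard deviation of the sampled frequencies
rng = np.random.default_rng(0)
# Frequencies are drawn once from a Gaussian; a learnable variant would train them.
B = rng.normal(scale=scale, size=(output_dim // 2,))
t = np.array([0.0, 0.25, 1.0])
emb = gaussian_fourier_embed(t, B)
print(emb.shape)  # (3, 128)
```

Because each sine/cosine pair shares a frequency, every pair lies on the unit circle, so the embedding has a fixed norm regardless of t.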
class gensbi.models.embedding.MLPEmbedder(in_dim, hidden_dim, rngs, param_dtype=jnp.float32)

Bases: flax.nnx.Module

MLP-based embedder with skip connections.

Parameters:
  • in_dim (int) – Input dimension.

  • hidden_dim (int) – Hidden dimension, must be a multiple of in_dim.

  • rngs (nnx.Rngs) – Random number generators for initialization.

  • param_dtype (DTypeLike, optional) – Data type for parameters. Defaults to jnp.float32.

__call__(x)

Forward pass of the MLP embedder.

Parameters:

x (Array) – Input array.

Returns:

Embedded output with skip connections.

Return type:

Array

in_layer
out_layer
p_skip
repeats
silu
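The requirement that hidden_dim be a multiple of in_dim suggests the skip path widens the input by repetition. The NumPy sketch below illustrates that reading under stated assumptions: a two-layer SiLU MLP (matching the listed `in_layer`, `out_layer`, `silu`, and `repeats` attributes), with the input tiled `repeats` times and added to the MLP output. `mlp_embed_with_skip` is hypothetical, the choice of np.tile over element-wise repetition is a guess, and `p_skip` is omitted because its role is not documented here.

```python
import numpy as np


def silu(x: np.ndarray) -> np.ndarray:
    return x / (1.0 + np.exp(-x))


def mlp_embed_with_skip(x, w_in, b_in, w_out, b_out):
    """x: (batch, in_dim) -> (batch, hidden_dim), with hidden_dim a multiple of in_dim."""
    in_dim = x.shape[-1]
    hidden_dim = w_out.shape[-1]
    if hidden_dim % in_dim != 0:
        raise ValueError("hidden_dim must be a multiple of in_dim")
    repeats = hidden_dim // in_dim
    mlp = silu(x @ w_in + b_in) @ w_out + b_out  # two-layer MLP branch
    skip = np.tile(x, (1, repeats))  # widen x to hidden_dim by repetition
    return mlp + skip


rng = np.random.default_rng(0)
in_dim, hidden_dim = 3, 12
w_in = rng.normal(size=(in_dim, hidden_dim))
b_in = np.zeros(hidden_dim)
w_out = rng.normal(size=(hidden_dim, hidden_dim))
b_out = np.zeros(hidden_dim)
x = rng.normal(size=(5, in_dim))
y = mlp_embed_with_skip(x, w_in, b_in, w_out, b_out)
print(y.shape)  # (5, 12)
```

With all MLP weights set to zero the output reduces to the tiled input, which is the usual motivation for a skip connection: the module starts close to an identity-like mapping.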
class gensbi.models.embedding.SimpleTimeEmbedding

Bases: flax.nnx.Module

Simple time embedding module using cosine and sine transformations.

__call__(t)

Compute time embedding.

Parameters:

t (Array) – Time values.

Returns:

Time embeddings.

Return type:

Array
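One minimal form of a cosine/sine time embedding is sketched below in NumPy. The frequencies actually used by SimpleTimeEmbedding are not documented here, so `simple_time_embed` and its `freqs` argument are hypothetical, as is the cosine-before-sine layout.

```python
import numpy as np


def simple_time_embed(t: np.ndarray, freqs=(1.0,)) -> np.ndarray:
    """Hypothetical sketch: t of shape (batch,) -> (batch, 2 * len(freqs)).

    Output layout is [cos(f * t) for each f, then sin(f * t) for each f].
    """
    f = np.asarray(freqs, dtype=float)
    arg = t[:, None] * f[None, :]  # (batch, len(freqs))
    return np.concatenate([np.cos(arg), np.sin(arg)], axis=-1)


t = np.array([0.0, np.pi / 2])
emb = simple_time_embed(t)
print(emb.shape)  # (2, 2)
```

Pairing cosine with sine keeps distinct times distinguishable within one period, which a single cosine alone would not.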

class gensbi.models.embedding.SinusoidalPosEmbed1D(hidden_size, max_len=5000, param_dtype=jnp.float32)

Bases: flax.nnx.Module

1D sinusoidal position embedder for sequences.

Parameters:
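  • hidden_size (int) – Embedding dimension.

  • max_len (int) – Maximum supported sequence length. Defaults to 5000.

  • param_dtype (jax.typing.DTypeLike) – Data type for parameters. Defaults to jnp.float32.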

__call__(ids)

Forward pass of the 1D sinusoidal position embedder.

Parameters:

ids (Array) – Input IDs with shape (batch, seq_len).

Returns:

Position embeddings of shape (1, seq_len, hidden_size).

Return type:

Array

hidden_size
pe
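A NumPy sketch of a 1D sinusoidal position table and the (1, seq_len, hidden_size) output described above follows. It assumes the standard Transformer layout (sine on even channels, cosine on odd channels, base 10000) and that the `pe` attribute is a precomputed (max_len, hidden_size) table sliced by sequence length; `sinusoidal_pe_1d` and `pos_embed_1d` are hypothetical helpers, not the library's code.

```python
import numpy as np


def sinusoidal_pe_1d(max_len: int, hidden_size: int) -> np.ndarray:
    """Fixed table of shape (max_len, hidden_size); even channels use sin, odd use cos."""
    if hidden_size % 2 != 0:
        raise ValueError("hidden_size must be even")
    pos = np.arange(max_len)[:, None]  # (max_len, 1)
    div = np.exp(np.arange(0, hidden_size, 2) * (-np.log(10000.0) / hidden_size))
    pe = np.zeros((max_len, hidden_size))
    pe[:, 0::2] = np.sin(pos * div)
    pe[:, 1::2] = np.cos(pos * div)
    return pe


def pos_embed_1d(pe: np.ndarray, ids: np.ndarray) -> np.ndarray:
    """ids: (batch, seq_len) -> (1, seq_len, hidden_size); only seq_len is used."""
    seq_len = ids.shape[1]
    return pe[None, :seq_len, :]  # leading 1 broadcasts over the batch


pe = sinusoidal_pe_1d(max_len=5000, hidden_size=8)
ids = np.zeros((2, 7), dtype=int)
out = pos_embed_1d(pe, ids)
print(out.shape)  # (1, 7, 8)
```

The leading axis of size 1 lets the result be added to a (batch, seq_len, hidden_size) activation through broadcasting.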
class gensbi.models.embedding.SinusoidalPosEmbed2D(hidden_size, max_h=128, max_w=128, param_dtype=jnp.float32)

Bases: flax.nnx.Module

2D sinusoidal position embedder for grid-like data (e.g., images).

Parameters:
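  • hidden_size (int) – Embedding dimension.

  • max_h (int) – Maximum grid height. Defaults to 128.

  • max_w (int) – Maximum grid width. Defaults to 128.

  • param_dtype (jax.typing.DTypeLike) – Data type for parameters. Defaults to jnp.float32.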

__call__(ids)

Compute 2D sinusoidal position embeddings.

Parameters:

ids (Array) – Input IDs with shape (batch, h, w).

Returns:

2D position embeddings of shape (batch, h*w, hidden_size).

Return type:

Array

hidden_size
pe_h
pe_w
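The separate `pe_h` and `pe_w` attributes suggest one sinusoidal table per grid axis. The NumPy sketch below shows one common way to combine them into the (batch, h*w, hidden_size) output described above, under these assumptions: half of hidden_size encodes the row and half the column, row codes come first, and cells are flattened row-major. `sincos_table` and `pos_embed_2d` are hypothetical helpers.

```python
import numpy as np


def sincos_table(n: int, dim: int) -> np.ndarray:
    """1D sin/cos table of shape (n, dim); dim must be even."""
    pos = np.arange(n)[:, None]
    div = np.exp(np.arange(0, dim, 2) * (-np.log(10000.0) / dim))
    table = np.zeros((n, dim))
    table[:, 0::2] = np.sin(pos * div)
    table[:, 1::2] = np.cos(pos * div)
    return table


def pos_embed_2d(pe_h: np.ndarray, pe_w: np.ndarray, ids: np.ndarray) -> np.ndarray:
    """ids: (batch, h, w) -> (batch, h * w, hidden_size), cells flattened row-major."""
    batch, h, w = ids.shape
    half = pe_h.shape[1]
    rows = np.broadcast_to(pe_h[:h, None, :], (h, w, half))  # row code for every cell
    cols = np.broadcast_to(pe_w[None, :w, :], (h, w, half))  # column code for every cell
    grid = np.concatenate([rows, cols], axis=-1).reshape(h * w, 2 * half)
    return np.broadcast_to(grid[None], (batch, h * w, 2 * half))


hidden_size = 8  # split evenly between row and column codes
pe_h = sincos_table(128, hidden_size // 2)  # max_h = 128
pe_w = sincos_table(128, hidden_size // 2)  # max_w = 128
ids = np.zeros((2, 3, 4), dtype=int)
out = pos_embed_2d(pe_h, pe_w, ids)
print(out.shape)  # (2, 12, 8)
```

Under this layout, the cell at row r and column c lands at flat index r * w + c, and its embedding is the concatenation of pe_h[r] and pe_w[c].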
class gensbi.models.embedding.SinusoidalTimeEmbedding(output_dim=128)

Bases: flax.nnx.Module

Sinusoidal time embedding.

Parameters:

output_dim (int) – Output embedding dimension. Defaults to 128.

__call__(t)

Compute sinusoidal time embedding.

Parameters:

t (Array) – Time values.

Returns:

Sinusoidal time embeddings.

Return type:

Array

output_dim = 128
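A NumPy sketch of a sinusoidal time embedding in the style common to Transformer and diffusion models follows. It assumes output_dim // 2 frequencies decaying geometrically from 1 towards 1 / max_period, with sine features before cosine features; `sinusoidal_time_embed` and its `max_period` argument are hypothetical and not confirmed to match this module.

```python
import numpy as np


def sinusoidal_time_embed(
    t: np.ndarray, output_dim: int = 128, max_period: float = 10000.0
) -> np.ndarray:
    """t of shape (batch,) -> (batch, output_dim) as [sin(t * f_k), cos(t * f_k)]."""
    half = output_dim // 2
    # Frequencies decay geometrically from 1 towards 1 / max_period.
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)
    args = t[:, None] * freqs[None, :]  # (batch, half)
    return np.concatenate([np.sin(args), np.cos(args)], axis=-1)


t = np.array([0.0, 0.5, 1.0])
emb = sinusoidal_time_embed(t)
print(emb.shape)  # (3, 128)
```

Unlike the Gaussian Fourier variant, the frequencies here are deterministic, so the embedding needs no random number generator.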