gensbi.models.embedding#
Submodules#
Classes#
- Embed: Wrapper around nnx.Embed that handles 3D input by removing the last dimension.
- FeatureEmbedder: General Feature Embedder supporting learned, 1D sinusoidal, and 2D sinusoidal embeddings.
- GaussianFourierEmbedding: Gaussian Fourier time embedding module.
- MLPEmbedder: MLP-based embedder with skip connections.
- SimpleTimeEmbedding: Simple time embedding module using cosine and sine transformations.
- SinusoidalPosEmbed1D: 1D sinusoidal position embedder.
- SinusoidalPosEmbed2D: 2D sinusoidal position embedder.
- SinusoidalTimeEmbedding: Sinusoidal time embedding module.
Package Contents#
- class gensbi.models.embedding.Embed(*args, **kwargs)[source]#
Bases:
flax.nnx.Module

Wrapper around nnx.Embed that handles 3D input by removing the last dimension.
- Parameters:
*args – Positional arguments passed to nnx.Embed.
**kwargs – Keyword arguments passed to nnx.Embed.
- __call__(ids)[source]#
Apply embedding to input IDs.
- Parameters:
ids (Array) – Input IDs with shape (batch, seq_len, 1).
- Returns:
Embedded output.
- Return type:
Array
- embed#
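The wrapper's behavior can be illustrated outside of flax. The sketch below is a hypothetical numpy reading of the docstring, assuming Embed squeezes the trailing singleton dimension and then performs a standard table lookup; the actual module delegates to nnx.Embed.

```python
import numpy as np

# Hypothetical sketch: squeeze the trailing dim, then look up embedding rows.
rng = np.random.default_rng(0)
num_embeddings, features = 10, 4
table = rng.normal(size=(num_embeddings, features))  # embedding table

def embed(ids):
    # ids: (batch, seq_len, 1) -> (batch, seq_len), then row lookup
    return table[ids[..., 0]]

ids = np.array([[[1], [3], [7]]])   # shape (1, 3, 1)
out = embed(ids)
print(out.shape)                    # (1, 3, 4)
```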
- class gensbi.models.embedding.FeatureEmbedder(num_embeddings, hidden_size, *, kind='absolute', param_dtype=jnp.float32, rngs=None, **kwargs)[source]#
Bases:
flax.nnx.Module

General Feature Embedder supporting learned, 1D sinusoidal, and 2D sinusoidal embeddings. 1D sinusoidal embeddings are suitable for sequences, while 2D sinusoidal embeddings are ideal for grid-like data (e.g., images).
- Parameters:
num_embeddings (int) – Number of embeddings.
hidden_size (int) – Hidden size/embedding dimension.
kind (str, optional) – Type of embedding: ‘absolute’, ‘pos1d’, or ‘pos2d’. Defaults to ‘absolute’.
param_dtype (DTypeLike, optional) – Data type for parameters. Defaults to jnp.float32.
rngs (nnx.Rngs, optional) – Random number generators for initialization.
**kwargs – Additional keyword arguments specific to the embedding type.
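To make the distinction between the kinds concrete, here is a hedged numpy sketch of how a kind switch between a learned table ('absolute') and a fixed 1D sinusoidal table ('pos1d') might look; the function name and internals are illustrative assumptions, not FeatureEmbedder's implementation.

```python
import numpy as np

# Illustrative dispatch between a learned lookup table and a fixed
# sinusoidal table; names and layout are assumptions for this sketch.
def feature_embedding(ids, hidden_size, kind="absolute", num_embeddings=100, seed=0):
    if kind == "absolute":
        # learned table lookup (randomly initialised here for illustration)
        table = np.random.default_rng(seed).normal(size=(num_embeddings, hidden_size))
        return table[ids]
    if kind == "pos1d":
        # fixed 1D sinusoidal embedding, suitable for sequences
        pos = ids[..., None].astype(float)
        div = np.exp(-np.log(10000.0) * np.arange(0, hidden_size, 2) / hidden_size)
        emb = np.zeros(ids.shape + (hidden_size,))
        emb[..., 0::2] = np.sin(pos * div)
        emb[..., 1::2] = np.cos(pos * div)
        return emb
    raise ValueError(f"unsupported kind: {kind}")

ids = np.arange(5)[None]                                # (1, 5)
print(feature_embedding(ids, 8, kind="pos1d").shape)    # (1, 5, 8)
```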
- class gensbi.models.embedding.GaussianFourierEmbedding(output_dim=128, learnable=False, *, rngs, param_dtype=jnp.float32)[source]#
Bases:
flax.nnx.Module

Gaussian Fourier time embedding module. Maps time values to Fourier features via the (optionally learnable) random projection B.
- Parameters:
output_dim (int)
learnable (bool)
rngs (flax.nnx.Rngs)
param_dtype (jax.typing.DTypeLike)
- __call__(t)[source]#
Compute Gaussian Fourier time embedding.
- Parameters:
t (Array) – Time values.
- Returns:
Gaussian Fourier time embeddings.
- Return type:
Array
- B#
- output_dim = 128#
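The common Gaussian Fourier feature construction can be sketched as follows: draw a fixed random frequency vector B from a standard normal and concatenate sine and cosine projections. The exact scale, frequency distribution, and layout used by GaussianFourierEmbedding may differ; this is a sketch under those assumptions.

```python
import numpy as np

# Sketch of Gaussian Fourier features for scalar time values, assuming
# B ~ N(0, 1) and a [sin, cos] concatenation; constants are assumptions.
def gaussian_fourier_embed(t, output_dim=128, seed=0):
    B = np.random.default_rng(seed).normal(size=(output_dim // 2,))
    proj = 2.0 * np.pi * t[:, None] * B[None, :]        # (batch, output_dim // 2)
    return np.concatenate([np.sin(proj), np.cos(proj)], axis=-1)

t = np.linspace(0.0, 1.0, 4)
emb = gaussian_fourier_embed(t, output_dim=128)
print(emb.shape)    # (4, 128)
```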
- class gensbi.models.embedding.MLPEmbedder(in_dim, hidden_dim, rngs, param_dtype=jnp.float32)[source]#
Bases:
flax.nnx.Module

MLP-based embedder with skip connections.
- Parameters:
in_dim (int) – Input dimension.
hidden_dim (int) – Hidden dimension, must be a multiple of in_dim.
rngs (nnx.Rngs) – Random number generators for initialization.
param_dtype (DTypeLike, optional) – Data type for parameters. Defaults to jnp.float32.
- __call__(x)[source]#
Forward pass of the MLP embedder.
- Parameters:
x (Array) – Input array.
- Returns:
Embedded output with skip connections.
- Return type:
Array
- in_layer#
- out_layer#
- p_skip#
- repeats#
- silu#
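One plausible reading of this module, given its attributes (in_layer, out_layer, silu, repeats, p_skip) and the requirement that hidden_dim be a multiple of in_dim: a two-layer SiLU MLP whose output is mixed with the input tiled `repeats = hidden_dim // in_dim` times. The mixing scheme below is an assumption for illustration, not the library's implementation.

```python
import numpy as np

# Hypothetical sketch of an MLP embedder with a tiled skip connection.
def silu(x):
    return x / (1.0 + np.exp(-x))

def mlp_embed(x, W_in, W_out, p_skip=0.5):
    in_dim = x.shape[-1]
    hidden_dim = W_out.shape[1]
    repeats = hidden_dim // in_dim      # hidden_dim must be a multiple of in_dim
    skip = np.tile(x, repeats)          # tile input up to the hidden width
    h = silu(x @ W_in)
    return (1.0 - p_skip) * (h @ W_out) + p_skip * skip

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 3))             # in_dim = 3
W_in = rng.normal(size=(3, 12))
W_out = rng.normal(size=(12, 12))       # hidden_dim = 12, a multiple of 3
print(mlp_embed(x, W_in, W_out).shape)  # (2, 12)
```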
- class gensbi.models.embedding.SimpleTimeEmbedding[source]#
Bases:
flax.nnx.Module

Simple time embedding module using cosine and sine transformations.
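Assuming the simplest form of a cosine/sine embedding, each scalar time t maps to the pair (sin t, cos t); SimpleTimeEmbedding may apply a different frequency or scaling, so treat this as a minimal sketch.

```python
import numpy as np

# Minimal sketch: embed each scalar time as its sine/cosine pair.
def simple_time_embed(t):
    return np.stack([np.sin(t), np.cos(t)], axis=-1)

t = np.array([0.0, 0.5, 1.0])
print(simple_time_embed(t).shape)   # (3, 2)
```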
- class gensbi.models.embedding.SinusoidalPosEmbed1D(hidden_size, max_len=5000, param_dtype=jnp.float32)[source]#
Bases:
flax.nnx.Module

1D sinusoidal position embedder for sequence data. Precomputes a fixed table of sinusoidal position embeddings for up to max_len positions (stored in pe).
- Parameters:
hidden_size (int)
max_len (int)
param_dtype (jax.typing.DTypeLike)
- __call__(ids)[source]#
Forward pass of the 1D sinusoidal position embedder.
- Parameters:
ids (Array) – Input IDs with shape (batch, seq_len).
- Returns:
Position embeddings of shape (1, seq_len, hidden_size).
- Return type:
Array
- pe#
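The precomputed table is presumably the standard 1D sinusoidal position table (as in "Attention Is All You Need"): sines at even feature indices, cosines at odd ones, with geometrically spaced frequencies. The constants below follow that convention and are an assumption about the `pe` buffer's contents.

```python
import numpy as np

# Standard 1D sinusoidal position table; constants follow the usual
# transformer recipe and may differ from SinusoidalPosEmbed1D's exact buffer.
def sinusoidal_table(max_len, hidden_size):
    pos = np.arange(max_len)[:, None].astype(float)
    div = np.exp(-np.log(10000.0) * np.arange(0, hidden_size, 2) / hidden_size)
    pe = np.zeros((max_len, hidden_size))
    pe[:, 0::2] = np.sin(pos * div)
    pe[:, 1::2] = np.cos(pos * div)
    return pe

pe = sinusoidal_table(max_len=5000, hidden_size=64)
seq = pe[None, :10]     # (1, seq_len, hidden_size), matching the __call__ output shape
print(seq.shape)        # (1, 10, 64)
```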
- class gensbi.models.embedding.SinusoidalPosEmbed2D(hidden_size, max_h=128, max_w=128, param_dtype=jnp.float32)[source]#
Bases:
flax.nnx.Module

2D sinusoidal position embedder for grid-like data (e.g., images). Combines row (pe_h) and column (pe_w) sinusoidal tables for grids of up to max_h x max_w positions.
- Parameters:
hidden_size (int)
max_h (int)
max_w (int)
param_dtype (jax.typing.DTypeLike)
- __call__(ids)[source]#
Compute 2D sinusoidal position embeddings.
- Parameters:
ids (Array) – Input IDs with shape (batch, h, w).
- Returns:
2D position embeddings of shape (batch, h*w, hidden_size).
- Return type:
Array
- pe_h#
- pe_w#
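A common recipe for 2D sinusoidal embeddings, consistent with the pe_h / pe_w attributes, concatenates two half-width 1D tables: one indexed by row, one by column. The sketch below assumes that layout; the module's actual ordering may differ.

```python
import numpy as np

# Sketch: half-width sinusoidal tables for rows and columns, concatenated.
def sin_table(n, dim):
    pos = np.arange(n)[:, None].astype(float)
    div = np.exp(-np.log(10000.0) * np.arange(0, dim, 2) / dim)
    pe = np.zeros((n, dim))
    pe[:, 0::2] = np.sin(pos * div)
    pe[:, 1::2] = np.cos(pos * div)
    return pe

def pos_embed_2d(h, w, hidden_size):
    pe_h = sin_table(h, hidden_size // 2)      # row component
    pe_w = sin_table(w, hidden_size // 2)      # column component
    grid = np.concatenate(
        [np.repeat(pe_h, w, axis=0),           # (h*w, hidden_size // 2)
         np.tile(pe_w, (h, 1))],               # (h*w, hidden_size // 2)
        axis=-1,
    )
    return grid[None]                          # (1, h*w, hidden_size)

print(pos_embed_2d(4, 4, 64).shape)   # (1, 16, 64)
```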
- class gensbi.models.embedding.SinusoidalTimeEmbedding(output_dim=128)[source]#
Bases:
flax.nnx.Module

Sinusoidal time embedding module. Maps time values to sinusoidal embeddings of dimension output_dim.
- Parameters:
output_dim (int)
- __call__(t)[source]#
Compute sinusoidal time embedding.
- Parameters:
t (Array) – Time values.
- Returns:
Sinusoidal time embeddings.
- Return type:
Array
- output_dim = 128#
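Sinusoidal time embeddings with log-spaced frequencies are a common choice in diffusion and flow models; the sketch below assumes that construction and the usual [sin, cos] concatenation, which may not match SinusoidalTimeEmbedding's exact frequency range.

```python
import numpy as np

# Sketch: sinusoidal time embedding over geometrically spaced frequencies.
def sinusoidal_time_embed(t, output_dim=128):
    half = output_dim // 2
    freqs = np.exp(-np.log(10000.0) * np.arange(half) / half)
    args = t[:, None] * freqs[None, :]          # (batch, output_dim // 2)
    return np.concatenate([np.sin(args), np.cos(args)], axis=-1)

t = np.array([0.0, 0.25, 1.0])
emb = sinusoidal_time_embed(t, output_dim=128)
print(emb.shape)    # (3, 128)
```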