gensbi.experimental.models.pixeldit.modules
Low-level building blocks for the PixelDiT port.
Faithful port of reference/PixelDiT/pixdit_core/modules.py into flax.nnx.
Deliberate deviations from the reference are called out inline.
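Classes
Buffer | Non-trainable array buffer (rope tables, sincos tables, etc.).
FinalLayer | RMSNorm + zero-init linear output projection (ref modules.py:241-250).
PixelMLP | Standard GELU MLP used in pixel-level DiT blocks (ref MLP, modules.py:223-238).
SwiGLU | SwiGLU feed-forward (ref FeedForward, modules.py:119-129).
TimestepConditioner | Embed a scalar timestep into hidden_size (ref modules.py:63-91).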
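Functions
_get_1d_sincos_pos_embed_from_grid(embed_dim, pos) | 1D sincos embedding for a grid of positions (ref modules.py:38-56).
_timestep_embedding(t, dim, max_period=10.0) | Sinusoidal timestep embedding in float32 (ref TimestepConditioner.timestep_embedding).
get_2d_sincos_pos_embed(embed_dim, h, w) | Pure-numpy 2D sincos positional embedding table.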
Module Contents
- class gensbi.experimental.models.pixeldit.modules.Buffer(value, *, hijax=None, ref=None, eager_sharding=None, **metadata)
Bases:
flax.nnx.Variable
Non-trainable array buffer (rope tables, sincos tables, etc.).
A dedicated Variable type so buffers are:
- filterable with nnx.state(model, Buffer),
- excluded from nnx.Param state (safe from dtype casts applied to parameter state), and
- importable by Tasks 3 and 6, which need the same distinction.
Mirrors RopeIds in fielddit/model.py:28-34.
- Parameters:
value (A | VariableMetadata[A])
hijax (bool | None)
ref (bool | None)
eager_sharding (bool | None)
metadata (Any)
- class gensbi.experimental.models.pixeldit.modules.FinalLayer(hidden_size, out_channels, *, rngs, param_dtype=jnp.bfloat16)
Bases:
flax.nnx.Module
RMSNorm + zero-init linear output projection (ref modules.py:241-250).
The linear’s kernel and bias are both zero-initialized, so the layer is exactly the zero map at initialisation (safe residual scaling).
- Parameters:
hidden_size (int)
out_channels (int)
rngs (flax.nnx.Rngs)
param_dtype (jax.typing.DTypeLike)
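To illustrate why a zero-initialised kernel and bias make the layer the zero map, here is a small NumPy sketch. It is not the flax.nnx implementation: the RMSNorm formula, the eps value, and the (in, out) kernel layout are assumptions made for the example.

```python
import numpy as np


def rms_norm(x, scale, eps=1e-6):
    """RMSNorm over the last axis (formula and eps are assumed, not taken from the port)."""
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * scale


hidden_size, out_channels = 8, 3
rng = np.random.default_rng(0)
x = rng.normal(size=(2, 5, hidden_size)).astype(np.float32)

scale = np.ones(hidden_size, dtype=np.float32)
kernel = np.zeros((hidden_size, out_channels), dtype=np.float32)  # zero-init kernel
bias = np.zeros(out_channels, dtype=np.float32)                   # zero-init bias

# Whatever the normalised input is, the projection returns zeros at initialisation.
out = rms_norm(x, scale) @ kernel + bias
```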
- class gensbi.experimental.models.pixeldit.modules.PixelMLP(dim, mlp_ratio=4.0, *, rngs, param_dtype=jnp.bfloat16)
Bases:
flax.nnx.Module
Standard GELU MLP used in pixel-level DiT blocks (ref MLP, modules.py:223-238).
Linear(dim → 4·dim) → GELU → Linear(→ dim). Both linears have biases; no dropout (we never train with dropout).
- Parameters:
dim (int)
mlp_ratio (float)
rngs (flax.nnx.Rngs)
param_dtype (jax.typing.DTypeLike)
- class gensbi.experimental.models.pixeldit.modules.SwiGLU(dim, mlp_ratio=4.0, *, rngs, param_dtype=jnp.bfloat16)
Bases:
flax.nnx.Module
SwiGLU feed-forward (ref FeedForward, modules.py:119-129).
hidden = int(2 * (dim * mlp_ratio) / 3); all three projections are bias-free. Forward: w2(silu(w1(x)) * w3(x)).
- Parameters:
dim (int)
mlp_ratio (float)
rngs (flax.nnx.Rngs)
param_dtype (jax.typing.DTypeLike)
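The hidden-width rule and forward pass stated above can be sketched in plain NumPy. The helper names (swiglu_hidden, swiglu_forward), the (in, out) kernel layout, and the random weights are assumptions made for illustration; the actual module is a flax.nnx.Module.

```python
import numpy as np


def silu(x):
    """SiLU activation: x * sigmoid(x)."""
    return x / (1.0 + np.exp(-x))


def swiglu_hidden(dim, mlp_ratio=4.0):
    """Hidden width as given in the docstring above."""
    return int(2 * (dim * mlp_ratio) / 3)


def swiglu_forward(x, w1, w2, w3):
    """w2(silu(w1(x)) * w3(x)) with bias-free projections; kernels stored as (in, out)."""
    return (silu(x @ w1) * (x @ w3)) @ w2


dim = 48
hidden = swiglu_hidden(dim)

rng = np.random.default_rng(0)
w1 = rng.normal(scale=0.02, size=(dim, hidden))  # gate projection
w3 = rng.normal(scale=0.02, size=(dim, hidden))  # value projection
w2 = rng.normal(scale=0.02, size=(hidden, dim))  # output projection

x = rng.normal(size=(2, dim))
y = swiglu_forward(x, w1, w2, w3)
```

Note that int() truncates, so the hidden width is not always a multiple of a convenient block size (for example, dim=64 with the default ratio gives 170).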
- class gensbi.experimental.models.pixeldit.modules.TimestepConditioner(hidden_size, freq_dim=256, *, rngs, param_dtype=jnp.bfloat16)
Bases:
flax.nnx.Module
Embed a scalar timestep into hidden_size (ref modules.py:63-91).
The sinusoid is always computed in float32; the result is cast to param_dtype only before the MLP projection. MLP weights are initialised with normal(std=0.02) (ref initialize_weights).
- Parameters:
hidden_size (int)
freq_dim (int)
rngs (flax.nnx.Rngs)
param_dtype (jax.typing.DTypeLike)
- gensbi.experimental.models.pixeldit.modules._get_1d_sincos_pos_embed_from_grid(embed_dim, pos)
1D sincos embedding for a grid of positions (ref modules.py:38-56).
Returns (M, embed_dim) where the order is [sin, cos] (reference line 55: cat([emb_sin, emb_cos])).
- Parameters:
embed_dim (int)
pos (numpy.ndarray)
- Return type:
numpy.ndarray
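A rough sketch of the [sin, cos] layout follows. The function name sincos_1d is hypothetical, and the frequency schedule (base 10000, as in MAE-style positional embeddings) is an assumption; only the output shape and the sin-then-cos ordering are taken from the description above.

```python
import numpy as np


def sincos_1d(embed_dim, pos):
    """Illustrative 1D sincos embedding; returns (M, embed_dim) ordered [sin, cos]."""
    assert embed_dim % 2 == 0, "embed_dim must be even"
    # Assumed frequency schedule: 1 / 10000^(i / (embed_dim / 2)).
    omega = np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.0)
    omega = 1.0 / 10000**omega
    # Outer product of flattened positions and frequencies: shape (M, embed_dim // 2).
    out = np.einsum("m,d->md", pos.reshape(-1).astype(np.float64), omega)
    # Sine half first, cosine half second.
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)


emb = sincos_1d(8, np.arange(5))
```

At position 0 the first half of the row is all zeros (sin) and the second half all ones (cos), which is a quick way to check the ordering.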
- gensbi.experimental.models.pixeldit.modules._timestep_embedding(t, dim, max_period=10.0)
Sinusoidal timestep embedding in float32 (ref TimestepConditioner.timestep_embedding).
Deliberate deviations from flux1.timestep_embedding:
- max_period = 10 (not 10000).
- No ×1000 time factor.
- Order is cat([cos, sin]) (reference line 80).
- Always computed in float32 (FieldDiT lesson: bf16 t quantizes fine differences; cast to param_dtype only before the MLP).
- Parameters:
dim (int)
max_period (float)
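The listed deviations can be illustrated with a NumPy sketch. The function name timestep_embedding_sketch is hypothetical, and the frequency formula exp(-log(max_period) * i / half) is an assumption borrowed from common DiT-style embeddings; the sketch also assumes an even dim. Only max_period = 10, the absence of a time factor, the [cos, sin] order, and float32 computation come from the description above.

```python
import numpy as np


def timestep_embedding_sketch(t, dim, max_period=10.0):
    """Illustrative sinusoidal timestep embedding; returns (N, dim) float32."""
    t = np.asarray(t, dtype=np.float32).reshape(-1)  # no x1000 time factor applied
    half = dim // 2
    # Assumed frequency schedule, kept in float32 throughout.
    log_period = np.float32(np.log(max_period))
    freqs = np.exp(-log_period * np.arange(half, dtype=np.float32) / half)
    args = t[:, None] * freqs[None, :]
    # Cosine half first, sine half second.
    return np.concatenate([np.cos(args), np.sin(args)], axis=-1).astype(np.float32)


emb = timestep_embedding_sketch([0.0, 0.5, 1.0], dim=8)
```

At t = 0 the first half of the row is all ones (cos) and the second half all zeros (sin), the opposite of the positional embedding's [sin, cos] layout.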
- gensbi.experimental.models.pixeldit.modules.get_2d_sincos_pos_embed(embed_dim, h, w)
Pure-numpy 2D sincos positional embedding table.
Returns (h*w, embed_dim) float32. Supports h != w.
The axis convention is copied exactly from the reference (modules.py:16-22): grid = meshgrid(grid_w, grid_h) (w goes first), so grid[0] is the w-axis and grid[1] is the h-axis. emb_h is built from grid[0] (the w-grid) following the reference naming; do not “fix” this.
- Parameters:
embed_dim (int)
h (int)
w (int)
- Return type:
numpy.ndarray
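The axis convention is easy to get wrong, so here is a small NumPy check of what meshgrid(grid_w, grid_h) produces. It only demonstrates the grid layout and assumes NumPy's default indexing="xy"; it does not reproduce the embedding itself, and the variable names w_coords and h_coords are chosen for the example.

```python
import numpy as np

h, w = 2, 3
grid_h = np.arange(h, dtype=np.float32)
grid_w = np.arange(w, dtype=np.float32)

# w goes first; with the default indexing="xy" each output array has shape (h, w).
grid = np.stack(np.meshgrid(grid_w, grid_h), axis=0)  # shape (2, h, w)

w_coords = grid[0]  # w-axis: values change along columns, rows are identical
h_coords = grid[1]  # h-axis: values change along rows, columns are identical
```

With h = 2 and w = 3, grid[0] is [[0, 1, 2], [0, 1, 2]] and grid[1] is [[0, 0, 0], [1, 1, 1]], so a name such as emb_h attached to grid[0] refers to w-coordinates.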