gensbi.experimental.models.pixeldit.modules#

Low-level building blocks for the PixelDiT port.

Faithful port of reference/PixelDiT/pixdit_core/modules.py into flax.nnx. Deliberate deviations from the reference are called out inline.

Classes#

Buffer

Non-trainable array buffer (rope tables, sincos tables, etc.).

FinalLayer

RMSNorm + zero-init linear output projection (ref modules.py:241-250).

PixelMLP

Standard GELU MLP used in pixel-level DiT blocks (ref MLP, modules.py:223-238).

SwiGLU

SwiGLU feed-forward (ref FeedForward, modules.py:119-129).

TimestepConditioner

Embed a scalar timestep into hidden_size (ref modules.py:63-91).

Functions#

_get_1d_sincos_pos_embed_from_grid(embed_dim, pos)

1D sincos embedding for a grid of positions (ref modules.py:38-56).

_timestep_embedding(t, dim[, max_period])

Sinusoidal timestep embedding in float32 (ref TimestepConditioner.timestep_embedding).

get_2d_sincos_pos_embed(embed_dim, h, w)

Pure-numpy 2D sincos positional embedding table.

Module Contents#

class gensbi.experimental.models.pixeldit.modules.Buffer(value, *, hijax=None, ref=None, eager_sharding=None, **metadata)[source]#

Bases: flax.nnx.Variable

Non-trainable array buffer (rope tables, sincos tables, etc.).

A dedicated Variable type so buffers are:

  • filterable with nnx.state(model, Buffer),

  • excluded from nnx.Param state (safe from dtype casts applied to parameter state), and

  • importable by Tasks 3 and 6 which need the same distinction.

Mirrors RopeIds in fielddit/model.py:28-34.

Parameters:
  • value (A | VariableMetadata[A])

  • hijax (bool | None)

  • ref (bool | None)

  • eager_sharding (bool | None)

  • metadata (Any)

class gensbi.experimental.models.pixeldit.modules.FinalLayer(hidden_size, out_channels, *, rngs, param_dtype=jnp.bfloat16)[source]#

Bases: flax.nnx.Module

RMSNorm + zero-init linear output projection (ref modules.py:241-250).

The linear’s kernel and bias are both zero-initialized, so the layer is exactly the zero map at initialisation (safe residual scaling).
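A small NumPy sketch of why the layer is the zero map at initialisation. The rms_norm helper, its eps value, and the unit scale are assumptions made for the illustration; only the zero-initialised kernel and bias come from the description above.

```python
import numpy as np


def rms_norm(x, scale, eps=1e-6):
    # Divide by the root-mean-square over the feature axis (eps is an assumed value).
    rms = np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)
    return (x / rms) * scale


hidden_size, out_channels = 16, 3
rng = np.random.default_rng(0)

scale = np.ones(hidden_size, dtype=np.float32)                     # assumed RMSNorm scale init
kernel = np.zeros((hidden_size, out_channels), dtype=np.float32)   # zero-init kernel
bias = np.zeros(out_channels, dtype=np.float32)                    # zero-init bias

x = rng.normal(size=(2, 5, hidden_size)).astype(np.float32)
y = rms_norm(x, scale) @ kernel + bias   # norm followed by the linear projection
```

Whatever the input, y is exactly zero until training moves the kernel or bias away from zero.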

Parameters:
  • hidden_size (int)

  • out_channels (int)

  • rngs (flax.nnx.Rngs)

  • param_dtype (jax.typing.DTypeLike)

__call__(x)[source]#
linear[source]#
norm[source]#
class gensbi.experimental.models.pixeldit.modules.PixelMLP(dim, mlp_ratio=4.0, *, rngs, param_dtype=jnp.bfloat16)[source]#

Bases: flax.nnx.Module

Standard GELU MLP used in pixel-level DiT blocks (ref MLP, modules.py:223-238).

Linear(dim → 4·dim) → GELU → Linear(→ dim), where the 4·dim hidden width corresponds to the default mlp_ratio = 4.0. Both linears have biases; there is no dropout (we never train with dropout).
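The data flow can be sketched in NumPy as below. This is an illustration under stated assumptions: the hidden width is taken to be int(dim * mlp_ratio), the tanh approximation of GELU is used (the port may use the exact form), and the weight initialisation is arbitrary.

```python
import numpy as np


def gelu_tanh(x):
    # Tanh approximation of GELU (assumed variant for this sketch).
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))


def pixel_mlp(x, w1, b1, w2, b2):
    # Linear -> GELU -> Linear, both linears with bias, no dropout.
    return gelu_tanh(x @ w1 + b1) @ w2 + b2


dim, mlp_ratio = 8, 4.0
hidden = int(dim * mlp_ratio)   # assumed rule for the hidden width
rng = np.random.default_rng(0)

w1 = rng.normal(scale=0.02, size=(dim, hidden))
b1 = np.zeros(hidden)
w2 = rng.normal(scale=0.02, size=(hidden, dim))
b2 = np.zeros(dim)

x = rng.normal(size=(2, 5, dim))
y = pixel_mlp(x, w1, b1, w2, b2)   # output has the same trailing dim as the input
```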

Parameters:
  • dim (int)

  • mlp_ratio (float)

  • rngs (flax.nnx.Rngs)

  • param_dtype (jax.typing.DTypeLike)

__call__(x)[source]#
fc1[source]#
fc2[source]#
class gensbi.experimental.models.pixeldit.modules.SwiGLU(dim, mlp_ratio=4.0, *, rngs, param_dtype=jnp.bfloat16)[source]#

Bases: flax.nnx.Module

SwiGLU feed-forward (ref FeedForward, modules.py:119-129).

hidden = int(2 * (dim * mlp_ratio) / 3); all three projections are bias-free. Forward: w2(silu(w1(x)) * w3(x)).
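The hidden-size rule and the forward expression above can be checked with a short NumPy sketch. The weight shapes and random initialisation are assumptions for illustration; the formulas are the ones stated in the docstring.

```python
import numpy as np


def silu(x):
    # SiLU / swish: x * sigmoid(x).
    return x / (1.0 + np.exp(-x))


def swiglu(x, w1, w2, w3):
    # w2(silu(w1(x)) * w3(x)), with all three projections bias-free.
    return (silu(x @ w1) * (x @ w3)) @ w2


dim, mlp_ratio = 8, 4.0
hidden = int(2 * (dim * mlp_ratio) / 3)   # int() truncates 21.33... to 21
rng = np.random.default_rng(0)

w1 = rng.normal(scale=0.02, size=(dim, hidden))   # gate projection
w3 = rng.normal(scale=0.02, size=(dim, hidden))   # value projection
w2 = rng.normal(scale=0.02, size=(hidden, dim))   # output projection

x = rng.normal(size=(2, 5, dim))
y = swiglu(x, w1, w2, w3)
```

The 2/3 factor keeps the parameter count of the three-matrix SwiGLU close to that of a two-matrix MLP with the same mlp_ratio; for example, dim = 384 with mlp_ratio = 4.0 gives hidden = 1024.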

Parameters:
  • dim (int)

  • mlp_ratio (float)

  • rngs (flax.nnx.Rngs)

  • param_dtype (jax.typing.DTypeLike)

__call__(x)[source]#
w1[source]#
w2[source]#
w3[source]#
class gensbi.experimental.models.pixeldit.modules.TimestepConditioner(hidden_size, freq_dim=256, *, rngs, param_dtype=jnp.bfloat16)[source]#

Bases: flax.nnx.Module

Embed a scalar timestep into hidden_size (ref modules.py:63-91).

The sinusoid is always computed in float32; the result is cast to param_dtype only before the MLP projection. MLP weights are initialised with normal(std=0.02) (ref initialize_weights).

Parameters:
  • hidden_size (int)

  • freq_dim (int)

  • rngs (flax.nnx.Rngs)

  • param_dtype (jax.typing.DTypeLike)

__call__(t)[source]#
freq_dim = 256[source]#
mlp_in[source]#
mlp_out[source]#
gensbi.experimental.models.pixeldit.modules._get_1d_sincos_pos_embed_from_grid(embed_dim, pos)[source]#

1D sincos embedding for a grid of positions (ref modules.py:38-56).

Returns (M, embed_dim) where the order is [sin, cos] (reference line 55: cat([emb_sin, emb_cos])).
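A minimal NumPy sketch of a 1D sincos table with the [sin, cos] ordering described above. The function name sincos_1d and the frequency base of 10000 are assumptions borrowed from the common MAE/DiT formulation, not confirmed details of this module.

```python
import numpy as np


def sincos_1d(embed_dim, pos, base=10000.0):
    # embed_dim must be even: half the channels are sin, half are cos.
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.0)
    omega = 1.0 / base ** omega                       # (embed_dim/2,) frequencies (assumed base)
    pos = np.asarray(pos, dtype=np.float64).reshape(-1)
    out = np.einsum("m,d->md", pos, omega)            # (M, embed_dim/2) outer product
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)   # order is [sin, cos]


emb = sincos_1d(8, np.arange(5))
```

At position 0 the first half of the row (sin) is all zeros and the second half (cos) is all ones, which is a quick way to confirm the ordering.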

Parameters:
  • embed_dim (int)

  • pos (numpy.ndarray)

Return type:

numpy.ndarray

gensbi.experimental.models.pixeldit.modules._timestep_embedding(t, dim, max_period=10.0)[source]#

Sinusoidal timestep embedding in float32 (ref TimestepConditioner.timestep_embedding).

Deliberate deviations from flux1.timestep_embedding:

  • max_period = 10 (not 10000).

  • No ×1000 time factor.

  • Order is cat([cos, sin]) (reference line 80).

  • Always computed in float32 (FieldDiT lesson: bf16 t quantizes fine differences; cast to param_dtype only before the MLP).
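The listed deviations can be sketched in NumPy as follows. The geometric frequency schedule exp(-log(max_period) * i / half) is an assumption taken from the usual DiT-style embedding; max_period = 10, the absence of a time factor, the [cos, sin] order, and float32 computation follow the description above.

```python
import math

import numpy as np


def timestep_embedding(t, dim, max_period=10.0):
    # Work in float32 throughout; any cast to a lower precision happens later.
    t = np.asarray(t, dtype=np.float32).reshape(-1)
    half = dim // 2
    # Assumed schedule: frequencies decay geometrically from 1 towards 1/max_period.
    freqs = np.exp(-math.log(max_period) * np.arange(half, dtype=np.float32) / half)
    args = t[:, None] * freqs[None, :]                # no x1000 time factor
    return np.concatenate([np.cos(args), np.sin(args)], axis=-1)   # order is [cos, sin]


emb = timestep_embedding([0.0, 0.5], dim=8)
```

With t = 0 the cos half is all ones and the sin half all zeros, the mirror image of the [sin, cos] ordering used by the positional tables.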

Parameters:
  • dim (int)

  • max_period (float)

gensbi.experimental.models.pixeldit.modules.get_2d_sincos_pos_embed(embed_dim, h, w)[source]#

Pure-numpy 2D sincos positional embedding table.

Returns (h*w, embed_dim) float32. Supports h != w.

The axis convention is copied exactly from the reference (modules.py:16-22): grid = meshgrid(grid_w, grid_h) — w goes first — so grid[0] is the w-axis and grid[1] is the h-axis. emb_h is built from grid[0] (the w-grid) following the reference naming; do not “fix” this.
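The axis convention is easiest to see on a small non-square grid. The sketch below is an assumption-laden illustration: sincos_1d and sincos_2d are hypothetical names, the frequency base of 10000 and the [emb_h, emb_w] concatenation order follow the common MAE-style layout, and only the meshgrid(grid_w, grid_h) ordering and the emb_h-from-grid[0] naming come from the description above.

```python
import numpy as np


def sincos_1d(embed_dim, pos, base=10000.0):
    # 1D table in [sin, cos] order (assumed frequency base).
    omega = 1.0 / base ** (np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.0))
    out = np.einsum("m,d->md", np.asarray(pos, dtype=np.float64).reshape(-1), omega)
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)


def sincos_2d(embed_dim, h, w):
    assert embed_dim % 4 == 0
    grid_h = np.arange(h, dtype=np.float32)
    grid_w = np.arange(w, dtype=np.float32)
    # w goes first: grid[0] varies along the w-axis, grid[1] along the h-axis.
    grid = np.stack(np.meshgrid(grid_w, grid_h), axis=0)   # shape (2, h, w)
    emb_h = sincos_1d(embed_dim // 2, grid[0])   # reference naming; built from the w-grid
    emb_w = sincos_1d(embed_dim // 2, grid[1])   # built from the h-grid
    return np.concatenate([emb_h, emb_w], axis=1).astype(np.float32)


table = sincos_2d(8, h=2, w=3)   # rows are laid out row-major: index = i_h * w + i_w
```

Rows 0 and 3 share a w index, so their first halves match; rows 0 and 1 share an h index, so their second halves match. Swapping the meshgrid arguments would change which half encodes which axis and break compatibility with weights trained against the reference layout.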

Parameters:
  • embed_dim (int)

  • h (int)

  • w (int)

Return type:

numpy.ndarray