gensbi.experimental.models.pixeldit.blocks

Transformer blocks for the PixelDiT port: patch-level MMDiT and pixel-level PiT.

MMDiTBlock is a faithful port of MMDiTBlockT2I + MMDiTJointAttention (reference/PixelDiT/pixdit_core/pixeldit_t2i.py:19-132) collapsed into a single flax.nnx module. PiTBlock ports the pixel-level block (pixeldit_c2i.py:114-187); see its docstring.

Two streams x (obs patches, (B, Lx, D)) and y (cond, (B, Ly, D)) each get their own qkv / proj / norms / mlp / adaLN, then share one joint scaled-dot-product attention over the concatenated token axis (cond first, matching the reference torch.cat([y, x]) order).
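The joint attention step described above can be sketched in plain JAX. This is an illustrative stand-in, not the module's code: `joint_attention` is a hypothetical helper, the inputs are assumed to be already projected and split into heads with shape (B, H, L, Dh), and the softmax attention is written out by hand rather than calling the library's attention().

```python
import jax
import jax.numpy as jnp


def joint_attention(qx, kx, vx, qy, ky, vy):
    """Joint SDPA over two streams; every input is (B, H, L, Dh). Hypothetical helper."""
    Ly = qy.shape[2]
    # Concatenate along the token axis, cond first (torch.cat([y, x]) order).
    q = jnp.concatenate([qy, qx], axis=2)
    k = jnp.concatenate([ky, kx], axis=2)
    v = jnp.concatenate([vy, vx], axis=2)
    logits = jnp.einsum("bhqd,bhkd->bhqk", q, k) * q.shape[-1] ** -0.5
    weights = jax.nn.softmax(logits, axis=-1)
    out = jnp.einsum("bhqk,bhkd->bhqd", weights, v)
    # Split back into the two streams: cond occupies the first Ly tokens.
    return out[:, :, Ly:], out[:, :, :Ly]


B, H, Lx, Ly, Dh = 2, 4, 16, 3, 8
keys = jax.random.split(jax.random.PRNGKey(0), 6)
qx, kx, vx = (jax.random.normal(k, (B, H, Lx, Dh)) for k in keys[:3])
qy, ky, vy = (jax.random.normal(k, (B, H, Ly, Dh)) for k in keys[3:])
out_x, out_y = joint_attention(qx, kx, vx, qy, ky, vy)
```

Every token attends to every token of both streams, while each stream keeps its own projections before and after this step.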

Rope is applied per stream before the joint concat because the two streams carry different rope tables: a 2D grid for obs, and a 1D table (or none) for cond. The library’s usual attention(pe=...) pattern applies one pe to the whole concatenated sequence, which cannot express per-stream tables. Instead, the block calls flux1.math.apply_rope() on each stream and then invokes attention(..., pe=None), which reduces to plain sdpa plus the output rearrange, so the math is identical to the reference.
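To illustrate why per-stream tables have to be applied before the concat, here is a minimal sketch with a hand-written rotation. `rope_angles` and `rotate` are hypothetical helpers written for this example; they do not reproduce the signature or table layout of flux1.math.apply_rope(), and the split of the head dimension between grid rows and columns is an assumption.

```python
import jax
import jax.numpy as jnp


def rope_angles(pos, dim, theta=10000.0):
    """Angles of shape (L, dim // 2) for positions pos of shape (L,). Hypothetical helper."""
    freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2) / dim))
    return pos[:, None] * freqs[None, :]


def rotate(x, angles):
    """Rotate consecutive feature pairs of x (B, H, L, Dh) by angles (L, Dh // 2)."""
    x1, x2 = x[..., 0::2], x[..., 1::2]
    cos, sin = jnp.cos(angles), jnp.sin(angles)
    out = jnp.stack([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)
    return out.reshape(x.shape)


B, H, Dh = 2, 4, 8
grid_h, grid_w, Ly = 3, 4, 5

# obs stream: 2D table, half of the rotation pairs driven by row, half by column.
rows, cols = jnp.meshgrid(jnp.arange(grid_h), jnp.arange(grid_w), indexing="ij")
angles_x = jnp.concatenate(
    [rope_angles(rows.ravel(), Dh // 2), rope_angles(cols.ravel(), Dh // 2)], axis=-1
)
# cond stream: 1D table over its own token index.
angles_y = rope_angles(jnp.arange(Ly), Dh)

kx_key, ky_key = jax.random.split(jax.random.PRNGKey(0))
qx = jax.random.normal(kx_key, (B, H, grid_h * grid_w, Dh))
qy = jax.random.normal(ky_key, (B, H, Ly, Dh))

# Each stream is rotated with its own table, and only then concatenated (cond first).
qx_rot = rotate(qx, angles_x)
qy_rot = rotate(qy, angles_y)
q_joint = jnp.concatenate([qy_rot, qx_rot], axis=2)
```

A single table over the concatenated axis would have to index cond and obs tokens with one shared position scheme, which is what the per-stream calls avoid.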

The adaLN projections are plain linears with no internal silu: c arrives pre-activated, so reusing flux1’s Modulation (which applies silu inside) would double-activate. zero_init=True zero-initialises the adaLN kernel+bias so the gated residuals start as an exact identity (c2i recipe); zero_init=False uses default init (t2i recipe).
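The identity-at-init property can be checked with a small sketch. `modulated_residual` is a hypothetical helper, the shift/scale/gate chunk order is assumed, and the normalisation layer is omitted for brevity; only the zero-initialised plain linear and the gated residual are modelled.

```python
import jax
import jax.numpy as jnp


def modulated_residual(x, c, w, b, fn):
    """Gated residual x + gate * fn(modulate(x)); c is assumed already silu-activated."""
    # Plain linear on c: no activation is applied here.
    shift, scale, gate = jnp.split(c @ w + b, 3, axis=-1)
    h = x * (1.0 + scale[:, None, :]) + shift[:, None, :]
    return x + gate[:, None, :] * fn(h)


B, L, D = 2, 5, 8
kx, kc = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(kx, (B, L, D))
c = jax.nn.silu(jax.random.normal(kc, (B, D)))  # activation happens upstream

# zero_init=True analogue: kernel and bias both zero, so shift = scale = gate = 0.
w_zero = jnp.zeros((D, 3 * D))
b_zero = jnp.zeros((3 * D,))
out = modulated_residual(x, c, w_zero, b_zero, jnp.tanh)
```

With a zero gate the sublayer output is multiplied by zero, so `out` equals `x` regardless of what `fn` computes.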

No attention-mask support: the cond length is fixed (YAGNI).

Classes

MMDiTBlock

Joint-attention MMDiT block with per-stream rope (ref pixeldit_t2i.py:19-132).

PiTBlock

Pixel-level DiT block: pixel-wise AdaLN + token compaction (ref pixeldit_c2i.py:114-187).

Module Contents

class gensbi.experimental.models.pixeldit.blocks.MMDiTBlock(hidden_size, num_heads, mlp_ratio=4.0, *, zero_init=True, rngs, param_dtype=jnp.bfloat16)

Bases: flax.nnx.Module

Joint-attention MMDiT block with per-stream rope (ref pixeldit_t2i.py:19-132).

Parameters:
  • hidden_size (int)

  • num_heads (int)

  • mlp_ratio (float)

  • zero_init (bool)

  • rngs (flax.nnx.Rngs)

  • param_dtype (jax.typing.DTypeLike)

__call__(x, y, c, pe_x, pe_y=None)

_qkv(qkv_lin, qk_norm, h, pe)

Project, split heads, q/k-norm, and (optionally) apply per-stream rope.

adaLN_x
adaLN_y
head_dim
hidden_size
mlp_x
mlp_y
norm_x1
norm_x2
norm_y1
norm_y2
num_heads
proj_x
proj_y
qk_norm_x
qk_norm_y
qkv_x
qkv_y

class gensbi.experimental.models.pixeldit.blocks.PiTBlock(pixel_dim, context_dim, patch_size, attn_dim, num_heads, mlp_ratio=4.0, post_modulation=False, *, zero_init=True, rngs, param_dtype=jnp.bfloat16)

Bases: flax.nnx.Module

Pixel-level DiT block: pixel-wise AdaLN + token compaction (ref pixeldit_c2i.py:114-187).

Operates on per-patch pixel tokens x of shape (B*L, p^2, D_pix) with per-patch conditioning s_cond of shape (B*L, D). The adaLN projection maps D -> n_mod * D_pix * p^2 and is reshaped to (B*L, p^2, n_mod * D_pix), so every pixel in the patch receives its own modulation parameters (“pixel-wise AdaLN”).
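The shape bookkeeping for pixel-wise AdaLN can be sketched as follows. The weight `w`, bias `b`, and the concrete sizes are made up for illustration; only the projection width and the reshape target come from the description above.

```python
import jax
import jax.numpy as jnp

BL, D, D_pix, p, n_mod = 6, 32, 4, 2, 6  # illustrative sizes; BL stands for B*L

kc, kw = jax.random.split(jax.random.PRNGKey(0))
s_cond = jax.random.normal(kc, (BL, D))

# Hypothetical adaLN parameters: D -> n_mod * D_pix * p^2.
w = jax.random.normal(kw, (D, n_mod * D_pix * p * p)) * 0.02
b = jnp.zeros((n_mod * D_pix * p * p,))

# One row of modulation parameters per pixel in the patch.
mod = (s_cond @ w + b).reshape(BL, p * p, n_mod * D_pix)

# Each chunk lines up with the pixel tokens x of shape (B*L, p^2, D_pix).
chunks = jnp.split(mod, n_mod, axis=-1)
```

Because the chunks already carry a p^2 axis, they broadcast element-wise against x with no extra unsqueeze, unlike token-level adaLN where one vector is shared across the sequence.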

Attention runs on compacted tokens: each patch’s p^2 * D_pix pixels are compressed to a single attn_dim token, attended globally over the (B, L) patch grid with 2D rope, then expanded back. The attention itself is flux1’s SelfAttention, which matches the reference RotaryAttention (no-bias qkv, per-head RMSNorm on q/k, biased proj).
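A shape-level sketch of the compact, attend, expand round trip is given below. It assumes the compress and expand steps are plain linear maps (`w_c`, `w_e` are hypothetical weights) and replaces flux1's SelfAttention with a single-head softmax attention without rope, so it only illustrates the data flow.

```python
import jax
import jax.numpy as jnp


def compact_attend_expand(x, w_c, w_e, batch):
    """x: (B*L, p^2, D_pix) -> same shape, attending over L compacted tokens."""
    BL, p2, d_pix = x.shape
    L = BL // batch
    # Compact: flatten each patch's pixels and project to one attn_dim token.
    tokens = x.reshape(batch, L, p2 * d_pix) @ w_c  # (B, L, attn_dim)
    # Stand-in for global attention over the patch grid (no heads, no rope).
    logits = jnp.einsum("bqd,bkd->bqk", tokens, tokens) * tokens.shape[-1] ** -0.5
    tokens = jnp.einsum("bqk,bkd->bqd", jax.nn.softmax(logits, axis=-1), tokens)
    # Expand: project back and restore the per-patch pixel layout.
    return (tokens @ w_e).reshape(BL, p2, d_pix)


B, L, p, D_pix, attn_dim = 2, 9, 2, 4, 16
kx, kc, ke = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (B * L, p * p, D_pix))
w_c = jax.random.normal(kc, (p * p * D_pix, attn_dim)) * 0.1
w_e = jax.random.normal(ke, (attn_dim, p * p * D_pix)) * 0.1
y = compact_attend_expand(x, w_c, w_e, batch=B)
```

The attention sequence length is L rather than L * p^2, which is the point of compaction: cost scales with the number of patches, not the number of pixels.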

post_modulation selects the gate-free post-modulation variant (n_mod = 4, chunk order scale1, shift1, scale2, shift2 — scale before shift, ref line 168). The default pre-variant uses the standard six-chunk DiT order and is an exact identity at zero_init=True; the post variant is not (no gates close the residuals).
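The two chunk orders, and the identity property at zero modulation, can be illustrated with a simplified sketch. `pre_block` and `post_block` are hypothetical; norms are omitted, `f` stands in for both sublayers, the six-chunk order is assumed to be shift, scale, gate per sublayer, and placing the post modulation on the sublayer output is an assumption, not something confirmed by the reference.

```python
import jax
import jax.numpy as jnp


def pre_block(x, mod, f):
    """Pre-modulation with gates: assumed order shift, scale, gate for each sublayer."""
    shift1, scale1, gate1, shift2, scale2, gate2 = jnp.split(mod, 6, axis=-1)
    x = x + gate1 * f(x * (1 + scale1) + shift1)
    return x + gate2 * f(x * (1 + scale2) + shift2)


def post_block(x, mod, f):
    """Gate-free post-modulation: scale before shift in the chunk order."""
    scale1, shift1, scale2, shift2 = jnp.split(mod, 4, axis=-1)
    x = x + (f(x) * (1 + scale1) + shift1)
    return x + (f(x) * (1 + scale2) + shift2)


BL, p2, D_pix = 6, 4, 8
x = jax.random.normal(jax.random.PRNGKey(0), (BL, p2, D_pix))

# Zero-initialised adaLN produces all-zero modulation parameters.
pre_out = pre_block(x, jnp.zeros((BL, p2, 6 * D_pix)), jnp.tanh)
post_out = post_block(x, jnp.zeros((BL, p2, 4 * D_pix)), jnp.tanh)
```

With zero modulation the gates in `pre_block` cancel both residual branches, while `post_block` still adds the unmodulated sublayer outputs.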

pe is the output of precompute_freqs_cis_2d(attn_dim // num_heads, grid_h, grid_w, ...) over the patch grid, shape (1, 1, L, head_dim/2, 2, 2).
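For orientation, a table with the stated shape can be built as below. This is a sketch of one plausible layout following the Flux-style convention of storing a 2x2 rotation matrix per frequency; `rot_table_1d` and `rot_table_2d` are hypothetical and may order frequencies differently from precompute_freqs_cis_2d. It assumes head_dim is divisible by 4.

```python
import jax.numpy as jnp


def rot_table_1d(pos, dim, theta=10000.0):
    """(L,) positions -> (L, dim // 2, 2, 2) rotation matrices. Hypothetical helper."""
    omega = 1.0 / (theta ** (jnp.arange(0, dim, 2) / dim))
    ang = pos[:, None].astype(jnp.float32) * omega[None, :]
    m = jnp.stack([jnp.cos(ang), -jnp.sin(ang), jnp.sin(ang), jnp.cos(ang)], axis=-1)
    return m.reshape(*ang.shape, 2, 2)


def rot_table_2d(head_dim, grid_h, grid_w):
    """2D table over a patch grid, shape (1, 1, L, head_dim // 2, 2, 2)."""
    rows, cols = jnp.meshgrid(jnp.arange(grid_h), jnp.arange(grid_w), indexing="ij")
    half = head_dim // 2  # half of the head dimension per grid axis (assumed split)
    table = jnp.concatenate(
        [rot_table_1d(rows.ravel(), half), rot_table_1d(cols.ravel(), half)], axis=1
    )
    return table[None, None]


attn_dim, num_heads, grid_h, grid_w = 32, 4, 3, 5
head_dim = attn_dim // num_heads
pe = rot_table_2d(head_dim, grid_h, grid_w)
```

The two leading singleton axes broadcast over batch and heads, and L equals grid_h * grid_w.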

Parameters:
  • pixel_dim (int)

  • context_dim (int)

  • patch_size (int)

  • attn_dim (int)

  • num_heads (int)

  • mlp_ratio (float)

  • post_modulation (bool)

  • zero_init (bool)

  • rngs (flax.nnx.Rngs)

  • param_dtype (jax.typing.DTypeLike)

__call__(x, s_cond, pe, batch)

adaLN
attn
attn_dim
compress
expand
mlp
n_mod = 4
norm1
norm2
patch_size
pixel_dim
post_modulation = False