gensbi.experimental.models.pixeldit#

PixelDiT: dual-level pixel-space DiT for conditional flow matching on 2D fields.

Classes#

MMDiTBlock

Joint-attention MMDiT block with per-stream rope (ref lines 19-132).

PiTBlock

Pixel-level DiT block: pixel-wise AdaLN + token compaction (ref pixeldit_c2i.py:114-187).

PixelDiT

Dual-level pixel-space DiT for conditional flow matching on 2D fields.

PixelDiTParams

Configuration for PixelDiT (mirrors the style of FieldDiTParams).

Package Contents#

class gensbi.experimental.models.pixeldit.MMDiTBlock(hidden_size, num_heads, mlp_ratio=4.0, *, zero_init=True, rngs, param_dtype=jnp.bfloat16)[source]#

Bases: flax.nnx.Module

Joint-attention MMDiT block with per-stream rope (ref lines 19-132).

Parameters:
  • hidden_size (int)

  • num_heads (int)

  • mlp_ratio (float)

  • zero_init (bool)

  • rngs (flax.nnx.Rngs)

  • param_dtype (jax.typing.DTypeLike)

__call__(x, y, c, pe_x, pe_y=None)[source]#
_qkv(qkv_lin, qk_norm, h, pe)[source]#

Project, split heads, q/k-norm, and (optionally) apply per-stream rope.
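The four steps named above can be sketched in plain JAX as follows. This is a minimal illustration, not the method's actual body: the function name qkv_sketch, the weight W_qkv, the (3, heads, head_dim) layout of the projected axis, and the scale-free RMS norm are all assumptions made for the example. The rope application follows the common Flux-style convention of multiplying channel pairs by 2x2 rotation matrices.

```python
import jax
import jax.numpy as jnp


def rms_norm(x, eps=1e-6):
    # Per-head RMS normalisation over the channel axis (learnable scale omitted).
    return x * jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps)


def apply_rope(x, pe):
    # x: (B, H, L, head_dim); pe: (1, 1, L, head_dim/2, 2, 2) rotation matrices.
    x_ = x.reshape(*x.shape[:-1], -1, 1, 2)
    out = pe[..., 0] * x_[..., 0] + pe[..., 1] * x_[..., 1]
    return out.reshape(x.shape)


def qkv_sketch(W_qkv, h, num_heads, pe=None):
    # Hypothetical stand-in for _qkv: project, split heads, norm q/k, optional rope.
    B, L, _ = h.shape
    qkv = (h @ W_qkv).reshape(B, L, 3, num_heads, -1)   # assumed layout
    qkv = qkv.transpose(2, 0, 3, 1, 4)                  # (3, B, H, L, head_dim)
    q, k, v = qkv[0], qkv[1], qkv[2]
    q, k = rms_norm(q), rms_norm(k)
    if pe is not None:
        q, k = apply_rope(q, pe), apply_rope(k, pe)
    return q, k, v


# Illustrative sizes only.
B, L, D, H = 2, 5, 32, 4
k_h, k_w = jax.random.split(jax.random.PRNGKey(0))
h = jax.random.normal(k_h, (B, L, D))
W_qkv = jax.random.normal(k_w, (D, 3 * D)) * 0.1

q, k, v = qkv_sketch(W_qkv, h, H)
# An all-identity rotation table should leave q and k unchanged.
pe_identity = jnp.broadcast_to(jnp.eye(2), (1, 1, L, D // H // 2, 2, 2))
q_rot, k_rot, _ = qkv_sketch(W_qkv, h, H, pe=pe_identity)
```

Passing pe=None skips the rotation entirely, which is one plausible reading of "optionally" in the docstring.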

adaLN_x#
adaLN_y#
head_dim#
hidden_size#
mlp_x#
mlp_y#
norm_x1#
norm_x2#
norm_y1#
norm_y2#
num_heads#
proj_x#
proj_y#
qk_norm_x#
qk_norm_y#
qkv_x#
qkv_y#
class gensbi.experimental.models.pixeldit.PiTBlock(pixel_dim, context_dim, patch_size, attn_dim, num_heads, mlp_ratio=4.0, post_modulation=False, *, zero_init=True, rngs, param_dtype=jnp.bfloat16)[source]#

Bases: flax.nnx.Module

Pixel-level DiT block: pixel-wise AdaLN + token compaction (ref pixeldit_c2i.py:114-187).

Operates on per-patch pixel tokens x of shape (B*L, p^2, D_pix) with per-patch conditioning s_cond of shape (B*L, D). The adaLN projection maps D -> n_mod * D_pix * p^2 and is reshaped to (B*L, p^2, n_mod * D_pix), so every pixel in the patch receives its own modulation parameters (“pixel-wise AdaLN”).
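The reshape described above can be illustrated with a short JAX sketch. A plain matrix W stands in for the adaLN projection, and the sizes are arbitrary. The assumption that the flat output axis is laid out as (p^2, n_mod * D_pix), and the use of the first two chunks as shift and scale, are illustrative choices rather than facts about the implementation.

```python
import jax
import jax.numpy as jnp

# Illustrative sizes (assumptions, not library defaults).
B, L, p, D, D_pix, n_mod = 2, 4, 4, 32, 16, 6

k_w, k_s, k_x = jax.random.split(jax.random.PRNGKey(0), 3)
W = jax.random.normal(k_w, (D, n_mod * D_pix * p * p)) * 0.02  # stand-in projection
s_cond = jax.random.normal(k_s, (B * L, D))                    # per-patch conditioning
x = jax.random.normal(k_x, (B * L, p * p, D_pix))              # per-patch pixel tokens

mod = s_cond @ W                                  # (B*L, n_mod * D_pix * p^2)
mod = mod.reshape(B * L, p * p, n_mod * D_pix)    # one parameter set per pixel
chunks = jnp.split(mod, n_mod, axis=-1)           # n_mod arrays of (B*L, p^2, D_pix)

# Each chunk has the same shape as x, so modulation is elementwise per pixel
# (normalisation of x omitted here for brevity).
shift, scale = chunks[0], chunks[1]
x_mod = x * (1 + scale) + shift
```

Because every chunk matches the shape of x exactly, no broadcasting over the pixel axis takes place, which is what distinguishes this from the usual per-token AdaLN.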

Attention runs on compacted tokens: each patch’s p^2 pixels (p^2 * D_pix values in total) are compressed into a single attn_dim token, these tokens attend globally over the (B, L) patch grid with 2D rope, and the result is expanded back to per-pixel features. The attention itself is flux1’s SelfAttention, which matches the reference RotaryAttention (no-bias qkv, per-head RMSNorm on q/k, biased proj).
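The compress, attend, expand sequence can be sketched in a few lines of JAX. The weights W_compress and W_expand are hypothetical stand-ins for the compress and expand layers, and a single-head softmax attention without rope or q/k norm replaces the real SelfAttention purely to keep the example short.

```python
import jax
import jax.numpy as jnp

# Illustrative sizes (assumptions, not library defaults).
B, L, p, D_pix, attn_dim = 2, 9, 4, 16, 64
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)

x = jax.random.normal(k1, (B * L, p * p, D_pix))                 # pixel tokens
W_compress = jax.random.normal(k2, (p * p * D_pix, attn_dim)) * 0.02
W_expand = jax.random.normal(k3, (attn_dim, p * p * D_pix)) * 0.02

# 1) Compact: flatten each patch and project it to one attn_dim token.
tokens = x.reshape(B, L, p * p * D_pix) @ W_compress             # (B, L, attn_dim)

# 2) Attend globally over the L patch tokens (simplified placeholder attention).
scores = jnp.einsum("bld,bmd->blm", tokens, tokens) / jnp.sqrt(attn_dim)
attended = jnp.einsum("blm,bmd->bld", jax.nn.softmax(scores, axis=-1), tokens)

# 3) Expand back to per-pixel features.
out = (attended @ W_expand).reshape(B * L, p * p, D_pix)
```

The attention cost scales with L (the number of patches) rather than with L * p^2 (the number of pixels), which is the point of compacting before attending.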

post_modulation=True selects the gate-free post-modulation variant (n_mod = 4, chunk order scale1, shift1, scale2, shift2; note that scale comes before shift, ref line 168). The default pre-modulation variant (post_modulation=False) uses the standard six-chunk DiT order and is an exact identity map at zero_init=True. The post variant is not, because it has no gates to close the residual branches.
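The identity property can be checked with a toy example. The sketch below assumes the usual DiT form for the pre-variant (shift, scale, gate per branch) and, as a guess, applies scale and shift to the branch output in the post-variant; the exact post-modulation formula is not given above, so treat post_block as hypothetical. The function branch is an arbitrary stand-in for the attention or MLP branch.

```python
import jax
import jax.numpy as jnp


def layer_norm(x, eps=1e-6):
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / jnp.sqrt(var + eps)


def branch(x):
    # Arbitrary non-zero stand-in for the attention / MLP branch.
    return jnp.tanh(x) + 0.5


def pre_block(x, mod):
    # Standard six-chunk DiT order: shift, scale, gate for each branch.
    shift1, scale1, gate1, shift2, scale2, gate2 = jnp.split(mod, 6, axis=-1)
    x = x + gate1 * branch(layer_norm(x) * (1 + scale1) + shift1)
    x = x + gate2 * branch(layer_norm(x) * (1 + scale2) + shift2)
    return x


def post_block(x, mod):
    # Assumed gate-free form; chunk order scale1, shift1, scale2, shift2.
    scale1, shift1, scale2, shift2 = jnp.split(mod, 4, axis=-1)
    x = x + (branch(layer_norm(x)) * (1 + scale1) + shift1)
    x = x + (branch(layer_norm(x)) * (1 + scale2) + shift2)
    return x


x = jax.random.normal(jax.random.PRNGKey(0), (8, 16, 16))
# A zero-initialised adaLN projection emits all-zero modulation parameters.
pre_out = pre_block(x, jnp.zeros((8, 16, 6 * 16)))
post_out = post_block(x, jnp.zeros((8, 16, 4 * 16)))

pre_is_identity = bool(jnp.array_equal(pre_out, x))
post_is_identity = bool(jnp.array_equal(post_out, x))
```

With zero modulation the gates in pre_block multiply each branch by zero, so the input passes through unchanged, while post_block still adds the branch output to the residual.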

pe is the output of precompute_freqs_cis_2d(attn_dim // num_heads, grid_h, grid_w, ...) over the patch grid, shape (1, 1, L, head_dim/2, 2, 2).
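To make the expected shape concrete, here is a hedged sketch of a 2D rope table built from 2x2 rotation matrices. It is not precompute_freqs_cis_2d: the helper names rope_1d and rope_2d_sketch are hypothetical, and the choice to give half of the frequency pairs to the row index and half to the column index is an assumption. Any extra arguments of the real function (the "..." above) are ignored here.

```python
import jax.numpy as jnp


def rope_1d(pos, dim, theta=10000.0):
    # pos: (L,) positions along one axis; dim: even channel count for that axis.
    freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2) / dim))   # (dim/2,)
    ang = pos[:, None] * freqs[None, :]                       # (L, dim/2)
    rot = jnp.stack(
        [jnp.cos(ang), -jnp.sin(ang), jnp.sin(ang), jnp.cos(ang)], axis=-1
    )
    return rot.reshape(*ang.shape, 2, 2)                      # (L, dim/2, 2, 2)


def rope_2d_sketch(head_dim, grid_h, grid_w, theta=10000.0):
    # Row-major (row, col) index for each of the L = grid_h * grid_w patches.
    rows, cols = jnp.meshgrid(jnp.arange(grid_h), jnp.arange(grid_w), indexing="ij")
    rows = rows.reshape(-1).astype(jnp.float32)
    cols = cols.reshape(-1).astype(jnp.float32)
    # Assumed split: half of the head_dim/2 channel pairs per spatial axis.
    pe = jnp.concatenate(
        [rope_1d(rows, head_dim // 2, theta), rope_1d(cols, head_dim // 2, theta)],
        axis=1,
    )
    return pe[None, None]                                     # (1, 1, L, head_dim/2, 2, 2)


attn_dim, num_heads = 64, 4
pe = rope_2d_sketch(attn_dim // num_heads, grid_h=3, grid_w=5)
```

For this split to work, head_dim has to be divisible by 4 so that each axis receives a whole number of channel pairs.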

Parameters:
  • pixel_dim (int)

  • context_dim (int)

  • patch_size (int)

  • attn_dim (int)

  • num_heads (int)

  • mlp_ratio (float)

  • post_modulation (bool)

  • zero_init (bool)

  • rngs (flax.nnx.Rngs)

  • param_dtype (jax.typing.DTypeLike)

__call__(x, s_cond, pe, batch)[source]#
adaLN#
attn#
attn_dim#
compress#
expand#
mlp#
n_mod = 4#
norm1#
norm2#
patch_size#
pixel_dim#
post_modulation = False#
class gensbi.experimental.models.pixeldit.PixelDiT(params)[source]#

Bases: flax.nnx.Module

Dual-level pixel-space DiT for conditional flow matching on 2D fields.

Forward: (t, obs=field, cond) -> velocity field of the same shape. Patch-level MMDiT blocks attend over patch tokens jointly with cond tokens, producing per-patch conditioning s_cond; pixel-level PiT blocks then refine every pixel inside each patch under that conditioning.
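The two levels operate on different views of the same field. As a rough sketch of the bookkeeping involved, the helpers below split a field into per-patch pixel tokens of shape (B*L, p^2, C) and reassemble it. The names patchify and unpatchify, the channel-last (B, H, W, C) layout, and the row-major patch order are assumptions for illustration, not the model's actual internals.

```python
import jax.numpy as jnp


def patchify(field, p):
    # field: (B, H, W, C) with H and W divisible by p.
    B, H, W, C = field.shape
    gh, gw = H // p, W // p
    x = field.reshape(B, gh, p, gw, p, C)
    x = x.transpose(0, 1, 3, 2, 4, 5)             # (B, gh, gw, p, p, C)
    return x.reshape(B * gh * gw, p * p, C)       # (B*L, p^2, C)


def unpatchify(x, p, field_shape, batch):
    # Inverse of patchify: (B*L, p^2, C) -> (B, H, W, C).
    H, W = field_shape
    gh, gw = H // p, W // p
    C = x.shape[-1]
    x = x.reshape(batch, gh, gw, p, p, C).transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(batch, H, W, C)


field = jnp.arange(2 * 8 * 12 * 3, dtype=jnp.float32).reshape(2, 8, 12, 3)
pixel_tokens = patchify(field, p=4)               # L = (8/4) * (12/4) = 6 patches
restored = unpatchify(pixel_tokens, p=4, field_shape=(8, 12), batch=2)
```

The round trip is lossless, which is consistent with the output velocity field having the same shape as the input field.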

Parameters:

params (PixelDiTParams)

__call__(t, obs, cond, obs_ids=None, cond_ids=None, conditioned=True, guidance=None)[source]#
cond_dim#
cond_embedder#
cond_in_channels#
field_shape#
final_layer#
in_channels#
param_dtype#
patch_blocks#
patch_size#
pe_patch#
pe_pit#
pixel_blocks#
pixel_embedder#
s_embedder#
t_conditioner#
token_grid#
class gensbi.experimental.models.pixeldit.PixelDiTParams[source]#

Configuration for PixelDiT (mirrors the style of FieldDiTParams).

Note: rngs is a live nnx.Rngs stream (mirrors FieldDiTParams / Flux1Params). Constructing two models from the same params object yields different weights, because the stream advances; build a fresh PixelDiTParams (or a fresh nnx.Rngs(seed)) per model for reproducibility.

__post_init__()[source]#
cond_dim: int#
cond_id_embedding: str = 'absolute'#
cond_in_channels: int = 1#
field_shape: Tuple[int, int]#
hidden_size: int = 384#
in_channels: int#
mlp_ratio: float = 4.0#
num_heads: int = 6#
param_dtype: jax.typing.DTypeLike#
patch_depth: int = 6#
patch_size: int = 4#
pit_post_modulation: bool = False#
pixel_attn_hidden_size: int | None = None#
pixel_depth: int = 2#
pixel_hidden_size: int = 16#
pixel_num_heads: int | None = None#
rngs: flax.nnx.Rngs#
rope_scale: float = 16.0#
theta: float = 10000.0#
use_cond_rope: bool = True#
use_pixel_abs_pos: bool = True#
zero_init_blocks: bool = True#