gensbi.experimental.models.pixeldit.model

PixelDiT config and assembly.

PixelDiT is a dual-level, pixel-space DiT for conditional pixel-space flow matching on 2D fields. A patch-level MMDiT transformer attends over patch tokens jointly with the cond tokens and produces per-patch conditioning; a pixel-level PiT stack then uses that conditioning to refine every pixel inside each patch.

Faithful port of reference/PixelDiT/pixdit_core/pixeldit_t2i.py: cond enters via tokens only, and the conditioning vector is c = silu(t_emb). The port omits repa, attention masks, and s caching, none of which is needed here (YAGNI).
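For reference, SiLU is x * sigmoid(x). The snippet below is a minimal standard-library sketch of c = silu(t_emb), assuming t_emb is a plain list of floats. The embedding values are made up for illustration, and nothing here is the gensbi implementation.

```python
import math


def silu(x: float) -> float:
    """SiLU (swish) activation: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))


# Hypothetical timestep-embedding values, chosen only for illustration.
t_emb = [-2.0, 0.0, 2.0]

# Conditioning vector as described above: element-wise SiLU of the embedding.
c = [silu(v) for v in t_emb]
```

Because cond enters through tokens only, c in this sketch depends on nothing but the time embedding.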

Classes

PixelDiT

Dual-level pixel-space DiT for conditional flow matching on 2D fields.

PixelDiTParams

Configuration for PixelDiT (mirrors the style of FieldDiTParams).

Module Contents

class gensbi.experimental.models.pixeldit.model.PixelDiT(params)

Bases: flax.nnx.Module

Dual-level pixel-space DiT for conditional flow matching on 2D fields.

Forward pass: (t, obs, cond) -> velocity field with the same shape as obs, where obs is the field. Patch-level MMDiT blocks attend over patch tokens jointly with cond tokens and produce the per-patch conditioning s_cond; pixel-level PiT blocks then refine every pixel inside each patch under that conditioning.
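The two levels can be pictured through simple shape bookkeeping. The sketch below assumes non-overlapping square patches and that each side of field_shape is divisible by patch_size. The helper patch_layout is hypothetical and not part of gensbi.

```python
from typing import Tuple


def patch_layout(field_shape: Tuple[int, int], patch_size: int):
    """Return (token grid, number of patch tokens, pixels per patch).

    Hypothetical helper: assumes non-overlapping square patches.
    """
    h, w = field_shape
    if h % patch_size or w % patch_size:
        raise ValueError("field_shape must be divisible by patch_size")
    grid = (h // patch_size, w // patch_size)
    num_tokens = grid[0] * grid[1]               # sequence length at the patch level
    pixels_per_patch = patch_size * patch_size   # pixels refined inside each patch
    return grid, num_tokens, pixels_per_patch


grid, num_tokens, pixels_per_patch = patch_layout((64, 64), patch_size=4)
print(grid, num_tokens, pixels_per_patch)  # (16, 16) 256 16
```

Under these assumptions, a 64x64 field with the default patch_size of 4 gives 256 patch tokens for the patch-level blocks and 16 pixels per patch for the pixel-level blocks.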

Parameters:

params (PixelDiTParams)

__call__(t, obs, cond, obs_ids=None, cond_ids=None, conditioned=True, guidance=None)
cond_dim
cond_embedder
cond_in_channels
field_shape
final_layer
in_channels
param_dtype
patch_blocks
patch_size
pe_patch
pe_pit
pixel_blocks
pixel_embedder
s_embedder
t_conditioner
token_grid
class gensbi.experimental.models.pixeldit.model.PixelDiTParams

Configuration for PixelDiT (mirrors the style of FieldDiTParams).

Note: rngs is a live nnx.Rngs stream (mirrors FieldDiTParams / Flux1Params). Constructing two models from the same params object yields different weights, because the stream advances; build a fresh PixelDiTParams (or a fresh nnx.Rngs(seed)) per model for reproducibility.
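The pitfall in the note can be reproduced with any stateful random stream. The sketch below uses Python's random.Random as a stand-in for nnx.Rngs. ToyParams and build_weights are hypothetical names for illustration only, not the gensbi or Flax API.

```python
import random
from dataclasses import dataclass


@dataclass
class ToyParams:
    """Hypothetical config holding a live, stateful random stream."""
    rng: random.Random


def build_weights(params: ToyParams, n: int = 4):
    """Stand-in for model construction: draws n values from the stream."""
    return [params.rng.gauss(0.0, 1.0) for _ in range(n)]


# Reusing one params object: the stream advances between the two builds.
shared = ToyParams(rng=random.Random(0))
w_a = build_weights(shared)
w_b = build_weights(shared)

# Fresh params (fresh stream, same seed) per model: builds are reproducible.
w_c = build_weights(ToyParams(rng=random.Random(0)))
w_d = build_weights(ToyParams(rng=random.Random(0)))

print(w_a == w_b)  # False
print(w_c == w_d)  # True
```

The same reasoning applies whenever a config object carries a live stream rather than a seed: sharing the object shares, and consumes, the stream state.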

__post_init__()
cond_dim: int
cond_id_embedding: str = 'absolute'
cond_in_channels: int = 1
field_shape: Tuple[int, int]
hidden_size: int = 384
in_channels: int
mlp_ratio: float = 4.0
num_heads: int = 6
param_dtype: jax.typing.DTypeLike
patch_depth: int = 6
patch_size: int = 4
pit_post_modulation: bool = False
pixel_attn_hidden_size: int | None = None
pixel_depth: int = 2
pixel_hidden_size: int = 16
pixel_num_heads: int | None = None
rngs: flax.nnx.Rngs
rope_scale: float = 16.0
theta: float = 10000.0
use_cond_rope: bool = True
use_pixel_abs_pos: bool = True
zero_init_blocks: bool = True
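As a quick sanity check on the defaults listed above, the sketch below derives two sizes under common transformer conventions. It assumes head_dim = hidden_size // num_heads and an MLP width of mlp_ratio * hidden_size; whether PixelDiT follows exactly these conventions is not stated here. The helper derived_sizes is hypothetical.

```python
# Default values copied from the PixelDiTParams field listing above.
defaults = {"hidden_size": 384, "num_heads": 6, "mlp_ratio": 4.0}


def derived_sizes(cfg: dict) -> dict:
    """Hypothetical helper: sizes implied by common transformer conventions."""
    if cfg["hidden_size"] % cfg["num_heads"]:
        raise ValueError("hidden_size must be divisible by num_heads")
    return {
        "head_dim": cfg["hidden_size"] // cfg["num_heads"],
        "mlp_width": int(cfg["hidden_size"] * cfg["mlp_ratio"]),
    }


sizes = derived_sizes(defaults)
print(sizes)  # {'head_dim': 64, 'mlp_width': 1536}
```

A check of this kind is a natural candidate for a dataclass __post_init__, although what PixelDiTParams.__post_init__ actually validates is not documented in this listing.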