gensbi.experimental.models.pixeldit.embedders#
Token embedders for the PixelDiT port.
Faithful port of reference/PixelDiT/pixdit_core/pixeldit_c2i.py embedders
plus the cond-token embedder from pixeldit_t2i.py. Operates channel-last
(B, H, W, C) throughout.
Classes#
CondTokenEmbedder | Condition token embedder: linear → RMSNorm → add id embedding.
PatchTokenEmbedder | Linear patch-token embedder (ref PatchTokenEmbedder, pixeldit_c2i.py:21-38).
PixelTokenEmbedder | Per-pixel projection + optional sincos abs-pos + patch grouping.
Functions#
patchify(x, p) | Fold a channel-last image into patch tokens, row-major.
unpatchify(tokens, grid, p, C) | Exact inverse of patchify().
Module Contents#
- class gensbi.experimental.models.pixeldit.embedders.CondTokenEmbedder(cond_in_channels, hidden_size, n_tokens, *, id_embedding='absolute', rngs, param_dtype=jnp.bfloat16)[source]#
Bases: flax.nnx.Module
Condition token embedder: linear → RMSNorm → add id embedding.
Faithful port of the t2i cond pipeline (y_embedder + y_pos_embedding, pixeldit_t2i.py:179-180, 267-268).
- Parameters:
  cond_in_channels (int) – Dimension of each condition token, D_c.
  hidden_size (int) – Output embedding dimension, D.
  n_tokens (int) – Number of condition tokens, K (used to build the id embedding table).
  id_embedding (str) – How to embed token positions: "absolute" (learned), "pos1d" (sinusoidal 1D), or "none" (no positional information added).
  rngs (flax.nnx.Rngs)
  param_dtype (jax.typing.DTypeLike)
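The linear → RMSNorm → add id embedding flow described above can be sketched in plain JAX. This is an illustrative sketch, not the module's code: the function name `cond_embed_sketch`, the presence of a bias and a learnable RMSNorm scale, the `eps` value, and the `(K, D)` shape of the id table are all assumptions, and only the `"absolute"` (learned table) mode is shown.

```python
import jax
import jax.numpy as jnp


def cond_embed_sketch(cond, kernel, bias, scale, id_table, eps=1e-6):
    """Hypothetical data flow: (B, K, D_c) -> (B, K, D)."""
    h = cond @ kernel + bias                                  # linear projection
    rms = jnp.sqrt(jnp.mean(h * h, axis=-1, keepdims=True) + eps)
    h = h / rms * scale                                       # RMSNorm over the feature axis
    return h + id_table                                       # id embedding, broadcast over batch


B, K, D_c, D = 2, 5, 8, 16
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
cond = jax.random.normal(k1, (B, K, D_c))
kernel = jax.random.normal(k2, (D_c, D)) * 0.1                # assumed init, for illustration only
bias = jnp.zeros((D,))
scale = jnp.ones((D,))
id_table = jax.random.normal(k3, (K, D)) * 0.02               # stand-in for a learned table

out = cond_embed_sketch(cond, kernel, bias, scale, id_table)
normed = cond_embed_sketch(cond, kernel, bias, scale, jnp.zeros((K, D)))
print(out.shape)
```

With the id table zeroed out, each output token has a root-mean-square of roughly 1, which is a quick way to check that the normalization acts on the feature axis.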
- class gensbi.experimental.models.pixeldit.embedders.PatchTokenEmbedder(in_features, hidden_size, *, rngs, param_dtype=jnp.bfloat16)[source]#
Bases: flax.nnx.Module
Linear patch-token embedder (ref PatchTokenEmbedder, pixeldit_c2i.py:21-38).
Linear(in_features → hidden_size, bias=True); kernel xavier_uniform, bias zeros. No norm layer (norm_layer=None case).
- Parameters:
  in_features (int)
  hidden_size (int)
  rngs (flax.nnx.Rngs)
  param_dtype (jax.typing.DTypeLike)
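As a rough illustration of the initialization described above, the following sketch builds a Xavier-uniform kernel and a zero bias with `jax.nn.initializers` and applies them to a batch of patch tokens. The function name `patch_embed_sketch` and the example sizes are hypothetical, and float32 is used here for readability rather than the documented `bfloat16` default.

```python
import jax
import jax.numpy as jnp

in_features, hidden_size = 12, 32            # example sizes, chosen arbitrarily
key = jax.random.PRNGKey(0)

# Glorot uniform is the same scheme as Xavier uniform.
kernel = jax.nn.initializers.glorot_uniform()(
    key, (in_features, hidden_size), jnp.float32
)
bias = jnp.zeros((hidden_size,), jnp.float32)


def patch_embed_sketch(tokens, kernel, bias):
    """Hypothetical stand-in: (B, L, in_features) -> (B, L, hidden_size)."""
    return tokens @ kernel + bias


tokens = jnp.ones((2, 6, in_features))
out = patch_embed_sketch(tokens, kernel, bias)

# Xavier-uniform samples lie in [-limit, limit].
limit = (6.0 / (in_features + hidden_size)) ** 0.5
print(out.shape, float(jnp.max(jnp.abs(kernel))) <= limit)
```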
- class gensbi.experimental.models.pixeldit.embedders.PixelTokenEmbedder(in_channels, pixel_hidden_size, field_shape, patch_size, *, use_abs_pos=True, rngs, param_dtype=jnp.bfloat16)[source]#
Bases: flax.nnx.Module
Per-pixel projection + optional sincos abs-pos + patch grouping.
Faithful port of PixelTokenEmbedder.forward (pixeldit_c2i.py:93-111). Operates channel-last: input (B, H, W, C), output (B·L, p², D_pix).
- Parameters:
  in_channels (int) – Number of input channels, C.
  pixel_hidden_size (int) – Per-pixel hidden dimension, D_pix.
  field_shape (tuple[int, int]) – (H, W), the fixed spatial resolution; the sincos table is precomputed and stored as a non-trainable Buffer.
  patch_size (int) – Patch size, p.
  use_abs_pos (bool) – If True (default), add the sincos 2D positional embedding.
  rngs (flax.nnx.Rngs)
  param_dtype (jax.typing.DTypeLike)
- __call__(x)[source]#
- Parameters:
  x (jax.Array) – (B, H, W, C) channel-last image.
- Returns:
  (B·L, p², D_pix) grouped pixel tokens.
- Return type:
  jax.Array
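The shape flow of this call, from `(B, H, W, C)` to `(B·L, p², D_pix)`, can be sketched as follows. The sketch assumes the per-pixel projection is a plain linear map, that the positional table has shape `(H, W, D_pix)` and is added before grouping, and that pixels within a patch are ordered row-major; the function name `pixel_embed_sketch` is hypothetical and a zero array stands in for the precomputed sincos table.

```python
import jax
import jax.numpy as jnp


def pixel_embed_sketch(x, kernel, bias, pos, p):
    """Hypothetical shape flow: (B, H, W, C) -> (B*L, p*p, D_pix)."""
    B, H, W, _ = x.shape
    D = kernel.shape[1]
    h = x @ kernel + bias                      # per-pixel projection: (B, H, W, D)
    if pos is not None:
        h = h + pos                            # (H, W, D), broadcast over the batch
    Hs, Ws = H // p, W // p                    # patch grid, so L = Hs * Ws
    h = h.reshape(B, Hs, p, Ws, p, D)          # split each spatial axis
    h = h.transpose(0, 1, 3, 2, 4, 5)          # (B, Hs, Ws, p, p, D)
    return h.reshape(B * Hs * Ws, p * p, D)    # one row of p*p pixel tokens per patch


B, H, W, C, D_pix, p = 2, 4, 6, 3, 8, 2
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(k1, (B, H, W, C))
kernel = jax.random.normal(k2, (C, D_pix)) * 0.1   # assumed init, for illustration only
bias = jnp.zeros((D_pix,))
pos = jnp.zeros((H, W, D_pix))                     # placeholder for the sincos table

out = pixel_embed_sketch(x, kernel, bias, pos, p)
print(out.shape)
```

Folding the batch and patch axes together means each patch can be treated as an independent short sequence of `p²` pixel tokens by whatever layer consumes the output.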
- gensbi.experimental.models.pixeldit.embedders.patchify(x, p)[source]#
Fold a channel-last image into patch tokens, row-major.
- Parameters:
  x (jax.Array) – (B, H, W, C) channel-last image.
  p (int) – Patch size (both spatial dims).
- Returns:
  (B, L, p²·C) where L = Hs*Ws, patches in row-major order (Hs, Ws) and pixels within each patch in row-major order (p, p).
- Return type:
  jax.Array
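One way to realize the row-major folding described above is a reshape, a transpose, and a second reshape. The sketch below is an assumption about how such a function could be written, not the library's source; `patchify_sketch` is a hypothetical name, and placing the channel axis innermost within each token is also assumed.

```python
import jax.numpy as jnp


def patchify_sketch(x, p):
    """Hypothetical fold: (B, H, W, C) -> (B, L, p*p*C), row-major throughout."""
    B, H, W, C = x.shape
    Hs, Ws = H // p, W // p                  # number of patches per spatial axis
    x = x.reshape(B, Hs, p, Ws, p, C)        # split H into (Hs, p) and W into (Ws, p)
    x = x.transpose(0, 1, 3, 2, 4, 5)        # (B, Hs, Ws, p, p, C)
    return x.reshape(B, Hs * Ws, p * p * C)


# A 4x4 single-channel image whose pixel values equal their row-major index.
x = jnp.arange(16, dtype=jnp.float32).reshape(1, 4, 4, 1)
tokens = patchify_sketch(x, 2)
print(tokens.shape)   # (1, 4, 4)
print(tokens[0, 0])   # top-left patch: pixels 0, 1, 4, 5
print(tokens[0, 1])   # top-right patch: pixels 2, 3, 6, 7
```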
- gensbi.experimental.models.pixeldit.embedders.unpatchify(tokens, grid, p, C)[source]#
Exact inverse of patchify().
- Parameters:
  tokens (jax.Array) – (B, L, p²·C) patch tokens.
  grid (tuple[int, int]) – (Hs, Ws), the number of patches along each spatial axis.
  p (int) – Patch size.
  C (int) – Number of output channels.
- Returns:
  (B, H, W, C) channel-last image.
- Return type:
  jax.Array
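The inverse relationship can be checked with a round trip. The sketch below pairs a hypothetical `unpatchify_sketch` with a matching `patchify_sketch`; both are illustrative stand-ins written under the assumption of row-major patches with the channel axis innermost, not the library functions themselves.

```python
import jax.numpy as jnp


def patchify_sketch(x, p):
    """Hypothetical fold: (B, H, W, C) -> (B, L, p*p*C)."""
    B, H, W, C = x.shape
    Hs, Ws = H // p, W // p
    x = x.reshape(B, Hs, p, Ws, p, C).transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(B, Hs * Ws, p * p * C)


def unpatchify_sketch(tokens, grid, p, C):
    """Hypothetical unfold: (B, L, p*p*C) -> (B, H, W, C)."""
    B = tokens.shape[0]
    Hs, Ws = grid
    x = tokens.reshape(B, Hs, Ws, p, p, C)   # undo the final reshape of the fold
    x = x.transpose(0, 1, 3, 2, 4, 5)        # (B, Hs, p, Ws, p, C)
    return x.reshape(B, Hs * p, Ws * p, C)   # merge (Hs, p) -> H and (Ws, p) -> W


x = jnp.arange(2 * 4 * 6 * 3, dtype=jnp.float32).reshape(2, 4, 6, 3)
tokens = patchify_sketch(x, 2)
x_rec = unpatchify_sketch(tokens, (2, 3), 2, 3)
roundtrip_ok = bool(jnp.array_equal(x, x_rec))
print(tokens.shape, x_rec.shape, roundtrip_ok)
```

Because both directions are pure reshapes and axis permutations, the round trip reproduces the input exactly rather than approximately.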