gensbi.experimental.models.pixeldit.embedders

Token embedders for the PixelDiT port.

Faithful port of the embedders in reference/PixelDiT/pixdit_core/pixeldit_c2i.py, plus the cond-token embedder from pixeldit_t2i.py. Image tensors are handled channel-last, (B, H, W, C), throughout.

Classes

CondTokenEmbedder

Condition token embedder: linear → RMSNorm → add id embedding.

PatchTokenEmbedder

Linear patch-token embedder (port of the reference PatchTokenEmbedder, pixeldit_c2i.py:21-38).

PixelTokenEmbedder

Per-pixel linear projection, optional sincos absolute positional embedding, and grouping of pixels into patches.

Functions

patchify(x, p)

Fold a channel-last image into patch tokens, row-major.

unpatchify(tokens, grid, p, C)

Exact inverse of patchify().

Module Contents

class gensbi.experimental.models.pixeldit.embedders.CondTokenEmbedder(cond_in_channels, hidden_size, n_tokens, *, id_embedding='absolute', rngs, param_dtype=jnp.bfloat16)

Bases: flax.nnx.Module

Condition token embedder: linear → RMSNorm → add id embedding.

Faithful port of the t2i cond pipeline (y_embedder + y_pos_embedding, pixeldit_t2i.py:179-180, 267-268).

Parameters:
  • cond_in_channels (int) – Dimension of each condition token D_c.

  • hidden_size (int) – Output embedding dimension D.

  • n_tokens (int) – Number of condition tokens K (used to build the id embedding table).

  • id_embedding (str) – How to embed token positions: "absolute" (learned), "pos1d" (sinusoidal 1D), or "none" (no positional information added).

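  • rngs (flax.nnx.Rngs) – Random-number streams used to initialise the parameters.

  • param_dtype (jax.typing.DTypeLike) – Dtype of the parameters (default jnp.bfloat16).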

__call__(cond)
Parameters:

cond (jax.Array) – (B, K, D_c) condition tokens; or (B, K) when cond_in_channels == 1 (auto-expanded to (B, K, 1)).

Returns:

(B, K, D) embedded condition tokens.

Return type:

jax.Array
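The cond pipeline above (linear projection, RMSNorm, then an additive id embedding) can be sketched functionally in JAX. This is a minimal sketch under stated assumptions, not the module's code: cond_embed, rms_norm, sincos_1d and the params dict are hypothetical names, and the projection bias, the RMSNorm epsilon of 1e-6 and the [sin, cos] channel layout of the "pos1d" table are all assumptions.

```python
import jax
import jax.numpy as jnp


def sincos_1d(n_tokens: int, dim: int) -> jax.Array:
    """Hypothetical 1D sin-cos table (n_tokens, dim); [sin | cos] layout is an assumption."""
    assert dim % 2 == 0
    omega = 1.0 / 10000 ** (jnp.arange(dim // 2, dtype=jnp.float32) / (dim // 2))
    angles = jnp.arange(n_tokens, dtype=jnp.float32)[:, None] * omega[None, :]
    return jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=-1)


def rms_norm(h: jax.Array, scale: jax.Array, eps: float = 1e-6) -> jax.Array:
    # Divide by the root-mean-square over the feature axis, then apply a learned scale.
    return h * jax.lax.rsqrt(jnp.mean(h**2, axis=-1, keepdims=True) + eps) * scale


def cond_embed(params: dict, cond: jax.Array, id_embedding: str = "absolute") -> jax.Array:
    """Hypothetical sketch: linear -> RMSNorm -> add id embedding."""
    if cond.ndim == 2:  # (B, K) with D_c == 1 is expanded to (B, K, 1)
        cond = cond[..., None]
    h = cond @ params["w"] + params["b"]  # (B, K, D_c) -> (B, K, D); bias is assumed
    h = rms_norm(h, params["scale"])
    K, D = h.shape[-2], h.shape[-1]
    if id_embedding == "absolute":
        h = h + params["id_table"]  # learned (K, D) table, broadcast over the batch
    elif id_embedding == "pos1d":
        h = h + sincos_1d(K, D)  # fixed sinusoidal table, nothing learned
    elif id_embedding != "none":
        raise ValueError(f"unknown id_embedding: {id_embedding!r}")
    return h


key = jax.random.PRNGKey(0)
B, K, D_c, D = 2, 5, 1, 8
params = {
    "w": jax.random.normal(key, (D_c, D)),
    "b": jnp.zeros((D,)),
    "scale": jnp.ones((D,)),
    "id_table": jnp.zeros((K, D)),
}
out = cond_embed(params, jnp.ones((B, K)), id_embedding="pos1d")
print(out.shape)  # (2, 5, 8)
```

Because the id embedding is added after the norm, the positional signal is not rescaled by RMSNorm in this sketch.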

cond_in_channels
id_embedding_kind = 'absolute'
norm
proj
class gensbi.experimental.models.pixeldit.embedders.PatchTokenEmbedder(in_features, hidden_size, *, rngs, param_dtype=jnp.bfloat16)

Bases: flax.nnx.Module

Linear patch-token embedder (port of the reference PatchTokenEmbedder, pixeldit_c2i.py:21-38).

Linear(in_features → hidden_size, bias=True), with the kernel initialised Xavier-uniform and the bias initialised to zeros. There is no normalisation layer (the reference's norm_layer=None case).

Parameters:
  • in_features (int) – Feature dimension of each input token.

  • hidden_size (int) – Output embedding dimension D.

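  • rngs (flax.nnx.Rngs) – Random-number streams used to initialise the parameters.

  • param_dtype (jax.typing.DTypeLike) – Dtype of the parameters (default jnp.bfloat16).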

__call__(x)
proj
class gensbi.experimental.models.pixeldit.embedders.PixelTokenEmbedder(in_channels, pixel_hidden_size, field_shape, patch_size, *, use_abs_pos=True, rngs, param_dtype=jnp.bfloat16)

Bases: flax.nnx.Module

Per-pixel linear projection, optional sincos absolute positional embedding, and grouping of pixels into patches.

Faithful port of PixelTokenEmbedder.forward (pixeldit_c2i.py:93-111). Operates channel-last: input (B, H, W, C), output (B·L, p², D_pix), where L = (H/p)·(W/p) is the number of patches per image.
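The forward pass described above can be sketched as a per-pixel matmul, a broadcast add of the positional table, and a reshape/transpose that gathers each p×p patch into one row of p² pixel tokens. This is a hypothetical sketch: pixel_embed and group_pixels are illustrative names, the (H, W, D_pix) shape of the positional table is an assumption, and pixels within a patch are assumed to follow the same row-major order as patchify().

```python
import jax
import jax.numpy as jnp


def group_pixels(h: jax.Array, p: int) -> jax.Array:
    """(B, H, W, D) -> (B*L, p*p, D) with L = (H//p)*(W//p); row-major order (assumed)."""
    B, H, W, D = h.shape
    assert H % p == 0 and W % p == 0, "H and W must be multiples of p"
    Hs, Ws = H // p, W // p
    h = h.reshape(B, Hs, p, Ws, p, D)
    h = h.transpose(0, 1, 3, 2, 4, 5)  # (B, Hs, Ws, p, p, D): patch-grid axes first
    return h.reshape(B * Hs * Ws, p * p, D)


def pixel_embed(x, w, b, pos_embed, p, use_abs_pos=True):
    """Hypothetical sketch: per-pixel Linear, optional abs-pos add, patch grouping."""
    h = x @ w + b  # (B, H, W, C) -> (B, H, W, D_pix), same weights at every pixel
    if use_abs_pos:
        h = h + pos_embed  # (H, W, D_pix) table broadcast over the batch
    return group_pixels(h, p)


B, H, W, C, D_pix, p = 2, 8, 8, 3, 16, 4
x = jnp.ones((B, H, W, C))
w = jnp.ones((C, D_pix))
out = pixel_embed(x, w, jnp.zeros((D_pix,)), jnp.zeros((H, W, D_pix)), p)
print(out.shape)  # (8, 16, 16): B*L = 2*4, p*p = 16, D_pix = 16
```

Folding the batch and patch axes together lets a downstream per-patch block treat each patch's p² pixels as an independent short sequence.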

Parameters:
  • in_channels (int) – Number of input channels C.

  • pixel_hidden_size (int) – Per-pixel hidden dimension D_pix.

  • field_shape (tuple[int, int]) – Fixed spatial resolution (H, W). The sincos table is precomputed for this shape and stored as a non-trainable Buffer.

  • patch_size (int) – Patch size p.

  • use_abs_pos (bool) – If True (default), add the sincos 2D positional embedding.

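  • rngs (flax.nnx.Rngs) – Random-number streams used to initialise the parameters.

  • param_dtype (jax.typing.DTypeLike) – Dtype of the parameters (default jnp.bfloat16).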

__call__(x)
Parameters:

x (jax.Array) – (B, H, W, C) channel-last image.

Returns:

(B·L, p², D_pix) grouped pixel tokens.

Return type:

jax.Array
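One common way to build a fixed 2D sincos table is to split the channels in half and encode the row index in one half and the column index in the other, each with a 1D sin-cos code. The sketch below follows that scheme; sincos_1d and sincos_2d are hypothetical names, and the 10000 base and the exact channel order are assumptions that may differ from the ported reference.

```python
import jax
import jax.numpy as jnp


def sincos_1d(pos: jax.Array, dim: int) -> jax.Array:
    """Sin-cos code of integer positions `pos` (any shape) -> pos.shape + (dim,)."""
    assert dim % 2 == 0
    omega = 1.0 / 10000 ** (jnp.arange(dim // 2, dtype=jnp.float32) / (dim // 2))
    angles = pos[..., None].astype(jnp.float32) * omega
    return jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=-1)


def sincos_2d(H: int, W: int, dim: int) -> jax.Array:
    """(H, W, dim) table: first dim/2 channels encode the row, last dim/2 the column."""
    assert dim % 4 == 0
    rows, cols = jnp.meshgrid(jnp.arange(H), jnp.arange(W), indexing="ij")
    return jnp.concatenate([sincos_1d(rows, dim // 2), sincos_1d(cols, dim // 2)], axis=-1)


table = sincos_2d(8, 8, 16)  # built once for a fixed field_shape, never trained
print(table.shape)  # (8, 8, 16)
```

Since the table depends only on (H, W) and the hidden size, computing it once at construction and keeping it out of the trainable parameters is sufficient, which is why a fixed field_shape is required.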

field_shape
patch_size
pixel_hidden_size
pos_embed
proj
use_abs_pos = True
gensbi.experimental.models.pixeldit.embedders.patchify(x, p)

Fold a channel-last image into patch tokens, row-major.

Parameters:
  • x (jax.Array) – (B, H, W, C) channel-last image.

  • p (int) – Patch size (both spatial dims).

Returns:

(B, L, p²·C), where Hs = H/p, Ws = W/p and L = Hs·Ws (so H and W must be multiples of p). Patches are ordered row-major over (Hs, Ws), and pixels within each patch are ordered row-major over (p, p).

Return type:

jax.Array
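The row-major layout described above corresponds to a reshape to (B, Hs, p, Ws, p, C), a transpose that moves the two patch-grid axes forward, and a final flatten. This is a sketch, not the module's source, and it assumes each token is flattened in (p, p, C) order with channels innermost.

```python
import jax
import jax.numpy as jnp


def patchify(x: jax.Array, p: int) -> jax.Array:
    """(B, H, W, C) -> (B, Hs*Ws, p*p*C); each token laid out as (p, p, C) (assumed)."""
    B, H, W, C = x.shape
    assert H % p == 0 and W % p == 0, "H and W must be multiples of p"
    Hs, Ws = H // p, W // p
    x = x.reshape(B, Hs, p, Ws, p, C)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (B, Hs, Ws, p, p, C): patch index before pixel index
    return x.reshape(B, Hs * Ws, p * p * C)


img = jnp.arange(16).reshape(1, 4, 4, 1)
tokens = patchify(img, 2)
print(tokens.shape)  # (1, 4, 4)
# Token rows: [0 1 4 5], [2 3 6 7], [8 9 12 13], [10 11 14 15]
```

On the 4×4 example the first token holds the top-left 2×2 block and the second the top-right block, confirming row-major order over both the patch grid and the pixels within a patch.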

gensbi.experimental.models.pixeldit.embedders.unpatchify(tokens, grid, p, C)

Exact inverse of patchify().

Parameters:
  • tokens (jax.Array) – (B, L, p²·C) patch tokens.

  • grid (tuple[int, int]) – (Hs, Ws): number of patches along the height and width axes.

  • p (int) – Patch size.

  • C (int) – Number of output channels.

Returns:

(B, H, W, C) channel-last image, with H = Hs·p and W = Ws·p.

Return type:

jax.Array
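Inverting the patch layout amounts to reversing the same reshape and transpose. The sketch below is hypothetical and assumes tokens flattened in (p, p, C) order; the comment above the example spells out the index mapping it implements, so the inverse can be checked entry by entry.

```python
import jax
import jax.numpy as jnp


def unpatchify(tokens: jax.Array, grid: tuple[int, int], p: int, C: int) -> jax.Array:
    """(B, Hs*Ws, p*p*C) -> (B, Hs*p, Ws*p, C); undoes a (p, p, C) token layout (assumed)."""
    B, L, F = tokens.shape
    Hs, Ws = grid
    assert L == Hs * Ws and F == p * p * C, "token shape does not match grid, p and C"
    x = tokens.reshape(B, Hs, Ws, p, p, C)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (B, Hs, p, Ws, p, C): interleave grid and pixel axes
    return x.reshape(B, Hs * p, Ws * p, C)


# Token entry [b, i*Ws + j, (u*p + v)*C + c] lands at image pixel [b, i*p + u, j*p + v, c].
B, Hs, Ws, p, C = 1, 2, 3, 2, 2
tokens = jnp.arange(B * Hs * Ws * p * p * C).reshape(B, Hs * Ws, p * p * C)
img = unpatchify(tokens, (Hs, Ws), p, C)
print(img.shape)  # (1, 4, 6, 2)
```

Passing grid explicitly keeps the inverse unambiguous: L alone does not determine whether the patch grid is, for example, 2×3 or 3×2.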