"""Rotary positional embedding tables for the PixelDiT port.

Ports ``precompute_freqs_cis_2d`` (reference ``modules.py:132-145``) and
``fetch_pos_text`` (reference ``pixeldit_t2i.py:232-241``, ported here as
``precompute_freqs_cis_1d``) from the PyTorch PixelDiT code, but emits the
real 2x2-rotation tensor consumed by ``gensbi.models.flux1.math.apply_rope``
instead of complex ``cis`` values.

The two formats are numerically equivalent: the reference rotates consecutive
value pairs ``(x[2j], x[2j+1])`` by the complex phase ``freqs_cis[n, j]``, and
``apply_rope`` applies the matching 2x2 rotation
``[[cos, -sin], [sin, cos]]`` to the same consecutive pairs.
"""
import jax.numpy as jnp
[docs]
def _rotation_from_angles(angles):
    """Turn per-pair rotation angles into a rope rotation table.

    Args:
        angles: array of shape ``(N, head_dim/2)`` with one angle per
            consecutive value pair.

    Returns:
        A float32 array of shape ``(1, 1, N, head_dim/2, 2, 2)``. Its trailing
        2x2 block is ``[[cos, -sin], [sin, cos]]`` of the corresponding angle,
        which is the layout ``flux1.math.rope`` produces.
    """
    cos = jnp.cos(angles)
    sin = jnp.sin(angles)
    # Assemble the matrix row by row: first row (cos, -sin), second (sin, cos).
    first_row = jnp.stack([cos, -sin], axis=-1)
    second_row = jnp.stack([sin, cos], axis=-1)
    matrices = jnp.stack([first_row, second_row], axis=-2)
    # Two leading singleton axes so the table broadcasts over (batch, heads).
    return jnp.expand_dims(matrices, axis=(0, 1)).astype(jnp.float32)
[docs]
def precompute_freqs_cis_2d(head_dim, height, width, theta=10000.0, scale=16.0):
    """2D axial rope on fractional positions in ``[0, scale]``.

    Token ``n = row * width + col`` (row-major flattening of the grid) sits at
    x = ``linspace(0, scale, width)[col]`` and
    y = ``linspace(0, scale, height)[row]``. Each axis uses ``head_dim // 4``
    frequencies ``theta ** (-4k / head_dim)``. The x and y angles for
    frequency ``k`` are interleaved, so pair ``2k`` is rotated by the x angle
    and pair ``2k + 1`` by the y angle. This is exactly how the reference
    interleaves the complex ``x_cis``/``y_cis`` phases.

    Args:
        head_dim: per-head feature size; must be divisible by 4.
        height: number of rows in the token grid.
        width: number of columns in the token grid.
        theta: rope base frequency.
        scale: upper end of the position range on both axes.

    Returns:
        A ``(1, 1, height*width, head_dim//2, 2, 2)`` float32 rotation table.

    Raises:
        ValueError: if ``head_dim`` is not divisible by 4.
    """
    if head_dim % 4 != 0:
        # The two interleaved axes contribute 2 * (head_dim // 4) pairs. That
        # equals the head_dim // 2 consecutive pairs rotated by apply_rope
        # only when head_dim is a multiple of 4. Otherwise the table would
        # silently have the wrong number of pairs.
        raise ValueError(
            f"head_dim must be divisible by 4 for 2D axial rope, got {head_dim}"
        )
    x_pos = jnp.linspace(0, scale, width)
    y_pos = jnp.linspace(0, scale, height)
    # "ij" indexing gives (height, width) grids whose row-major flatten
    # enumerates tokens row by row.
    yy, xx = jnp.meshgrid(y_pos, x_pos, indexing="ij")
    freqs = 1.0 / (theta ** (jnp.arange(0, head_dim, 4)[: head_dim // 4] / head_dim))
    x_ang = jnp.outer(xx.reshape(-1), freqs)
    y_ang = jnp.outer(yy.reshape(-1), freqs)
    # Stacking on a new last axis and flattening interleaves x/y per frequency.
    angles = jnp.stack([x_ang, y_ang], axis=-1).reshape(x_ang.shape[0], -1)
    return _rotation_from_angles(angles)
[docs]
def precompute_freqs_cis_1d(head_dim, length, theta=10000.0):
    """1D integer-position rope for cond tokens (reference ``fetch_pos_text``).

    Token ``n`` (``0 <= n < length``) rotates pair ``k`` by the angle
    ``n * theta ** (-2k / head_dim)``. The single axis covers the full head
    dim, so there are ``head_dim / 2`` frequencies (step 2).

    Args:
        head_dim: per-head feature size; must be even.
        length: number of tokens (positions ``0 .. length-1``).
        theta: rope base frequency.

    Returns:
        Same output layout as :func:`precompute_freqs_cis_2d` with
        ``N = length``: a ``(1, 1, length, head_dim//2, 2, 2)`` float32 table.

    Raises:
        ValueError: if ``head_dim`` is odd.
    """
    if head_dim % 2 != 0:
        # An odd head_dim would yield (head_dim + 1) // 2 frequencies, which
        # does not match the head_dim / 2 value pairs being rotated.
        raise ValueError(f"head_dim must be even for 1D rope, got {head_dim}")
    freqs = 1.0 / (theta ** (jnp.arange(0, head_dim, 2) / head_dim))
    angles = jnp.outer(jnp.arange(length), freqs)
    return _rotation_from_angles(angles)