# Source code for gensbi.experimental.models.pixeldit.rope

"""Rotary positional embedding tables for the PixelDiT port.

Ports ``precompute_freqs_cis_2d`` (reference ``modules.py:132-145``) and
``fetch_pos_text`` (reference ``pixeldit_t2i.py:232-241``) from the PyTorch
PixelDiT code, but emits the real-valued 2x2 rotation tensor consumed by
``gensbi.models.flux1.math.apply_rope`` instead of complex ``cis`` values.

The two formats are numerically equivalent: the reference rotates consecutive
value pairs ``(x[2j], x[2j+1])`` by the complex phase ``freqs_cis[n, j]``, and
``apply_rope`` applies the matching 2x2 rotation to the same consecutive pairs.
"""

import jax.numpy as jnp


def _rotation_from_angles(angles):
    """Turn per-pair rotation angles into a broadcastable 2x2 rotation table.

    Args:
        angles: Array of shape ``(N, head_dim/2)``; entry ``[n, j]`` is the
            angle used to rotate the ``j``-th consecutive value pair of
            token ``n``.

    Returns:
        float32 array of shape ``(1, 1, N, head_dim/2, 2, 2)`` whose trailing
        2x2 block is ``[[cos, -sin], [sin, cos]]``, matching the layout of
        ``flux1.math.rope``. The two leading singleton axes are presumably
        batch/head broadcast axes -- confirm against ``apply_rope`` callers.
    """
    cos = jnp.cos(angles)
    sin = jnp.sin(angles)
    # Assemble each matrix row on its own, then stack the rows along a new
    # second-to-last axis: row 0 = (cos, -sin), row 1 = (sin, cos).
    upper_row = jnp.stack([cos, -sin], axis=-1)
    lower_row = jnp.stack([sin, cos], axis=-1)
    table = jnp.stack([upper_row, lower_row], axis=-2)
    return table[jnp.newaxis, jnp.newaxis].astype(jnp.float32)


def precompute_freqs_cis_2d(head_dim, height, width, theta=10000.0, scale=16.0):
    """2D axial rope on fractional positions in ``[0, scale]``.

    Tokens lie on a ``height x width`` grid flattened row-major
    (``n = row * width + col``). Column and row coordinates are each spaced
    evenly over ``[0, scale]``. Each axis gets ``head_dim // 4`` frequencies.
    The x and y angle pairs are interleaved exactly as the reference
    interleaves the complex ``x_cis``/``y_cis`` phases: for frequency ``k``,
    pair ``2k`` is rotated by the x angle and pair ``2k + 1`` by the y angle.

    Args:
        head_dim: Per-head feature size; must be a multiple of 4.
        height: Number of grid rows.
        width: Number of grid columns.
        theta: RoPE frequency base.
        scale: Upper end of the position range used for both axes.

    Returns:
        A ``(1, 1, height*width, head_dim//2, 2, 2)`` float32 rotation table.

    Raises:
        ValueError: If ``head_dim`` is not a multiple of 4. Otherwise the
            table would hold ``2 * (head_dim // 4)`` pairs rather than the
            documented ``head_dim // 2``, and would only fail later with an
            opaque broadcast error when applied.
    """
    if head_dim % 4 != 0:
        raise ValueError(
            f"head_dim must be a multiple of 4 for 2D axial rope, got {head_dim}"
        )
    x_pos = jnp.linspace(0, scale, width)
    y_pos = jnp.linspace(0, scale, height)
    # "ij" indexing yields (height, width) grids, so reshape(-1) is row-major.
    yy, xx = jnp.meshgrid(y_pos, x_pos, indexing="ij")
    # head_dim // 4 frequencies per axis, exponents 0, 4, 8, ... over head_dim.
    freqs = 1.0 / (theta ** (jnp.arange(0, head_dim, 4)[: head_dim // 4] / head_dim))
    x_ang = jnp.outer(xx.reshape(-1), freqs)
    y_ang = jnp.outer(yy.reshape(-1), freqs)
    # (N, head_dim//4, 2) -> (N, head_dim//2): x/y angles alternate per pair.
    angles = jnp.stack([x_ang, y_ang], axis=-1).reshape(x_ang.shape[0], -1)
    return _rotation_from_angles(angles)


def precompute_freqs_cis_1d(head_dim, length, theta=10000.0):
    """1D integer-position rope for cond tokens (reference ``fetch_pos_text``).

    Token ``n`` (for ``n`` in ``0 .. length - 1``) sits at position ``n``.
    The full head dim is rotated by this single axis, so there are
    ``head_dim // 2`` frequencies (exponent step 2). The output layout is the
    same as :func:`precompute_freqs_cis_2d` with ``N = length``.

    Args:
        head_dim: Per-head feature size; must be even.
        length: Number of tokens.
        theta: RoPE frequency base.

    Returns:
        A ``(1, 1, length, head_dim//2, 2, 2)`` float32 rotation table.

    Raises:
        ValueError: If ``head_dim`` is odd. An odd dim cannot be split into
            consecutive value pairs, and ``arange(0, head_dim, 2)`` would
            yield ``(head_dim + 1) // 2`` frequencies instead of the
            documented ``head_dim // 2``.
    """
    if head_dim % 2 != 0:
        raise ValueError(f"head_dim must be even for 1D rope, got {head_dim}")
    freqs = 1.0 / (theta ** (jnp.arange(0, head_dim, 2) / head_dim))
    angles = jnp.outer(jnp.arange(length), freqs)
    return _rotation_from_angles(angles)