Source code for gensbi.experimental.models.fielddit.cond
"""Pluggable condition embedders for FieldDiT (Phase 1: scalar / vector).
A condition is embedded into a token stream (consumed by the MMDiT core via
joint attention) plus a pooled summary (added to the modulation vector for
flagged-C decoder modulation).
"""
import jax.numpy as jnp
from flax import nnx
from jax.typing import DTypeLike
[docs]
class ScalarCondEmbedder(nnx.Module):
    """Embed a few condition tokens (e.g. theta scalars) to ``hidden_size``.

    Each of the ``k`` condition tokens is projected by the same linear layer
    (``token_proj``). The resulting token embeddings are mean-pooled over the
    token axis and passed through a second linear layer (``summary_proj``) to
    form a single summary vector per batch element.

    Input ``cond`` is ``(B, k, in_channels)`` (or ``(B, k)``, auto-expanded to
    ``(B, k, 1)`` — the 2D form raises unless ``in_channels == 1``).
    Returns ``(cond_tokens (B, k, hidden), summary (B, hidden))`` where the
    summary is a projection of the mean-pooled tokens.

    Args:
        in_channels: Number of features per condition token (size of the
            last axis of ``cond``).
        hidden_size: Width of both the token embeddings and the summary.
        rngs: NNX RNG container used to initialise both projections.
        param_dtype: Dtype of the projection parameters (default bfloat16).
    """

    def __init__(
        self,
        in_channels: int,
        hidden_size: int,
        rngs: nnx.Rngs,
        param_dtype: DTypeLike = jnp.bfloat16,
    ):
        self.in_channels = in_channels
        # Per-token projection applied to the last axis:
        # (B, k, in_channels) -> (B, k, hidden_size).
        self.token_proj = nnx.Linear(
            in_features=in_channels,
            out_features=hidden_size,
            use_bias=True,
            rngs=rngs,
            param_dtype=param_dtype,
        )
        # Projection of the mean-pooled tokens: (B, hidden_size) -> (B, hidden_size).
        self.summary_proj = nnx.Linear(
            in_features=hidden_size,
            out_features=hidden_size,
            use_bias=True,
            rngs=rngs,
            param_dtype=param_dtype,
        )

    def __call__(self, cond: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
        """Embed ``cond`` into condition tokens and a pooled summary.

        Args:
            cond: Condition array of shape ``(B, k, in_channels)``, or
                ``(B, k)`` when ``in_channels == 1``.

        Returns:
            ``(tokens, summary)`` with shapes ``(B, k, hidden_size)`` and
            ``(B, hidden_size)``.

        Raises:
            ValueError: if ``cond`` is 2D while ``in_channels != 1``, is not
                2D/3D, has a last axis different from ``in_channels``, or has
                zero condition tokens (``k == 0``).
        """
        if cond.ndim == 2:
            if self.in_channels != 1:
                raise ValueError(
                    "2D cond (B, k) is only valid when cond_in_channels == 1; "
                    f"this embedder has cond_in_channels={self.in_channels} — "
                    f"pass cond as (B, k, {self.in_channels})"
                )
            cond = cond[..., None]

        # Shapes are static under jit, so these checks add no traced ops.
        # Without the rank check, a 4D input would "work" but the axis=1 mean
        # would produce a summary of the wrong shape.
        if cond.ndim != 3:
            raise ValueError(
                f"cond must have shape (B, k, {self.in_channels}) or (B, k); "
                f"got shape {tuple(cond.shape)}"
            )
        if cond.shape[-1] != self.in_channels:
            raise ValueError(
                f"cond last axis must be cond_in_channels={self.in_channels}; "
                f"got shape {tuple(cond.shape)}"
            )
        # Mean-pooling over an empty token axis would silently yield NaN.
        if cond.shape[1] == 0:
            raise ValueError(
                f"cond must contain at least one condition token; "
                f"got shape {tuple(cond.shape)}"
            )

        tokens = self.token_proj(cond)
        summary = self.summary_proj(jnp.mean(tokens, axis=1))
        return tokens, summary