gensbi.experimental.models.fielddit.cond
Pluggable condition embedders for FieldDiT (Phase 1: scalar / vector).
Each embedder maps a condition to two outputs: a token stream, which the MMDiT core consumes via joint attention, and a pooled summary, which is added to the modulation vector for flagged-C decoder modulation.
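To make the two outputs concrete, here is a minimal JAX sketch of how a consumer might use them. It is not the FieldDiT or MMDiT code. The helpers joint_attention and modulate are hypothetical. Attention is single-head, with Q/K/V weights shared across the two streams for brevity, whereas MMDiT-style blocks typically use per-stream projections. "Added to the modulation vector" is assumed to mean summing the summary with a timestep embedding before an adaLN-style shift/scale. What the flagged-C decoder path does beyond that is not described here.

```python
import jax
import jax.numpy as jnp


def joint_attention(x_tokens, cond_tokens, wq, wk, wv):
    """Single-head attention over the concatenated [x; cond] token stream.

    Hypothetical simplification: one shared set of Q/K/V weights, no output projection.
    """
    n = x_tokens.shape[1]
    seq = jnp.concatenate([x_tokens, cond_tokens], axis=1)  # (B, n + k, H)
    queries, keys, values = seq @ wq, seq @ wk, seq @ wv
    logits = jnp.einsum("bqh,bkh->bqk", queries, keys) / jnp.sqrt(queries.shape[-1])
    out = jax.nn.softmax(logits, axis=-1) @ values  # (B, n + k, H)
    # Split back into the data stream and the condition stream.
    return out[:, :n], out[:, n:]


def modulate(h, t_emb, summary, w_mod):
    """Assumed adaLN-style shift/scale driven by a modulation vector that includes the summary."""
    c = t_emb + summary  # (B, H): pooled condition is added to the timestep embedding
    shift, scale = jnp.split(jax.nn.silu(c) @ w_mod, 2, axis=-1)  # each (B, H)
    return h * (1 + scale[:, None, :]) + shift[:, None, :]


B, n, k, H = 2, 5, 3, 8
ks = jax.random.split(jax.random.PRNGKey(0), 7)
x = jax.random.normal(ks[0], (B, n, H))     # data tokens
cond = jax.random.normal(ks[1], (B, k, H))  # condition tokens from an embedder
wq, wk, wv = (jax.random.normal(kk, (H, H)) / jnp.sqrt(H) for kk in ks[2:5])

x_out, cond_out = joint_attention(x, cond, wq, wk, wv)
t_emb = jax.random.normal(ks[5], (B, H))
summary = cond.mean(axis=1)  # stand-in for the embedder's pooled summary
w_mod = jax.random.normal(ks[6], (H, 2 * H)) / jnp.sqrt(H)
y = modulate(x_out, t_emb, summary, w_mod)
```

Because the condition tokens share one softmax with the data tokens, every data token can attend to every condition token and vice versa; the pooled summary instead reaches all data tokens uniformly through the shift and scale.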
Classes
ScalarCondEmbedder: Embed a few condition tokens (e.g. theta scalars) to hidden_size.
Module Contents
- class gensbi.experimental.models.fielddit.cond.ScalarCondEmbedder(in_channels, hidden_size, rngs, param_dtype=jnp.bfloat16)
Bases: flax.nnx.Module
Embed a few condition tokens (e.g. theta scalars) to hidden_size.
Input cond has shape (B, k, in_channels), or (B, k), which is auto-expanded to (B, k, 1); the 2D form raises an error unless in_channels == 1. Returns (cond_tokens (B, k, hidden_size), summary (B, hidden_size)), where the summary is a projection of the mean-pooled tokens.
- Parameters:
in_channels (int)
hidden_size (int)
rngs (flax.nnx.Rngs)
param_dtype (jax.typing.DTypeLike)
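The documented input/output contract can be approximated functionally as below. This is a sketch under assumptions, not the gensbi implementation. The documentation only states that tokens are embedded to hidden_size and that the summary is a projection of the mean-pooled tokens. The single linear layer per token, the linear summary projection, the initialization, and the names init_scalar_cond_embedder and scalar_cond_embed are all hypothetical. Parameters are stored in param_dtype, which defaults to bfloat16 to match the documented signature.

```python
import jax
import jax.numpy as jnp


def init_scalar_cond_embedder(key, in_channels, hidden_size, param_dtype=jnp.bfloat16):
    """Hypothetical init: one per-token projection and one summary projection."""
    k_tok, k_sum = jax.random.split(key)
    return {
        "tok_w": (jax.random.normal(k_tok, (in_channels, hidden_size)) * in_channels ** -0.5).astype(param_dtype),
        "tok_b": jnp.zeros((hidden_size,), param_dtype),
        "sum_w": (jax.random.normal(k_sum, (hidden_size, hidden_size)) * hidden_size ** -0.5).astype(param_dtype),
        "sum_b": jnp.zeros((hidden_size,), param_dtype),
    }


def scalar_cond_embed(params, cond, in_channels):
    """Map cond (B, k, in_channels) or (B, k) to (tokens (B, k, H), summary (B, H))."""
    if cond.ndim == 2:
        # A 2D input is only unambiguous when each token carries a single scalar.
        if in_channels != 1:
            raise ValueError(f"2D cond requires in_channels == 1, got {in_channels}")
        cond = cond[..., None]  # (B, k) -> (B, k, 1)
    if cond.ndim != 3 or cond.shape[-1] != in_channels:
        raise ValueError(f"expected (B, k, {in_channels}), got {cond.shape}")
    tokens = cond @ params["tok_w"] + params["tok_b"]     # (B, k, H), same weights for every token
    pooled = tokens.mean(axis=1)                          # (B, H), mean over the k tokens
    summary = pooled @ params["sum_w"] + params["sum_b"]  # (B, H), projection of the pooled tokens
    return tokens, summary


# Usage: two samples, each with k = 3 scalar theta values.
params = init_scalar_cond_embedder(jax.random.PRNGKey(0), in_channels=1, hidden_size=8)
theta = jnp.array([[0.5, -1.2, 3.0], [0.1, 0.2, 0.3]])  # (B=2, k=3)
tokens, summary = scalar_cond_embed(params, theta, in_channels=1)
# tokens.shape == (2, 3, 8); summary.shape == (2, 8)
```

Mean-pooling before the projection makes the summary invariant to the order of the k condition tokens, while the per-token embeddings keep each scalar individually addressable through attention.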