gensbi.experimental.models.fielddit.cond

Pluggable condition embedders for FieldDiT (Phase 1: scalar / vector).

A condition is embedded into two outputs: a token stream, which the MMDiT core consumes via joint attention, and a pooled summary, which is added to the modulation vector for flagged-C decoder modulation.
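The following is a minimal NumPy sketch of how the two outputs might be consumed. It assumes a single-head joint attention without the per-stream Q/K/V projections a real MMDiT block uses, and a hypothetical timestep embedding `time_emb` as the base of the modulation vector. `joint_attention` and `time_emb` are illustrative names, not part of gensbi.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along `axis`.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def joint_attention(field_tokens, cond_tokens):
    """Hypothetical single-head joint attention over the concatenated sequence.

    Assumption: no learned Q/K/V projections; real MMDiT blocks use separate
    projections per stream before attending over the joint sequence.
    """
    x = np.concatenate([field_tokens, cond_tokens], axis=1)  # (B, N + k, H)
    scale = 1.0 / np.sqrt(x.shape[-1])
    weights = softmax(np.einsum("bqh,bkh->bqk", x, x) * scale)  # (B, N + k, N + k)
    out = np.einsum("bqk,bkh->bqh", weights, x)  # (B, N + k, H)
    n = field_tokens.shape[1]
    # Split back into the field stream and the condition stream.
    return out[:, :n], out[:, n:]


B, N, k, H = 2, 16, 3, 8
rng = np.random.default_rng(0)
field_tokens = rng.normal(size=(B, N, H))  # field (data) tokens
cond_tokens = rng.normal(size=(B, k, H))   # embedder token stream
summary = rng.normal(size=(B, H))          # embedder pooled summary
time_emb = rng.normal(size=(B, H))         # hypothetical timestep embedding

field_out, cond_out = joint_attention(field_tokens, cond_tokens)
# The pooled summary is added to the modulation vector.
vec = time_emb + summary
```

Because the condition tokens join the attention sequence, every field token can attend to every condition token, while the summary gives the modulation path a single fixed-size vector per sample.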

Classes

ScalarCondEmbedder

Embed a few condition tokens (e.g., theta scalars) into vectors of size hidden_size.

Module Contents

class gensbi.experimental.models.fielddit.cond.ScalarCondEmbedder(in_channels, hidden_size, rngs, param_dtype=jnp.bfloat16)

Bases: flax.nnx.Module

Embed a few condition tokens (e.g., theta scalars) into vectors of size hidden_size.

The input cond has shape (B, k, in_channels). A 2D input of shape (B, k) is auto-expanded to (B, k, 1), so it is accepted only when in_channels == 1; otherwise an error is raised. Returns a tuple (cond_tokens, summary) with shapes (B, k, hidden_size) and (B, hidden_size), where summary is a projection of the mean-pooled tokens.
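The NumPy sketch below mirrors the documented shape contract of __call__. It assumes that token_proj and summary_proj are single affine maps, that pooling happens over the projected tokens, and that the 2D-input error is a ValueError; the actual module uses flax.nnx layers and may differ. The function scalar_cond_embed and the weight arrays are hypothetical.

```python
import numpy as np


def scalar_cond_embed(cond, w_token, b_token, w_summary, b_summary):
    """Hypothetical shape-level sketch of ScalarCondEmbedder.__call__.

    cond: (B, k, in_channels), or (B, k) when in_channels == 1.
    Assumption: token_proj and summary_proj are plain affine maps.
    """
    in_channels = w_token.shape[0]
    if cond.ndim == 2:
        # The 2D form is only valid for a single channel per token.
        if in_channels != 1:
            raise ValueError(
                f"2D cond of shape (B, k) requires in_channels == 1, got {in_channels}"
            )
        cond = cond[..., None]  # (B, k) -> (B, k, 1)
    tokens = cond @ w_token + b_token          # token_proj: (B, k, hidden_size)
    pooled = tokens.mean(axis=1)               # mean over the k tokens: (B, hidden_size)
    summary = pooled @ w_summary + b_summary   # summary_proj: (B, hidden_size)
    return tokens, summary


B, k, hidden_size = 4, 3, 32
rng = np.random.default_rng(0)
w_token = rng.normal(size=(1, hidden_size))
b_token = np.zeros(hidden_size)
w_summary = rng.normal(size=(hidden_size, hidden_size))
b_summary = np.zeros(hidden_size)

theta = rng.normal(size=(B, k))  # k scalar parameters per sample, 2D form
tokens, summary = scalar_cond_embed(theta, w_token, b_token, w_summary, b_summary)
```

Each of the k scalars becomes its own token, so the condition stream length equals the number of conditioning parameters rather than being collapsed into one vector.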

Parameters:
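  • in_channels (int): number of channels per condition token (the last axis of cond).

  • hidden_size (int): width of the output tokens and of the summary vector.

  • rngs (flax.nnx.Rngs): random number streams used to initialize the parameters.

  • param_dtype (jax.typing.DTypeLike): dtype of the parameters (default jnp.bfloat16).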

__call__(cond)
in_channels
summary_proj
token_proj