gensbi.experimental.models.fielddit

FieldDiT: conditional flow-matching for 2D field-level inference (Phase 1).

Submodules

Classes

FieldDiT

Conditional flow-matching network for 2D fields (pixel space).

FieldDiTParams

Configuration for FieldDiT (mirrors the style of Flux1Params).

RopeIds

Integer rope/positional id buffers.

Package Contents

class gensbi.experimental.models.fielddit.FieldDiT(params)

Bases: flax.nnx.Module

Conditional flow-matching network for 2D fields (pixel space).

Forward pass: (t, obs=field, cond) -> a velocity field with the same shape as obs. The conv encoder is modulated by the time embedding only (or by the full vec when cond_modulates_encoder=True); the decoder and the MMDiT core are modulated by vec = time (+ cond summary when use_cond_summary_in_vec=True) (+ guidance when guidance_embed=True). The meeting-grid rope2d ids for obs and the absolute ids for cond are built internally.

Parameters:

params (FieldDiTParams)
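The modulation logic described above can be sketched as follows. This is a minimal numpy illustration, not the gensbi implementation: it assumes the vec terms are summed (as in Flux-style models), uses a bare sinusoidal timestep embedding with no learned projection in place of time_in and cond_embedder, and the helper names timestep_embedding, build_vec and encoder_vec are hypothetical.

```python
import numpy as np


def timestep_embedding(t, dim, max_period=10000.0):
    """Sinusoidal embedding of times t (shape [B]) into shape [B, dim]."""
    half = dim // 2
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)
    args = t[:, None] * freqs[None, :]
    return np.concatenate([np.cos(args), np.sin(args)], axis=-1)


def build_vec(t, cond_summary=None, guidance=None, *, dim=8,
              use_cond_summary_in_vec=True, guidance_embed=False):
    """Hypothetical vec that modulates the decoder and the MMDiT core."""
    vec = timestep_embedding(t, dim)  # time term is always present
    if use_cond_summary_in_vec and cond_summary is not None:
        vec = vec + cond_summary  # optional cond summary, shape [B, dim]
    if guidance_embed and guidance is not None:
        vec = vec + timestep_embedding(guidance, dim)  # optional guidance term
    return vec


def encoder_vec(t, full_vec, *, dim=8, cond_modulates_encoder=False):
    """Vector that modulates the conv encoder: time only unless the flag is set."""
    return full_vec if cond_modulates_encoder else timestep_embedding(t, dim)


t = np.array([0.0, 0.5])
cond_summary = np.ones((2, 8))  # stand-in for a learned cond summary
vec = build_vec(t, cond_summary)  # decoder / core modulation
enc = encoder_vec(t, vec)  # encoder modulation (time only by default)
```

Keeping the encoder on time only by default means the cond signal reaches the network through the MMDiT core and, optionally, through vec, rather than through the pixel-space features.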

__call__(t, obs, cond, obs_ids=None, cond_ids=None, conditioned=True, guidance=None)
cond_dim
cond_embedder
cond_ids
cond_modulates_encoder
core
decoder
encoder
field_shape
guidance_embed
in_channels
obs_ids
param_dtype
time_in
token_grid
tokenizer
untokenizer
use_cond_summary_in_vec

class gensbi.experimental.models.fielddit.FieldDiTParams

Configuration for FieldDiT (mirrors the style of Flux1Params).

The meeting-grid token count is not set directly; it is derived from field_shape, encoder_widths (which sets the encoder depth) and patch_size: tokens = (H / (2**D * p)) * (W / (2**D * p)), where (H, W) = field_shape, p = patch_size and D = len(encoder_widths) - 1.
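The token-count formula can be checked with a small helper. This is a sketch, not part of gensbi: meeting_grid_shape and num_tokens are hypothetical names, and the divisibility check is an assumption (the real class may validate shapes differently, or not at all). As the formula implies, each encoder level after the first halves H and W, and patchification then divides both by p.

```python
def meeting_grid_shape(field_shape, encoder_widths, patch_size):
    """(h, w) of the token grid for a given field shape and encoder config."""
    H, W = field_shape
    D = len(encoder_widths) - 1  # number of 2x downsampling levels
    stride = 2 ** D * patch_size  # total spatial reduction per axis
    if H % stride or W % stride:
        # Assumption: reject shapes that do not divide evenly.
        raise ValueError(
            f"field_shape {field_shape} is not divisible by 2**{D} * {patch_size} = {stride}"
        )
    return H // stride, W // stride


def num_tokens(field_shape, encoder_widths, patch_size):
    h, w = meeting_grid_shape(field_shape, encoder_widths, patch_size)
    return h * w


print(num_tokens((64, 64), (32, 64, 128), 2))  # D = 2, stride 8 -> 8 * 8 = 64
```

Because the token count grows with H * W and shrinks by a factor of 4 per extra encoder level, adding depth is the main lever for keeping attention cost manageable on large fields.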

Note: rngs is a live nnx.Rngs stream (mirrors Flux1Params). Constructing two models from the same params object yields different weights, because the stream advances; build a fresh FieldDiTParams (or a fresh nnx.Rngs(seed)) per model for reproducibility.
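The effect described in the note can be reproduced with any stateful random stream. The sketch below uses numpy's Generator as an analog of nnx.Rngs; ToyParams and init_weights are hypothetical stand-ins, not gensbi or flax APIs.

```python
import numpy as np


class ToyParams:
    """Hypothetical stand-in for a params object holding a live RNG stream."""

    def __init__(self, seed):
        self.rng = np.random.default_rng(seed)


def init_weights(params, shape=(3,)):
    """Draw initial weights; every call advances the shared stream."""
    return params.rng.normal(size=shape)


shared = ToyParams(0)
w_a = init_weights(shared)  # first model built from `shared`
w_b = init_weights(shared)  # second model: the stream has advanced, so weights differ
w_c = init_weights(ToyParams(0))  # fresh params with the same seed reproduce w_a
print(np.allclose(w_a, w_b), np.allclose(w_a, w_c))  # False True
```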

__post_init__()
axes_dim: List[int] | None = None
cond_dim: int
cond_in_channels: int = 1
cond_modulates_encoder: bool = False
depth: int = 2
depth_single_blocks: int = 2
encoder_widths: Tuple[int, ...]
field_shape: Tuple[int, int]
guidance_embed: bool = False
in_channels: int
mlp_ratio: float = 4.0
norm_groups: int = 8
num_heads: int = 12
param_dtype: jax.typing.DTypeLike
patch_size: int = 2
qkv_bias: bool = False
res_blocks_down: int = 2
res_blocks_up: int = 2
rngs: flax.nnx.Rngs
theta: int | None = None
use_cond_summary_in_vec: bool = True
vec_in_dim: int | None = None

class gensbi.experimental.models.fielddit.RopeIds(value, *, hijax=None, ref=None, eager_sharding=None, **metadata)

Bases: flax.nnx.Variable

Integer rope/positional id buffers.

A dedicated Variable type so the ids are (a) filterable with nnx.state(model, RopeIds), (b) excluded from nnx.Param state, and (c) safe from blanket float-dtype casts applied to the parameter state.

Parameters:
  • value (A | VariableMetadata[A])

  • hijax (bool | None)

  • ref (bool | None)

  • eager_sharding (bool | None)

  • metadata (Any)
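The three properties of RopeIds can be illustrated without flax. In this sketch, Param and RopeIdsVar are hypothetical plain-Python stand-ins for nnx.Param and RopeIds, filter_state mimics type-based filtering in the spirit of nnx.state(model, RopeIds), and grid_ids builds (row, col) integer ids for a token grid; the two-column id layout is an assumption, not the gensbi format.

```python
import numpy as np


class Param:
    """Hypothetical stand-in for nnx.Param (a trainable float parameter)."""

    def __init__(self, value):
        self.value = value


class RopeIdsVar:
    """Hypothetical stand-in for RopeIds (an integer id buffer)."""

    def __init__(self, value):
        self.value = value


def grid_ids(h, w):
    """Integer (row, col) ids for an h x w token grid, shape [h * w, 2]."""
    rows, cols = np.meshgrid(np.arange(h), np.arange(w), indexing="ij")
    return np.stack([rows, cols], axis=-1).reshape(h * w, 2).astype(np.int32)


def filter_state(state, kind):
    """Select entries by variable type."""
    return {k: v for k, v in state.items() if isinstance(v, kind)}


def cast_params(state, dtype):
    """Cast only Param entries; id buffers keep their integer dtype."""
    return {
        k: Param(v.value.astype(dtype)) if isinstance(v, Param) else v
        for k, v in state.items()
    }


state = {
    "kernel": Param(np.ones((2, 2), dtype=np.float32)),
    "obs_ids": RopeIdsVar(grid_ids(2, 3)),
}
cast = cast_params(state, np.float16)
print(cast["kernel"].value.dtype, cast["obs_ids"].value.dtype)  # float16 int32
```

Keying the cast on the variable type, rather than casting every leaf, matters because float16 represents integers exactly only up to 2048, so a blanket cast could silently corrupt larger position ids.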