gensbi.experimental.models.fielddit#
FieldDiT: conditional flow-matching for 2D field-level inference (Phase 1).
Submodules#
Classes#
| FieldDiT | Conditional flow-matching network for 2D fields (pixel space). |
| FieldDiTParams | Configuration for FieldDiT. |
| RopeIds | Integer rope/positional id buffers. |
Package Contents#
- class gensbi.experimental.models.fielddit.FieldDiT(params)[source]#
Bases: flax.nnx.Module
Conditional flow-matching network for 2D fields (pixel space).
Forward: (t, obs=field, cond) -> velocity field of the same shape. The conv encoder is modulated by time only (or by the full vec when cond_modulates_encoder=True); the decoder and the MMDiT core are modulated by vec = time (+ cond summary if flagged-C) (+ guidance). The meeting-grid rope2d obs ids and absolute cond ids are built internally.
- Parameters:
params (FieldDiTParams)
- cond_dim#
- cond_embedder#
- cond_ids#
- cond_modulates_encoder#
- core#
- decoder#
- encoder#
- field_shape#
- guidance_embed#
- in_channels#
- obs_ids#
- param_dtype#
- time_in#
- token_grid#
- tokenizer#
- untokenizer#
- use_cond_summary_in_vec#
- class gensbi.experimental.models.fielddit.FieldDiTParams[source]#
Configuration for FieldDiT (mirrors the style of Flux1Params).
The meeting-grid token count is derived from field_shape, encoder_widths (depth) and patch_size; it is not prescribed: tokens = (H / (2**D * p)) * (W / (2**D * p)), with (H, W) = field_shape, p = patch_size and D = len(encoder_widths) - 1.
Note: rngs is a live nnx.Rngs stream (mirrors Flux1Params). Constructing two models from the same params object yields different weights, because the stream advances; build a fresh FieldDiTParams (or a fresh nnx.Rngs(seed)) per model for reproducibility.
- axes_dim: List[int] | None = None#
- cond_dim: int#
- cond_in_channels: int = 1#
- cond_modulates_encoder: bool = False#
- depth: int = 2#
- depth_single_blocks: int = 2#
- encoder_widths: Tuple[int, ...]#
- field_shape: Tuple[int, int]#
- guidance_embed: bool = False#
- in_channels: int#
- mlp_ratio: float = 4.0#
- norm_groups: int = 8#
- num_heads: int = 12#
- param_dtype: jax.typing.DTypeLike#
- patch_size: int = 2#
- qkv_bias: bool = False#
- res_blocks_down: int = 2#
- res_blocks_up: int = 2#
- rngs: flax.nnx.Rngs#
- theta: int | None = None#
- use_cond_summary_in_vec: bool = True#
- vec_in_dim: int | None = None#
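The token-count formula in the FieldDiTParams description can be checked by hand before building a model. The sketch below is a plain-Python illustration, not part of the library: the helper name `meeting_grid_tokens` is hypothetical, and the divisibility check is an assumption (the documentation only states the formula, not how non-divisible shapes are handled).

```python
from typing import Sequence, Tuple


def meeting_grid_tokens(
    field_shape: Tuple[int, int],
    encoder_widths: Sequence[int],
    patch_size: int = 2,
) -> int:
    """Hypothetical helper: tokens = (H / (2**D * p)) * (W / (2**D * p))."""
    height, width = field_shape
    # D = len(encoder_widths) - 1, as stated in the FieldDiTParams docs.
    depth = len(encoder_widths) - 1
    # Total reduction factor per spatial axis: 2**D from the encoder, p from patching.
    stride = (2 ** depth) * patch_size
    # Assumption: both sides must divide evenly; the docs do not say what happens otherwise.
    if height % stride or width % stride:
        raise ValueError(
            f"field_shape {field_shape} is not divisible by 2**D * p = {stride}"
        )
    return (height // stride) * (width // stride)


# Three encoder widths give D = 2, so the stride is 2**2 * 2 = 8 per axis.
square_tokens = meeting_grid_tokens((64, 64), (32, 64, 128), patch_size=2)
wide_tokens = meeting_grid_tokens((128, 64), (32, 64, 128), patch_size=2)
print(square_tokens, wide_tokens)  # prints: 64 128
```

Adding an entry to encoder_widths raises D by one and so cuts the token count by a factor of four at fixed field_shape and patch_size.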
- class gensbi.experimental.models.fielddit.RopeIds(value, *, hijax=None, ref=None, eager_sharding=None, **metadata)[source]#
Bases: flax.nnx.Variable
Integer rope/positional id buffers.
A dedicated Variable type so the ids are (a) filterable with nnx.state(model, RopeIds), (b) excluded from nnx.Param state, and (c) safe from blanket float-dtype casts applied to the parameter state.
- Parameters:
value (A | VariableMetadata[A])
hijax (bool | None)
ref (bool | None)
eager_sharding (bool | None)
metadata (Any)