gensbi.experimental.recipes.field_pipeline
Field-shaped conditional pipeline (experimental).
ConditionalPipeline assumes token-shaped observations (B, dim, ch):
it flattens dim_obs, resolves embedding ids, and builds a rank-2 prior.
Pixel-space field models (FieldDiT) need (B, H, W, C) observations, a
rank-3 event_shape, no external ids, and rank-aware input expansion.
Classes
FieldConditionalPipeline | Conditional pipeline for pixel-space fields
FieldConditionalWrapper | Wrapper for field-shaped conditional models.
Module Contents
- class gensbi.experimental.recipes.field_pipeline.FieldConditionalPipeline(model, train_dataset, val_dataset, field_shape, dim_cond, method, ch_obs=1, ch_cond=1, params=None, training_config=None)
Bases: gensbi.recipes.conditional_pipeline.ConditionalPipeline
Conditional pipeline for pixel-space fields (B, H, W, C).
Differences from ConditionalPipeline:
  - event_shape = (*field_shape, ch_obs): prior, path, and sampling are field-shaped (sample returns (nsamples, H, W, C)).
  - No obs/cond id resolution: obs_ids/cond_ids are None in all extras; the model builds its rope ids internally and ignores them.
  - FieldConditionalWrapper for event-rank-aware expansion.
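As a rough illustration of the first difference, the sketch below builds event_shape from field_shape and ch_obs and draws noise with the shape that sample is documented to return. It uses plain NumPy rather than the gensbi API. The standard-normal draw is only a stand-in, since this page does not state which prior the pipeline uses, and the values of field_shape, ch_obs, and nsamples are made up for the example.

```python
import numpy as np

# Hypothetical values chosen for illustration only.
field_shape = (32, 32)  # (H, W)
ch_obs = 1              # number of observation channels C
nsamples = 4

# Event shape as stated above: (*field_shape, ch_obs) -> (H, W, C).
event_shape = (*field_shape, ch_obs)

# Stand-in for field-shaped prior draws (assumption: standard normal noise).
rng = np.random.default_rng(0)
x0 = rng.standard_normal((nsamples, *event_shape))

print(event_shape)  # (32, 32, 1)
print(x0.shape)     # (4, 32, 32, 1)
```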
Datasets must yield (obs, cond) batches with obs of shape (B, H, W, C) and cond of shape (B, k) or (B, k, c).
- Parameters:
- class gensbi.experimental.recipes.field_pipeline.FieldConditionalWrapper(model)
Bases: gensbi.utils.model_wrapping.ModelWrapper
Wrapper for field-shaped conditional models.
Expansion is event-rank-aware (a field event is rank 3, (H, W, C)):
  - obs with ndim == 3 is treated as unbatched -> (1, H, W, C).
  - cond with ndim == 1 ((k,)) -> (1, k); 2D+ cond is assumed batched ((B, k) / (B, k, c)).
  - batch-1 cond is broadcast to the obs batch (sampling N draws for a single x_o; the model itself deliberately does not broadcast).
  - ids are passed through untouched (FieldDiT builds rope ids internally).
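The expansion rules above can be sketched in a few lines of NumPy. This is not the wrapper's actual code: expand_field_inputs and EVENT_RANK are hypothetical names introduced here, and the real implementation presumably operates on JAX arrays. The sketch only mirrors the documented shape behaviour.

```python
import numpy as np

EVENT_RANK = 3  # a field event is (H, W, C)


def expand_field_inputs(obs, cond):
    """Hypothetical helper mirroring the documented expansion rules."""
    obs = np.asarray(obs)
    cond = np.asarray(cond)
    # obs with ndim == 3 is unbatched: (H, W, C) -> (1, H, W, C).
    if obs.ndim == EVENT_RANK:
        obs = obs[None, ...]
    # cond with ndim == 1 is unbatched: (k,) -> (1, k); 2D+ is assumed batched.
    if cond.ndim == 1:
        cond = cond[None, ...]
    # A batch-1 cond is broadcast to the obs batch (N draws for one x_o).
    if cond.shape[0] == 1 and obs.shape[0] > 1:
        cond = np.broadcast_to(cond, (obs.shape[0], *cond.shape[1:]))
    return obs, cond


# Unbatched field and unbatched condition.
obs, cond = expand_field_inputs(np.zeros((8, 8, 1)), np.zeros(5))
print(obs.shape, cond.shape)  # (1, 8, 8, 1) (1, 5)

# Batched field with a single condition of shape (1, k, c).
obs, cond = expand_field_inputs(np.zeros((16, 8, 8, 1)), np.zeros((1, 5, 2)))
print(obs.shape, cond.shape)  # (16, 8, 8, 1) (16, 5, 2)
```

Doing the broadcast in the wrapper keeps the model strict about batch sizes, while sampling code can still pass a single observed condition for many draws.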
- __call__(t, obs, cond, obs_ids=None, cond_ids=None, conditioned=True, guidance=None, **kwargs)
Call the wrapped model with obs and t.
Uses keyword arguments when calling the underlying model for safety (avoids positional-argument order bugs).
- Parameters:
t (Array) – time (batch_size).
obs (Array) – input data to the model (batch_size, …).
**kwargs – additional information forwarded to the model, e.g., text condition.
- Returns:
model output.
- Return type:
Array
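To show why keyword forwarding helps, here is a minimal sketch with a toy model whose positional order differs from the wrapper's. ToyWrapper and toy_model are illustrative stand-ins, not the gensbi ModelWrapper or FieldDiT, and the sketch omits obs_ids, cond_ids, and guidance for brevity.

```python
import numpy as np


def toy_model(obs, t, cond, conditioned=True):
    # Hypothetical model whose positional order (obs, t, cond) differs
    # from the wrapper's call order (t, obs, cond).
    return obs * t.reshape(-1, 1, 1, 1) + cond.mean()


class ToyWrapper:
    """Illustrative stand-in, not the gensbi ModelWrapper."""

    def __init__(self, model):
        self.model = model

    def __call__(self, t, obs, cond, conditioned=True, **kwargs):
        # Forward by keyword so the model's argument order cannot matter.
        return self.model(t=t, obs=obs, cond=cond, conditioned=conditioned, **kwargs)


t = np.full((2,), 0.5)       # time, shape (batch_size,)
obs = np.ones((2, 4, 4, 1))  # field-shaped input (B, H, W, C)
cond = np.zeros((2, 3))      # condition (B, k)

out = ToyWrapper(toy_model)(t, obs, cond)
print(out.shape)  # (2, 4, 4, 1)
```

Calling toy_model(t, obs, cond) positionally would silently bind t to the obs parameter, which is the kind of bug that keyword forwarding avoids.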