gensbi.experimental.recipes
Classes
Class | Description
FieldConditionalPipeline | Conditional pipeline for pixel-space fields (B, H, W, C).
FieldConditionalWrapper | Wrapper for field-shaped conditional models.
VAE1DPipeline | Pipeline for training and evaluating 1D Variational Autoencoders (VAE1D) in GenSBI.
VAE2DPipeline | Pipeline for training and evaluating 2D Variational Autoencoders (VAE2D) in GenSBI.
Package Contents
- class gensbi.experimental.recipes.FieldConditionalPipeline(model, train_dataset, val_dataset, field_shape, dim_cond, method, ch_obs=1, ch_cond=1, params=None, training_config=None)
Bases: gensbi.recipes.conditional_pipeline.ConditionalPipeline

Conditional pipeline for pixel-space fields (B, H, W, C).

Differences from ConditionalPipeline:
  - event_shape = (*field_shape, ch_obs): the prior, path, and sampling are field-shaped (sample returns (nsamples, H, W, C)).
  - No obs/cond id resolution: obs_ids and cond_ids are None in all extras; the model builds its RoPE ids internally and ignores them.
  - Uses FieldConditionalWrapper for event-rank-aware expansion.

Datasets must yield (obs, cond) batches with obs of shape (B, H, W, C) and cond of shape (B, k) or (B, k, c).
- _wrap_model()
Wrap the model for evaluation with FieldConditionalWrapper.
- cond_ids = None
- event_shape
- field_shape
- loss_obj
- method
- obs_ids = None
- path
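The dataset contract for FieldConditionalPipeline can be illustrated with a minimal sketch of conforming batches. It assumes an iterable of NumPy (obs, cond) pairs is an acceptable dataset (the accepted dataset types are not documented in this section); the generator `field_batches` and its arguments are hypothetical. It also shows how the per-sample event_shape = (*field_shape, ch_obs) relates to a batch of obs.

```python
import numpy as np

def field_batches(n_batches, batch_size, field_shape, ch_obs=1, k=3, seed=0):
    """Hypothetical generator yielding (obs, cond) batches in the documented layout."""
    rng = np.random.default_rng(seed)
    for _ in range(n_batches):
        # obs: (B, H, W, C) pixel-space field, channels last
        obs = rng.standard_normal((batch_size, *field_shape, ch_obs)).astype(np.float32)
        # cond: (B, k) flat conditioning vector; (B, k, c) is also documented as valid
        cond = rng.standard_normal((batch_size, k)).astype(np.float32)
        yield obs, cond

field_shape = (32, 32)
ch_obs = 1
# Per-sample event shape, mirroring event_shape = (*field_shape, ch_obs)
event_shape = (*field_shape, ch_obs)

obs, cond = next(field_batches(n_batches=1, batch_size=8, field_shape=field_shape))
print(obs.shape, cond.shape, event_shape)  # (8, 32, 32, 1) (8, 3) (32, 32, 1)
```

Each obs batch is simply the event shape with a leading batch axis, which is why sampling from this pipeline returns (nsamples, H, W, C).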
- class gensbi.experimental.recipes.FieldConditionalWrapper(model)
Bases: gensbi.utils.model_wrapping.ModelWrapper

Wrapper for field-shaped conditional models.

Expansion is event-rank-aware (a field event is rank 3, (H, W, C)):
  - obs with ndim == 3 is treated as unbatched and expanded to (1, H, W, C).
  - cond with ndim == 1, i.e. (k,), is expanded to (1, k); cond with 2 or more dimensions is assumed batched ((B, k) or (B, k, c)).
  - A batch-1 cond is broadcast to the obs batch (e.g., when sampling N draws for a single x_o); the model itself deliberately does not broadcast.
  - ids are passed through untouched (FieldDiT builds its RoPE ids internally).
- __call__(t, obs, cond, obs_ids=None, cond_ids=None, conditioned=True, guidance=None, **kwargs)
Call the wrapped model with obs and t. Keyword arguments are used when calling the underlying model for safety (this avoids positional-argument order bugs).
- Parameters:
t (Array) – time, shape (batch_size,).
obs (Array) – input data to the model, shape (batch_size, …).
**kwargs – additional information forwarded to the model, e.g., a text condition.
- Returns:
model output.
- Return type:
Array
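The expansion rules and keyword forwarding described for FieldConditionalWrapper can be sketched as follows. This is a hypothetical stand-in (`ToyFieldWrapper`, `toy_model`), not the GenSBI implementation: NumPy replaces jax.numpy, and the toy model returns the shapes it receives instead of a model output, so the effect of each rule is visible.

```python
import numpy as np

class ToyFieldWrapper:
    """Hypothetical stand-in for FieldConditionalWrapper (not the GenSBI code)."""

    def __init__(self, model):
        self.model = model

    @staticmethod
    def _expand(obs, cond):
        obs = np.asarray(obs)
        cond = np.asarray(cond)
        # A field event is rank 3 (H, W, C): rank-3 obs is one unbatched sample.
        if obs.ndim == 3:
            obs = obs[None, ...]  # -> (1, H, W, C)
        # A rank-1 cond (k,) is unbatched; rank >= 2 is assumed batched.
        if cond.ndim == 1:
            cond = cond[None, :]  # -> (1, k)
        # Broadcast a batch-1 cond to the obs batch (N draws for a single x_o).
        if cond.shape[0] == 1 and obs.shape[0] > 1:
            cond = np.broadcast_to(cond, (obs.shape[0], *cond.shape[1:]))
        return obs, cond

    def __call__(self, t, obs, cond, obs_ids=None, cond_ids=None,
                 conditioned=True, guidance=None, **kwargs):
        obs, cond = self._expand(obs, cond)
        # Keyword arguments bind by name, so the model's parameter order cannot
        # silently swap inputs; ids are passed through untouched.
        return self.model(t=t, obs=obs, cond=cond, obs_ids=obs_ids,
                          cond_ids=cond_ids, conditioned=conditioned,
                          guidance=guidance, **kwargs)

def toy_model(obs, cond, t, obs_ids=None, cond_ids=None, conditioned=True, guidance=None):
    """Hypothetical model whose positional order differs from the wrapper's."""
    # Reports the shapes it receives instead of computing a real output.
    return obs.shape, cond.shape, t.shape

wrapper = ToyFieldWrapper(toy_model)
x_o = np.zeros(5)                          # single observation, shape (k,)
field_noise = np.zeros((100, 16, 16, 1))   # N = 100 field samples
print(wrapper(np.full(100, 0.5), field_noise, x_o))
# ((100, 16, 16, 1), (100, 5), (100,))
```

A single x_o of shape (k,) thus becomes a (100, k) condition matching the 100 field samples, while the model itself never needs to broadcast. jax.numpy offers the same indexing and `broadcast_to` calls; batch-size validation and tracing behavior of the real wrapper are not shown.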
- class gensbi.experimental.recipes.VAE1DPipeline(train_dataset, val_dataset, params, training_config=None)
Bases: AbstractVAEPipeline

Pipeline for training and evaluating 1D Variational Autoencoders (VAE1D) in GenSBI.
Inherits from AbstractVAEPipeline and uses the AutoEncoder1D model class.
- Parameters:
params (gensbi.experimental.models.autoencoders.AutoEncoderParams)
- class gensbi.experimental.recipes.VAE2DPipeline(train_dataset, val_dataset, params, training_config=None)
Bases: AbstractVAEPipeline

Pipeline for training and evaluating 2D Variational Autoencoders (VAE2D) in GenSBI.
Inherits from AbstractVAEPipeline and uses the AutoEncoder2D model class.
- Parameters:
params (gensbi.experimental.models.autoencoders.AutoEncoderParams)