gensbi.experimental.recipes.vae_pipeline#
VAE Pipeline module for GenSBI.
This module provides an abstract pipeline class and concrete implementations for training and evaluating Variational Autoencoders (VAEs) within the GenSBI framework. It manages model instantiation, training and validation loops, optimizer and EMA setup, checkpointing, and utility functions for saving and restoring models.
The AbstractVAEPipeline class defines the general workflow for VAE-based models, including optimizer configuration, KL annealing, and early stopping. Subclasses such as VAE1DPipeline and VAE2DPipeline implement pipelines for 1D and 2D autoencoder architectures, respectively.
Typical usage involves subclassing or instantiating the provided pipelines with appropriate datasets, model parameters, and (optionally) training configurations.
Key features:
- Model and EMA initialization
- Training and validation step functions (JIT-compiled)
- Learning rate scheduling and gradient clipping
- Early stopping and checkpoint management
- Support for custom training configurations
See the VAE1DPipeline and VAE2DPipeline classes for concrete examples.
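A typical workflow might look like the following sketch. Constructor and method names are taken from the API documented below; the dataset objects and parameter values are illustrative placeholders, not runnable defaults:

```python
from flax import nnx

from gensbi.experimental.recipes.vae_pipeline import VAE1DPipeline
from gensbi.experimental.models.autoencoders import AutoEncoderParams

# Illustrative only: concrete AutoEncoderParams fields and dataset
# objects depend on your setup.
params = AutoEncoderParams(...)  # architecture hyperparameters
pipeline = VAE1DPipeline(train_dataset, val_dataset, params)

rngs = nnx.Rngs(0)
loss_array, val_loss_array = pipeline.train(rngs, nsteps=10_000, save_model=True)

# Later, restore the trained model (and its EMA copy) from the checkpoint directory.
pipeline.restore_model()
```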
Classes#
- AbstractVAEPipeline: Abstract pipeline for training and evaluating Variational Autoencoders (VAEs) in GenSBI.
- VAE1DPipeline: Pipeline for training and evaluating 1D Variational Autoencoders (VAE1D) in GenSBI.
- VAE2DPipeline: Pipeline for training and evaluating 2D Variational Autoencoders (VAE2D) in GenSBI.
Functions#
- Parse a VAE configuration file.
- Parse a training configuration file.
Module Contents#
- class gensbi.experimental.recipes.vae_pipeline.AbstractVAEPipeline(model_cls, train_dataset, val_dataset, params, training_config=None)[source]#
Abstract pipeline for training and evaluating Variational Autoencoders (VAEs) in GenSBI.
This class manages model creation, optimizer and EMA setup, training and validation loops, checkpointing, and utility functions. It is designed to be subclassed for specific VAE architectures.
- Parameters:
params (gensbi.experimental.models.autoencoders.AutoEncoderParams)
- _get_ema_optimizer()[source]#
Construct the EMA optimizer for maintaining an exponential moving average of model parameters.
- Returns:
ema_optimizer – The EMA optimizer instance.
- Return type:
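The EMA mechanism itself is just an exponential moving average over parameters. A minimal dictionary-based sketch of the update rule (the helper name and decay value are illustrative assumptions, not the pipeline's actual API):

```python
def ema_update(ema_params, params, decay=0.999):
    """Blend current parameters into the running average: ema <- decay * ema + (1 - decay) * p."""
    return {k: decay * ema_params[k] + (1 - decay) * p for k, p in params.items()}
```

In the pipeline this kind of update runs after every optimizer step, so the EMA model tracks a smoothed trajectory of the training parameters.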
- _get_kl_schedule(nsteps)[source]#
Construct a KL annealing schedule for training.
- Parameters:
nsteps (int) – Number of training steps.
- Returns:
schedule – KL weight schedule function.
- Return type:
Callable
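KL annealing ramps the KL weight up gradually so that early training focuses on reconstruction quality. A minimal linear-warmup sketch (the warmup fraction and the actual schedule shape used by the pipeline are assumptions):

```python
def make_kl_schedule(nsteps, warmup_frac=0.3, max_weight=1.0):
    """Return a schedule mapping step -> KL weight, ramping linearly up to max_weight."""
    warmup_steps = max(1, int(nsteps * warmup_frac))
    def schedule(step):
        return max_weight * min(1.0, step / warmup_steps)
    return schedule
```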
- _get_optimizer()[source]#
Construct the optimizer for training, including learning rate scheduling and gradient clipping.
- Returns:
optimizer – The optimizer instance for the model.
- Return type:
nnx.Optimizer
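Gradient clipping by global norm rescales the entire gradient vector when its L2 norm exceeds a threshold. A self-contained sketch of the idea on a flat list of gradients (the pipeline would apply this inside its optimizer chain, e.g. via an optax-style transformation):

```python
import math

def clip_by_global_norm(grads, max_norm=1.0):
    """Rescale grads so their global L2 norm is at most max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    scale = min(1.0, max_norm / max(norm, 1e-12))
    return [g * scale for g in grads]
```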
- classmethod get_default_training_config()[source]#
Return a dictionary of default training configuration parameters for VAE training.
- Returns:
training_config – Default training configuration.
- Return type:
dict
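The returned dictionary typically bundles optimization and loop settings in one place. The sketch below shows the general shape of such a config; every key name and value is an illustrative assumption, so consult the actual return value of get_default_training_config() for the real defaults:

```python
def get_default_training_config():
    # Hypothetical keys and values, for illustration only.
    return {
        "nsteps": 10_000,
        "learning_rate": 1e-3,
        "grad_clip_norm": 1.0,
        "ema_decay": 0.999,
        "kl_warmup_frac": 0.3,
        "early_stopping_patience": 10,
        "experiment_id": 0,
    }
```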
- get_train_step_fn()[source]#
Return the training step function, which performs a single optimization step.
- Returns:
train_step – JIT-compiled training step function.
- Return type:
Callable
- get_val_step_fn()[source]#
Return the validation step function, which computes the loss on a validation batch.
- Returns:
val_step – JIT-compiled validation step function.
- Return type:
Callable
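Both step functions follow the standard JAX pattern: a pure loss function differentiated with jax.value_and_grad and wrapped in jax.jit. A toy sketch with a linear encoder/decoder stands in for the real model here; the pipeline's actual VAE loss and nnx-based state handling differ:

```python
import jax
import jax.numpy as jnp

def vae_loss(params, batch, kl_weight):
    # Toy deterministic "VAE": linear encode/decode plus a KL-like penalty on the code.
    code = batch @ params["enc"]
    recon = code @ params["dec"]
    recon_loss = jnp.mean((recon - batch) ** 2)
    kl = 0.5 * jnp.mean(code ** 2)
    return recon_loss + kl_weight * kl

@jax.jit
def train_step(params, batch, kl_weight, lr=1e-2):
    # One SGD step: compute loss and gradients, then update every parameter leaf.
    loss, grads = jax.value_and_grad(vae_loss)(params, batch, kl_weight)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss

val_step = jax.jit(vae_loss)  # validation just evaluates the loss on a held-out batch
```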
- init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir)[source]#
Initialize the pipeline from a configuration file.
- Parameters:
config_path (str) – Path to the configuration file.
dim_obs (int)
dim_cond (int)
checkpoint_dir (str)
- restore_model(experiment_id=None)[source]#
Restore the model and EMA model from checkpoint directories.
- Parameters:
experiment_id (int, optional) – Identifier for the experiment/checkpoint. If None, uses the current training config value.
- save_model(experiment_id=None)[source]#
Save the current model and EMA model to checkpoint directories.
- Parameters:
experiment_id (int, optional) – Identifier for the experiment/checkpoint. If None, uses the current training config value.
- train(rngs, nsteps=None, save_model=True)[source]#
Run the training loop for the VAE model.
- Parameters:
rngs (nnx.Rngs) – Random number generators for training/validation steps.
nsteps (int, optional) – Number of training steps. If None, uses the value from training config.
save_model (bool, optional) – Whether to save the model after training.
- Returns:
loss_array (list) – List of training losses.
val_loss_array (list) – List of validation losses.
- Return type:
Tuple[list, list]
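Early stopping monitors the validation loss and aborts training when it stops improving. A minimal plain-Python loop sketch that mirrors, but greatly simplifies, the structure of train() (the step-function signatures and patience value are assumptions):

```python
def train_loop(train_step, val_step, nsteps, patience=10):
    """Run up to nsteps of training, stopping early if validation loss fails to improve."""
    loss_array, val_loss_array = [], []
    best_val, steps_without_improvement = float("inf"), 0
    for step in range(nsteps):
        loss_array.append(train_step(step))
        val_loss = val_step(step)
        val_loss_array.append(val_loss)
        if val_loss < best_val:
            best_val, steps_without_improvement = val_loss, 0
        else:
            steps_without_improvement += 1
            if steps_without_improvement >= patience:
                break  # validation loss has plateaued
    return loss_array, val_loss_array
```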
- class gensbi.experimental.recipes.vae_pipeline.VAE1DPipeline(train_dataset, val_dataset, params, training_config=None)[source]#
Bases:
AbstractVAEPipeline
Pipeline for training and evaluating 1D Variational Autoencoders (VAE1D) in GenSBI.
Inherits from AbstractVAEPipeline and uses the AutoEncoder1D model class.
- Parameters:
params (gensbi.experimental.models.autoencoders.AutoEncoderParams)
- class gensbi.experimental.recipes.vae_pipeline.VAE2DPipeline(train_dataset, val_dataset, params, training_config=None)[source]#
Bases:
AbstractVAEPipeline
Pipeline for training and evaluating 2D Variational Autoencoders (VAE2D) in GenSBI.
Inherits from AbstractVAEPipeline and uses the AutoEncoder2D model class.
- Parameters:
params (gensbi.experimental.models.autoencoders.AutoEncoderParams)