gensbi.recipes.pipeline#
Pipeline module for GenSBI.
This module provides an abstract pipeline class for training and evaluating conditional generative models (such as conditional flow matching or diffusion models) in the GenSBI framework. It handles model creation, training loop, optimizer setup, checkpointing, and evaluation utilities.
For practical implementations, subclasses should implement specific model architectures, loss functions, and sampling methods. See JointPipeline and ConditionalPipeline for concrete examples.
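A typical workflow with a concrete subclass looks roughly as follows (a minimal sketch; the import path, datasets, and dimensions are illustrative assumptions):

import jax
from flax import nnx
from gensbi.recipes import ConditionalPipeline  # import path assumed

# Construct the pipeline; model=None falls back to the default architecture,
# and omitting params/training_config uses the class defaults.
pipeline = ConditionalPipeline(None, train_ds, val_ds, dim_obs=2, dim_cond=3)

See train and sample below for running the training loop and drawing samples.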
Classes#
- AbstractPipeline – Abstract base class for GenSBI training pipelines.
- ModelEMA – Exponential Moving Average (EMA) optimizer for maintaining a smoothed version of model parameters.
Functions#
- _get_batch_sampler – Create a batch sampler that processes samples in chunks.
- Update the EMA model with the current model parameters.
Module Contents#
- class gensbi.recipes.pipeline.AbstractPipeline(model, train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, ch_cond=None, params=None, training_config=None)[source]#
Bases: abc.ABC
Abstract base class for GenSBI training pipelines.
This class provides a template for implementing training and evaluation pipelines for conditional generative models. Subclasses should implement model creation, default parameter setup, loss function, sampling, and evaluation methods.
- Parameters:
train_dataset (iterable) – Training dataset, should yield batches of data.
val_dataset (iterable) – Validation dataset, should yield batches of data.
dim_obs (int) – Dimensionality of the parameter (theta) space.
dim_cond (int) – Dimensionality of the observation (x) space.
model (nnx.Module, optional) – The model to be trained. If None, the model is created using _make_model.
params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.
ch_obs (int, optional) – Number of channels in the observation data. Default is 1.
ch_cond (int, optional) – Number of channels in the conditional data (if applicable). Default is None.
training_config (dict, optional) – Training configuration. If None, uses defaults from get_default_training_config.
- abstractmethod _get_default_params(rngs)[source]#
Return a dictionary of default model parameters.
- Parameters:
rngs (flax.nnx.Rngs) – Random number generators used to initialize the default parameters.
- _get_ema_optimizer()[source]#
Construct the EMA optimizer for maintaining an exponential moving average of model parameters.
- Returns:
ema_optimizer – The EMA optimizer instance.
- Return type:
ModelEMA
- _get_optimizer()[source]#
Construct the optimizer for training, including learning rate scheduling and gradient clipping.
- Returns:
optimizer – The optimizer instance for the model.
- Return type:
nnx.Optimizer
- abstractmethod _wrap_model()[source]#
Wrap the model for evaluation, using either JointWrapper or ConditionalWrapper.
- classmethod get_default_training_config()[source]#
Return a dictionary of default training configuration parameters.
- Returns:
training_config – Default training configuration.
- Return type:
dict
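The defaults are a convenient starting point for overrides before constructing a pipeline (a sketch; the "nsteps" key is an assumption about the dictionary layout):

config = ConditionalPipeline.get_default_training_config()
config["nsteps"] = 20_000  # hypothetical key; adjust to the actual config layout
pipeline = ConditionalPipeline(None, train_ds, val_ds, dim_obs=2, dim_cond=3,
                               training_config=config)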
- abstractmethod get_sampler(key, x_o, step_size=0.01, use_ema=True, time_grid=None, **model_extras)[source]#
Get a sampler function for generating samples from the trained model.
- Parameters:
key (jax.random.PRNGKey) – Random number generator key.
x_o (array-like) – Conditioning variable.
step_size (float, optional) – Step size for the sampler.
use_ema (bool, optional) – Whether to use the EMA model for sampling.
time_grid (array-like, optional) – Time grid for the sampler (if applicable).
model_extras (dict, optional) – Additional model-specific parameters.
- Returns:
sampler – A function that generates samples when called with a random key and number of samples.
- Return type:
Callable: key, nsamples -> samples
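For example, given a trained pipeline and an observation x_o:

import jax

# Build the sampler once and reuse it; the returned callable maps
# (key, nsamples) -> samples.
sampler = pipeline.get_sampler(jax.random.PRNGKey(0), x_o, step_size=0.01, use_ema=True)
samples = sampler(jax.random.PRNGKey(1), 5000)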
- get_train_step_fn(loss_fn)[source]#
Return the training step function, which performs a single optimization step.
- Returns:
train_step – JIT-compiled training step function.
- Return type:
Callable
- get_val_step_fn(loss_fn)[source]#
Return the validation step function, which computes the loss on a validation batch without updating the model parameters.
- Returns:
val_step – JIT-compiled validation step function.
- Return type:
Callable
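Conceptually, the step functions returned by get_train_step_fn and get_val_step_fn follow the standard flax.nnx pattern sketched below (an illustration under assumptions, not GenSBI's exact implementation; the loss and batch structure are placeholders):

from flax import nnx

def loss_fn(model, batch):
    # Placeholder loss; concrete pipelines supply their own objectives
    # (e.g. flow matching or diffusion losses).
    pred = model(batch["x"])
    return ((pred - batch["y"]) ** 2).mean()

@nnx.jit
def train_step(model, optimizer, batch):
    # Differentiate the loss w.r.t. the model and apply one optimizer update.
    loss, grads = nnx.value_and_grad(loss_fn)(model, batch)
    optimizer.update(grads)  # note: newer flax releases use update(model, grads)
    return loss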
- abstractmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir)[source]#
Initialize the pipeline from a configuration file.
- Parameters:
train_dataset (iterable) – Training dataset.
val_dataset (iterable) – Validation dataset.
dim_obs (int) – Dimensionality of the parameter (theta) space.
dim_cond (int) – Dimensionality of the observation (x) space.
config_path (str) – Path to the configuration file.
checkpoint_dir (str) – Directory for saving checkpoints.
- Returns:
pipeline – An instance of the pipeline initialized from the configuration.
- Return type:
AbstractPipeline
- restore_model(experiment_id=None)[source]#
Restore model and EMA model from checkpoints.
- Parameters:
experiment_id (str, optional) – Experiment identifier. If None, uses training_config value.
- abstractmethod sample(key, x_o, nsamples=10000)[source]#
Generate samples from the trained model.
- Parameters:
key (jax.random.PRNGKey) – Random number generator key.
x_o (array-like) – Conditioning variable (e.g., observed data).
nsamples (int, optional) – Number of samples to generate.
- Returns:
samples – Generated samples of size (nsamples, dim_obs, ch_obs).
- Return type:
array-like
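For example, to draw 10,000 samples conditioned on an observation x_o:

import jax

samples = pipeline.sample(jax.random.PRNGKey(0), x_o, nsamples=10000)
# samples has shape (nsamples, dim_obs, ch_obs)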
- sample_batched(key, x_o, nsamples, *args, chunk_size=50, show_progress_bars=True, **kwargs)[source]#
Generate samples from the trained model in batches.
- Parameters:
key (jax.random.PRNGKey) – Random number generator key.
x_o (array-like) – Conditioning variable (e.g., observed data).
nsamples (int) – Number of samples to generate.
chunk_size (int, optional) – Size of each batch for sampling. Default is 50.
show_progress_bars (bool, optional) – Whether to display progress bars during sampling. Default is True.
args (tuple) – Additional positional arguments for the sampler.
kwargs (dict) – Additional keyword arguments for the sampler.
- Returns:
samples – Generated samples of shape (nsamples, batch_size_cond, dim_obs, ch_obs).
- Return type:
array-like
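Chunked sampling bounds device memory when many samples are requested (illustrative sizes):

import jax

# Draw 100,000 samples in chunks of 100; progress bars track the chunks.
samples = pipeline.sample_batched(
    jax.random.PRNGKey(0), x_o, 100_000, chunk_size=100, show_progress_bars=True
)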
- save_model(experiment_id=None)[source]#
Save model and EMA model checkpoints.
- Parameters:
experiment_id (str, optional) – Experiment identifier. If None, uses training_config value.
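Together with restore_model this gives a simple checkpointing round trip (the experiment identifier is a hypothetical example):

pipeline.save_model(experiment_id="run_01")
# ... later, e.g. in a fresh process with an equivalently constructed pipeline:
pipeline.restore_model(experiment_id="run_01")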
- train(rngs, nsteps=None, save_model=True)[source]#
Run the training loop for the model.
- Parameters:
rngs (nnx.Rngs) – Random number generators for training/validation steps.
nsteps (int, optional) – Number of training steps. If None, uses the value from training_config.
save_model (bool, optional) – Whether to save model and EMA checkpoints after training. Default is True.
- Returns:
loss_array (list) – List of training losses.
val_loss_array (list) – List of validation losses.
- Return type:
Tuple[list, list]
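For example:

from flax import nnx

# Run the training loop; nsteps falls back to the training_config value
# when omitted.
loss_array, val_loss_array = pipeline.train(nnx.Rngs(0), nsteps=10_000)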
- class gensbi.recipes.pipeline.ModelEMA(model, tx)[source]#
Bases:
flax.nnx.Optimizer
Exponential Moving Average (EMA) optimizer for maintaining a smoothed version of model parameters.
This optimizer keeps an exponential moving average of the model parameters, which can help stabilize training and improve evaluation performance. The EMA parameters are updated at each training step.
- Parameters:
model (nnx.Module) – The model whose parameters will be tracked.
tx (optax.GradientTransformation) – The Optax transformation defining the EMA update rule.
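For instance, an EMA copy of a model can be maintained with an optax transformation (a sketch; the wiring of the EMA copy and the decay value are assumptions, optax.ema being one suitable GradientTransformation):

import optax
from flax import nnx

ema_model = nnx.clone(model)                       # smoothed copy of the model
ema = ModelEMA(ema_model, optax.ema(decay=0.999))
ema.update(ema_model, model)                       # fold current params into the EMA copy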
- update(model, model_orginal)[source]#
Update the EMA parameters using the current model parameters.
- Parameters:
model (nnx.Module) – The model holding the EMA parameters to be updated.
model_orginal (flax.nnx.Module) – The original model with the current parameters.
- gensbi.recipes.pipeline._get_batch_sampler(sampler_fn, ncond, chunk_size, show_progress_bars=True)[source]#
Create a batch sampler that processes samples in chunks.
- Parameters:
sampler_fn (Callable) – Sampling function.
ncond (int) – Number of conditions.
chunk_size (int) – Size of each chunk.
show_progress_bars (bool, optional) – Whether to show progress bars.
- Returns:
Batch sampler function.
- Return type:
Callable
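A possible invocation (a sketch; sampler_fn is assumed to map (key, nsamples) -> samples, as produced by get_sampler, and the returned callable is assumed to take the same arguments):

import jax

batch_sampler = _get_batch_sampler(sampler_fn, ncond=1, chunk_size=50)
samples = batch_sampler(jax.random.PRNGKey(0), 10_000)  # drawn in chunks of 50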