Source code for gensbi.experimental.recipes.vae_pipeline

"""
VAE Pipeline module for GenSBI.

This module provides an abstract pipeline class and concrete implementations for training and evaluating Variational Autoencoders (VAEs)
within the GenSBI framework. It manages model instantiation, training and validation loops, optimizer and EMA setup, checkpointing,
and utility functions for saving and restoring models.

The `AbstractVAEPipeline` class defines the general workflow for VAE-based models, including optimizer configuration, KL annealing,
and early stopping. Subclasses such as `VAE1DPipeline` and `VAE2DPipeline` implement pipelines for 1D and 2D autoencoder architectures, respectively.

Typical usage involves subclassing or instantiating the provided pipelines with appropriate datasets, model parameters, and (optionally) training configurations.

Key features:
- Model and EMA initialization
- Training and validation step functions (JIT-compiled)
- Learning rate scheduling and gradient clipping
- Early stopping and checkpoint management
- Support for custom training configurations

See the `VAE1DPipeline` and `VAE2DPipeline` classes for concrete examples.
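
Example (a minimal sketch; the hyperparameter values below are illustrative and
`train_dataset` / `val_dataset` are assumed to be iterables yielding training batches)::

    from flax import nnx
    from jax import numpy as jnp

    from gensbi.experimental.models.autoencoders import AutoEncoderParams
    from gensbi.experimental.recipes.vae_pipeline import VAE1DPipeline

    params = AutoEncoderParams(
        resolution=64,
        in_channels=1,
        ch=32,
        out_ch=1,
        ch_mult=[1, 2],
        num_res_blocks=2,
        z_channels=4,
        scale_factor=1.0,
        shift_factor=0.0,
        rngs=nnx.Rngs(0),
        param_dtype=jnp.float32,
    )

    pipeline = VAE1DPipeline(train_dataset, val_dataset, params=params)
    losses, val_losses = pipeline.train(nnx.Rngs(train_step=0, val_step=1))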
"""

from flax import nnx
import jax
from jax import numpy as jnp
from typing import Any, Callable, Optional, Tuple
from jax import Array

from numpyro import distributions as dist

import abc
from functools import partial

import optax

from optax.schedules import linear_schedule, constant_schedule

import yaml

import orbax.checkpoint as ocp

from tqdm import tqdm

import os

from gensbi.experimental.models.autoencoders import (
    AutoEncoderParams,
    vae_loss_fn,
    AutoEncoder1D,
    AutoEncoder2D,
)
from gensbi.recipes.pipeline import ema_step, ModelEMA


def parse_training_config(config_path: str):
    """
    Parse a training configuration file.

    Parameters
    ----------
    config_path : str
        Path to the configuration file.

    Returns
    -------
    config : dict
        Parsed configuration dictionary.
    """
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)

    # Training parameters
    train_params = config.get("training", {})
    multistep = train_params.get("multistep", 1)
    experiment_id = train_params.get("experiment_id", 1)
    early_stopping = train_params.get("early_stopping", True)
    nsteps = train_params.get("nsteps", 30000) * multistep
    val_every = train_params.get("val_every", 100) * multistep

    # Optimizer parameters
    opt_params = config.get("optimizer", {})
    MAX_LR = opt_params.get("max_lr", 1e-3)
    MIN_LR = opt_params.get("min_lr", 0.0)
    MIN_SCALE = MIN_LR / MAX_LR if MAX_LR > 0 else 0.0
    warmup_steps = opt_params.get("warmup_steps", 500)
    ema_decay = opt_params.get("ema_decay", 0.999)
    decay_transition = opt_params.get("decay_transition", 0.85)

    training_config = {}

    # overwrite the defaults with the config file values
    training_config["nsteps"] = nsteps
    training_config["ema_decay"] = ema_decay
    training_config["decay_transition"] = decay_transition
    training_config["max_lr"] = MAX_LR
    training_config["min_lr"] = MIN_LR
    training_config["min_scale"] = MIN_SCALE
    training_config["val_every"] = val_every
    training_config["early_stopping"] = early_stopping
    training_config["experiment_id"] = experiment_id
    training_config["multistep"] = multistep
    training_config["warmup_steps"] = warmup_steps

    return training_config
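
# The configuration file is expected to contain `training` and `optimizer`
# sections keyed as parsed above. Illustrative sketch (every key is optional
# and falls back to the defaults used in `parse_training_config`):
#
#     training:
#       nsteps: 30000
#       val_every: 100
#       multistep: 1
#       experiment_id: 1
#       early_stopping: true
#
#     optimizer:
#       max_lr: 1.0e-3
#       min_lr: 0.0
#       warmup_steps: 500
#       ema_decay: 0.999
#       decay_transition: 0.85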

# AutoEncoderParams:
#     resolution: int
#     in_channels: int
#     ch: int
#     out_ch: int
#     ch_mult: list[int]
#     num_res_blocks: int
#     z_channels: int
#     scale_factor: float
#     shift_factor: float
#     rngs: nnx.Rngs
#     param_dtype: DTypeLike

def parse_autoencoder_params(config_path: str):
    """
    Parse a VAE configuration file.

    Parameters
    ----------
    config_path : str
        Path to the configuration file.

    Returns
    -------
    config : dict
        Parsed configuration dictionary.
    """
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)

    model_params = config.get("vae_model", {})

    params_dict = dict(
        resolution=model_params.get("resolution"),
        in_channels=model_params.get("in_channels"),
        ch=model_params.get("ch"),
        out_ch=model_params.get("out_ch"),
        ch_mult=model_params.get("ch_mult"),
        num_res_blocks=model_params.get("num_res_blocks"),
        z_channels=model_params.get("z_channels"),
        scale_factor=model_params.get("scale_factor", 1.0),
        shift_factor=model_params.get("shift_factor", 0.0),
        param_dtype=getattr(jnp, model_params.get("param_dtype", "float32")),
    )

    return params_dict
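
# Corresponding `vae_model` section of the configuration file (illustrative
# values; `param_dtype` must name a jax.numpy dtype such as "float32" or
# "bfloat16"):
#
#     vae_model:
#       resolution: 64
#       in_channels: 1
#       ch: 32
#       out_ch: 1
#       ch_mult: [1, 2]
#       num_res_blocks: 2
#       z_channels: 4
#       scale_factor: 1.0
#       shift_factor: 0.0
#       param_dtype: float32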

class AbstractVAEPipeline:
    """
    Abstract pipeline for training and evaluating Variational Autoencoders (VAEs) in GenSBI.

    This class manages model creation, optimizer and EMA setup, training and validation loops,
    checkpointing, and utility functions. It is designed to be subclassed for specific VAE architectures.
    """

    def __init__(
        self,
        model_cls,
        train_dataset,
        val_dataset,
        params: AutoEncoderParams,
        training_config=None,
    ):
        """
        Initialize the VAE pipeline.

        Parameters
        ----------
        model_cls : type
            The class of the VAE model to instantiate (e.g., AutoEncoder1D or AutoEncoder2D).
        train_dataset : iterable
            Training dataset.
        val_dataset : iterable
            Validation dataset.
        params : AutoEncoderParams
            Model hyperparameters and configuration.
        training_config : dict, optional
            Training configuration dictionary. If None, defaults are used.
        """
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset

        self.train_dataset_iter = iter(self.train_dataset)
        self.val_dataset_iter = iter(self.val_dataset)

        self.params = params
        self.training_config = training_config

        if training_config is None:
            self.training_config = self.get_default_training_config()

        self.training_config["min_scale"] = (
            self.training_config["min_lr"] / self.training_config["max_lr"]
            if self.training_config["max_lr"] > 0
            else 0.0
        )

        os.makedirs(self.training_config["checkpoint_dir"], exist_ok=True)

        self.model = model_cls(self.params)
        self.ema_model = nnx.clone(self.model)

        self.loss_fn = vae_loss_fn

    @classmethod
    def init_pipeline_from_config(
        cls,
        train_dataset,
        val_dataset,
        dim_obs: int,
        dim_cond: int,
        config_path: str,
        checkpoint_dir: str,
    ):
        """
        Initialize the pipeline from a configuration file.

        Parameters
        ----------
        train_dataset : iterable
            Training dataset.
        val_dataset : iterable
            Validation dataset.
        dim_obs : int
            Dimensionality of the observations.
        dim_cond : int
            Dimensionality of the conditioning variables.
        config_path : str
            Path to the configuration file.
        checkpoint_dir : str
            Directory where checkpoints are written.

        Returns
        -------
        pipeline : AbstractVAEPipeline
            The initialized pipeline instance.
        """
        params_dict = parse_autoencoder_params(config_path)

        params = AutoEncoderParams(
            rngs=nnx.Rngs(0),
            **params_dict,
        )

        # Training parameters
        training_config = cls.get_default_training_config()
        training_config["checkpoint_dir"] = checkpoint_dir

        training_config_ = parse_training_config(config_path)
        for key, value in training_config_.items():
            training_config[key] = value  # update with config file values

        pipeline = cls(
            train_dataset,
            val_dataset,
            params=params,
            training_config=training_config,
        )

        return pipeline
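
    # Example call on a concrete subclass (sketch; paths are hypothetical, and
    # `dim_obs` / `dim_cond` are accepted but not used by the current implementation):
    #
    #     pipeline = VAE1DPipeline.init_pipeline_from_config(
    #         train_dataset,
    #         val_dataset,
    #         dim_obs=1,
    #         dim_cond=1,
    #         config_path="configs/vae.yaml",
    #         checkpoint_dir="checkpoints/vae1d",
    #     )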

    def _get_ema_optimizer(self):
        """
        Construct the EMA optimizer for maintaining an exponential moving average of model parameters.

        Returns
        -------
        ema_optimizer : ModelEMA
            The EMA optimizer instance.
        """
        ema_tx = optax.ema(self.training_config["ema_decay"])
        ema_optimizer = ModelEMA(self.ema_model, ema_tx)
        return ema_optimizer

    def _get_kl_schedule(self, nsteps):
        """
        Construct the KL weight schedule for training.

        Currently a constant weight of 0.1 is used; linear annealing is left commented out.

        Parameters
        ----------
        nsteps : int
            Number of training steps.

        Returns
        -------
        schedule : Callable
            KL weight schedule function.
        """
        # schedule = linear_schedule(0.1, 1, nsteps)
        schedule = constant_schedule(0.1)
        return schedule

    # def _get_optimizer(self):
    #     """
    #     Construct the optimizer for training, including learning rate scheduling and gradient clipping.
    #
    #     Returns
    #     -------
    #     optimizer : nnx.Optimizer
    #         The optimizer instance for the model.
    #     """
    #     opt = optax.chain(
    #         optax.adaptive_grad_clip(10.0),
    #         optax.adamw(self.training_config["max_lr"]),
    #         reduce_on_plateau(
    #             patience=self.training_config["patience"],
    #             cooldown=self.training_config["cooldown"],
    #             factor=self.training_config["factor"],
    #             rtol=self.training_config["rtol"],
    #             accumulation_size=self.training_config["accumulation_size"],
    #             min_scale=self.training_config["min_scale"],
    #         ),
    #     )
    #
    #     if self.training_config["multistep"] > 1:
    #         opt = optax.MultiSteps(opt, self.training_config["multistep"])
    #
    #     optimizer = nnx.Optimizer(self.model, opt, wrt=nnx.Param)
    #     return optimizer

    def _get_optimizer(self):
        """
        Construct the optimizer for training, including learning rate scheduling and gradient clipping.

        Returns
        -------
        optimizer : nnx.Optimizer
            The optimizer instance for the model.
        """
        warmup_steps = (
            self.training_config["warmup_steps"] * self.training_config["multistep"]
        )
        nsteps = self.training_config["nsteps"]
        max_lr = self.training_config["max_lr"]
        min_lr = self.training_config["min_lr"]

        # Joined schedule: linear warmup for `warmup_steps`, then constant LR until
        # a `decay_transition` fraction of the training steps, then cosine decay to min_lr.
        decay_transition = self.training_config["decay_transition"]

        warmup_schedule = optax.linear_schedule(
            init_value=1e-7, end_value=max_lr, transition_steps=warmup_steps
        )
        constant_schedule = optax.constant_schedule(value=max_lr)
        decay_schedule = optax.cosine_decay_schedule(
            init_value=max_lr,
            decay_steps=int((1 - decay_transition) * nsteps),
            alpha=min_lr / max_lr,
        )

        schedule = optax.join_schedules(
            schedules=[
                warmup_schedule,
                constant_schedule,
                decay_schedule,
            ],
            boundaries=[warmup_steps, int(decay_transition * nsteps)],
        )

        # define the weight decay mask to avoid applying weight decay to bias and norm parameters
        def decay_mask_fn(params):
            return jax.tree_util.tree_map(lambda x: x.ndim > 1, params)

        opt = optax.chain(
            optax.adaptive_grad_clip(10.0),
            optax.adamw(schedule, mask=decay_mask_fn),
        )

        if self.training_config["multistep"] > 1:
            opt = optax.MultiSteps(opt, self.training_config["multistep"])

        optimizer = nnx.Optimizer(self.model, opt, wrt=nnx.Param)
        return optimizer
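
    # Worked example with the default training config (nsteps=50_000,
    # warmup_steps=500, multistep=1, decay_transition=0.70, max_lr=1e-4,
    # min_lr=1e-6): linear warmup over the first 500 steps, constant 1e-4 until
    # step 35_000, then cosine decay over the remaining 15_000 steps down to 1e-6.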

    @classmethod
    def get_default_training_config(cls):
        """
        Return a dictionary of default training configuration parameters for VAE training.

        Returns
        -------
        training_config : dict
            Default training configuration.
        """
        training_config = {}
        training_config["nsteps"] = 50_000
        training_config["ema_decay"] = 0.999
        training_config["warmup_steps"] = 500
        training_config["decay_transition"] = 0.70
        training_config["max_lr"] = 1e-4
        training_config["min_lr"] = 1e-6
        training_config["val_every"] = 100
        training_config["early_stopping"] = True
        training_config["experiment_id"] = 1
        training_config["multistep"] = 1
        training_config["checkpoint_dir"] = os.path.join(os.getcwd(), "checkpoints")
        return training_config

    def update_training_config(self, new_config):
        """
        Update the training configuration with new parameters.

        Parameters
        ----------
        new_config : dict
            New training configuration parameters.
        """
        self.training_config.update(new_config)

        self.training_config["min_scale"] = (
            self.training_config["min_lr"] / self.training_config["max_lr"]
            if self.training_config["max_lr"] > 0
            else 0.0
        )
        return

    def get_train_step_fn(self):
        """
        Return the training step function, which performs a single optimization step.

        Returns
        -------
        train_step : Callable
            JIT-compiled training step function.
        """

        @nnx.jit
        def train_step(model, optimizer, batch, key, kl_weight):
            loss, grads = nnx.value_and_grad(self.loss_fn)(model, batch, key, kl_weight)
            optimizer.update(model, grads, value=loss)
            return loss

        return train_step

    def get_val_step_fn(self):
        """
        Return the validation step function, which computes the loss on a validation batch.

        Returns
        -------
        val_step : Callable
            JIT-compiled validation step function.
        """

        @nnx.jit
        def val_step(model, batch, key, kl_weight):
            loss = self.loss_fn(model, batch, key, kl_weight)
            return loss

        return val_step

    def save_model(self, experiment_id=None):
        """
        Save the current model and EMA model to checkpoint directories.

        Parameters
        ----------
        experiment_id : int, optional
            Identifier for the experiment/checkpoint. If None, uses the current training config value.
        """
        if experiment_id is None:
            experiment_id = self.training_config["experiment_id"]

        checkpoint_dir = self.training_config["checkpoint_dir"]
        checkpoint_dir_ema = os.path.join(self.training_config["checkpoint_dir"], "ema")

        os.makedirs(checkpoint_dir, exist_ok=True)
        os.makedirs(checkpoint_dir_ema, exist_ok=True)

        # Save the model
        checkpoint_manager = ocp.CheckpointManager(
            checkpoint_dir,
            options=ocp.CheckpointManagerOptions(
                max_to_keep=None,
                keep_checkpoints_without_metrics=True,
                create=True,
            ),
        )

        _, state = nnx.split(self.model)

        checkpoint_manager.save(
            experiment_id,
            args=ocp.args.Composite(state=ocp.args.StandardSave(state)),
        )
        checkpoint_manager.close()

        # Save the EMA model
        _, ema_state = nnx.split(self.ema_model)

        checkpoint_manager_ema = ocp.CheckpointManager(
            checkpoint_dir_ema,
            options=ocp.CheckpointManagerOptions(
                max_to_keep=None,
                keep_checkpoints_without_metrics=True,
                create=True,
            ),
        )

        checkpoint_manager_ema.save(
            experiment_id,
            args=ocp.args.Composite(state=ocp.args.StandardSave(ema_state)),
        )
        checkpoint_manager_ema.close()

        print("Saved model to checkpoint")
        return

    def restore_model(self, experiment_id=None):
        """
        Restore the model and EMA model from checkpoint directories.

        Parameters
        ----------
        experiment_id : int, optional
            Identifier for the experiment/checkpoint. If None, uses the current training config value.
        """
        if experiment_id is None:
            experiment_id = self.training_config["experiment_id"]

        graphdef, model_state = nnx.split(self.model)

        with ocp.CheckpointManager(
            self.training_config["checkpoint_dir"],
            options=ocp.CheckpointManagerOptions(read_only=True),
        ) as read_mgr:
            restored = read_mgr.restore(
                experiment_id,
                args=ocp.args.Composite(
                    state=ocp.args.StandardRestore(item=model_state)
                ),
            )

        self.model = nnx.merge(graphdef, restored["state"])
        self.model.training = False

        # Restore the EMA model
        graphdef, model_state_ema = nnx.split(self.ema_model)

        with ocp.CheckpointManager(
            os.path.join(self.training_config["checkpoint_dir"], "ema"),
            options=ocp.CheckpointManagerOptions(read_only=True),
        ) as read_mgr_ema:
            restored_ema = read_mgr_ema.restore(
                experiment_id,
                args=ocp.args.Composite(
                    state=ocp.args.StandardRestore(item=model_state_ema)
                ),
            )

        self.ema_model = nnx.merge(graphdef, restored_ema["state"])
        self.ema_model.training = False

        print("Restored model from checkpoint")
        return

    def train(
        self, rngs: nnx.Rngs, nsteps: Optional[int] = None, save_model=True
    ) -> Tuple[list, list]:
        """
        Run the training loop for the VAE model.

        Parameters
        ----------
        rngs : nnx.Rngs
            Random number generators for training/validation steps.
        nsteps : int, optional
            Number of training steps. If None, uses the value from training config.
        save_model : bool, optional
            Whether to save the model after training.

        Returns
        -------
        loss_array : list
            List of training losses.
        val_loss_array : list
            List of validation losses.
        """
        self.model.train(update_KL=True)

        optimizer = self._get_optimizer()
        ema_optimizer = self._get_ema_optimizer()

        best_state = nnx.state(self.model)
        best_state_ema = nnx.state(self.ema_model)

        train_step = self.get_train_step_fn()
        val_step = self.get_val_step_fn()

        batch_val = next(self.val_dataset_iter)
        min_val = val_step(self.model, batch_val, rngs.val_step(), 1.0)

        val_error_ratio = 0.1
        counter = 0
        cmax = 10

        loss_array = []
        val_loss_array = []

        if nsteps is None:
            nsteps = self.training_config["nsteps"]

        early_stopping = self.training_config["early_stopping"]
        val_every = self.training_config["val_every"]
        kl_schedule = self._get_kl_schedule(nsteps)
        experiment_id = self.training_config["experiment_id"]

        pbar = tqdm(range(nsteps))
        l_train = None

        for j in pbar:
            if counter > cmax and early_stopping:
                print("Early stopping")
                graphdef = nnx.graphdef(self.model)
                self.model = nnx.merge(graphdef, best_state)
                self.ema_model = nnx.merge(graphdef, best_state_ema)
                break

            kl_weight = kl_schedule(j)

            batch = next(self.train_dataset_iter)
            loss = train_step(
                self.model, optimizer, batch, rngs.train_step(), kl_weight
            )

            # update the parameters ema
            if j % self.training_config["multistep"] == 0:
                ema_step(self.ema_model, self.model, ema_optimizer)

            if j == 0:
                l_train = loss
            else:
                l_train = 0.9 * l_train + 0.1 * loss

            if j > 0 and j % val_every == 0:
                batch_val = next(self.val_dataset_iter)
                l_val = val_step(self.model, batch_val, rngs.val_step(), kl_weight)

                if (
                    l_val < l_train
                ):  # TODO figure out something more clever to do, since the loss may be negative here
                    ratio = 0.0
                else:
                    ratio = jnp.abs((l_train - l_val) / (l_train + 1e-8))

                if ratio > val_error_ratio:
                    counter += 1
                else:
                    counter = 0

                pbar.set_postfix(
                    loss=f"{l_train:.4f}",
                    diff_ratio=f"{ratio:.4f}",
                    counter=counter,
                    val_loss=f"{l_val:.4f}",
                )

                loss_array.append(l_train)
                val_loss_array.append(l_val)

                if l_val < min_val:
                    min_val = l_val
                    best_state = nnx.state(self.model)
                    best_state_ema = nnx.state(self.ema_model)

                l_val = 0
                l_train = 0

        self.model.eval(update_KL=False)

        if save_model:
            self.save_model(experiment_id)

        return loss_array, val_loss_array
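
# Training usage (sketch; the rng stream names match the calls made inside `train`):
#
#     rngs = nnx.Rngs(train_step=0, val_step=1)
#     losses, val_losses = pipeline.train(rngs, save_model=False)
#
# Early stopping triggers once the relative gap between the smoothed training
# loss and the validation loss exceeds val_error_ratio (10%) for more than
# cmax (10) consecutive validation checks.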

class VAE1DPipeline(AbstractVAEPipeline):
    """
    Pipeline for training and evaluating 1D Variational Autoencoders (VAE1D) in GenSBI.

    Inherits from AbstractVAEPipeline and uses the AutoEncoder1D model class.
    """

    def __init__(
        self,
        train_dataset,
        val_dataset,
        params: AutoEncoderParams,
        training_config=None,
    ):
        """
        Initialize the 1D VAE pipeline.

        Parameters
        ----------
        train_dataset : iterable
            Training dataset.
        val_dataset : iterable
            Validation dataset.
        params : AutoEncoderParams
            Model hyperparameters and configuration.
        training_config : dict, optional
            Training configuration dictionary. If None, defaults are used.
        """
        super().__init__(
            AutoEncoder1D,
            train_dataset,
            val_dataset,
            params,
            training_config,
        )
        return

class VAE2DPipeline(AbstractVAEPipeline):
    """
    Pipeline for training and evaluating 2D Variational Autoencoders (VAE2D) in GenSBI.

    Inherits from AbstractVAEPipeline and uses the AutoEncoder2D model class.
    """

    def __init__(
        self,
        train_dataset,
        val_dataset,
        params: AutoEncoderParams,
        training_config=None,
    ):
        """
        Initialize the 2D VAE pipeline.

        Parameters
        ----------
        train_dataset : iterable
            Training dataset.
        val_dataset : iterable
            Validation dataset.
        params : AutoEncoderParams
            Model hyperparameters and configuration.
        training_config : dict, optional
            Training configuration dictionary. If None, defaults are used.
        """
        super().__init__(
            AutoEncoder2D,
            train_dataset,
            val_dataset,
            params,
            training_config,
        )