Training Guide#
This guide explains how the training pipeline in GenSBI works, covers best practices for training flow and diffusion models, and shows how to customize the training loop.
Training 101#
Training a flow matching model with the default pipeline takes only a few lines of code. Here is a quick example:
from flax import nnx
from gensbi.recipes import Flux1FlowPipeline
from gensbi.models import Flux1Params
train_dataset = ... # define a training dataset (infinite iterator)
val_dataset = ... # define a validation dataset (infinite iterator)
obs_dim = ... # dimension of the parameters (theta)
cond_dim = ... # dimension of the simulator observations (x)
params = Flux1Params(...) # the parameters for your model
# Instantiate the pipeline
pipeline = Flux1FlowPipeline(
    train_dataset,
    val_dataset,
    obs_dim,
    cond_dim,
    params=params,
)
# Train the model
# Note: GenSBI uses Flax NNX, so we pass a random key generator
pipeline.train(rngs=nnx.Rngs(0))
For a full example, see the 5-minute guide or the full notebook example.
Pipeline Overview#
The default training pipeline (e.g., Flux1FlowPipeline, SimformerFlowPipeline) is built on the AbstractPipeline class. It manages the entire training lifecycle:
State Management: Uses Flax NNX for managing model parameters and optimizer states.
Steps vs. Epochs: Training runs for a fixed number of steps (num_steps), not epochs. This is common in generative modeling, where datasets (like simulation outputs) might be effectively infinite.
EMA (Exponential Moving Average): The pipeline maintains a shadow copy of the model weights using EMA. This version is smoother and often yields better generation results. It is saved in a separate checkpoints/ema folder. (A minimal sketch of the EMA update follows this list.)
Early Stopping: Validation runs every val_every steps. If the validation loss stops improving or diverges significantly from the training loss (controlled by val_error_ratio), training stops early.
Checkpointing: Models are automatically saved to the checkpoints directory. Both the latest training state and the EMA state are preserved.
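To illustrate what the EMA mechanism does, here is a minimal, hypothetical sketch of an exponential-moving-average update over a pytree of weights; the helper name ema_update and the decay value are illustrative only, not the pipeline's internal API:
import jax

def ema_update(ema_params, params, decay=0.999):
    # Shadow weights track the training weights with exponential smoothing:
    # ema <- decay * ema + (1 - decay) * params
    return jax.tree_util.tree_map(
        lambda e, p: decay * e + (1.0 - decay) * p, ema_params, params
    )
In the default pipeline this role is played by the EMA optimizer returned by _get_ema_optimizer and the ema_step call shown later in this guide.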
Data Preprocessing#
To ensure efficient training, especially when your simulator is computationally expensive or when training on GPUs, it is highly advisable to use an optimized data loader. We recommend using grain, a high-performance data loader for JAX, with multiprocessing pre-fetching.
However, using multiprocessing in a Python script alongside JAX requires careful handling to prevent subprocesses from attempting to initialize GPU resources, which can lead to errors or hangs.
Here is a step-by-step example of how to guard your script and set up a performant data loader:
1. Guarding GPU Initialization#
The very first lines of your script must set the JAX_PLATFORMS environment variable. This ensures that only the main process attempts to use the GPU (CUDA), while worker subprocesses (spawned by the data loader) are forced to use the CPU.
import os
if __name__ != "__main__":
# Worker processes must not touch the GPU
os.environ["JAX_PLATFORMS"] = "cpu"
else:
# Main process uses the GPU (change to 'cpu' if you don't have one)
os.environ["JAX_PLATFORMS"] = "cuda"
import grain.python as grain # Import after setting env vars
import numpy as np
import jax
from jax import numpy as jnp
from numpyro import distributions as dist
2. Defining the Simulator#
Next, define your simulator logic. In this example, we create a simple simulator that takes parameters (\(\theta\)) and generates observations (\(x\)).
def _simulator(key, thetas):
    # Simulate data: x depends on theta with some noise
    xs = thetas + 1 + jax.random.normal(key, thetas.shape) * 0.1
    # Ensure correct shapes for concatenation
    thetas = thetas[..., None]
    xs = xs[..., None]
    # For Joint Pipelines, concatenating theta and x is common
    # Typically: data = [parameters, observations]
    data = jnp.concatenate([thetas, xs], axis=1)
    return data
# Define a prior distribution for parameters
theta_prior = dist.Uniform(
    low=jnp.array([-2.0, -2.0, -2.0]), high=jnp.array([2.0, 2.0, 2.0])
)
def simulator(key, nsamples):
    theta_key, sample_key = jax.random.split(key, 2)
    # Sample parameters from the prior
    thetas = theta_prior.sample(theta_key, (nsamples,))
    # Run simulation
    return _simulator(sample_key, thetas)
3. Setting up the Data Loader#
In your main execution block, generate your dataset and configure the grain loader. We use grain.MapDataset for simple array-based data, shuffle it, and then use mp_prefetch to load batches in parallel.
def main():
    # --- 1. Define Dimensions ---
    obs_dim = 3  # Dimension of simulator input (theta)
    cond_dim = 3  # Dimension of simulator output (x)

    # --- 2. Generate Data ---
    # In a real scenario, you might load this from disk
    train_data = jnp.asarray(simulator(jax.random.PRNGKey(0), 10_000), dtype=jnp.bfloat16)
    val_data = jnp.asarray(simulator(jax.random.PRNGKey(1), 2_000), dtype=jnp.bfloat16)

    # Helper to split the joint data back into (obs, cond) for the pipeline
    def split_obs_cond(data):
        # Assuming first 'obs_dim' columns are observations (theta)
        # and remaining columns are conditions (x)
        return data[:, :obs_dim], data[:, obs_dim:]

    batch_size = 1024

    # --- 3. Create Training Pipeline with Grain ---
    # Create basic source dataset
    train_ds = grain.MapDataset.source(np.array(train_data))
    train_ds = train_ds.shuffle(seed=42).repeat()  # Infinite iterator

    # Convert to iter dataset for batching and prefetching
    train_iter_ds = train_ds.to_iter_dataset()

    # Automatically determine optimal multiprocessing settings
    performance_config = grain.experimental.pick_performance_config(
        ds=train_iter_ds,
        ram_budget_mb=1024 * 4,  # Adjust based on available RAM
        max_workers=None,  # Defaults to CPU count
        max_buffer_size=None,
    )

    # Apply batching, splitting, and multiprocessing
    train_dataset_grain = (
        train_iter_ds
        .batch(batch_size)
        .map(split_obs_cond)
        .mp_prefetch(performance_config.multiprocessing_options)
    )

    # --- 4. Create Validation Pipeline ---
    # Validation usually doesn't need complex prefetching
    val_dataset_grain = (
        grain.MapDataset.source(np.array(val_data))
        .to_iter_dataset()
        .batch(batch_size)
        .map(split_obs_cond)
    )

    # ... Initialize and run your GenSBI pipeline here ...
    print("Data loaders ready. Starting training...")
    # ...
# Run main only if this is the entry point
if __name__ == "__main__":
main()
To see a complete runnable script that puts all of this together, check out the first_model.py example.
Configuration & Hyperparameters#
You can control most training behaviors via the training_config dictionary without subclassing the pipeline.
Default Configuration Keys#
These are the standard hyperparameters available in AbstractPipeline:
| Key | Default | Description |
|---|---|---|
| num_steps |  | Total number of training steps. |
| max_lr |  | Maximum learning rate. |
|  |  | Steps to wait for improvement before reducing the learning rate. |
| multistep |  | Gradient accumulation: accumulates gradients over \(N\) steps before updating the weights. |
|  |  | Decay rate for the Exponential Moving Average. |
| val_every |  | How frequently (in steps) to run validation. |
| early_stopping |  | Whether to enable early stopping. |
| checkpoint_dir |  | Directory to save model states. |
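The exact default values depend on the installed GenSBI version; you can print them directly using the same _get_default_training_config method used in the next subsection (a quick sketch):
from gensbi.recipes.flux1 import Flux1FlowPipeline

# Print the default training configuration of the installed version
for key, value in Flux1FlowPipeline._get_default_training_config().items():
    print(f"{key}: {value}")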
Modifying the Configuration#
To change parameters from the default training configuration, you can pass a customized dictionary to the pipeline constructor. A common use case is changing the checkpoint directory or the number of training steps.
from gensbi.recipes.flux1 import Flux1FlowPipeline
# 1. Retrieve the default configuration
training_config = Flux1FlowPipeline._get_default_training_config()
# 2. Modify specific settings
training_config["checkpoint_dir"] = "/path/to/custom/checkpoints"
training_config["num_steps"] = 50_000 # Train for longer
# 3. Instantiate the pipeline with the custom config
pipeline = Flux1FlowPipeline(
    ...,  # Pass standard arguments (model, datasets, etc.) here
    training_config=training_config,  # Pass the modified config here
)
Best Practices#
1. Prefer Flow Matching over Diffusion#
Flow Matching models are the recommended default in GenSBI. While the library also supports standard Diffusion, Flow Matching models are generally easier to train (they learn straighter transport paths) and faster to sample from. Use Flux1FlowPipeline or similar classes unless you have a specific research need for Diffusion.
2. Use Large Effective Batch Sizes#
Generative models benefit significantly from seeing many different time steps (\(t \in [0, 1]\)) in a single update.
Target: 1,000 to 10,000 samples per batch (start at 1024).
Why: Small batches lead to high-variance gradients because they cover only a tiny slice of the time interval.
3. Gradient Accumulation (multistep)#
If your GPU cannot fit a batch size of 1024+, use the multistep configuration to achieve a large effective batch size.
For example, if your GPU fits a batch of 128:
Set the physical batch_size=128 in your DataLoader.
Set training_config["multistep"] = 8.
Effective batch size: \(128 \times 8 = 1024\).
A minimal configuration sketch is shown below.
Advanced Customization#
If the options provided by training_config aren’t enough, you can subclass the pipeline to modify internal logic.
Customizing the Optimizer#
Override _get_optimizer to change the optimization algorithm or schedule.
import optax
from flax import nnx
from gensbi.recipes.flux1 import Flux1FlowPipeline
class CustomFluxPipeline(Flux1FlowPipeline):
    def _get_optimizer(self):
        """
        Overrides the default optimizer to use a linear schedule with AdamW.
        """
        # Retrieve config values
        lr = self.training_config["max_lr"]
        steps = self.training_config["num_steps"]

        # Example: Linear decay schedule
        schedule = optax.linear_schedule(
            init_value=lr,
            end_value=0.0,
            transition_steps=steps,
        )

        # Use AdamW with the schedule
        opt = optax.adamw(learning_rate=schedule, weight_decay=1e-4)

        # Return nnx.Optimizer wrapping the model
        return nnx.Optimizer(self.model, opt, wrt=nnx.Param)
Writing a New Training Loop#
If you need to change how batches are processed or add custom logging, override the train method. The loop below mirrors the default behavior and can serve as a starting point:
def train(
    self, rngs: nnx.Rngs, nsteps: Optional[int] = None, save_model=True
) -> Tuple[list, list]:
    """
    Run the training loop for the model.
    """
    # Initialize optimizers
    optimizer = self._get_optimizer()
    ema_optimizer = self._get_ema_optimizer()

    # Save initial state for best model tracking
    best_state = nnx.state(self.model)
    best_state_ema = nnx.state(self.ema_model)

    loss_fn = self.get_loss_fn()

    # JIT compile steps
    train_step = self.get_train_step_fn(loss_fn)
    val_step = self.get_val_step_fn(loss_fn)

    # Initial validation check
    batch_val = next(self.val_dataset_iter)
    min_val = val_step(self.model, batch_val, rngs.val_step())

    val_error_ratio = 1.1
    counter = 0
    cmax = 10

    loss_array = []
    val_loss_array = []

    self.model.train()

    if nsteps is None:
        nsteps = self.training_config["num_steps"]

    early_stopping = self.training_config["early_stopping"]
    val_every = self.training_config["val_every"]
    experiment_id = self.training_config["experiment_id"]

    pbar = tqdm(range(nsteps))
    l_train = None

    for j in pbar:
        # Check early stopping conditions
        if counter > cmax and early_stopping:
            print("Early stopping")
            # Restore best state
            graphdef = nnx.graphdef(self.model)
            self.model = nnx.merge(graphdef, best_state)
            self.ema_model = nnx.merge(graphdef, best_state_ema)
            break

        batch = next(self.train_dataset_iter)

        # Optimization step
        loss = train_step(self.model, optimizer, batch, rngs.train_step())

        # EMA Update
        if j % self.training_config["multistep"] == 0:
            ema_step(self.ema_model, self.model, ema_optimizer)

        # Smoothing loss for display
        if j == 0:
            l_train = loss
        else:
            l_train = 0.9 * l_train + 0.1 * loss

        # Validation Loop
        if j > 0 and j % val_every == 0:
            batch_val = next(self.val_dataset_iter)
            l_val = val_step(self.model, batch_val, rngs.val_step())

            # Divergence check
            ratio = l_val / l_train
            if ratio > val_error_ratio:
                counter += 1
            else:
                counter = 0

            pbar.set_postfix(
                loss=f"{l_train:.4f}",
                ratio=f"{ratio:.4f}",
                counter=counter,
                val_loss=f"{l_val:.4f}",
            )

            loss_array.append(l_train)
            val_loss_array.append(l_val)

            # Keep track of best model
            if l_val < min_val:
                min_val = l_val
                best_state = nnx.state(self.model)
                best_state_ema = nnx.state(self.ema_model)

            l_val = 0
            l_train = 0

    self.model.eval()

    if save_model:
        self.save_model(experiment_id)

    self._wrap_model()

    return loss_array, val_loss_array
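If you only need extra behavior around the default loop rather than a different loop, a lighter-weight option is to call the parent implementation and post-process its outputs. A minimal, hypothetical sketch (the class name and print statements are illustrative):
from flax import nnx
from gensbi.recipes.flux1 import Flux1FlowPipeline

class LoggingFluxPipeline(Flux1FlowPipeline):
    def train(self, rngs: nnx.Rngs, nsteps=None, save_model=True):
        # Run the standard training loop, then report the recorded losses
        loss_array, val_loss_array = super().train(
            rngs, nsteps=nsteps, save_model=save_model
        )
        if loss_array:
            print(f"Final smoothed training loss: {float(loss_array[-1]):.4f}")
            print(f"Final validation loss: {float(val_loss_array[-1]):.4f}")
        return loss_array, val_loss_array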
Further Help#
If you have any questions not covered in this guide or encounter bugs while training your models, please feel free to open an issue on the GitHub page.