gensbi.recipes.unconditional_pipeline#
Pipelines for training and using unconditional models for simulation-based inference.
Classes#
- UnconditionalDiffusionPipeline – Diffusion pipeline for training and using an unconditional model for simulation-based inference.
- UnconditionalFlowPipeline – Flow pipeline for training and using an unconditional model for simulation-based inference.
Module Contents#
- class gensbi.recipes.unconditional_pipeline.UnconditionalDiffusionPipeline(model, train_dataset, val_dataset, dim_obs, ch_obs=1, params=None, training_config=None)[source]#
Bases: gensbi.recipes.pipeline.AbstractPipeline
Diffusion pipeline for training and using an unconditional model for simulation-based inference.
- Parameters:
model (nnx.Module) – The model to be trained.
train_dataset (grain dataset or iterator over batches) – Training dataset.
val_dataset (grain dataset or iterator over batches) – Validation dataset.
dim_obs (int) – Dimension of the parameter space.
ch_obs (int) – Number of channels in the observation space.
params (optional) – Parameters for the model. Ignored if a custom model is provided.
training_config (dict, optional) – Configuration for training. If None, default configuration is used.
Examples
Minimal example of how to instantiate and use the UnconditionalDiffusionPipeline:
```python
# %% Imports
import os

# Set JAX backend (use 'cuda' for GPU, 'cpu' otherwise)
# os.environ["JAX_PLATFORMS"] = "cuda"

import grain
import numpy as np
import jax
from jax import numpy as jnp
from gensbi.recipes import UnconditionalDiffusionPipeline
from gensbi.utils.model_wrapping import _expand_dims, _expand_time
from gensbi.utils.plotting import plot_marginals
import matplotlib.pyplot as plt

from flax import nnx


# %% Define a simulator
def simulator(key, nsamples):
    # A simple 2D Gaussian with means (3, 3) and standard deviations (0.5, 1).
    return 3 + jax.random.normal(key, (nsamples, 2)) * jnp.array([0.5, 1]).reshape(
        1, 2
    )


# %% Define your training and validation datasets.
train_data = simulator(jax.random.PRNGKey(0), 100_000).reshape(-1, 2, 1)
val_data = simulator(jax.random.PRNGKey(1), 2000).reshape(-1, 2, 1)

# %% Normalize the dataset
# It is important to normalize the data to have zero mean and unit variance.
# This helps the model training process.
means = jnp.mean(train_data, axis=0)
stds = jnp.std(train_data, axis=0)


def normalize(data, means, stds):
    return (data - means) / stds


def unnormalize(data, means, stds):
    return data * stds + means


def process_data(data):
    return normalize(data, means, stds)


# %% Create the input pipeline using Grain
# We use Grain to create an efficient input pipeline.
# This involves shuffling, repeating for multiple epochs, and batching the data.
# We also map the process_data function to prepare (normalize) the data.
batch_size = 256

train_dataset_grain = (
    grain.MapDataset.source(np.array(train_data))
    .shuffle(42)
    .repeat()
    .to_iter_dataset()
    .batch(batch_size)
    .map(process_data)
    # .mp_prefetch()  # Uncomment if you want to use multiprocessing prefetching
)

val_dataset_grain = (
    grain.MapDataset.source(np.array(val_data))
    .shuffle(42)  # Use a different seed/strategy for validation if needed
    .repeat()
    .to_iter_dataset()
    .batch(batch_size)
    .map(process_data)
    # .mp_prefetch()  # Uncomment if you want to use multiprocessing prefetching
)


# %% Define your model
# Here we define an MLP velocity field model;
# this model only works for inputs of shape (batch, dim, 1).
# For more complex models, please refer to the transformer-based models in gensbi.models.
class MLP(nnx.Module):
    def __init__(self, input_dim: int = 2, hidden_dim: int = 512, *, rngs: nnx.Rngs):
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        din = input_dim + 1  # input dimension plus one for the time variable

        self.linear1 = nnx.Linear(din, self.hidden_dim, rngs=rngs)
        self.linear2 = nnx.Linear(self.hidden_dim, self.hidden_dim, rngs=rngs)
        self.linear3 = nnx.Linear(self.hidden_dim, self.hidden_dim, rngs=rngs)
        self.linear4 = nnx.Linear(self.hidden_dim, self.hidden_dim, rngs=rngs)
        self.linear5 = nnx.Linear(self.hidden_dim, self.input_dim, rngs=rngs)

    def __call__(self, t: jax.Array, obs: jax.Array, node_ids, *args, **kwargs):
        # For this specific model we use samples of shape (batch, dim),
        # while transformer models use (batch, dim, c).
        obs = _expand_dims(obs)[..., 0]
        t = _expand_time(t)
        if t.ndim == 3:
            t = t.reshape(t.shape[0], t.shape[1])
        t = jnp.broadcast_to(t, (obs.shape[0], 1))

        h = jnp.concatenate([obs, t], axis=-1)

        x = self.linear1(h)
        x = jax.nn.gelu(x)

        x = self.linear2(x)
        x = jax.nn.gelu(x)

        x = self.linear3(x)
        x = jax.nn.gelu(x)

        x = self.linear4(x)
        x = jax.nn.gelu(x)

        x = self.linear5(x)

        return x[..., None]  # return shape (batch, dim, 1)


# Your nnx.Module model here, e.g., a simple MLP or the Simformer model.
# If you define a custom model, it should take the following arguments:
#   t: Array,
#   obs: Array,
#   node_ids: Array (optional, if your model is a transformer-based model),
#   *args, **kwargs
# The obs input should have shape (batch_size, dim_joint, c),
# and the output will be of the same shape.
model = MLP(rngs=nnx.Rngs(42))

# %% Instantiate the pipeline
dim_obs = 2  # Dimension of the parameter space
ch_obs = 1  # Number of channels of the parameter space

# The UnconditionalDiffusionPipeline handles the training loop and sampling.
# We configure it with the model, the datasets, and the dimensions,
# using a default training configuration.
training_config = UnconditionalDiffusionPipeline.get_default_training_config()
training_config["nsteps"] = 10000

pipeline = UnconditionalDiffusionPipeline(
    model,
    train_dataset_grain,
    val_dataset_grain,
    dim_obs,
    ch_obs,
    training_config=training_config,
)

# %% Train the model
# We create a random key for training and start the training process.
rngs = nnx.Rngs(42)
pipeline.train(rngs, save_model=False)  # set save_model=True to save the model

# %% Sample from the trained model
# We generate new samples using the trained model.
samples = pipeline.sample(rngs.sample(), nsamples=100_000)
# Finally, we unnormalize the samples to get them back to the original scale.
samples = unnormalize(samples, means, stds)

# %% Plot the samples
# We verify the model's performance by plotting the marginal distributions
# of the generated samples.
print(samples.mean(axis=0), samples.std(axis=0))

plot_marginals(
    np.array(samples[..., 0]), true_param=[3, 3], gridsize=20, range=[(-2, 8), (-2, 8)]
)
plt.savefig("unconditional_diffusion_samples.png", dpi=300, bbox_inches="tight")
plt.show()
```
Note
If you plan on using multiprocessing prefetching, ensure that your script is wrapped in an
if __name__ == "__main__": guard. See https://docs.python.org/3/library/multiprocessing.html
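A minimal sketch of that guard pattern (build_datasets and run_training are hypothetical helpers standing in for the dataset construction and pipeline.train calls from the example above):

```python
# Hypothetical helpers: build_datasets would construct the Grain pipelines
# (including .mp_prefetch()), run_training would instantiate and train the pipeline.
def main():
    train_ds, val_ds = build_datasets()  # uses .mp_prefetch() internally
    run_training(train_ds, val_ds)


if __name__ == "__main__":
    # The guard ensures worker processes spawned by multiprocessing prefetching
    # do not re-execute the module's top-level code.
    main()
```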
- _wrap_model()[source]#
Wrap the model for evaluation (using either JointWrapper or ConditionalWrapper).
- classmethod get_default_training_config(sde='EDM')[source]#
Return a dictionary of default training configuration parameters.
- Parameters:
sde (str, optional) – SDE formulation to use for the diffusion process. Default is "EDM".
- Returns:
training_config – Default training configuration.
- Return type:
dict
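For instance, to fetch and tweak the defaults before instantiating the pipeline (a minimal sketch; nsteps is the key used in the example above, other keys depend on your GenSBI version):

```python
from gensbi.recipes import UnconditionalDiffusionPipeline

# Start from the defaults for the EDM SDE and adjust the number of training steps.
config = UnconditionalDiffusionPipeline.get_default_training_config(sde="EDM")
config["nsteps"] = 20_000
# Pass the modified dict to the pipeline via training_config=config.
```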
- get_sampler(nsteps=18, use_ema=True, return_intermediates=False, **model_extras)[source]#
Get a sampler function for generating samples from the trained model.
- Parameters:
nsteps (int, optional) – Number of sampling steps. Default is 18.
use_ema (bool, optional) – Whether to use the EMA model for sampling.
return_intermediates (bool, optional) – Whether to also return the intermediate states of the sampling trajectory.
model_extras (dict, optional) – Additional model-specific parameters.
- Returns:
sampler – A function that generates samples when called with a random key and number of samples.
- Return type:
Callable: key, nsamples -> samples
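A short usage sketch, assuming a trained pipeline instance as in the example above:

```python
import jax

# Build a reusable sampler from the trained (EMA) model once...
sampler = pipeline.get_sampler(nsteps=18, use_ema=True)

# ...then draw batches of samples with fresh random keys.
samples = sampler(jax.random.PRNGKey(0), 5_000)  # shape: (5000, dim_obs, ch_obs)
```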
- classmethod init_pipeline_from_config()[source]#
- Abstractmethod:
Initialize the pipeline from a configuration file.
- Parameters:
train_dataset (iterable) – Training dataset.
val_dataset (iterable) – Validation dataset.
dim_obs (int) – Dimensionality of the parameter (theta) space.
dim_cond (int) – Dimensionality of the observation (x) space.
config_path (str) – Path to the configuration file.
checkpoint_dir (str) – Directory for saving checkpoints.
- Returns:
pipeline – An instance of the pipeline initialized from the configuration.
- Return type:
UnconditionalDiffusionPipeline
- sample(key, nsamples=10000, nsteps=18, use_ema=True, return_intermediates=False, **model_extras)[source]#
Generate samples from the trained model.
- Parameters:
key (jax.random.PRNGKey) – Random number generator key.
nsamples (int, optional) – Number of samples to generate. Default is 10000.
nsteps (int, optional) – Number of sampling steps. Default is 18.
use_ema (bool, optional) – Whether to use the EMA model for sampling.
return_intermediates (bool, optional) – Whether to also return the intermediate states of the sampling trajectory.
model_extras (dict, optional) – Additional model-specific parameters.
- Returns:
samples – Generated samples of shape (nsamples, dim_obs, ch_obs).
- Return type:
array-like
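For example, continuing from the trained pipeline above (a sketch; the nsteps value is illustrative):

```python
import jax

# Draw 10,000 samples with a finer 50-step discretization.
samples = pipeline.sample(jax.random.PRNGKey(1), nsamples=10_000, nsteps=50)
# If the training data was normalized, map the samples back to the original scale:
samples = unnormalize(samples, means, stds)
```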
- abstractmethod sample_batched(*args, **kwargs)[source]#
Generate samples from the trained model in batches.
- Parameters:
key (jax.random.PRNGKey) – Random number generator key.
x_o (array-like) – Conditioning variable (e.g., observed data).
nsamples (int) – Number of samples to generate.
chunk_size (int, optional) – Size of each batch for sampling. Default is 50.
show_progress_bars (bool, optional) – Whether to display progress bars during sampling. Default is True.
args (tuple) – Additional positional arguments for the sampler.
kwargs (dict) – Additional keyword arguments for the sampler.
- Returns:
samples – Generated samples of shape (nsamples, batch_size_cond, dim_obs, ch_obs).
- Return type:
array-like
- class gensbi.recipes.unconditional_pipeline.UnconditionalFlowPipeline(model, train_dataset, val_dataset, dim_obs, ch_obs=1, params=None, training_config=None)[source]#
Bases: gensbi.recipes.pipeline.AbstractPipeline
Flow pipeline for training and using an unconditional model for simulation-based inference.
- Parameters:
model (nnx.Module) – The model to be trained.
train_dataset (grain dataset or iterator over batches) – Training dataset.
val_dataset (grain dataset or iterator over batches) – Validation dataset.
dim_obs (int) – Dimension of the parameter space.
ch_obs (int) – Number of channels in the observation space.
params (optional) – Parameters for the model. Ignored if a custom model is provided.
training_config (dict, optional) – Configuration for training. If None, default configuration is used.
Examples
Minimal example of how to instantiate and use the UnconditionalFlowPipeline:
```python
# %% Imports
import os

# Set JAX backend (use 'cuda' for GPU, 'cpu' otherwise)
# os.environ["JAX_PLATFORMS"] = "cuda"

import grain
import numpy as np
import jax
from jax import numpy as jnp
from gensbi.recipes import UnconditionalFlowPipeline
from gensbi.utils.model_wrapping import _expand_dims, _expand_time
from gensbi.utils.plotting import plot_marginals
import matplotlib.pyplot as plt

from flax import nnx


# %% Define a simulator
def simulator(key, nsamples):
    # A simple 2D Gaussian with means (3, 3) and standard deviations (0.5, 1).
    return 3 + jax.random.normal(key, (nsamples, 2)) * jnp.array([0.5, 1]).reshape(
        1, 2
    )


# %% Define your training and validation datasets.
# We generate a training dataset and a validation dataset using the simulator.
# The simulator generates samples from a 2D Gaussian distribution.
train_data = simulator(jax.random.PRNGKey(0), 100_000).reshape(-1, 2, 1)
val_data = simulator(jax.random.PRNGKey(1), 2000).reshape(-1, 2, 1)

# %% Normalize the dataset
# It is important to normalize the data to have zero mean and unit variance.
# This helps the model training process.
means = jnp.mean(train_data, axis=0)
stds = jnp.std(train_data, axis=0)


def normalize(data, means, stds):
    return (data - means) / stds


def unnormalize(data, means, stds):
    return data * stds + means


def process_data(data):
    return normalize(data, means, stds)


# %% Create the input pipeline using Grain
# We use Grain to create an efficient input pipeline.
# This involves shuffling, repeating for multiple epochs, and batching the data.
# We also map the process_data function to prepare (normalize) the data.
batch_size = 256

train_dataset_grain = (
    grain.MapDataset.source(np.array(train_data))
    .shuffle(42)
    .repeat()
    .to_iter_dataset()
    .batch(batch_size)
    .map(process_data)
    # .mp_prefetch()  # Uncomment if you want to use multiprocessing prefetching
)

val_dataset_grain = (
    grain.MapDataset.source(np.array(val_data))
    .shuffle(42)  # Use a different seed/strategy for validation if needed
    .repeat()
    .to_iter_dataset()
    .batch(batch_size)
    .map(process_data)
    # .mp_prefetch()  # Uncomment if you want to use multiprocessing prefetching
)


# %% Define your model
# Here we define an MLP velocity field model;
# this model only works for inputs of shape (batch, dim, 1).
# For more complex models, please refer to the transformer-based models in gensbi.models.
class MLP(nnx.Module):
    def __init__(self, input_dim: int = 2, hidden_dim: int = 128, *, rngs: nnx.Rngs):
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        din = input_dim + 1  # input dimension plus one for the time variable

        self.linear1 = nnx.Linear(din, self.hidden_dim, rngs=rngs)
        self.linear2 = nnx.Linear(self.hidden_dim, self.hidden_dim, rngs=rngs)
        self.linear3 = nnx.Linear(self.hidden_dim, self.hidden_dim, rngs=rngs)
        self.linear4 = nnx.Linear(self.hidden_dim, self.hidden_dim, rngs=rngs)
        self.linear5 = nnx.Linear(self.hidden_dim, self.input_dim, rngs=rngs)

    def __call__(self, t: jax.Array, obs: jax.Array, node_ids, *args, **kwargs):
        # For this specific model we use samples of shape (batch, dim),
        # while transformer models use (batch, dim, c).
        obs = _expand_dims(obs)[..., 0]
        t = _expand_time(t)
        t = jnp.broadcast_to(t, (obs.shape[0], 1))

        h = jnp.concatenate([obs, t], axis=-1)

        x = self.linear1(h)
        x = jax.nn.gelu(x)

        x = self.linear2(x)
        x = jax.nn.gelu(x)

        x = self.linear3(x)
        x = jax.nn.gelu(x)

        x = self.linear4(x)
        x = jax.nn.gelu(x)

        x = self.linear5(x)

        return x[..., None]  # return shape (batch, dim, 1)


# Your nnx.Module model here, e.g., a simple MLP or the Simformer model.
# If you define a custom model, it should take the following arguments:
#   t: Array,
#   obs: Array,
#   node_ids: Array (optional, if your model is a transformer-based model),
#   *args, **kwargs
# The obs input should have shape (batch_size, dim_joint, c),
# and the output will be of the same shape.
model = MLP(rngs=nnx.Rngs(42))

# %% Instantiate the pipeline
# The UnconditionalFlowPipeline handles the training loop and sampling.
# We configure it with the model, the datasets, and the dimensions,
# using a default training configuration.
training_config = UnconditionalFlowPipeline.get_default_training_config()
training_config["nsteps"] = 10000

dim_obs = 2  # Dimension of the parameter space
ch_obs = 1  # Number of channels of the parameter space

pipeline = UnconditionalFlowPipeline(
    model,
    train_dataset_grain,
    val_dataset_grain,
    dim_obs,
    ch_obs,
    training_config=training_config,
)

# %% Train the model
# We create a random key for training and start the training process.
rngs = nnx.Rngs(42)
pipeline.train(rngs, save_model=False)  # set save_model=True to save the model

# %% Sample from the trained model
# We generate new samples using the trained model.
samples = pipeline.sample(rngs.sample(), nsamples=100_000)
# Finally, we unnormalize the samples to get them back to the original scale.
samples = unnormalize(samples, means, stds)

# %% Plot the samples
# We verify the model's performance by plotting the marginal distributions
# of the generated samples.
plot_marginals(
    np.array(samples[..., 0]), true_param=[3, 3], gridsize=30, range=[(-2, 8), (-2, 8)]
)
plt.savefig("unconditional_flow_samples.png", dpi=300, bbox_inches="tight")
plt.show()
```
Note
If you plan on using multiprocessing prefetching, ensure that your script is wrapped in an
if __name__ == "__main__": guard. See https://docs.python.org/3/library/multiprocessing.html
- _wrap_model()[source]#
Wrap the model for evaluation (using either JointWrapper or ConditionalWrapper).
- get_sampler(step_size=0.01, use_ema=True, time_grid=None, **model_extras)[source]#
Get a sampler function for generating samples from the trained model.
- Parameters:
step_size (float, optional) – Step size for the ODE solver. Default is 0.01.
use_ema (bool, optional) – Whether to use the EMA model for sampling.
time_grid (array-like, optional) – Time grid for the sampler (if applicable).
model_extras (dict, optional) – Additional model-specific parameters.
- Returns:
sampler – A function that generates samples when called with a random key and number of samples.
- Return type:
Callable: key, nsamples -> samples
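A short usage sketch, assuming a trained flow pipeline as in the example above:

```python
import jax
from jax import numpy as jnp

# Integrate the flow ODE with a smaller step size for higher accuracy.
sampler = pipeline.get_sampler(step_size=0.005, use_ema=True)
samples = sampler(jax.random.PRNGKey(0), 5_000)  # shape: (5000, dim_obs, ch_obs)

# A custom time grid can be supplied instead, if the solver supports it:
sampler = pipeline.get_sampler(time_grid=jnp.linspace(0.0, 1.0, 100))
```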
- classmethod init_pipeline_from_config()[source]#
- Abstractmethod:
Initialize the pipeline from a configuration file.
- Parameters:
train_dataset (iterable) – Training dataset.
val_dataset (iterable) – Validation dataset.
dim_obs (int) – Dimensionality of the parameter (theta) space.
dim_cond (int) – Dimensionality of the observation (x) space.
config_path (str) – Path to the configuration file.
checkpoint_dir (str) – Directory for saving checkpoints.
- Returns:
pipeline – An instance of the pipeline initialized from the configuration.
- Return type:
UnconditionalFlowPipeline
- sample(key, nsamples=10000, step_size=0.01, use_ema=True, time_grid=None, **model_extras)[source]#
Generate samples from the trained model.
- Parameters:
key (jax.random.PRNGKey) – Random number generator key.
nsamples (int, optional) – Number of samples to generate. Default is 10000.
step_size (float, optional) – Step size for the ODE solver. Default is 0.01.
use_ema (bool, optional) – Whether to use the EMA model for sampling.
time_grid (array-like, optional) – Time grid for the sampler (if applicable).
model_extras (dict, optional) – Additional model-specific parameters.
- Returns:
samples – Generated samples of shape (nsamples, dim_obs, ch_obs).
- Return type:
array-like
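For example, continuing from the trained flow pipeline above (a sketch; the step_size value is illustrative):

```python
import jax

# Smaller step sizes integrate the flow ODE more accurately, at higher cost.
samples = pipeline.sample(jax.random.PRNGKey(1), nsamples=10_000, step_size=0.005)
samples = unnormalize(samples, means, stds)  # back to the original data scale
```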
- abstractmethod sample_batched(*args, **kwargs)[source]#
Generate samples from the trained model in batches.
- Parameters:
key (jax.random.PRNGKey) – Random number generator key.
x_o (array-like) – Conditioning variable (e.g., observed data).
nsamples (int) – Number of samples to generate.
chunk_size (int, optional) – Size of each batch for sampling. Default is 50.
show_progress_bars (bool, optional) – Whether to display progress bars during sampling. Default is True.
args (tuple) – Additional positional arguments for the sampler.
kwargs (dict) – Additional keyword arguments for the sampler.
- Returns:
samples – Generated samples of shape (nsamples, batch_size_cond, dim_obs, ch_obs).
- Return type:
array-like