gensbi.recipes#
Cookie-cutter modules for creating and training simulation-based inference (SBI) models.
Submodules#
Classes#
- Diffusion pipeline for training and using a Conditional model for simulation-based inference.
- Flow pipeline for training and using a Conditional model for simulation-based inference.
- Diffusion pipeline for training and using a Conditional model for simulation-based inference.
- Flow pipeline for training and using a Conditional model for simulation-based inference.
- Diffusion pipeline for training and using a Joint model for simulation-based inference.
- Flow pipeline for training and using a Joint model for simulation-based inference.
- Diffusion pipeline for training and using a Joint model for simulation-based inference.
- Flow pipeline for training and using a Joint model for simulation-based inference.
- Diffusion pipeline for training and using a Joint model for simulation-based inference.
- Flow pipeline for training and using a Joint model for simulation-based inference.
- Diffusion pipeline for training and using an Unconditional model for simulation-based inference.
- Flow pipeline for training and using an Unconditional model for simulation-based inference.
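All recipe pipelines are importable directly from the package namespace, as in the examples further down this page; a short sketch using the four base pipelines documented below:

# The recipe pipelines are exposed at the package level (as used in the examples below).
from gensbi.recipes import (
    ConditionalDiffusionPipeline,
    ConditionalFlowPipeline,
    JointDiffusionPipeline,
    JointFlowPipeline,
)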
Package Contents#
- class gensbi.recipes.ConditionalDiffusionPipeline(model, train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, ch_cond=1, id_embedding_strategy=('absolute', 'absolute'), params=None, training_config=None)[source]#
Bases: gensbi.recipes.pipeline.AbstractPipeline
Diffusion pipeline for training and using a Conditional model for simulation-based inference.
- Parameters:
model (nnx.Module) – The model to be trained.
train_dataset (grain dataset or iterator over batches) – Training dataset.
val_dataset (grain dataset or iterator over batches) – Validation dataset.
dim_obs (int or tuple of int) – Dimension of the parameter space (number of tokens). Can represent unstructured data, time-series, or patchified 2D images. For images, provide a tuple (height, width).
dim_cond (int or tuple of int) – Dimension of the observation space (number of tokens). Can represent unstructured data, time-series, or patchified 2D images. For images, provide a tuple (height, width).
ch_obs (int, optional) – Number of channels per token in the observation data. Default is 1.
ch_cond (int, optional) – Number of channels per token in the conditional data. Default is 1.
params (ConditionalParams, optional) – Parameters for the Conditional model. If None, default parameters are used.
training_config (dict, optional) – Configuration for training. If None, default configuration is used.
Examples
Minimal example of how to instantiate and use the ConditionalDiffusionPipeline:
# %% Imports
import os

# Set JAX backend (use 'cuda' for GPU, 'cpu' otherwise)
# os.environ["JAX_PLATFORMS"] = "cuda"

import grain
import numpy as np
import jax
from jax import numpy as jnp
from numpyro import distributions as dist
from flax import nnx

from gensbi.recipes import ConditionalDiffusionPipeline
from gensbi.models import Flux1, Flux1Params

from gensbi.utils.plotting import plot_marginals
import matplotlib.pyplot as plt


# %%

theta_prior = dist.Uniform(
    low=jnp.array([-2.0, -2.0, -2.0]), high=jnp.array([2.0, 2.0, 2.0])
)

dim_obs = 3
dim_cond = 3
dim_joint = dim_obs + dim_cond


# %%
def simulator(key, nsamples):
    theta_key, sample_key = jax.random.split(key, 2)
    thetas = theta_prior.sample(theta_key, (nsamples,))

    xs = thetas + 1 + jax.random.normal(sample_key, thetas.shape) * 0.1

    thetas = thetas[..., None]
    xs = xs[..., None]

    # when making a dataset for the joint pipeline, thetas need to come first
    data = jnp.concatenate([thetas, xs], axis=1)

    return data


# %% Define your training and validation datasets.
# We generate a training dataset and a validation dataset using the simulator.
# The simulator is a simple function that generates parameters (theta) and data (x).
# In this example, we use a simple Gaussian simulator.
train_data = simulator(jax.random.PRNGKey(0), 100_000)
val_data = simulator(jax.random.PRNGKey(1), 2000)
# %% Normalize the dataset
# It is important to normalize the data to have zero mean and unit variance.
# This helps the model training process.
means = jnp.mean(train_data, axis=0)
stds = jnp.std(train_data, axis=0)


def normalize(data, means, stds):
    return (data - means) / stds


def unnormalize(data, means, stds):
    return data * stds + means


# %% Prepare the data for the pipeline
# The pipeline expects the data to be split into observations and conditions.
# We also apply normalization at this stage.
def split_obs_cond(data):
    data = normalize(data, means, stds)
    return (
        data[:, :dim_obs],
        data[:, dim_obs:],
    )  # assuming first dim_obs are obs, last dim_cond are cond


# %%

# %% Create the input pipeline using Grain
# We use Grain to create an efficient input pipeline.
# This involves shuffling, repeating for multiple epochs, and batching the data.
# We also map the split_obs_cond function to prepare the data for the model.
batch_size = 256

train_dataset_grain = (
    grain.MapDataset.source(np.array(train_data))
    .shuffle(42)
    .repeat()
    .to_iter_dataset()
    .batch(batch_size)
    .map(split_obs_cond)
    # .mp_prefetch()  # Uncomment if you want to use multiprocessing prefetching
)

val_dataset_grain = (
    grain.MapDataset.source(np.array(val_data))
    .shuffle(
        42
    )  # Use a different seed/strategy for validation if needed, but shuffling is fine
    .repeat()
    .to_iter_dataset()
    .batch(batch_size)
    .map(split_obs_cond)
    # .mp_prefetch()  # Uncomment if you want to use multiprocessing prefetching
)

# %% Define your model
# specific model parameters are defined here.
# For Flux1, we need to specify dimensions, embedding strategies, and other architecture details.
params = Flux1Params(
    in_channels=1,
    vec_in_dim=None,
    context_in_dim=1,
    mlp_ratio=3,
    num_heads=2,
    depth=4,
    depth_single_blocks=8,
    axes_dim=[
        10,
    ],
    qkv_bias=True,
    dim_obs=dim_obs,
    dim_cond=dim_cond,
    id_embedding_strategy=("absolute", "absolute"),
    theta=10 * dim_joint,
    rngs=nnx.Rngs(default=42),
    param_dtype=jnp.float32,
)

model = Flux1(params)

# %% Instantiate the pipeline
# The ConditionalDiffusionPipeline handles the training loop and sampling.
# We configure it with the model, datasets, dimensions using a default training configuration.
training_config = ConditionalDiffusionPipeline.get_default_training_config()
training_config["nsteps"] = 10000

pipeline = ConditionalDiffusionPipeline(
    model,
    train_dataset_grain,
    val_dataset_grain,
    dim_obs,
    dim_cond,
    training_config=training_config,
)

# %% Train the model
# We create a random key for training and start the training process.
rngs = nnx.Rngs(42)
pipeline.train(
    rngs, save_model=False
)  # if you want to save the model, set save_model=True

# %% Sample from the posterior
# To generate samples, we first need an observation (and its corresponding condition).
# We generate a new sample from the simulator, normalize it, and extract the condition x_o.

new_sample = simulator(jax.random.PRNGKey(20), 1)
true_theta = new_sample[:, :dim_obs, :]  # extract observation from the joint sample

new_sample = normalize(new_sample, means, stds)
x_o = new_sample[:, dim_obs:, :]  # extract condition from the joint sample

# Then we invoke the pipeline's sample method.
samples = pipeline.sample(rngs.sample(), x_o, nsamples=100_000)
# Finally, we unnormalize the samples to get them back to the original scale.
samples = unnormalize(samples, means[:dim_obs], stds[:dim_obs])

# %% Plot the samples
# We verify the model's performance by plotting the marginal distributions of the generated samples
# against the true parameters.
plot_marginals(
    np.array(samples[..., 0]),
    gridsize=30,
    true_param=np.array(true_theta[0, :, 0]),
    range=[(1, 3), (1, 3), (-0.6, 0.5)],
)

plt.savefig(
    "conditional_diffusion_pipeline_marginals.png", dpi=100, bbox_inches="tight"
)
plt.show()

# %%
Note
If you plan on using multiprocessing prefetching, ensure that your script is wrapped in an
if __name__ == "__main__": guard. See https://docs.python.org/3/library/multiprocessing.html
Note
Sampling in the latent space (latent diffusion/flow) is not currently supported.
- _wrap_model()[source]#
Wrap the model for evaluation (either using JointWrapper or ConditionalWrapper).
- classmethod get_default_training_config()[source]#
Return a dictionary of default training configuration parameters.
- Returns:
training_config – Default training configuration.
- Return type:
dict
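The usual pattern, shown in the example above, is to fetch the defaults, override selected keys, and pass the dictionary to the constructor via training_config; which keys exist besides "nsteps" depends on the library defaults:

# Fetch the defaults, tweak them, and hand them to the pipeline constructor
# ("nsteps" is the only key shown in the example above; others are library defaults).
training_config = ConditionalDiffusionPipeline.get_default_training_config()
training_config["nsteps"] = 10_000

pipeline = ConditionalDiffusionPipeline(
    model,
    train_dataset_grain,
    val_dataset_grain,
    dim_obs,
    dim_cond,
    training_config=training_config,
)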
- get_sampler(x_o, nsteps=18, use_ema=True, return_intermediates=False, **model_extras)[source]#
Get a sampler function for generating samples from the trained model.
- Parameters:
x_o (array-like) – Conditioning variable.
nsteps (int, optional) – Number of sampling steps. Default is 18.
use_ema (bool, optional) – Whether to use the EMA model for sampling.
return_intermediates (bool, optional) – Whether to also return intermediate states of the sampling process.
model_extras (dict, optional) – Additional model-specific parameters.
- Returns:
sampler – A function that generates samples when called with a random key and number of samples.
- Return type:
Callable: key, nsamples -> samples
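A short sketch of reusing the returned sampler with different keys, assuming a trained pipeline and a conditioning array x_o prepared as in the example above:

import jax

# Build the sampler once; it can then be called repeatedly with fresh keys.
# x_o is the conditioning array, e.g. shaped (1, dim_cond, ch_cond) as in the example above.
sampler = pipeline.get_sampler(x_o, nsteps=18, use_ema=True)

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
samples_a = sampler(key1, 10_000)
samples_b = sampler(key2, 10_000)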
- get_train_step_fn(loss_fn)[source]#
Return the training step function, which performs a single optimization step.
- Returns:
train_step – JIT-compiled training step function.
- Return type:
Callable
- classmethod init_pipeline_from_config()[source]#
- Abstractmethod:
Initialize the pipeline from a configuration file.
- Parameters:
train_dataset (iterable) – Training dataset.
val_dataset (iterable) – Validation dataset.
dim_obs (int) – Dimensionality of the parameter (theta) space.
dim_cond (int) – Dimensionality of the observation (x) space.
config_path (str) – Path to the configuration file.
checkpoint_dir (str) – Directory for saving checkpoints.
- Returns:
pipeline – An instance of the pipeline initialized from the configuration.
- Return type:
- sample(key, x_o, nsamples=10000, nsteps=18, use_ema=True, return_intermediates=False, **model_extras)[source]#
Generate samples from the trained model.
- Parameters:
key (jax.random.PRNGKey) – Random number generator key.
x_o (array-like) – Conditioning variable (e.g., observed data).
nsamples (int, optional) – Number of samples to generate.
- Returns:
samples – Generated samples of size (nsamples, dim_obs, ch_obs).
- Return type:
array-like
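A minimal posterior-sampling call, mirroring the end of the example above; the shape comments are assumptions based on the documented return shape:

import jax

key = jax.random.PRNGKey(0)
# x_o: conditioning data, shaped (1, dim_cond, ch_cond) in the example above.
samples = pipeline.sample(key, x_o, nsamples=100_000)
# samples has shape (nsamples, dim_obs, ch_obs) and is still in normalized units,
# so undo whatever normalization was applied to the training data.
samples = unnormalize(samples, means[:dim_obs], stds[:dim_obs])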
- cond_ids#
- loss_fn#
- obs_ids#
- path#
- class gensbi.recipes.ConditionalFlowPipeline(model, train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, ch_cond=1, id_embedding_strategy=('absolute', 'absolute'), params=None, training_config=None)[source]#
Bases: gensbi.recipes.pipeline.AbstractPipeline
Flow pipeline for training and using a Conditional model for simulation-based inference.
- Parameters:
model (nnx.Module) – The model to be trained.
train_dataset (grain dataset or iterator over batches) – Training dataset.
val_dataset (grain dataset or iterator over batches) – Validation dataset.
dim_obs (int or tuple of int) – Dimension of the parameter space (number of tokens). Can represent unstructured data, time-series, or patchified 2D images. For images, provide a tuple (height, width).
dim_cond (int or tuple of int) – Dimension of the observation space (number of tokens). Can represent unstructured data, time-series, or patchified 2D images. For images, provide a tuple (height, width).
ch_obs (int, optional) – Number of channels per token in the observation data. Default is 1.
ch_cond (int, optional) – Number of channels per token in the conditional data. Default is 1.
params (ConditionalParams, optional) – Parameters for the Conditional model. If None, default parameters are used.
training_config (dict, optional) – Configuration for training. If None, default configuration is used.
Examples
Minimal example of how to instantiate and use the ConditionalFlowPipeline:
# %% Imports
import os

# Set JAX backend (use 'cuda' for GPU, 'cpu' otherwise)
# os.environ["JAX_PLATFORMS"] = "cuda"

import grain
import numpy as np
import jax
from jax import numpy as jnp
from numpyro import distributions as dist
from flax import nnx

from gensbi.recipes import ConditionalFlowPipeline
from gensbi.models import Flux1, Flux1Params

from gensbi.utils.plotting import plot_marginals
import matplotlib.pyplot as plt


# %%

theta_prior = dist.Uniform(
    low=jnp.array([-2.0, -2.0, -2.0]), high=jnp.array([2.0, 2.0, 2.0])
)

dim_obs = 3
dim_cond = 3
dim_joint = dim_obs + dim_cond


# %%
def simulator(key, nsamples):
    theta_key, sample_key = jax.random.split(key, 2)
    thetas = theta_prior.sample(theta_key, (nsamples,))

    xs = thetas + 1 + jax.random.normal(sample_key, thetas.shape) * 0.1

    thetas = thetas[..., None]
    xs = xs[..., None]

    # when making a dataset for the joint pipeline, thetas need to come first
    data = jnp.concatenate([thetas, xs], axis=1)

    return data


# %% Define your training and validation datasets.
# We generate a training dataset and a validation dataset using the simulator.
# The simulator is a simple function that generates parameters (theta) and data (x).
# In this example, we use a simple Gaussian simulator.
train_data = simulator(jax.random.PRNGKey(0), 100_000)
val_data = simulator(jax.random.PRNGKey(1), 2000)


# %% Normalize the dataset
# It is important to normalize the data to have zero mean and unit variance.
# This helps the model training process.
means = jnp.mean(train_data, axis=0)
stds = jnp.std(train_data, axis=0)


def normalize(data, means, stds):
    return (data - means) / stds


def unnormalize(data, means, stds):
    return data * stds + means


# %% Prepare the data for the pipeline
# The pipeline expects the data to be split into observations and conditions.
# We also apply normalization at this stage.
def split_obs_cond(data):
    data = normalize(data, means, stds)
    return (
        data[:, :dim_obs],
        data[:, dim_obs:],
    )  # assuming first dim_obs are obs, last dim_cond are cond


# %%

# %% Create the input pipeline using Grain
# We use Grain to create an efficient input pipeline.
# This involves shuffling, repeating for multiple epochs, and batching the data.
# We also map the split_obs_cond function to prepare the data for the model.
batch_size = 256

train_dataset_grain = (
    grain.MapDataset.source(np.array(train_data))
    .shuffle(42)
    .repeat()
    .to_iter_dataset()
    .batch(batch_size)
    .map(split_obs_cond)
    # .mp_prefetch()  # Uncomment if you want to use multiprocessing prefetching
)

val_dataset_grain = (
    grain.MapDataset.source(np.array(val_data))
    .shuffle(42)  # Use a different seed/strategy for validation if needed, but shuffling is fine
    .repeat()
    .to_iter_dataset()
    .batch(batch_size)
    .map(split_obs_cond)
    # .mp_prefetch()  # Uncomment if you want to use multiprocessing prefetching
)

# %% Define your model
# specific model parameters are defined here.
# For Flux1, we need to specify dimensions, embedding strategies, and other architecture details.
params = Flux1Params(
    in_channels=1,
    vec_in_dim=None,
    context_in_dim=1,
    mlp_ratio=3,
    num_heads=2,
    depth=4,
    depth_single_blocks=8,
    axes_dim=[
        10,
    ],
    qkv_bias=True,
    dim_obs=dim_obs,
    dim_cond=dim_cond,
    theta=10 * dim_joint,
    id_embedding_strategy=("absolute", "absolute"),
    rngs=nnx.Rngs(default=42),
    param_dtype=jnp.float32,
)

model = Flux1(params)

# %% Instantiate the pipeline
# The ConditionalFlowPipeline handles the training loop and sampling.
# We configure it with the model, datasets, dimensions, and training configuration.
training_config = ConditionalFlowPipeline.get_default_training_config()
training_config["nsteps"] = 10000

pipeline = ConditionalFlowPipeline(
    model,
    train_dataset_grain,
    val_dataset_grain,
    dim_obs,
    dim_cond,
    training_config=training_config,
)

# %% Train the model
# We create a random key for training and start the training process.
rngs = nnx.Rngs(42)
pipeline.train(
    rngs, save_model=False
)  # if you want to save the model, set save_model=True

# %% Sample from the posterior
# To generate samples, we first need an observation (and its corresponding condition).
# We generate a new sample from the simulator, normalize it, and extract the condition x_o.

new_sample = simulator(jax.random.PRNGKey(20), 1)
true_theta = new_sample[:, :dim_obs, :]  # extract observation from the joint sample

new_sample = normalize(new_sample, means, stds)
x_o = new_sample[:, dim_obs:, :]  # extract condition from the joint sample

# Then we invoke the pipeline's sample method.
samples = pipeline.sample(rngs.sample(), x_o, nsamples=100_000)
# Finally, we unnormalize the samples to get them back to the original scale.
samples = unnormalize(samples, means[:dim_obs], stds[:dim_obs])
# %% Plot the samples
plot_marginals(
    np.array(samples[..., 0]),
    gridsize=30,
    true_param=np.array(true_theta[0, :, 0]),
    range=[(1, 3), (1, 3), (-0.6, 0.5)],
)
plt.savefig("conditional_flow_pipeline_marginals.png", dpi=100, bbox_inches="tight")
plt.show()

# %%
Note
If you plan on using multiprocessing prefetching, ensure that your script is wrapped in an
if __name__ == "__main__": guard. See https://docs.python.org/3/library/multiprocessing.html
Note
Sampling in the latent space (latent diffusion/flow) is not currently supported.
- _wrap_model()[source]#
Wrap the model for evaluation (either using JointWrapper or ConditionalWrapper).
- get_sampler(x_o, step_size=0.01, use_ema=True, time_grid=None, **model_extras)[source]#
Get a sampler function for generating samples from the trained model.
- Parameters:
x_o (array-like) – Conditioning variable.
step_size (float, optional) – Step size for the sampler.
use_ema (bool, optional) – Whether to use the EMA model for sampling.
time_grid (array-like, optional) – Time grid for the sampler (if applicable).
model_extras (dict, optional) – Additional model-specific parameters.
- Returns:
sampler – A function that generates samples when called with a random key and number of samples.
- Return type:
Callable: key, nsamples -> samples
- get_train_step_fn(loss_fn)[source]#
Return the training step function, which performs a single optimization step.
- Returns:
train_step – JIT-compiled training step function.
- Return type:
Callable
- classmethod init_pipeline_from_config()[source]#
- Abstractmethod:
Initialize the pipeline from a configuration file.
- Parameters:
train_dataset (iterable) – Training dataset.
val_dataset (iterable) – Validation dataset.
dim_obs (int) – Dimensionality of the parameter (theta) space.
dim_cond (int) – Dimensionality of the observation (x) space.
config_path (str) – Path to the configuration file.
checkpoint_dir (str) – Directory for saving checkpoints.
- Returns:
pipeline – An instance of the pipeline initialized from the configuration.
- Return type:
- sample(key, x_o, nsamples=10000, step_size=0.01, use_ema=True, time_grid=None, **model_extras)[source]#
Generate samples from the trained model.
- Parameters:
key (jax.random.PRNGKey) – Random number generator key.
x_o (array-like) – Conditioning variable (e.g., observed data).
nsamples (int, optional) – Number of samples to generate.
- Returns:
samples – Generated samples of size (nsamples, dim_obs, ch_obs).
- Return type:
array-like
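For the flow pipeline the ODE solver is controlled by step_size or, alternatively, by an explicit time_grid; a sketch assuming a trained pipeline and a conditioning array x_o as in the example above (the 0-to-1 grid is an assumed convention):

import jax
from jax import numpy as jnp

key = jax.random.PRNGKey(0)
# Control the ODE integration through the step size ...
samples = pipeline.sample(key, x_o, nsamples=50_000, step_size=0.01)

# ... or through an explicit time grid (assumed here to run from t=0 to t=1).
time_grid = jnp.linspace(0.0, 1.0, 100)
samples = pipeline.sample(key, x_o, nsamples=50_000, time_grid=time_grid)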
- cond_ids#
- loss_fn#
- obs_ids#
- p0_obs#
- path#
- class gensbi.recipes.Flux1DiffusionPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, ch_cond=1, params=None, training_config=None)[source]#
Bases: gensbi.recipes.conditional_pipeline.ConditionalDiffusionPipeline
Diffusion pipeline for training and using a Conditional model for simulation-based inference.
- Parameters:
train_dataset (grain dataset or iterator over batches) – Training dataset.
val_dataset (grain dataset or iterator over batches) – Validation dataset.
dim_obs (int or tuple of int) – Dimension of the parameter space (number of tokens). Can represent unstructured data, time-series, or patchified 2D images. For images, provide a tuple (height, width).
dim_cond (int or tuple of int) – Dimension of the observation space (number of tokens). Can represent unstructured data, time-series, or patchified 2D images. For images, provide a tuple (height, width).
params (ConditionalParams, optional) – Parameters for the Conditional model. If None, default parameters are used.
training_config (dict, optional) – Configuration for training. If None, default configuration is used.
Examples
The minimal usage example is identical to the ConditionalDiffusionPipeline example shown above.
Note
If you plan on using multiprocessing prefetching, ensure that your script is wrapped in an
if __name__ == "__main__": guard. See https://docs.python.org/3/library/multiprocessing.html
Note
Sampling in the latent space (latent diffusion/flow) is not currently supported.
- classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir)[source]#
Initialize the pipeline from a configuration file.
- Parameters:
config_path (str) – Path to the configuration file.
dim_obs (int)
dim_cond (int)
checkpoint_dir (str)
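A sketch of constructing the pipeline from a configuration file, reusing the datasets and dimensions from the example above; the config path and checkpoint directory below are hypothetical placeholders, and the expected file format is whatever the library's config loader uses:

# Hypothetical paths: substitute your own configuration file and checkpoint directory.
pipeline = Flux1DiffusionPipeline.init_pipeline_from_config(
    train_dataset_grain,
    val_dataset_grain,
    dim_obs,
    dim_cond,
    config_path="flux1_diffusion_config.yaml",  # hypothetical file name
    checkpoint_dir="./checkpoints",             # hypothetical directory
)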
- ch_cond = 1#
- ch_obs = 1#
- dim_cond#
- dim_obs#
- ema_model#
- class gensbi.recipes.Flux1FlowPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, ch_cond=1, params=None, training_config=None)[source]#
Bases: gensbi.recipes.conditional_pipeline.ConditionalFlowPipeline
Flow pipeline for training and using a Conditional model for simulation-based inference.
- Parameters:
train_dataset (grain dataset or iterator over batches) – Training dataset.
val_dataset (grain dataset or iterator over batches) – Validation dataset.
dim_obs (int or tuple of int) – Dimension of the parameter space (number of tokens). Can represent unstructured data, time-series, or patchified 2D images. For images, provide a tuple (height, width).
dim_cond (int or tuple of int) – Dimension of the observation space (number of tokens). Can represent unstructured data, time-series, or patchified 2D images. For images, provide a tuple (height, width).
ch_obs (int, optional) – Number of channels per token in the observation data. Default is 1.
ch_cond (int, optional) – Number of channels per token in the conditional data. Default is 1.
params (ConditionalParams, optional) – Parameters for the Conditional model. If None, default parameters are used.
training_config (dict, optional) – Configuration for training. If None, default configuration is used.
Examples
The minimal usage example is identical to the ConditionalFlowPipeline example shown above.
Note
If you plan on using multiprocessing prefetching, ensure that your script is wrapped in an
if __name__ == "__main__": guard. See https://docs.python.org/3/library/multiprocessing.html
Note
Sampling in the latent space (latent diffusion/flow) is not currently supported.
- classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir)[source]#
Initialize the pipeline from a configuration file.
- Parameters:
config_path (str) – Path to the configuration file.
dim_obs (int)
dim_cond (int)
checkpoint_dir (str)
- ch_cond = 1#
- ch_obs = 1#
- dim_cond#
- dim_obs#
- ema_model#
- class gensbi.recipes.Flux1JointDiffusionPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, params=None, training_config=None, condition_mask_kind='structured')[source]#
Bases: gensbi.recipes.joint_pipeline.JointDiffusionPipeline
Diffusion pipeline for training and using a Joint model for simulation-based inference.
- Parameters:
train_dataset (grain dataset or iterator over batches) – Training dataset.
val_dataset (grain dataset or iterator over batches) – Validation dataset.
dim_obs (int) – Dimension of the parameter space.
dim_cond (int) – Dimension of the observation space.
ch_obs (int, optional) – Number of channels for the observation space. Default is 1.
params (optional) – Parameters for the Joint model. If None, default parameters are used.
training_config (dict, optional) – Configuration for training. If None, default configuration is used.
condition_mask_kind (str, optional) – Kind of condition mask to use. One of [“structured”, “posterior”].
Examples
Minimal example of how to instantiate and use the JointDiffusionPipeline:
# %% Imports
import os

# Set JAX backend (use 'cuda' for GPU, 'cpu' otherwise)
# os.environ["JAX_PLATFORMS"] = "cuda"

import grain
import numpy as np
import jax
from jax import numpy as jnp
from gensbi.recipes import JointDiffusionPipeline
from gensbi.utils.plotting import plot_marginals

from gensbi.models import Simformer, SimformerParams
import matplotlib.pyplot as plt

from numpyro import distributions as dist


from flax import nnx

# %%

theta_prior = dist.Uniform(
    low=jnp.array([-2.0, -2.0, -2.0]), high=jnp.array([2.0, 2.0, 2.0])
)

dim_obs = 3
dim_cond = 3
dim_joint = dim_obs + dim_cond


# %%
def simulator(key, nsamples):
    theta_key, sample_key = jax.random.split(key, 2)
    thetas = theta_prior.sample(theta_key, (nsamples,))

    xs = thetas + 1 + jax.random.normal(sample_key, thetas.shape) * 0.1

    thetas = thetas[..., None]
    xs = xs[..., None]

    # when making a dataset for the joint pipeline, thetas need to come first
    data = jnp.concatenate([thetas, xs], axis=1)

    return data


# %% Define your training and validation datasets.
# We generate a training dataset and a validation dataset using the simulator.
# The simulator is a simple function that generates parameters (theta) and data (x).
# In this example, we use a simple Gaussian simulator.
train_data = simulator(jax.random.PRNGKey(0), 100_000)
val_data = simulator(jax.random.PRNGKey(1), 2000)
# %% Normalize the dataset
# It is important to normalize the data to have zero mean and unit variance.
# This helps the model training process.
means = jnp.mean(train_data, axis=0)
stds = jnp.std(train_data, axis=0)


def normalize(data, means, stds):
    return (data - means) / stds


def unnormalize(data, means, stds):
    return data * stds + means


# %% Prepare the data for the pipeline
# The pipeline expects the data to be normalized but not split (for joint pipelines).


def process_data(data):
    return normalize(data, means, stds)


# %%
train_data.shape

# %%

# %% Create the input pipeline using Grain
# We use Grain to create an efficient input pipeline.
# This involves shuffling, repeating for multiple epochs, and batching the data.
# We also map the process_data function to prepare (normalize) the data for the model.
batch_size = 256

train_dataset_grain = (
    grain.MapDataset.source(np.array(train_data))
    .shuffle(42)
    .repeat()
    .to_iter_dataset()
    .batch(batch_size)
    .map(process_data)
    # .mp_prefetch()  # Uncomment if you want to use multiprocessing prefetching
)

val_dataset_grain = (
    grain.MapDataset.source(np.array(val_data))
    .shuffle(
        42
    )  # Use a different seed/strategy for validation if needed, but shuffling is fine
    .repeat()
    .to_iter_dataset()
    .batch(batch_size)
    .map(process_data)
    # .mp_prefetch()  # Uncomment if you want to use multiprocessing prefetching
)

# %% Define your model
# specific model parameters are defined here.
# For Simformer, we need to specify dimensions, embedding strategies, and other architecture details.
params = SimformerParams(
    rngs=nnx.Rngs(0),
    in_channels=1,
    dim_value=20,
    dim_id=10,
    dim_condition=10,
    dim_joint=dim_joint,
    fourier_features=128,
    num_heads=4,
    num_layers=6,
    widening_factor=3,
    qkv_features=40,
    num_hidden_layers=1,
)

model = Simformer(params)

# %% Instantiate the pipeline
# The JointDiffusionPipeline handles the training loop and sampling.
# We configure it with the model, datasets, dimensions using a default training configuration.
# We also specify the condition_mask_kind, which determines how conditioning is handled during training.
training_config = JointDiffusionPipeline.get_default_training_config()
training_config["nsteps"] = 10000

pipeline = JointDiffusionPipeline(
    model,
    train_dataset_grain,
    val_dataset_grain,
    dim_obs,
    dim_cond,
    condition_mask_kind="posterior",
    training_config=training_config,
)

# %% Train the model
# We create a random key for training and start the training process.
rngs = nnx.Rngs(42)
pipeline.train(
    rngs, save_model=False
)  # if you want to save the model, set save_model=True

# %% Sample from the posterior
# To generate samples, we first need an observation (and its corresponding condition).
# We generate a new sample from the simulator, normalize it, and extract the condition x_o.

new_sample = simulator(jax.random.PRNGKey(20), 1)
true_theta = new_sample[:, :dim_obs, :]  # extract observation from the joint sample

new_sample = normalize(new_sample, means, stds)
x_o = new_sample[:, dim_obs:, :]  # extract condition from the joint sample

# Then we invoke the pipeline's sample method.
samples = pipeline.sample(rngs.sample(), x_o, nsamples=100_000)
# Finally, we unnormalize the samples to get them back to the original scale.
samples = unnormalize(samples, means[:dim_obs], stds[:dim_obs])

# %% Plot the samples
# We verify the model's performance by plotting the marginal distributions of the generated samples
# against the true parameters.
plot_marginals(
    np.array(samples[..., 0]),
    gridsize=30,
    true_param=np.array(true_theta[0, :, 0]),
    range=[(1, 3), (1, 3), (-0.6, 0.5)],
)
plt.savefig("joint_diffusion_pipeline_marginals.png", dpi=100, bbox_inches="tight")
plt.show()
Note
If you plan on using multiprocessing prefetching, ensure that your script is wrapped in an
if __name__ == "__main__": guard. See https://docs.python.org/3/library/multiprocessing.html
- classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir)[source]#
Initialize the pipeline from a configuration file.
- Parameters:
config_path (str) – Path to the configuration file.
dim_obs (int)
dim_cond (int)
checkpoint_dir (str)
- ch_obs = 1#
- dim_joint#
- ema_model#
- class gensbi.recipes.Flux1JointFlowPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, params=None, training_config=None, condition_mask_kind='structured')[source]#
Bases: gensbi.recipes.joint_pipeline.JointFlowPipeline
Flow pipeline for training and using a Joint model for simulation-based inference.
- Parameters:
train_dataset (grain dataset or iterator over batches) – Training dataset.
val_dataset (grain dataset or iterator over batches) – Validation dataset.
dim_obs (int) – Dimension of the parameter space.
dim_cond (int) – Dimension of the observation space.
ch_obs (int, optional) – Number of channels for the observation space. Default is 1.
params (JointParams, optional) – Parameters for the Joint model. If None, default parameters are used.
training_config (dict, optional) – Configuration for training. If None, default configuration is used.
condition_mask_kind (str, optional) – Kind of condition mask to use. One of [“structured”, “posterior”].
Examples
Minimal example of how to instantiate and use the JointFlowPipeline:
# %% Imports
import os

# Set JAX backend (use 'cuda' for GPU, 'cpu' otherwise)
# os.environ["JAX_PLATFORMS"] = "cuda"

import grain
import numpy as np
import jax
from jax import numpy as jnp
from numpyro import distributions as dist
from flax import nnx

from gensbi.recipes import JointFlowPipeline
from gensbi.models import Simformer, SimformerParams

from gensbi.utils.plotting import plot_marginals
import matplotlib.pyplot as plt


# %%

theta_prior = dist.Uniform(
    low=jnp.array([-2.0, -2.0, -2.0]), high=jnp.array([2.0, 2.0, 2.0])
)

dim_obs = 3
dim_cond = 3
dim_joint = dim_obs + dim_cond


# %%
def simulator(key, nsamples):
    theta_key, sample_key = jax.random.split(key, 2)
    thetas = theta_prior.sample(theta_key, (nsamples,))

    xs = thetas + 1 + jax.random.normal(sample_key, thetas.shape) * 0.1

    thetas = thetas[..., None]
    xs = xs[..., None]

    # when making a dataset for the joint pipeline, thetas need to come first
    data = jnp.concatenate([thetas, xs], axis=1)

    return data


# %% Define your training and validation datasets.
# We generate a training dataset and a validation dataset using the simulator.
# The simulator is a simple function that generates parameters (theta) and data (x).
# In this example, we use a simple Gaussian simulator.
train_data = simulator(jax.random.PRNGKey(0), 100_000)
val_data = simulator(jax.random.PRNGKey(1), 2000)
# %% Normalize the dataset
# It is important to normalize the data to have zero mean and unit variance.
# This helps the model training process.
means = jnp.mean(train_data, axis=0)
stds = jnp.std(train_data, axis=0)


def normalize(data, means, stds):
    return (data - means) / stds


def unnormalize(data, means, stds):
    return data * stds + means


# %% Prepare the data for the pipeline
# The pipeline expects the data to be normalized but not split (for joint pipelines).
def process_data(data):
    return normalize(data, means, stds)


# %%
train_data.shape

# %%

# %% Create the input pipeline using Grain
# We use Grain to create an efficient input pipeline.
# This involves shuffling, repeating for multiple epochs, and batching the data.
# We also map the process_data function to prepare (normalize) the data for the model.
batch_size = 256

train_dataset_grain = (
    grain.MapDataset.source(np.array(train_data))
    .shuffle(42)
    .repeat()
    .to_iter_dataset()
    .batch(batch_size)
    .map(process_data)
    # .mp_prefetch()  # Uncomment if you want to use multiprocessing prefetching
)

val_dataset_grain = (
    grain.MapDataset.source(np.array(val_data))
    .shuffle(42)
    .repeat()
    .to_iter_dataset()
    .batch(batch_size)
    .map(process_data)
    # .mp_prefetch()  # Uncomment if you want to use multiprocessing prefetching
)

# %% Define your model
# specific model parameters are defined here.
# For Simformer, we need to specify dimensions, embedding strategies, and other architecture details.
params = SimformerParams(
    rngs=nnx.Rngs(0),
    in_channels=1,
    dim_value=20,
    dim_id=10,
    dim_condition=10,
    dim_joint=dim_joint,
    fourier_features=128,
    num_heads=4,
    num_layers=6,
    widening_factor=3,
    qkv_features=40,
    num_hidden_layers=1,
)

model = Simformer(params)

# %% Instantiate the pipeline
# The JointFlowPipeline handles the training loop and sampling.
# We configure it with the model, datasets, dimensions using a default training configuration.
# We also specify the condition_mask_kind, which determines how conditioning is handled during training.
training_config = JointFlowPipeline.get_default_training_config()
training_config["nsteps"] = 10000

pipeline = JointFlowPipeline(
    model,
    train_dataset_grain,
    val_dataset_grain,
    dim_obs,
    dim_cond,
    condition_mask_kind="posterior",
    training_config=training_config,
)

# %% Train the model
# We create a random key for training and start the training process.
rngs = nnx.Rngs(42)
pipeline.train(
    rngs, save_model=False
)  # if you want to save the model, set save_model=True

# %% Sample from the posterior
# To generate samples, we first need an observation (and its corresponding condition).
# We generate a new sample from the simulator, normalize it, and extract the condition x_o.

new_sample = simulator(jax.random.PRNGKey(20), 1)
true_theta = new_sample[:, :dim_obs, :]  # extract observation from the joint sample

new_sample = normalize(new_sample, means, stds)
x_o = new_sample[:, dim_obs:, :]  # extract condition from the joint sample

# Then we invoke the pipeline's sample method.
samples = pipeline.sample(rngs.sample(), x_o, nsamples=100_000)
# Finally, we unnormalize the samples to get them back to the original scale.
samples = unnormalize(samples, means[:dim_obs], stds[:dim_obs])

# %% Plot the samples
# We verify the model's performance by plotting the marginal distributions of the generated samples
# against the true parameters.
plot_marginals(
    np.array(samples[..., 0]),
    gridsize=30,
    true_param=np.array(true_theta[0, :, 0]),
    range=[(1, 3), (1, 3), (-0.6, 0.5)],
)
plt.savefig("joint_flow_pipeline_marginals.png", dpi=100, bbox_inches="tight")
plt.show()
Note
If you plan on using multiprocessing prefetching, ensure that your script is wrapped in an
if __name__ == "__main__": guard. See https://docs.python.org/3/library/multiprocessing.html
- classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir)[source]#
Initialize the pipeline from a configuration file.
- Parameters:
config_path (str) – Path to the configuration file.
dim_obs (int)
dim_cond (int)
checkpoint_dir (str)
- ch_obs = 1#
- dim_joint#
- ema_model#
- class gensbi.recipes.JointDiffusionPipeline(model, train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, params=None, training_config=None, condition_mask_kind='structured')[source]#
Bases: gensbi.recipes.pipeline.AbstractPipeline
Diffusion pipeline for training and using a Joint model for simulation-based inference.
- Parameters:
model (nnx.Module) – The model to be trained.
train_dataset (grain dataset or iterator over batches) – Training dataset.
val_dataset (grain dataset or iterator over batches) – Validation dataset.
dim_obs (int) – Dimension of the parameter space.
dim_cond (int) – Dimension of the observation space.
ch_obs (int, optional) – Number of channels for the observation space. Default is 1.
params (optional) – Parameters for the Joint model. If None, default parameters are used.
training_config (dict, optional) – Configuration for training. If None, default configuration is used.
condition_mask_kind (str, optional) – Kind of condition mask to use. One of [“structured”, “posterior”].
Examples
The minimal JointDiffusionPipeline example is identical to the one shown under Flux1JointDiffusionPipeline above.
Note
If you plan on using multiprocessing prefetching, ensure that your script is wrapped in an
if __name__ == "__main__": guard. See https://docs.python.org/3/library/multiprocessing.html
- _wrap_model()[source]#
Wrap the model for evaluation (either using JointWrapper or ConditionalWrapper).
- classmethod get_default_training_config()[source]#
Return a dictionary of default training configuration parameters.
- Returns:
training_config – Default training configuration.
- Return type:
dict
- get_sampler(x_o, nsteps=18, use_ema=True, return_intermediates=False, **model_extras)[source]#
Get a sampler function for generating samples from the trained model.
- Parameters:
x_o (array-like) – Conditioning variable.
nsteps (int, optional) – Number of sampling steps. Default is 18.
use_ema (bool, optional) – Whether to use the EMA model for sampling.
return_intermediates (bool, optional) – Whether to also return intermediate states of the sampling process.
model_extras (dict, optional) – Additional model-specific parameters.
- Returns:
sampler – A function that generates samples when called with a random key and number of samples.
- Return type:
Callable: key, nsamples -> samples
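A minimal usage sketch, assuming a trained pipeline and a conditioning array x_o shaped (1, dim_cond, 1) as in the example above; the returned callable follows the (key, nsamples) signature documented here:

import jax

sampler = pipeline.get_sampler(x_o, nsteps=18, use_ema=True)
samples = sampler(jax.random.PRNGKey(0), 10_000)   # shape (10_000, dim_obs, ch_obs)
more = sampler(jax.random.PRNGKey(1), 10_000)      # reuse the sampler with a fresh key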
- classmethod init_pipeline_from_config()[source]#
- Abstractmethod:
Initialize the pipeline from a configuration file.
- Parameters:
train_dataset (iterable) – Training dataset.
val_dataset (iterable) – Validation dataset.
dim_obs (int) – Dimensionality of the parameter (theta) space.
dim_cond (int) – Dimensionality of the observation (x) space.
config_path (str) – Path to the configuration file.
checkpoint_dir (str) – Directory for saving checkpoints.
- Returns:
pipeline – An instance of the pipeline initialized from the configuration.
- Return type:
- sample(key, x_o, nsamples=10000, nsteps=18, use_ema=True, return_intermediates=False, **model_extras)[source]#
Generate samples from the trained model.
- Parameters:
key (jax.random.PRNGKey) – Random number generator key.
x_o (array-like) – Conditioning variable (e.g., observed data).
nsamples (int, optional) – Number of samples to generate.
- Returns:
samples – Generated samples of size (nsamples, dim_obs, ch_obs).
- Return type:
array-like
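The conditioning array is expected with a leading batch axis and a trailing channel axis, matching the (1, dim_cond, 1) slice taken in the example above; a hedged sketch, where obs is assumed to hold the dim_cond observed (and already normalized) values:

import jax

x_o = obs.reshape(1, dim_cond, 1)   # (batch, dim_cond, channels)
samples = pipeline.sample(jax.random.PRNGKey(0), x_o, nsamples=10_000)
# samples has shape (nsamples, dim_obs, ch_obs)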
- condition_mask_kind = 'structured'#
- loss_fn#
- path#
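Relatedly, the Note above recommends an if __name__ == "__main__": guard whenever mp_prefetch() is enabled on the grain datasets. A minimal sketch of that layout, reusing process_data and train_data from the example above (the helper name build_train_dataset is hypothetical):

import grain
import numpy as np

def build_train_dataset(train_data, batch_size=256):
    # Same chain as in the example above, but with .mp_prefetch() enabled;
    # spawned worker processes re-import this module, hence the guard below.
    return (
        grain.MapDataset.source(np.array(train_data))
        .shuffle(42)
        .repeat()
        .to_iter_dataset()
        .batch(batch_size)
        .map(process_data)
        .mp_prefetch()
    )

if __name__ == "__main__":
    train_dataset_grain = build_train_dataset(train_data)
    # construct and train the pipeline here, as in the example above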
- class gensbi.recipes.JointFlowPipeline(model, train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, params=None, training_config=None, condition_mask_kind='structured')[source]#
Bases:
gensbi.recipes.pipeline.AbstractPipeline
Flow pipeline for training and using a Joint model for simulation-based inference.
- Parameters:
model (nnx.Module) – The model to be trained.
train_dataset (grain dataset or iterator over batches) – Training dataset.
val_dataset (grain dataset or iterator over batches) – Validation dataset.
dim_obs (int) – Dimension of the parameter space.
dim_cond (int) – Dimension of the observation space.
ch_obs (int, optional) – Number of channels for the observation space. Default is 1.
params (JointParams, optional) – Parameters for the Joint model. If None, default parameters are used.
training_config (dict, optional) – Configuration for training. If None, default configuration is used.
condition_mask_kind (str, optional) – Kind of condition mask to use. One of [“structured”, “posterior”].
Examples
Minimal example on how to instantiate and use the JointFlowPipeline:
1# %% Imports 2import os 3 4# Set JAX backend (use 'cuda' for GPU, 'cpu' otherwise) 5# os.environ["JAX_PLATFORMS"] = "cuda" 6 7import grain 8import numpy as np 9import jax 10from jax import numpy as jnp 11from numpyro import distributions as dist 12from flax import nnx 13 14from gensbi.recipes import JointFlowPipeline 15from gensbi.models import Simformer, SimformerParams 16 17from gensbi.utils.plotting import plot_marginals 18import matplotlib.pyplot as plt 19 20 21# %% 22 23theta_prior = dist.Uniform( 24 low=jnp.array([-2.0, -2.0, -2.0]), high=jnp.array([2.0, 2.0, 2.0]) 25) 26 27dim_obs = 3 28dim_cond = 3 29dim_joint = dim_obs + dim_cond 30 31 32# %% 33def simulator(key, nsamples): 34 theta_key, sample_key = jax.random.split(key, 2) 35 thetas = theta_prior.sample(theta_key, (nsamples,)) 36 37 xs = thetas + 1 + jax.random.normal(sample_key, thetas.shape) * 0.1 38 39 thetas = thetas[..., None] 40 xs = xs[..., None] 41 42 # when making a dataset for the joint pipeline, thetas need to come first 43 data = jnp.concatenate([thetas, xs], axis=1) 44 45 return data 46 47 48# %% Define your training and validation datasets. 49# We generate a training dataset and a validation dataset using the simulator. 50# The simulator is a simple function that generates parameters (theta) and data (x). 51# In this example, we use a simple Gaussian simulator. 52train_data = simulator(jax.random.PRNGKey(0), 100_000) 53val_data = simulator(jax.random.PRNGKey(1), 2000) 54# %% Normalize the dataset 55# It is important to normalize the data to have zero mean and unit variance. 56# This helps the model training process. 57means = jnp.mean(train_data, axis=0) 58stds = jnp.std(train_data, axis=0) 59 60 61def normalize(data, means, stds): 62 return (data - means) / stds 63 64 65def unnormalize(data, means, stds): 66 return data * stds + means 67 68 69# %% Prepare the data for the pipeline 70# The pipeline expects the data to be normalized but not split (for joint pipelines). 71 72 73# %% Prepare the data for the pipeline 74# The pipeline expects the data to be normalized but not split (for joint pipelines). 75def process_data(data): 76 return normalize(data, means, stds) 77 78 79# %% 80train_data.shape 81 82# %% 83 84# %% Create the input pipeline using Grain 85# We use Grain to create an efficient input pipeline. 86# This involves shuffling, repeating for multiple epochs, and batching the data. 87# We also map the process_data function to prepare (normalize) the data for the model. 88# We also map the process_data function to prepare (normalize) the data for the model. 89# %% Create the input pipeline using Grain 90# We use Grain to create an efficient input pipeline. 91# This involves shuffling, repeating for multiple epochs, and batching the data. 92# We also map the process_data function to prepare (normalize) the data for the model. 93batch_size = 256 94 95train_dataset_grain = ( 96 grain.MapDataset.source(np.array(train_data)) 97 .shuffle(42) 98 .repeat() 99 .to_iter_dataset() 100 .batch(batch_size) 101 .map(process_data) 102 # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching 103) 104 105val_dataset_grain = ( 106 grain.MapDataset.source(np.array(val_data)) 107 .shuffle(42) 108 .repeat() 109 .to_iter_dataset() 110 .batch(batch_size) 111 .map(process_data) 112 # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching 113) 114 115# %% Define your model 116# specific model parameters are defined here. 
117# For Simformer, we need to specify dimensions, embedding strategies, and other architecture details. 118params = SimformerParams( 119 rngs=nnx.Rngs(0), 120 in_channels=1, 121 dim_value=20, 122 dim_id=10, 123 dim_condition=10, 124 dim_joint=dim_joint, 125 fourier_features=128, 126 num_heads=4, 127 num_layers=6, 128 widening_factor=3, 129 qkv_features=40, 130 num_hidden_layers=1, 131) 132 133model = Simformer(params) 134 135# %% Instantiate the pipeline 136# The JointFlowPipeline handles the training loop and sampling. 137# We configure it with the model, datasets, dimensions using a default training configuration. 138# We also specify the condition_mask_kind, which determines how conditioning is handled during training. 139training_config = JointFlowPipeline.get_default_training_config() 140training_config["nsteps"] = 10000 141 142pipeline = JointFlowPipeline( 143 model, 144 train_dataset_grain, 145 val_dataset_grain, 146 dim_obs, 147 dim_cond, 148 condition_mask_kind="posterior", 149 training_config=training_config, 150) 151 152# %% Train the model 153# We create a random key for training and start the training process. 154rngs = nnx.Rngs(42) 155pipeline.train( 156 rngs, save_model=False 157) # if you want to save the model, set save_model=True 158 159# %% Sample from the posterior 160# To generate samples, we first need an observation (and its corresponding condition). 161# We generate a new sample from the simulator, normalize it, and extract the condition x_o. 162 163new_sample = simulator(jax.random.PRNGKey(20), 1) 164true_theta = new_sample[:, :dim_obs, :] # extract observation from the joint sample 165 166new_sample = normalize(new_sample, means, stds) 167x_o = new_sample[:, dim_obs:, :] # extract condition from the joint sample 168 169# Then we invoke the pipeline's sample method. 170samples = pipeline.sample(rngs.sample(), x_o, nsamples=100_000) 171# Finally, we unnormalize the samples to get them back to the original scale. 172samples = unnormalize(samples, means[:dim_obs], stds[:dim_obs]) 173 174# %% Plot the samples 175# We verify the model's performance by plotting the marginal distributions of the generated samples 176# against the true parameters. 177plot_marginals( 178 np.array(samples[..., 0]), 179 gridsize=30, 180 true_param=np.array(true_theta[0, :, 0]), 181 range=[(1, 3), (1, 3), (-0.6, 0.5)], 182) 183plt.savefig("joint_flow_pipeline_marginals.png", dpi=100, bbox_inches="tight") 184plt.show()
Note
If you plan on using multiprocessing prefetching, ensure that your script is wrapped in an
if __name__ == "__main__": guard. See https://docs.python.org/3/library/multiprocessing.html
- _wrap_model()[source]#
Wrap the model for evaluation (either using JointWrapper or ConditionalWrapper).
- get_sampler(x_o, step_size=0.01, use_ema=True, time_grid=None, **model_extras)[source]#
Get a sampler function for generating samples from the trained model.
- Parameters:
x_o (array-like) – Conditioning variable.
step_size (float, optional) – Step size for the sampler.
use_ema (bool, optional) – Whether to use the EMA model for sampling.
time_grid (array-like, optional) – Time grid for the sampler (if applicable).
model_extras (dict, optional) – Additional model-specific parameters.
- Returns:
sampler – A function that generates samples when called with a random key and number of samples.
- Return type:
Callable: key, nsamples -> samples
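As with the diffusion variant, the returned callable takes (key, nsamples); a sketch assuming a finer step size than the 0.01 default:

import jax

sampler = pipeline.get_sampler(x_o, step_size=0.005, use_ema=True)  # smaller steps, presumably a finer integration
samples = sampler(jax.random.PRNGKey(0), 10_000)   # shape (10_000, dim_obs, ch_obs)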
- classmethod init_pipeline_from_config()[source]#
- Abstractmethod:
Initialize the pipeline from a configuration file.
- Parameters:
train_dataset (iterable) – Training dataset.
val_dataset (iterable) – Validation dataset.
dim_obs (int) – Dimensionality of the parameter (theta) space.
dim_cond (int) – Dimensionality of the observation (x) space.
config_path (str) – Path to the configuration file.
checkpoint_dir (str) – Directory for saving checkpoints.
- Returns:
pipeline – An instance of the pipeline initialized from the configuration.
- Return type:
- sample(key, x_o, nsamples=10000, step_size=0.01, use_ema=True, time_grid=None, **model_extras)[source]#
Generate samples from the trained model.
- Parameters:
key (jax.random.PRNGKey) – Random number generator key.
x_o (array-like) – Conditioning variable (e.g., observed data).
nsamples (int, optional) – Number of samples to generate.
- Returns:
samples – Generated samples of size (nsamples, dim_obs, ch_obs).
- Return type:
array-like
- condition_mask_kind = 'structured'#
- dim_joint#
- loss_fn#
- p0_joint#
- p0_obs#
- path#
- class gensbi.recipes.SimformerDiffusionPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, params=None, training_config=None, edge_mask=None, condition_mask_kind='structured')[source]#
Bases:
gensbi.recipes.joint_pipeline.JointDiffusionPipeline
Diffusion pipeline for training and using a Joint model for simulation-based inference.
- Parameters:
train_dataset (grain dataset or iterator over batches) – Training dataset.
val_dataset (grain dataset or iterator over batches) – Validation dataset.
dim_obs (int) – Dimension of the parameter space.
dim_cond (int) – Dimension of the observation space.
ch_obs (int, optional) – Number of channels for the observation space. Default is 1.
params (optional) – Parameters for the Joint model. If None, default parameters are used.
training_config (dict, optional) – Configuration for training. If None, default configuration is used.
condition_mask_kind (str, optional) – Kind of condition mask to use. One of [“structured”, “posterior”].
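The edge_mask argument is not described above; by assumption it is a boolean (dim_joint, dim_joint) matrix that restricts which tokens may attend to each other, with None meaning no restriction. A hedged sketch under that assumption:

import jax.numpy as jnp

dim_joint = dim_obs + dim_cond
# Assumption: True marks an allowed edge between token i and token j;
# an all-True mask should behave like the default edge_mask=None.
edge_mask = jnp.ones((dim_joint, dim_joint), dtype=bool)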
Examples
Minimal example on how to instantiate and use the JointDiffusionPipeline:
1# %% Imports 2import os 3 4# Set JAX backend (use 'cuda' for GPU, 'cpu' otherwise) 5# os.environ["JAX_PLATFORMS"] = "cuda" 6 7import grain 8import numpy as np 9import jax 10from jax import numpy as jnp 11from gensbi.recipes import JointDiffusionPipeline 12from gensbi.utils.plotting import plot_marginals 13 14from gensbi.models import Simformer, SimformerParams 15import matplotlib.pyplot as plt 16 17from numpyro import distributions as dist 18 19 20from flax import nnx 21 22# %% 23 24theta_prior = dist.Uniform( 25 low=jnp.array([-2.0, -2.0, -2.0]), high=jnp.array([2.0, 2.0, 2.0]) 26) 27 28dim_obs = 3 29dim_cond = 3 30dim_joint = dim_obs + dim_cond 31 32 33# %% 34def simulator(key, nsamples): 35 theta_key, sample_key = jax.random.split(key, 2) 36 thetas = theta_prior.sample(theta_key, (nsamples,)) 37 38 xs = thetas + 1 + jax.random.normal(sample_key, thetas.shape) * 0.1 39 40 thetas = thetas[..., None] 41 xs = xs[..., None] 42 43 # when making a dataset for the joint pipeline, thetas need to come first 44 data = jnp.concatenate([thetas, xs], axis=1) 45 46 return data 47 48 49# %% Define your training and validation datasets. 50# We generate a training dataset and a validation dataset using the simulator. 51# The simulator is a simple function that generates parameters (theta) and data (x). 52# In this example, we use a simple Gaussian simulator. 53train_data = simulator(jax.random.PRNGKey(0), 100_000) 54val_data = simulator(jax.random.PRNGKey(1), 2000) 55# %% Normalize the dataset 56# It is important to normalize the data to have zero mean and unit variance. 57# This helps the model training process. 58means = jnp.mean(train_data, axis=0) 59stds = jnp.std(train_data, axis=0) 60 61 62def normalize(data, means, stds): 63 return (data - means) / stds 64 65 66def unnormalize(data, means, stds): 67 return data * stds + means 68 69 70# %% Prepare the data for the pipeline 71# The pipeline expects the data to be normalized but not split (for joint pipelines). 72 73 74def process_data(data): 75 return normalize(data, means, stds) 76 77 78# %% 79train_data.shape 80 81# %% 82 83# %% Create the input pipeline using Grain 84# We use Grain to create an efficient input pipeline. 85# This involves shuffling, repeating for multiple epochs, and batching the data. 86# We also map the process_data function to prepare (normalize) the data for the model. 87batch_size = 256 88 89train_dataset_grain = ( 90 grain.MapDataset.source(np.array(train_data)) 91 .shuffle(42) 92 .repeat() 93 .to_iter_dataset() 94 .batch(batch_size) 95 .map(process_data) 96 # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching 97) 98 99val_dataset_grain = ( 100 grain.MapDataset.source(np.array(val_data)) 101 .shuffle( 102 42 103 ) # Use a different seed/strategy for validation if needed, but shuffling is fine 104 .repeat() 105 .to_iter_dataset() 106 .batch(batch_size) 107 .map(process_data) 108 # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching 109) 110 111# %% Define your model 112# specific model parameters are defined here. 113# For Simformer, we need to specify dimensions, embedding strategies, and other architecture details. 
114params = SimformerParams( 115 rngs=nnx.Rngs(0), 116 in_channels=1, 117 dim_value=20, 118 dim_id=10, 119 dim_condition=10, 120 dim_joint=dim_joint, 121 fourier_features=128, 122 num_heads=4, 123 num_layers=6, 124 widening_factor=3, 125 qkv_features=40, 126 num_hidden_layers=1, 127) 128 129model = Simformer(params) 130 131# %% Instantiate the pipeline 132# The JointDiffusionPipeline handles the training loop and sampling. 133# We configure it with the model, datasets, dimensions using a default training configuration. 134# We also specify the condition_mask_kind, which determines how conditioning is handled during training. 135training_config = JointDiffusionPipeline.get_default_training_config() 136training_config["nsteps"] = 10000 137 138pipeline = JointDiffusionPipeline( 139 model, 140 train_dataset_grain, 141 val_dataset_grain, 142 dim_obs, 143 dim_cond, 144 condition_mask_kind="posterior", 145 training_config=training_config, 146) 147 148# %% Train the model 149# We create a random key for training and start the training process. 150rngs = nnx.Rngs(42) 151pipeline.train( 152 rngs, save_model=False 153) # if you want to save the model, set save_model=True 154 155# %% Sample from the posterior 156# To generate samples, we first need an observation (and its corresponding condition). 157# We generate a new sample from the simulator, normalize it, and extract the condition x_o. 158 159new_sample = simulator(jax.random.PRNGKey(20), 1) 160true_theta = new_sample[:, :dim_obs, :] # extract observation from the joint sample 161 162new_sample = normalize(new_sample, means, stds) 163x_o = new_sample[:, dim_obs:, :] # extract condition from the joint sample 164 165# Then we invoke the pipeline's sample method. 166samples = pipeline.sample(rngs.sample(), x_o, nsamples=100_000) 167# Finally, we unnormalize the samples to get them back to the original scale. 168samples = unnormalize(samples, means[:dim_obs], stds[:dim_obs]) 169 170# %% Plot the samples 171# We verify the model's performance by plotting the marginal distributions of the generated samples 172# against the true parameters. 173plot_marginals( 174 np.array(samples[..., 0]), 175 gridsize=30, 176 true_param=np.array(true_theta[0, :, 0]), 177 range=[(1, 3), (1, 3), (-0.6, 0.5)], 178) 179plt.savefig("joint_diffusion_pipeline_marginals.png", dpi=100, bbox_inches="tight") 180plt.show()
Note
If you plan on using multiprocessing prefetching, ensure that your script is wrapped in an
if __name__ == "__main__": guard. See https://docs.python.org/3/library/multiprocessing.html
- classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir)[source]#
Initialize the pipeline from a configuration file.
- Parameters:
train_dataset (grain dataset or iterator over batches) – Training dataset.
val_dataset (grain dataset or iterator over batches) – Validation dataset.
dim_obs (int) – Dimension of the parameter space.
dim_cond (int) – Dimension of the observation space.
config_path (str) – Path to the configuration file.
checkpoint_dir (str) – Directory for saving checkpoints.
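A usage sketch, assuming a configuration file already exists (its expected format is not documented here; both paths are placeholders):

pipeline = SimformerDiffusionPipeline.init_pipeline_from_config(
    train_dataset_grain,
    val_dataset_grain,
    dim_obs,
    dim_cond,
    config_path="simformer_config.yaml",     # placeholder path
    checkpoint_dir="checkpoints/simformer",  # placeholder directory
)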
- sample(key, x_o, nsamples=10000, nsteps=18, use_ema=True, return_intermediates=False)[source]#
Generate samples from the trained model.
- Parameters:
key (jax.random.PRNGKey) – Random number generator key.
x_o (array-like) – Conditioning variable (e.g., observed data).
nsamples (int, optional) – Number of samples to generate.
- Returns:
samples – Generated samples of size (nsamples, dim_obs, ch_obs).
- Return type:
array-like
- ch_obs = 1#
- dim_joint#
- edge_mask = None#
- ema_model#
- class gensbi.recipes.SimformerFlowPipeline(train_dataset, val_dataset, dim_obs, dim_cond, ch_obs=1, params=None, training_config=None, edge_mask=None, condition_mask_kind='structured')[source]#
Bases:
gensbi.recipes.joint_pipeline.JointFlowPipeline
Flow pipeline for training and using a Joint model for simulation-based inference.
- Parameters:
train_dataset (grain dataset or iterator over batches) – Training dataset.
val_dataset (grain dataset or iterator over batches) – Validation dataset.
dim_obs (int) – Dimension of the parameter space.
dim_cond (int) – Dimension of the observation space.
ch_obs (int, optional) – Number of channels for the observation space. Default is 1.
params (JointParams, optional) – Parameters for the Joint model. If None, default parameters are used.
training_config (dict, optional) – Configuration for training. If None, default configuration is used.
condition_mask_kind (str, optional) – Kind of condition mask to use. One of [“structured”, “posterior”].
Examples
Minimal example on how to instantiate and use the JointFlowPipeline:
1# %% Imports 2import os 3 4# Set JAX backend (use 'cuda' for GPU, 'cpu' otherwise) 5# os.environ["JAX_PLATFORMS"] = "cuda" 6 7import grain 8import numpy as np 9import jax 10from jax import numpy as jnp 11from numpyro import distributions as dist 12from flax import nnx 13 14from gensbi.recipes import JointFlowPipeline 15from gensbi.models import Simformer, SimformerParams 16 17from gensbi.utils.plotting import plot_marginals 18import matplotlib.pyplot as plt 19 20 21# %% 22 23theta_prior = dist.Uniform( 24 low=jnp.array([-2.0, -2.0, -2.0]), high=jnp.array([2.0, 2.0, 2.0]) 25) 26 27dim_obs = 3 28dim_cond = 3 29dim_joint = dim_obs + dim_cond 30 31 32# %% 33def simulator(key, nsamples): 34 theta_key, sample_key = jax.random.split(key, 2) 35 thetas = theta_prior.sample(theta_key, (nsamples,)) 36 37 xs = thetas + 1 + jax.random.normal(sample_key, thetas.shape) * 0.1 38 39 thetas = thetas[..., None] 40 xs = xs[..., None] 41 42 # when making a dataset for the joint pipeline, thetas need to come first 43 data = jnp.concatenate([thetas, xs], axis=1) 44 45 return data 46 47 48# %% Define your training and validation datasets. 49# We generate a training dataset and a validation dataset using the simulator. 50# The simulator is a simple function that generates parameters (theta) and data (x). 51# In this example, we use a simple Gaussian simulator. 52train_data = simulator(jax.random.PRNGKey(0), 100_000) 53val_data = simulator(jax.random.PRNGKey(1), 2000) 54# %% Normalize the dataset 55# It is important to normalize the data to have zero mean and unit variance. 56# This helps the model training process. 57means = jnp.mean(train_data, axis=0) 58stds = jnp.std(train_data, axis=0) 59 60 61def normalize(data, means, stds): 62 return (data - means) / stds 63 64 65def unnormalize(data, means, stds): 66 return data * stds + means 67 68 69# %% Prepare the data for the pipeline 70# The pipeline expects the data to be normalized but not split (for joint pipelines). 71 72 73# %% Prepare the data for the pipeline 74# The pipeline expects the data to be normalized but not split (for joint pipelines). 75def process_data(data): 76 return normalize(data, means, stds) 77 78 79# %% 80train_data.shape 81 82# %% 83 84# %% Create the input pipeline using Grain 85# We use Grain to create an efficient input pipeline. 86# This involves shuffling, repeating for multiple epochs, and batching the data. 87# We also map the process_data function to prepare (normalize) the data for the model. 88# We also map the process_data function to prepare (normalize) the data for the model. 89# %% Create the input pipeline using Grain 90# We use Grain to create an efficient input pipeline. 91# This involves shuffling, repeating for multiple epochs, and batching the data. 92# We also map the process_data function to prepare (normalize) the data for the model. 93batch_size = 256 94 95train_dataset_grain = ( 96 grain.MapDataset.source(np.array(train_data)) 97 .shuffle(42) 98 .repeat() 99 .to_iter_dataset() 100 .batch(batch_size) 101 .map(process_data) 102 # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching 103) 104 105val_dataset_grain = ( 106 grain.MapDataset.source(np.array(val_data)) 107 .shuffle(42) 108 .repeat() 109 .to_iter_dataset() 110 .batch(batch_size) 111 .map(process_data) 112 # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching 113) 114 115# %% Define your model 116# specific model parameters are defined here. 
117# For Simformer, we need to specify dimensions, embedding strategies, and other architecture details. 118params = SimformerParams( 119 rngs=nnx.Rngs(0), 120 in_channels=1, 121 dim_value=20, 122 dim_id=10, 123 dim_condition=10, 124 dim_joint=dim_joint, 125 fourier_features=128, 126 num_heads=4, 127 num_layers=6, 128 widening_factor=3, 129 qkv_features=40, 130 num_hidden_layers=1, 131) 132 133model = Simformer(params) 134 135# %% Instantiate the pipeline 136# The JointFlowPipeline handles the training loop and sampling. 137# We configure it with the model, datasets, dimensions using a default training configuration. 138# We also specify the condition_mask_kind, which determines how conditioning is handled during training. 139training_config = JointFlowPipeline.get_default_training_config() 140training_config["nsteps"] = 10000 141 142pipeline = JointFlowPipeline( 143 model, 144 train_dataset_grain, 145 val_dataset_grain, 146 dim_obs, 147 dim_cond, 148 condition_mask_kind="posterior", 149 training_config=training_config, 150) 151 152# %% Train the model 153# We create a random key for training and start the training process. 154rngs = nnx.Rngs(42) 155pipeline.train( 156 rngs, save_model=False 157) # if you want to save the model, set save_model=True 158 159# %% Sample from the posterior 160# To generate samples, we first need an observation (and its corresponding condition). 161# We generate a new sample from the simulator, normalize it, and extract the condition x_o. 162 163new_sample = simulator(jax.random.PRNGKey(20), 1) 164true_theta = new_sample[:, :dim_obs, :] # extract observation from the joint sample 165 166new_sample = normalize(new_sample, means, stds) 167x_o = new_sample[:, dim_obs:, :] # extract condition from the joint sample 168 169# Then we invoke the pipeline's sample method. 170samples = pipeline.sample(rngs.sample(), x_o, nsamples=100_000) 171# Finally, we unnormalize the samples to get them back to the original scale. 172samples = unnormalize(samples, means[:dim_obs], stds[:dim_obs]) 173 174# %% Plot the samples 175# We verify the model's performance by plotting the marginal distributions of the generated samples 176# against the true parameters. 177plot_marginals( 178 np.array(samples[..., 0]), 179 gridsize=30, 180 true_param=np.array(true_theta[0, :, 0]), 181 range=[(1, 3), (1, 3), (-0.6, 0.5)], 182) 183plt.savefig("joint_flow_pipeline_marginals.png", dpi=100, bbox_inches="tight") 184plt.show()
Note
If you plan on using multiprocessing prefetching, ensure that your script is wrapped in an
if __name__ == "__main__": guard. See https://docs.python.org/3/library/multiprocessing.html
- classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_obs, dim_cond, config_path, checkpoint_dir)[source]#
Initialize the pipeline from a configuration file.
- Parameters:
train_dataset (grain dataset or iterator over batches) – Training dataset.
val_dataset (grain dataset or iterator over batches) – Validation dataset.
dim_obs (int) – Dimension of the parameter space.
dim_cond (int) – Dimension of the observation space.
config_path (str) – Path to the configuration file.
checkpoint_dir (str) – Directory for saving checkpoints.
- sample(key, x_o, nsamples=10000, step_size=0.01, use_ema=True, time_grid=None)[source]#
Generate samples from the trained model.
- Parameters:
key (jax.random.PRNGKey) – Random number generator key.
x_o (array-like) – Conditioning variable (e.g., observed data).
nsamples (int, optional) – Number of samples to generate.
- Returns:
samples – Generated samples of size (nsamples, dim_obs, ch_obs).
- Return type:
array-like
- ch_obs = 1#
- dim_joint#
- edge_mask = None#
- ema_model#
- class gensbi.recipes.UnconditionalDiffusionPipeline(model, train_dataset, val_dataset, dim_obs, ch_obs=1, params=None, training_config=None)[source]#
Bases:
gensbi.recipes.pipeline.AbstractPipeline
Diffusion pipeline for training and using an Unconditional model for simulation-based inference.
- Parameters:
model (nnx.Module) – The model to be trained.
train_dataset (grain dataset or iterator over batches) – Training dataset.
val_dataset (grain dataset or iterator over batches) – Validation dataset.
dim_obs (int) – Dimension of the parameter space.
ch_obs (int) – Number of channels in the observation space.
params (optional) – Parameters for the model. Ignored if a custom model is provided.
training_config (dict, optional) – Configuration for training. If None, default configuration is used.
Examples
Minimal example on how to instantiate and use the UnconditionalDiffusionPipeline:
1# %% Imports 2import os 3 4# Set JAX backend (use 'cuda' for GPU, 'cpu' otherwise) 5# os.environ["JAX_PLATFORMS"] = "cuda" 6 7import grain 8import numpy as np 9import jax 10from jax import numpy as jnp 11from gensbi.recipes import UnconditionalDiffusionPipeline 12from gensbi.utils.model_wrapping import _expand_dims, _expand_time 13from gensbi.utils.plotting import plot_marginals 14import matplotlib.pyplot as plt 15from gensbi.models import Simformer, SimformerParams 16 17 18from flax import nnx 19 20 21# %% define a simulator 22def simulator(key, nsamples): 23 return 3 + jax.random.normal(key, (nsamples, 2)) * jnp.array([0.5, 1]).reshape( 24 1, 2 25 ) # a simple 2D gaussian 26 27 28# %% 29 30 31# %% Define your training and validation datasets. 32train_data = simulator(jax.random.PRNGKey(0), 100_000).reshape(-1, 2, 1) 33val_data = simulator(jax.random.PRNGKey(1), 2000).reshape(-1, 2, 1) 34# %% 35# %% Normalize the dataset 36# It is important to normalize the data to have zero mean and unit variance. 37# This helps the model training process. 38means = jnp.mean(train_data, axis=0) 39stds = jnp.std(train_data, axis=0) 40 41 42def normalize(data, means, stds): 43 return (data - means) / stds 44 45 46def unnormalize(data, means, stds): 47 return data * stds + means 48 49 return normalize(data, means, stds) 50 51 52def process_data(data): 53 return normalize(data, means, stds) 54 55 56# %% Create the input pipeline using Grain 57# We use Grain to create an efficient input pipeline. 58# This involves shuffling, repeating for multiple epochs, and batching the data. 59# We also map the process_data function to prepare (normalize) the data. 60batch_size = 256 61 62train_dataset_grain = ( 63 grain.MapDataset.source(np.array(train_data)) 64 .shuffle(42) 65 .repeat() 66 .to_iter_dataset() 67 .batch(batch_size) 68 .map(process_data) 69 # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching 70) 71 72val_dataset_grain = ( 73 grain.MapDataset.source(np.array(val_data)) 74 .shuffle( 75 42 76 ) # Use a different seed/strategy for validation if needed, but shuffling is fine 77 .repeat() 78 .to_iter_dataset() 79 .batch(batch_size) 80 .map(process_data) 81 # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching 82) 83# %% Define your model 84# Here we define a MLP velocity field model, 85# this model only works for inputs of shape (batch, dim, 1). 86# For more complex models, please refer to the transformer-based models in gensbi.models. 
87 88 89class MLP(nnx.Module): 90 def __init__(self, input_dim: int = 2, hidden_dim: int = 512, *, rngs: nnx.Rngs): 91 92 self.input_dim = input_dim 93 self.hidden_dim = hidden_dim 94 95 din = input_dim + 1 96 97 self.linear1 = nnx.Linear(din, self.hidden_dim, rngs=rngs) 98 self.linear2 = nnx.Linear(self.hidden_dim, self.hidden_dim, rngs=rngs) 99 self.linear3 = nnx.Linear(self.hidden_dim, self.hidden_dim, rngs=rngs) 100 self.linear4 = nnx.Linear(self.hidden_dim, self.hidden_dim, rngs=rngs) 101 self.linear5 = nnx.Linear(self.hidden_dim, self.input_dim, rngs=rngs) 102 103 def __call__(self, t: jax.Array, obs: jax.Array, node_ids, *args, **kwargs): 104 obs = _expand_dims(obs)[ 105 ..., 0 106 ] # for this specific model, we use samples of shape (batch, dim), while for transformer models we use (batch, dim, c) 107 t = _expand_time(t) 108 if t.ndim == 3: 109 t = t.reshape(t.shape[0], t.shape[1]) 110 t = jnp.broadcast_to(t, (obs.shape[0], 1)) 111 112 h = jnp.concatenate([obs, t], axis=-1) 113 114 x = self.linear1(h) 115 x = jax.nn.gelu(x) 116 117 x = self.linear2(x) 118 x = jax.nn.gelu(x) 119 120 x = self.linear3(x) 121 x = jax.nn.gelu(x) 122 123 x = self.linear4(x) 124 x = jax.nn.gelu(x) 125 126 x = self.linear5(x) 127 128 return x[..., None] # return shape (batch, dim, 1) 129 130 131model = MLP( 132 rngs=nnx.Rngs(42) 133) # your nnx.Module model here, e.g., a simple MLP, or the Simformer model 134# if you define a custom model, it should take as input the following arguments: 135# t: Array, 136# obs: Array, 137# node_ids: Array (optional, if your model is a transformer-based model) 138# *args 139# **kwargs 140 141# the obs input should have shape (batch_size, dim_joint, c), and the output will be of the same shape 142# %% Instantiate the pipeline 143dim_obs = 2 # Dimension of the parameter space 144ch_obs = 1 # Number of channels of the parameter space 145 146# The UnconditionalDiffusionPipeline handles the training loop and sampling. 147# We configure it with the model, datasets, dimensions using a default training configuration. 148training_config = UnconditionalDiffusionPipeline.get_default_training_config() 149training_config["nsteps"] = 10000 150 151pipeline = UnconditionalDiffusionPipeline( 152 model, 153 train_dataset_grain, 154 val_dataset_grain, 155 dim_obs, 156 ch_obs, 157 training_config=training_config, 158) 159 160# %% Train the model 161# We create a random key for training and start the training process. 162rngs = nnx.Rngs(42) 163pipeline.train( 164 rngs, save_model=False 165) # if you want to save the model, set save_model=True 166 167# %% Sample from the posterior 168# We generate new samples using the trained model. 169samples = pipeline.sample(rngs.sample(), nsamples=100_000) 170# Finally, we unnormalize the samples to get them back to the original scale. 171samples = unnormalize(samples, means, stds) 172 173# %% Plot the samples 174# We verify the model's performance by plotting the marginal distributions of the generated samples. 175samples.mean(axis=0), samples.std(axis=0) 176# %% 177 178plot_marginals( 179 np.array(samples[..., 0]), true_param=[3, 3], gridsize=20, range=[(-2, 8), (-2, 8)] 180) 181plt.savefig("unconditional_diffusion_samples.png", dpi=300, bbox_inches="tight") 182plt.show() 183 184# %%
Note
If you plan on using multiprocessing prefetching, ensure that your script is wrapped in an
if __name__ == "__main__": guard. See https://docs.python.org/3/library/multiprocessing.html
- _wrap_model()[source]#
Wrap the model for evaluation (either using JointWrapper or ConditionalWrapper).
- classmethod get_default_training_config(sde='EDM')[source]#
Return a dictionary of default training configuration parameters.
- Returns:
training_config – Default training configuration.
- Return type:
dict
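Only the 'EDM' value appears in the signature above, so other SDE choices are not assumed here; a minimal sketch:

training_config = UnconditionalDiffusionPipeline.get_default_training_config(sde="EDM")
training_config["nsteps"] = 10_000  # override individual entries as in the example above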
- get_sampler(nsteps=18, use_ema=True, return_intermediates=False, **model_extras)[source]#
Get a sampler function for generating samples from the trained model.
- Parameters:
nsteps (int, optional) – Number of sampling steps. Default is 18.
use_ema (bool, optional) – Whether to use the EMA model for sampling.
return_intermediates (bool, optional) – Whether to also return intermediate states of the sampling trajectory.
model_extras (dict, optional) – Additional model-specific parameters.
- Returns:
sampler – A function that generates samples when called with a random key and number of samples.
- Return type:
Callable: key, nsamples -> samples
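Unlike the conditional pipelines, no conditioning variable is passed here; a sketch drawing several independent batches from one sampler:

import jax

sampler = pipeline.get_sampler(nsteps=18, use_ema=True)
keys = jax.random.split(jax.random.PRNGKey(0), 4)
batches = [sampler(k, 25_000) for k in keys]   # four independent batches of 25,000 samples each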
- classmethod init_pipeline_from_config()[source]#
- Abstractmethod:
Initialize the pipeline from a configuration file.
- Parameters:
train_dataset (iterable) – Training dataset.
val_dataset (iterable) – Validation dataset.
dim_obs (int) – Dimensionality of the parameter (theta) space.
dim_cond (int) – Dimensionality of the observation (x) space.
config_path (str) – Path to the configuration file.
checkpoint_dir (str) – Directory for saving checkpoints.
- Returns:
pipeline – An instance of the pipeline initialized from the configuration.
- Return type:
- sample(key, nsamples=10000, nsteps=18, use_ema=True, return_intermediates=False, **model_extras)[source]#
Generate samples from the trained model.
- Parameters:
key (jax.random.PRNGKey) – Random number generator key.
nsamples (int, optional) – Number of samples to generate.
- Returns:
samples – Generated samples of size (nsamples, dim_obs, ch_obs).
- Return type:
array-like
- abstractmethod sample_batched(*args, **kwargs)[source]#
Generate samples from the trained model in batches.
- Parameters:
key (jax.random.PRNGKey) – Random number generator key.
x_o (array-like) – Conditioning variable (e.g., observed data).
nsamples (int) – Number of samples to generate.
chunk_size (int, optional) – Size of each batch for sampling. Default is 50.
show_progress_bars (bool, optional) – Whether to display progress bars during sampling. Default is True.
args (tuple) – Additional positional arguments for the sampler.
kwargs (dict) – Additional keyword arguments for the sampler.
- Returns:
samples – Generated samples of shape (nsamples, batch_size_cond, dim_obs, ch_obs).
- Return type:
array-like
- loss_fn#
- obs_ids#
- path#
- class gensbi.recipes.UnconditionalFlowPipeline(model, train_dataset, val_dataset, dim_obs, ch_obs=1, params=None, training_config=None)[source]#
Bases:
gensbi.recipes.pipeline.AbstractPipeline
Flow pipeline for training and using an Unconditional model for simulation-based inference.
- Parameters:
model (nnx.Module) – The model to be trained.
train_dataset (grain dataset or iterator over batches) – Training dataset.
val_dataset (grain dataset or iterator over batches) – Validation dataset.
dim_obs (int) – Dimension of the parameter space.
ch_obs (int) – Number of channels in the observation space.
params (optional) – Parameters for the model. Ignored if a custom model is provided.
training_config (dict, optional) – Configuration for training. If None, default configuration is used.
Examples
Minimal example on how to instantiate and use the UnconditionalFlowPipeline:
1# %% Imports 2import os 3 4# Set JAX backend (use 'cuda' for GPU, 'cpu' otherwise) 5# os.environ["JAX_PLATFORMS"] = "cuda" 6 7import grain 8import numpy as np 9import jax 10from jax import numpy as jnp 11from gensbi.recipes import UnconditionalFlowPipeline 12from gensbi.utils.model_wrapping import _expand_dims, _expand_time 13from gensbi.utils.plotting import plot_marginals 14import matplotlib.pyplot as plt 15 16 17from flax import nnx 18 19 20# %% define a simulator 21def simulator(key, nsamples): 22 return 3 + jax.random.normal(key, (nsamples, 2)) * jnp.array([0.5, 1]).reshape( 23 1, 2 24 ) # a simple 2D gaussian 25 26 27# %% Define your training and validation datasets. 28# We generate a training dataset and a validation dataset using the simulator. 29# The simulator generates samples from a 2D Gaussian distribution. 30train_data = simulator(jax.random.PRNGKey(0), 100_000).reshape(-1, 2, 1) 31val_data = simulator(jax.random.PRNGKey(1), 2000).reshape(-1, 2, 1) 32 33# %% Normalize the dataset 34# It is important to normalize the data to have zero mean and unit variance. 35# This helps the model training process. 36means = jnp.mean(train_data, axis=0) 37stds = jnp.std(train_data, axis=0) 38 39 40def normalize(data, means, stds): 41 return (data - means) / stds 42 43 44def unnormalize(data, means, stds): 45 return data * stds + means 46 47 48def process_data(data): 49 return normalize(data, means, stds) 50 51 52# %% Create the input pipeline using Grain 53# We use Grain to create an efficient input pipeline. 54# This involves shuffling, repeating for multiple epochs, and batching the data. 55# We also map the process_data function to prepare (normalize) the data. 56batch_size = 256 57 58train_dataset_grain = ( 59 grain.MapDataset.source(np.array(train_data)) 60 .shuffle(42) 61 .repeat() 62 .to_iter_dataset() 63 .batch(batch_size) 64 .map(process_data) 65 # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching 66) 67 68val_dataset_grain = ( 69 grain.MapDataset.source(np.array(val_data)) 70 .shuffle( 71 42 72 ) # Use a different seed/strategy for validation if needed, but shuffling is fine 73 .repeat() 74 .to_iter_dataset() 75 .batch(batch_size) 76 .map(process_data) 77 # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching 78) 79 80 81# %% Define your model 82# Here we define a MLP velocity field model, 83# this model only works for inputs of shape (batch, dim, 1). 84# For more complex models, please refer to the transformer-based models in gensbi.models. 
85class MLP(nnx.Module): 86 def __init__(self, input_dim: int = 2, hidden_dim: int = 128, *, rngs: nnx.Rngs): 87 88 self.input_dim = input_dim 89 self.hidden_dim = hidden_dim 90 91 din = input_dim + 1 92 93 self.linear1 = nnx.Linear(din, self.hidden_dim, rngs=rngs) 94 self.linear2 = nnx.Linear(self.hidden_dim, self.hidden_dim, rngs=rngs) 95 self.linear3 = nnx.Linear(self.hidden_dim, self.hidden_dim, rngs=rngs) 96 self.linear4 = nnx.Linear(self.hidden_dim, self.hidden_dim, rngs=rngs) 97 self.linear5 = nnx.Linear(self.hidden_dim, self.input_dim, rngs=rngs) 98 99 def __call__(self, t: jax.Array, obs: jax.Array, node_ids, *args, **kwargs): 100 obs = _expand_dims(obs)[ 101 ..., 0 102 ] # for this specific model, we use samples of shape (batch, dim), while for transformer models we use (batch, dim, c) 103 t = _expand_time(t) 104 t = jnp.broadcast_to(t, (obs.shape[0], 1)) 105 106 h = jnp.concatenate([obs, t], axis=-1) 107 108 x = self.linear1(h) 109 x = jax.nn.gelu(x) 110 111 x = self.linear2(x) 112 x = jax.nn.gelu(x) 113 114 x = self.linear3(x) 115 x = jax.nn.gelu(x) 116 117 x = self.linear4(x) 118 x = jax.nn.gelu(x) 119 120 x = self.linear5(x) 121 122 return x[..., None] # return shape (batch, dim, 1) 123 124 125model = MLP( 126 rngs=nnx.Rngs(42) 127) # your nnx.Module model here, e.g., a simple MLP, or the Simformer model 128# if you define a custom model, it should take as input the following arguments: 129# t: Array, 130# obs: Array, 131# node_ids: Array (optional, if your model is a transformer-based model) 132# *args 133# **kwargs 134 135# the obs input should have shape (batch_size, dim_joint, c), and the output will be of the same shape 136 137# %% Instantiate the pipeline 138# The UnconditionalFlowPipeline handles the training loop and sampling. 139# We configure it with the model, datasets, dimensions using a default training configuration. 140training_config = UnconditionalFlowPipeline.get_default_training_config() 141training_config["nsteps"] = 10000 142 143dim_obs = 2 # Dimension of the parameter space 144ch_obs = 1 # Number of channels of the parameter space 145 146pipeline = UnconditionalFlowPipeline( 147 model, 148 train_dataset_grain, 149 val_dataset_grain, 150 dim_obs, 151 ch_obs, 152 training_config=training_config, 153) 154 155# %% Train the model 156# We create a random key for training and start the training process. 157rngs = nnx.Rngs(42) 158pipeline.train( 159 rngs, save_model=False 160) # if you want to save the model, set save_model=True 161 162# %% Sample from the posterior 163# We generate new samples using the trained model. 164samples = pipeline.sample(rngs.sample(), nsamples=100_000) 165# Finally, we unnormalize the samples to get them back to the original scale. 166samples = unnormalize(samples, means, stds) 167 168# %% Plot the samples 169# We verify the model's performance by plotting the marginal distributions of the generated samples. 170plot_marginals( 171 np.array(samples[..., 0]), true_param=[3, 3], gridsize=30, range=[(-2, 8), (-2, 8)] 172) 173plt.savefig("unconditional_flow_samples.png", dpi=300, bbox_inches="tight") 174plt.show() 175# %%
Note
If you plan on using multiprocessing prefetching, ensure that your script is wrapped in an
if __name__ == "__main__": guard. See https://docs.python.org/3/library/multiprocessing.html
- _wrap_model()[source]#
Wrap the model for evaluation (either using JointWrapper or ConditionalWrapper).
- get_sampler(step_size=0.01, use_ema=True, time_grid=None, **model_extras)[source]#
Get a sampler function for generating samples from the trained model.
- Parameters:
step_size (float, optional) – Step size for the sampler.
use_ema (bool, optional) – Whether to use the EMA model for sampling.
time_grid (array-like, optional) – Time grid for the sampler (if applicable).
model_extras (dict, optional) – Additional model-specific parameters.
- Returns:
sampler – A function that generates samples when called with a random key and number of samples.
- Return type:
Callable: key, nsamples -> samples
- classmethod init_pipeline_from_config()[source]#
- Abstractmethod:
Initialize the pipeline from a configuration file.
- Parameters:
train_dataset (iterable) – Training dataset.
val_dataset (iterable) – Validation dataset.
dim_obs (int) – Dimensionality of the parameter (theta) space.
dim_cond (int) – Dimensionality of the observation (x) space.
config_path (str) – Path to the configuration file.
checkpoint_dir (str) – Directory for saving checkpoints.
- Returns:
pipeline – An instance of the pipeline initialized from the configuration.
- Return type:
- sample(key, nsamples=10000, step_size=0.01, use_ema=True, time_grid=None, **model_extras)[source]#
Generate samples from the trained model.
- Parameters:
key (jax.random.PRNGKey) – Random number generator key.
nsamples (int, optional) – Number of samples to generate.
- Returns:
samples – Generated samples of size (nsamples, dim_obs, ch_obs).
- Return type:
array-like
- abstractmethod sample_batched(*args, **kwargs)[source]#
Generate samples from the trained model in batches.
- Parameters:
key (jax.random.PRNGKey) – Random number generator key.
x_o (array-like) – Conditioning variable (e.g., observed data).
nsamples (int) – Number of samples to generate.
chunk_size (int, optional) – Size of each batch for sampling. Default is 50.
show_progress_bars (bool, optional) – Whether to display progress bars during sampling. Default is True.
args (tuple) – Additional positional arguments for the sampler.
kwargs (dict) – Additional keyword arguments for the sampler.
- Returns:
samples – Generated samples of shape (nsamples, batch_size_cond, dim_obs, ch_obs).
- Return type:
array-like
- loss_fn#
- obs_ids#
- p0_obs#
- path#