Source code for gensbi.recipes.conditional_pipeline

"""
Pipeline for training and using a Conditional model for simulation-based inference.
"""

import jax
import jax.numpy as jnp
from flax import nnx
import optax
from optax.contrib import reduce_on_plateau

from numpyro import distributions as dist
from tqdm.auto import tqdm
from functools import partial
import orbax.checkpoint as ocp

from typing import Union, Tuple

from gensbi.flow_matching.path import AffineProbPath
from gensbi.flow_matching.path.scheduler import CondOTScheduler
from gensbi.flow_matching.solver import ODESolver

from gensbi.diffusion.path import EDMPath
from gensbi.diffusion.path.scheduler import EDMScheduler, VEScheduler
from gensbi.diffusion.solver import SDESolver

from gensbi.models import ConditionalCFMLoss, ConditionalWrapper, ConditionalDiffLoss

from einops import repeat

from gensbi.models.flux1 import model
from gensbi.utils.model_wrapping import _expand_dims

import os

import yaml

from gensbi.recipes.pipeline import AbstractPipeline

from gensbi.recipes.utils import init_ids_1d, init_ids_2d


class ConditionalFlowPipeline(AbstractPipeline):
    """
    Flow pipeline for training and using a Conditional model for simulation-based inference.

    Parameters
    ----------
    model : nnx.Module
        The model to be trained.
    train_dataset : grain dataset or iterator over batches
        Training dataset.
    val_dataset : grain dataset or iterator over batches
        Validation dataset.
    dim_obs : int or tuple of int
        Dimension of the parameter space (number of tokens). Can represent unstructured data,
        time-series, or patchified 2D images. For images, provide a tuple (height, width).
    dim_cond : int or tuple of int
        Dimension of the observation space (number of tokens). Can represent unstructured data,
        time-series, or patchified 2D images. For images, provide a tuple (height, width).
    ch_obs : int, optional
        Number of channels per token in the observation data. Default is 1.
    ch_cond : int, optional
        Number of channels per token in the conditional data. Default is 1.
    id_embedding_strategy : tuple of str, optional
        Positional-id strategy for the observation and conditioning tokens, respectively.
        Each entry must be one of "absolute", "pos1d", "rope1d" (1D ids) or "pos2d", "rope2d"
        (2D ids). Default is ("absolute", "absolute").
    params : ConditionalParams, optional
        Parameters for the Conditional model. If None, default parameters are used.
    training_config : dict, optional
        Configuration for training. If None, default configuration is used.

    Examples
    --------
    Minimal example of how to instantiate and use the ConditionalFlowPipeline:

    .. literalinclude:: /examples/conditional_flow_pipeline.py
        :language: python
        :linenos:

    .. image:: /examples/conditional_flow_pipeline_marginals.png
        :width: 600

    .. note::
        If you plan on using multiprocessing prefetching, ensure that your script is wrapped
        in an ``if __name__ == "__main__":`` guard.
        See https://docs.python.org/3/library/multiprocessing.html

    .. note::
        Sampling in the latent space (latent diffusion/flow) is not currently supported.
    """

    def __init__(
        self,
        model,
        train_dataset,
        val_dataset,
        dim_obs: Union[int, Tuple[int, int]],
        dim_cond: Union[int, Tuple[int, int]],
        ch_obs=1,
        ch_cond=1,
        id_embedding_strategy=("absolute", "absolute"),
        params=None,
        training_config=None,
    ):
        # If latent diffusion is enabled, make sure to adjust the dimensionality
        # of the transformer model accordingly.
        super().__init__(
            model=model,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            dim_obs=dim_obs,
            dim_cond=dim_cond,
            ch_obs=ch_obs,
            ch_cond=ch_cond,
            params=params,
            training_config=training_config,
        )

        embeddings_1d = ["absolute", "pos1d", "rope1d"]
        embeddings_2d = ["pos2d", "rope2d"]

        if id_embedding_strategy[0] in embeddings_1d:
            obs_ids = init_ids_1d(dim_obs, semantic_id=0)
        elif id_embedding_strategy[0] in embeddings_2d:
            obs_ids = init_ids_2d(dim_obs, semantic_id=0)
        else:
            raise ValueError(
                f"Unknown id embedding strategy: {id_embedding_strategy[0]}"
            )

        if id_embedding_strategy[1] in embeddings_1d:
            cond_ids = init_ids_1d(dim_cond, semantic_id=1)
        elif id_embedding_strategy[1] in embeddings_2d:
            cond_ids = init_ids_2d(dim_cond, semantic_id=1)
        else:
            raise ValueError(
                f"Unknown id embedding strategy: {id_embedding_strategy[1]}"
            )
        self.obs_ids = obs_ids
        self.cond_ids = cond_ids

        self.path = AffineProbPath(scheduler=CondOTScheduler())
        self.loss_fn = ConditionalCFMLoss(self.path)

        self.p0_obs = dist.Independent(
            dist.Normal(
                loc=jnp.zeros((self.dim_obs, self.ch_obs)),
                scale=jnp.ones((self.dim_obs, self.ch_obs)),
            ),
            reinterpreted_batch_ndims=2,
        )
    @classmethod
    def init_pipeline_from_config(
        cls,
    ):
        raise NotImplementedError(
            "Initialization from config not implemented for ConditionalFlowPipeline."
        )

    def _make_model(self):
        raise NotImplementedError(
            "Model creation not implemented for ConditionalFlowPipeline."
        )

    def _get_default_params(self):
        raise NotImplementedError(
            "Default parameters not implemented for ConditionalFlowPipeline."
        )
    def get_loss_fn(
        self,
    ):
        def loss_fn(model, batch, key: jax.random.PRNGKey):
            # batch is a tuple of (observation tokens, conditioning tokens)
            # obs = batch[:, : self.dim_obs, ...]
            # cond = batch[:, self.dim_obs :, ...]
            obs, cond = batch

            rng_x0, rng_t = jax.random.split(key, 2)

            batch_size = obs.shape[0]

            x_1 = obs
            # x_0 = self.p0_obs.sample(rng_x0, (batch_size,))
            x_0 = jax.random.normal(rng_x0, (batch_size, self.dim_obs, self.ch_obs))
            t = jax.random.uniform(rng_t, x_1.shape[0])

            obs_batch = (x_0, x_1, t)

            loss = self.loss_fn(model, obs_batch, cond, self.obs_ids, self.cond_ids)
            return loss

        return loss_fn
    # need to change wrt
    # def _get_optimizer(self):
    #     """
    #     Construct the optimizer for training, including learning rate scheduling and gradient clipping.
    #
    #     Returns
    #     -------
    #     optimizer : nnx.Optimizer
    #         The optimizer instance for the model.
    #     """
    #     # sbi_model_params = nnx.All(nnx.Param, nnx.PathContains('sbi_model'))
    #     # sbi_model_params = nnx.All(nnx.Param, nnx.PathContains("model"))
    #     opt = optax.chain(
    #         optax.adaptive_grad_clip(10.0),
    #         optax.adamw(self.training_config["max_lr"]),
    #         reduce_on_plateau(
    #             patience=self.training_config["patience"],
    #             cooldown=self.training_config["cooldown"],
    #             factor=self.training_config["factor"],
    #             rtol=self.training_config["rtol"],
    #             accumulation_size=self.training_config["accumulation_size"],
    #             min_scale=self.training_config["min_scale"],
    #         ),
    #     )
    #     if self.training_config["multistep"] > 1:
    #         opt = optax.MultiSteps(opt, self.training_config["multistep"])
    #     # optimizer = nnx.Optimizer(self.model, opt, wrt=sbi_model_params)
    #     optimizer = nnx.Optimizer(self.model, opt, wrt=nnx.Param)
    #     return optimizer
    # need to select the right weights to apply the updates
    def get_train_step_fn(self, loss_fn):
        """
        Return the training step function, which performs a single optimization step.

        Returns
        -------
        train_step : Callable
            JIT-compiled training step function.
        """
        # sbi_model_params = nnx.All(nnx.Param, nnx.PathContains('sbi_model'))
        # sbi_model_params = nnx.All(nnx.Param, nnx.PathContains("model"))

        @nnx.jit  # something bad happens here
        def train_step(model, optimizer, batch, key: jax.random.PRNGKey):
            # diff_state = nnx.DiffState(
            #     0, sbi_model_params
            # )  # filter head params of the first argument
            # loss, grads = nnx.value_and_grad(loss_fn, argnums=diff_state)(
            #     model, batch, key
            # )
            loss, grads = nnx.value_and_grad(loss_fn)(model, batch, key)
            optimizer.update(model, grads, value=loss)
            return loss

        return train_step

    def _wrap_model(self):
        self.model_wrapped = ConditionalWrapper(self.model)
        self.ema_model_wrapped = ConditionalWrapper(self.ema_model)
        return
    def get_sampler(
        self,
        x_o,
        step_size=0.01,
        use_ema=True,
        time_grid=None,
        **model_extras,
    ):
        if use_ema:
            vf_wrapped = self.ema_model_wrapped
        else:
            vf_wrapped = self.model_wrapped

        if time_grid is None:
            time_grid = jnp.array([0.0, 1.0])
            return_intermediates = False
        else:
            assert jnp.all(time_grid[:-1] <= time_grid[1:])
            return_intermediates = True

        cond = _expand_dims(x_o)

        solver = ODESolver(velocity_model=vf_wrapped)

        model_extras = {
            "cond": cond,
            "obs_ids": self.obs_ids,
            "cond_ids": self.cond_ids,
            **model_extras,
        }

        sampler_ = solver.get_sampler(
            method="Dopri5",
            step_size=step_size,
            return_intermediates=return_intermediates,
            model_extras=model_extras,
            time_grid=time_grid,
        )

        def sampler(key, nsamples):
            x_init = jax.random.normal(key, (nsamples, self.dim_obs, self.ch_obs))
            samples = sampler_(x_init)
            return samples

        return sampler
    def sample(
        self,
        key,
        x_o,
        nsamples=10_000,
        step_size=0.01,
        use_ema=True,
        time_grid=None,
        **model_extras,
    ):
        sampler_ = self.get_sampler(
            x_o,
            step_size=step_size,
            use_ema=use_ema,
            time_grid=time_grid,
            **model_extras,
        )
        samples = sampler_(key, nsamples)
        return samples
    # def compute_unnorm_logprob(
    #     self, x_1, x_o, step_size=0.01, use_ema=True, time_grid=None, **model_extras
    # ):
    #     if use_ema:
    #         model = self.ema_model_wrapped
    #     else:
    #         model = self.model_wrapped
    #
    #     if time_grid is None:
    #         time_grid = jnp.array([1.0, 0.0])
    #         return_intermediates = False
    #     else:
    #         # assert time grid is decreasing
    #         assert jnp.all(time_grid[:-1] >= time_grid[1:])
    #         return_intermediates = True
    #
    #     solver = ODESolver(velocity_model=model)
    #
    #     # x_1 = _expand_dims(x_1)
    #     assert (
    #         x_1.ndim == 2
    #     ), "x_1 must be of shape (num_samples, dim_obs), currently sampling for multiple channels is not supported."
    #
    #     cond = _expand_dims(x_o)
    #
    #     model_extras = {
    #         "cond": cond,
    #         "obs_ids": self.obs_ids,
    #         "cond_ids": self.cond_ids,
    #         **model_extras,
    #     }
    #
    #     logp_sampler = solver.get_unnormalized_logprob(
    #         time_grid=time_grid,
    #         method="Dopri5",
    #         step_size=step_size,
    #         log_p0=self.p0_obs.log_prob,
    #         model_extras=model_extras,
    #         return_intermediates=return_intermediates,
    #     )
    #
    #     if len(x_1) > 4:
    #         # we trigger precompilation first
    #         _ = logp_sampler(x_1[:4])
    #
    #     exact_log_p = logp_sampler(x_1)
    #     return exact_log_p
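
# Usage sketch (illustrative, not part of the library API): a minimal end-to-end run of the
# flow pipeline, assuming a conditional transformer `my_model`, grain-style `train_ds` /
# `val_ds` iterators yielding (obs, cond) batches, and an observation `x_o`. All of these
# names, shapes, and the `train` call are hypothetical placeholders; see the literalinclude
# example referenced in the class docstring for the maintained version.
#
#     pipeline = ConditionalFlowPipeline(
#         my_model,
#         train_ds,
#         val_ds,
#         dim_obs=3,       # three parameter tokens to infer
#         dim_cond=10,     # ten conditioning tokens
#         ch_obs=1,
#         ch_cond=1,
#         id_embedding_strategy=("absolute", "absolute"),
#     )
#     pipeline.train(...)  # training entry point, assuming AbstractPipeline provides one
#     key = jax.random.PRNGKey(0)
#     samples = pipeline.sample(key, x_o, nsamples=5_000, step_size=0.01)
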
class ConditionalDiffusionPipeline(AbstractPipeline):
    """
    Diffusion pipeline for training and using a Conditional model for simulation-based inference.

    Parameters
    ----------
    model : nnx.Module
        The model to be trained.
    train_dataset : grain dataset or iterator over batches
        Training dataset.
    val_dataset : grain dataset or iterator over batches
        Validation dataset.
    dim_obs : int or tuple of int
        Dimension of the parameter space (number of tokens). Can represent unstructured data,
        time-series, or patchified 2D images. For images, provide a tuple (height, width).
    dim_cond : int or tuple of int
        Dimension of the observation space (number of tokens). Can represent unstructured data,
        time-series, or patchified 2D images. For images, provide a tuple (height, width).
    ch_obs : int, optional
        Number of channels per token in the observation data. Default is 1.
    ch_cond : int, optional
        Number of channels per token in the conditional data. Default is 1.
    id_embedding_strategy : tuple of str, optional
        Positional-id strategy for the observation and conditioning tokens, respectively.
        Each entry must be one of "absolute", "pos1d", "rope1d" (1D ids) or "pos2d", "rope2d"
        (2D ids). Default is ("absolute", "absolute").
    params : ConditionalParams, optional
        Parameters for the Conditional model. If None, default parameters are used.
    training_config : dict, optional
        Configuration for training. If None, default configuration is used.

    Examples
    --------
    Minimal example of how to instantiate and use the ConditionalDiffusionPipeline:

    .. literalinclude:: /examples/conditional_diffusion_pipeline.py
        :language: python
        :linenos:

    .. image:: /examples/conditional_diffusion_pipeline_marginals.png
        :width: 600

    .. note::
        If you plan on using multiprocessing prefetching, ensure that your script is wrapped
        in an ``if __name__ == "__main__":`` guard.
        See https://docs.python.org/3/library/multiprocessing.html

    .. note::
        Sampling in the latent space (latent diffusion/flow) is not currently supported.
    """

    def __init__(
        self,
        model,
        train_dataset,
        val_dataset,
        dim_obs: Union[int, Tuple[int, int]],
        dim_cond: Union[int, Tuple[int, int]],
        ch_obs=1,
        ch_cond=1,
        id_embedding_strategy=("absolute", "absolute"),
        params=None,
        training_config=None,
    ):
        super().__init__(
            model=model,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            dim_obs=dim_obs,
            dim_cond=dim_cond,
            ch_obs=ch_obs,
            ch_cond=ch_cond,
            params=params,
            training_config=training_config,
        )

        # # Flux1 uses different ids for obs and cond
        # obs_ids = jnp.zeros((1, dim_obs, 2), dtype=jnp.int32)
        # obs_ids = obs_ids.at[..., 0].set(jnp.arange(dim_obs))
        # cond_ids = jnp.zeros((1, dim_cond, 2), dtype=jnp.int32)
        # cond_ids = cond_ids.at[..., 0].set(jnp.arange(dim_cond))
        # cond_ids = cond_ids.at[..., 1].set(
        #     1
        # )  # set second channel to 1 for conditioning tokens

        embeddings_1d = ["absolute", "pos1d", "rope1d"]
        embeddings_2d = ["pos2d", "rope2d"]

        if id_embedding_strategy[0] in embeddings_1d:
            obs_ids = init_ids_1d(dim_obs, semantic_id=0)
        elif id_embedding_strategy[0] in embeddings_2d:
            obs_ids = init_ids_2d(dim_obs, semantic_id=0)
        else:
            raise ValueError(
                f"Unknown id embedding strategy: {id_embedding_strategy[0]}"
            )

        if id_embedding_strategy[1] in embeddings_1d:
            cond_ids = init_ids_1d(dim_cond, semantic_id=1)
        elif id_embedding_strategy[1] in embeddings_2d:
            cond_ids = init_ids_2d(dim_cond, semantic_id=1)
        else:
            raise ValueError(
                f"Unknown id embedding strategy: {id_embedding_strategy[1]}"
            )
        self.obs_ids = obs_ids
        self.cond_ids = cond_ids

        self.path = EDMPath(
            scheduler=EDMScheduler(
                sigma_min=self.training_config["sigma_min"],
                sigma_max=self.training_config["sigma_max"],
            )
        )
        self.loss_fn = ConditionalDiffLoss(self.path)
    @classmethod
    def init_pipeline_from_config(
        cls,
    ):
        raise NotImplementedError(
            "Initialization from config not implemented for ConditionalDiffusionPipeline."
        )

    def _make_model(self):
        raise NotImplementedError(
            "Model creation not implemented for ConditionalDiffusionPipeline."
        )

    def _get_default_params(self):
        raise NotImplementedError(
            "Default parameters not implemented for ConditionalDiffusionPipeline."
        )
    @classmethod
    def get_default_training_config(cls):
        config = super().get_default_training_config()
        config.update(
            {
                "sigma_min": 0.002,  # from the EDM paper
                "sigma_max": 80.0,
            }
        )
        return config
    def get_loss_fn(
        self,
    ):
        def loss_fn(model, batch, key: jax.random.PRNGKey):
            # batch is a tuple of (observation tokens, conditioning tokens)
            # jax debug print(batch.shape)  # (batch_size, dim_obs + dim_cond)
            # obs = jnp.take_along_axis(batch, self.obs_ids, axis=1)
            # cond = jnp.take_along_axis(batch, self.cond_ids, axis=1)
            # obs = batch[:, : self.dim_obs, ...]
            # cond = batch[:, self.dim_obs :, ...]
            obs, cond = batch

            rng_x0, rng_sigma = jax.random.split(key, 2)

            x_1 = obs
            # sigma = self.path.sample_sigma(rng_sigma, (x_1.shape[0],))
            sigma = self.path.sample_sigma(rng_sigma, (x_1.shape[0], 1, 1))
            # sigma = repeat(sigma, f"b -> b {'1 ' * (x_1.ndim - 1)}")  # TODO fixme

            obs_batch = (x_1, sigma)

            loss = self.loss_fn(
                rng_x0, model, obs_batch, cond, self.obs_ids, self.cond_ids
            )
            return loss

        return loss_fn
    # def _get_optimizer(self):
    #     """
    #     Construct the optimizer for training, including learning rate scheduling and gradient clipping.
    #
    #     Returns
    #     -------
    #     optimizer : nnx.Optimizer
    #         The optimizer instance for the model.
    #     """
    #     # sbi_model_params = nnx.All(nnx.Param, nnx.PathContains("sbi_model"))
    #     # sbi_model_params = nnx.All(nnx.Param, nnx.PathContains("model"))
    #     opt = optax.chain(
    #         optax.adaptive_grad_clip(10.0),
    #         optax.adamw(self.training_config["max_lr"]),
    #         reduce_on_plateau(
    #             patience=self.training_config["patience"],
    #             cooldown=self.training_config["cooldown"],
    #             factor=self.training_config["factor"],
    #             rtol=self.training_config["rtol"],
    #             accumulation_size=self.training_config["accumulation_size"],
    #             min_scale=self.training_config["min_scale"],
    #         ),
    #     )
    #     if self.training_config["multistep"] > 1:
    #         opt = optax.MultiSteps(opt, self.training_config["multistep"])
    #     # optimizer = nnx.Optimizer(self.model, opt, wrt=sbi_model_params)
    #     optimizer = nnx.Optimizer(self.model, opt, wrt=nnx.Param)
    #     return optimizer
    def get_train_step_fn(self, loss_fn):
        """
        Return the training step function, which performs a single optimization step.

        Returns
        -------
        train_step : Callable
            JIT-compiled training step function.
        """
        # sbi_model_params = nnx.All(nnx.Param, nnx.PathContains("sbi_model"))
        # sbi_model_params = nnx.All(nnx.Param, nnx.PathContains("model"))

        @nnx.jit
        def train_step(model, optimizer, batch, key: jax.random.PRNGKey):
            # diff_state = nnx.DiffState(0, sbi_model_params)
            # loss, grads = nnx.value_and_grad(loss_fn, argnums=diff_state)(
            #     model, batch, key
            # )
            loss, grads = nnx.value_and_grad(loss_fn)(model, batch, key)
            optimizer.update(model, grads, value=loss)
            return loss

        return train_step

    def _wrap_model(self):
        self.model_wrapped = ConditionalWrapper(self.model)
        self.ema_model_wrapped = ConditionalWrapper(self.ema_model)
        return
    def get_sampler(
        self,
        x_o,
        nsteps=18,
        use_ema=True,
        return_intermediates=False,
        **model_extras,
    ):
        if use_ema:
            model = self.ema_model_wrapped
        else:
            model = self.model_wrapped

        cond = _expand_dims(x_o)

        solver = SDESolver(score_model=model, path=self.path)

        model_extras = {
            "cond": cond,
            "obs_ids": self.obs_ids,
            "cond_ids": self.cond_ids,
            **model_extras,
        }

        sampler_ = solver.get_sampler(
            nsteps=nsteps,
            return_intermediates=return_intermediates,
            model_extras=model_extras,
        )

        def sampler(key, nsamples):
            key1, key2 = jax.random.split(key, 2)
            x_init = self.path.sample_prior(
                key1, (nsamples, self.dim_obs, self.ch_obs)
            )
            samples = sampler_(key2, x_init)
            return samples

        return sampler
    def sample(
        self,
        key,
        x_o,
        nsamples=10_000,
        nsteps=18,
        use_ema=True,
        return_intermediates=False,
        **model_extras,
    ):
        sampler = self.get_sampler(
            x_o,
            nsteps=nsteps,
            use_ema=use_ema,
            return_intermediates=return_intermediates,
            **model_extras,
        )
        return sampler(key, nsamples)
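
# Usage sketch (illustrative, not part of the library API): the diffusion pipeline mirrors the
# flow pipeline above. `my_model`, `train_ds`, `val_ds`, and `x_o` are hypothetical
# placeholders; the sigma keys shown are exactly the ones added by get_default_training_config.
#
#     training_config = ConditionalDiffusionPipeline.get_default_training_config()
#     training_config.update({"sigma_min": 0.002, "sigma_max": 80.0})
#
#     pipeline = ConditionalDiffusionPipeline(
#         my_model,
#         train_ds,
#         val_ds,
#         dim_obs=3,
#         dim_cond=10,
#         training_config=training_config,
#     )
#     key = jax.random.PRNGKey(0)
#     samples = pipeline.sample(key, x_o, nsamples=5_000, nsteps=18)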