# Source code for gensbi.diffusion.solver.edm_solver

from functools import partial
from typing import Callable, Optional, Sequence, Tuple, Union, Any

import jax
import jax.numpy as jnp
from jax import jit
from jax import Array

from gensbi.solver import Solver
from gensbi.diffusion.solver.edm_samplers import edm_sampler, edm_ablation_sampler
from gensbi.diffusion.path import EDMPath


class EDMSolver(Solver):
    """Sampling solver for score models trained with an ``EDMPath``.

    Dispatches to ``edm_sampler`` when the solver scheduler is named
    ``"EDM"`` and to ``edm_ablation_sampler`` otherwise (see
    :meth:`get_sampler`).
    """

    # Training-path scheduler names accepted by ``__init__``.
    _SUPPORTED_SCHEDULERS = ("EDM", "EDM-VP", "EDM-VE")

    def __init__(self, score_model: Callable, path: EDMPath) -> None:
        """
        Initialize the EDM solver.

        Parameters
        ----------
        score_model : Callable
            The score model function.
        path : EDMPath
            The EDMPath object. ``path.scheduler.name`` must be one of
            ``"EDM"``, ``"EDM-VP"`` or ``"EDM-VE"``.

        Raises
        ------
        AssertionError
            If ``path.scheduler.name`` is not a supported scheduler name.

        Example:
            .. code-block:: python

                from gensbi.diffusion.solver import EDMSolver
                from gensbi.diffusion.path import EDMPath
                from gensbi.diffusion.path.scheduler import EDMScheduler
                import jax, jax.numpy as jnp

                scheduler = EDMScheduler()
                path = EDMPath(scheduler)

                def score_model(x, t):
                    return x + t

                solver = EDMSolver(score_model, path)
                key = jax.random.PRNGKey(0)
                x_init = jax.random.normal(key, (16, 2))
                samples = solver.sample(key, x_init, nsteps=10)
                print(samples.shape)
        """
        self.score_model = score_model
        self.path = path
        # Report the value that was actually validated (the scheduler name).
        assert self.path.scheduler.name in self._SUPPORTED_SCHEDULERS, (
            "Path must be one of ['EDM', 'EDM-VP', 'EDM-VE'], "
            f"got {self.path.scheduler.name}."
        )

    def get_sampler(
        self,
        condition_mask: Optional[Array] = None,
        condition_value: Optional[Array] = None,
        cfg_scale: Optional[float] = None,
        nsteps: int = 18,
        method: str = "Heun",
        return_intermediates: bool = False,
        static_model_kwargs: Optional[dict] = None,
        solver_params: Optional[dict] = None,
        solver_scheduler: Optional[Any] = None,
    ) -> Callable:
        """
        Returns a jitted sampler function for the EDM path.

        Parameters
        ----------
        condition_mask : Optional[Array]
            Mask for conditioning.
        condition_value : Optional[Array]
            Value for conditioning.
        cfg_scale : Optional[float]
            Classifier-free guidance scale (not implemented; must be None).
        nsteps : int
            Number of steps.
        method : str
            Integration method.
        return_intermediates : bool
            Whether to return intermediate steps.
        static_model_kwargs : Optional[dict]
            Static model arguments baked into the sampler. Condition-dependent
            data should be passed at call time via ``model_extras``.
        solver_params : Optional[dict]
            Additional solver parameters. Recognised keys: ``S_churn``
            (default 0), ``S_min`` (default 0), ``S_max`` (default inf) and
            ``S_noise`` (default 1); they are forwarded to the sampler.
        solver_scheduler : Optional[Any]
            Scheduler to use for the solver. If None, the path's scheduler
            is used.

        Returns
        -------
        Callable
            ``sample(key, x_init, model_extras=None)`` sampler function.
            Runtime ``model_extras`` are merged over ``static_model_kwargs``
            (runtime values win on key collisions).

        Raises
        ------
        NotImplementedError
            If ``cfg_scale`` is not None.
        """
        # Fail fast on the unsupported option before doing any other work.
        if cfg_scale is not None:
            raise NotImplementedError(
                "CFG scale is not implemented for EDM samplers yet."
            )

        if solver_scheduler is None:
            solver_scheduler = self.path.scheduler
        if static_model_kwargs is None:
            static_model_kwargs = {}
        if solver_params is None:
            solver_params = {}

        if solver_scheduler.name == "EDM":
            # NOTE(review): this branch passes only ``solver_scheduler`` to the
            # sampler. If the path was trained with EDM-VP/EDM-VE but an
            # "EDM" solver scheduler is supplied, confirm that edm_sampler
            # still calls the model with the training preconditioning.
            sampler_ = edm_sampler
        else:
            # Bind the training scheduler as denoise_scheduler so the model
            # is always called with the preconditioning it was trained with.
            training_scheduler = self.path.scheduler

            def ablation_sampler(sched, model, x_1, **kw):
                """Adapt edm_ablation_sampler to edm_sampler's call signature.

                Inserts the training scheduler in the ``denoise_scheduler``
                positional slot so callers use ``(sched, model, x_1, **kw)``.
                """
                return edm_ablation_sampler(
                    sched, training_scheduler, model, x_1, **kw
                )

            sampler_ = ablation_sampler

        S_churn = solver_params.get("S_churn", 0)
        S_min = solver_params.get("S_min", 0)
        S_max = solver_params.get("S_max", float("inf"))
        S_noise = solver_params.get("S_noise", 1)

        @jit
        def sample(
            key: Array, x_init: Array, model_extras: Optional[dict] = None
        ) -> Array:
            """Run the configured sampler from ``x_init`` using ``key``."""
            if model_extras is None:
                model_extras = {}
            return sampler_(
                solver_scheduler,
                self.score_model,
                x_init,
                key=key,
                condition_mask=condition_mask,
                condition_value=condition_value,
                return_intermediates=return_intermediates,
                n_steps=nsteps,
                S_churn=S_churn,
                S_min=S_min,
                S_max=S_max,
                S_noise=S_noise,
                method=method,
                model_kwargs={**static_model_kwargs, **model_extras},
            )

        return sample

    def sample(
        self,
        key: Array,
        x_init: Array,
        condition_mask: Optional[Array] = None,
        condition_value: Optional[Array] = None,
        cfg_scale: Optional[float] = None,
        nsteps: int = 18,
        method: str = "Heun",
        return_intermediates: bool = False,
        model_extras: Optional[dict] = None,
        solver_params: Optional[dict] = None,
        solver_scheduler: Optional[Any] = None,
    ) -> Array:
        """
        Sample from the EDM path using the sampler from :meth:`get_sampler`.

        Each call builds a new jitted sampler via ``get_sampler``. When
        sampling repeatedly with the same settings, call ``get_sampler`` once
        and reuse the returned function to avoid re-tracing.

        Parameters
        ----------
        key : Array
            JAX random key.
        x_init : Array
            Initial value.
        condition_mask : Optional[Array]
            Mask for conditioning.
        condition_value : Optional[Array]
            Value for conditioning.
        cfg_scale : Optional[float]
            Classifier-free guidance scale (not implemented; must be None).
        nsteps : int
            Number of steps.
        method : str
            Integration method.
        return_intermediates : bool
            Whether to return intermediate steps.
        model_extras : Optional[dict]
            Runtime model extras (e.g. ``cond``, ``obs_ids``).
        solver_params : Optional[dict]
            Additional solver parameters.
        solver_scheduler : Optional[Any]
            Scheduler to use for the solver. If None, the path's scheduler
            is used.

        Returns
        -------
        Array
            Sampled output.

        Raises
        ------
        NotImplementedError
            If ``cfg_scale`` is not None.
        """
        sampler = self.get_sampler(
            condition_mask=condition_mask,
            condition_value=condition_value,
            cfg_scale=cfg_scale,
            nsteps=nsteps,
            method=method,
            return_intermediates=return_intermediates,
            solver_params=solver_params,
            solver_scheduler=solver_scheduler,
        )
        return sampler(key, x_init, model_extras=model_extras)