# Source code for gensbi.recipes.flux1

"""
Pipeline for training and using a Flux1 model for simulation-based inference.
"""

import jax.numpy as jnp
from flax import nnx


from gensbi.models import (
    Flux1,
    Flux1Params,
)

import yaml

from gensbi.recipes.conditional_pipeline import ConditionalPipeline
from gensbi.recipes.utils import parse_training_config


def parse_flux1_params(config_path: str):
    """
    Parse the ``model`` section of a Flux1 YAML configuration file.

    Parameters
    ----------
    config_path : str
        Path to the YAML configuration file.

    Returns
    -------
    params_dict : dict
        Keyword arguments suitable for constructing :class:`Flux1Params`
        (missing entries are filled with defaults).
    """
    with open(config_path, "r") as f:
        model_section = yaml.safe_load(f).get("model", {})

    # Plain scalar options together with their fallbacks, read in one pass.
    defaults = {
        "in_channels": 1,
        "vec_in_dim": None,
        "context_in_dim": 1,
        "mlp_ratio": 4,
        "num_heads": 6,
        "depth": 8,
        "depth_single_blocks": 16,
        # Fields whose real defaults are resolved inside Flux1Params.
        "axes_dim": None,
        "val_emb_dim": None,
        "id_emb_dim": None,
        "id_merge_mode": "sum",
        "qkv_bias": True,
        "theta": None,
    }
    params_dict = {
        key: model_section.get(key, fallback) for key, fallback in defaults.items()
    }

    # These two options need post-processing beyond a plain lookup:
    # the embedding strategy must be a tuple, and the dtype name must be
    # resolved to the actual jax.numpy dtype object.
    params_dict["id_embedding_strategy"] = tuple(
        model_section.get("id_embedding_strategy", ("absolute", "absolute"))
    )
    params_dict["param_dtype"] = getattr(
        jnp, model_section.get("param_dtype", "float32")
    )

    return params_dict
def get_default_flux1_params(
    dim_obs: int, dim_cond: int, ch_obs: int = 1, ch_cond: int = 1
) -> Flux1Params:
    """
    Build a :class:`Flux1Params` instance populated with default values.

    Parameters
    ----------
    dim_obs : int
        Dimension of the parameter (observation-target) space.
    dim_cond : int
        Dimension of the conditioning (observation) space.
    ch_obs : int, optional
        Number of channels in the observation data. Default is 1.
    ch_cond : int, optional
        Number of channels in the conditional data. Default is 1.

    Returns
    -------
    Flux1Params
        Default model parameters for the given problem dimensions.
    """
    # RoPE-style frequency base scaled with the joint dimensionality.
    dim_joint = dim_obs + dim_cond
    return Flux1Params(
        rngs=nnx.Rngs(default=42),
        dim_obs=dim_obs,
        dim_cond=dim_cond,
        in_channels=ch_obs,
        context_in_dim=ch_cond,
        vec_in_dim=None,
        mlp_ratio=4,
        num_heads=6,
        depth=8,
        depth_single_blocks=16,
        qkv_bias=True,
        axes_dim=[6, 0],
        theta=10 * dim_joint,
        id_embedding_strategy=("absolute", "absolute"),
        param_dtype=jnp.float32,
    )
def _flux1_config_from_path(config_path: str, dim_obs: int, dim_cond: int):
    """
    Parse the pieces of a Flux1 pipeline configuration shared by all pipelines.

    Parameters
    ----------
    config_path : str
        Path to the YAML configuration file.
    dim_obs : int
        Dimension of the parameter space.
    dim_cond : int
        Dimension of the observation space.

    Returns
    -------
    params : Flux1Params
        The parsed model parameters.
    training_config : dict
        The parsed training configuration.
    method : str
        The methodology (flow or diffusion) specified in the config.

    Raises
    ------
    ValueError
        If the configured model type is not ``"flux"``.
    """
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)

    # Methodology block: which generative method and which architecture.
    strategy = config.get("strategy", {})
    method = strategy.get("method")
    model_type = strategy.get("model")
    if model_type != "flux":
        raise ValueError(f"Model type {model_type} not supported.")

    params_dict = parse_flux1_params(config_path)

    # theta of -1 (or absent) means "auto": scale with the joint dimension,
    # matching the default used elsewhere in this module.
    if params_dict["theta"] in (None, -1):
        params_dict["theta"] = 10 * (dim_obs + dim_cond)

    params = Flux1Params(
        rngs=nnx.Rngs(0),
        dim_obs=dim_obs,
        dim_cond=dim_cond,
        **params_dict,
    )

    training_config = parse_training_config(config_path)
    return params, training_config, method
class Flux1FlowPipeline(ConditionalPipeline):
    def __init__(
        self,
        train_dataset,
        val_dataset,
        dim_obs: int,
        dim_cond: int,
        ch_obs=1,
        ch_cond=1,
        params=None,
        training_config=None,
    ):
        """
        Flow pipeline for training and using a Flux1 model for simulation-based inference.

        Parameters
        ----------
        train_dataset : grain dataset or iterator over batches
            Training dataset.
        val_dataset : grain dataset or iterator over batches
            Validation dataset.
        dim_obs : int
            Dimension of the parameter space.
        dim_cond : int
            Dimension of the observation space.
        ch_obs : int, optional
            Number of channels in the observation data. Default is 1.
            Ignored when ``params`` is given.
        ch_cond : int, optional
            Number of channels in the conditional data. Default is 1.
            Ignored when ``params`` is given.
        params : Flux1Params, optional
            Parameters for the Flux1 model. If None, default parameters are used.
        training_config : dict, optional
            Configuration for training. If None, default configuration is used.

        Examples
        --------
        Minimal example on how to instantiate and use the Flux1FlowPipeline:

        .. literalinclude:: /examples/flux1_flow_pipeline.py
            :language: python
            :linenos:

        .. image:: /examples/flux1_flow_pipeline_marginals.png
            :width: 600

        .. note::
            If you plan on using multiprocessing prefetching, ensure that your
            script is wrapped in a ``if __name__ == "__main__":`` guard. See
            https://docs.python.org/3/library/multiprocessing.html
        """
        # An explicit ``params`` is the source of truth for the channel
        # counts; without it, build defaults from the requested dimensions.
        if params is None:
            params = get_default_flux1_params(dim_obs, dim_cond, ch_obs, ch_cond)
        else:
            ch_obs = params.in_channels
            ch_cond = params.context_in_dim

        model = self._make_model(params)

        # Imported here rather than at module top level (presumably to avoid
        # an import cycle with gensbi.core — preserved as-is).
        from gensbi.core import FlowMatchingMethod

        super().__init__(
            model=model,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            dim_obs=dim_obs,
            dim_cond=dim_cond,
            method=FlowMatchingMethod(),
            ch_obs=ch_obs,
            ch_cond=ch_cond,
            params=params,
            training_config=training_config,
            id_embedding_strategy=params.id_embedding_strategy,
        )

        # Exponential-moving-average copy, initialized from the fresh model.
        self.ema_model = nnx.clone(self.model)

    @classmethod
    def init_pipeline_from_config(
        cls,
        train_dataset,
        val_dataset,
        dim_obs: int,
        dim_cond: int,
        config_path: str,
        checkpoint_dir: str,
        **kwargs,
    ):
        """
        Initialize the pipeline from a configuration file.

        Parameters
        ----------
        config_path : str
            Path to the configuration file.
        **kwargs
            Additional keyword arguments forwarded to the constructor.

        Raises
        ------
        ValueError
            If the configured method is not ``"flow"``.
        """
        params, training_config, method = _flux1_config_from_path(
            config_path, dim_obs, dim_cond
        )

        if method != "flow":
            raise ValueError(f"Method {method} not supported in Flux1FlowPipeline.")

        # Record the checkpoint directory alongside the training options.
        training_config["checkpoint_dir"] = checkpoint_dir

        return cls(
            train_dataset,
            val_dataset,
            dim_obs,
            dim_cond,
            ch_obs=params.in_channels,
            ch_cond=params.context_in_dim,
            params=params,
            training_config=training_config,
            **kwargs,
        )

    def _make_model(self, params):
        """Create and return the Flux1 model to be trained."""
        return Flux1(params)

    @classmethod
    def get_default_params(cls, dim_obs, dim_cond, ch_obs, ch_cond):
        """Return a dictionary of default model parameters."""
        return get_default_flux1_params(dim_obs, dim_cond, ch_obs, ch_cond)
class Flux1DiffusionPipeline(ConditionalPipeline):
    def __init__(
        self,
        train_dataset,
        val_dataset,
        dim_obs: int,
        dim_cond: int,
        ch_obs=1,
        ch_cond=1,
        params=None,
        training_config=None,
    ):
        """
        Diffusion pipeline for training and using a Flux1 model for simulation-based inference.

        Parameters
        ----------
        train_dataset : grain dataset or iterator over batches
            Training dataset.
        val_dataset : grain dataset or iterator over batches
            Validation dataset.
        dim_obs : int
            Dimension of the parameter space.
        dim_cond : int
            Dimension of the observation space.
        ch_obs : int, optional
            Number of channels in the observation data. Default is 1.
            Ignored when ``params`` is given.
        ch_cond : int, optional
            Number of channels in the conditional data. Default is 1.
            Ignored when ``params`` is given.
        params : Flux1Params, optional
            Parameters for the Flux1 model. If None, default parameters are used.
        training_config : dict, optional
            Configuration for training. If None, default configuration is used.

        Examples
        --------
        Minimal example on how to instantiate and use the Flux1DiffusionPipeline:

        .. literalinclude:: /examples/flux1_diffusion_pipeline.py
            :language: python
            :linenos:

        .. image:: /examples/flux1_diffusion_pipeline_marginals.png
            :width: 600

        .. note::
            If you plan on using multiprocessing prefetching, ensure that your
            script is wrapped in a ``if __name__ == "__main__":`` guard. See
            https://docs.python.org/3/library/multiprocessing.html
        """
        # An explicit ``params`` is the source of truth for the channel
        # counts; without it, build defaults from the requested dimensions.
        if params is None:
            params = get_default_flux1_params(dim_obs, dim_cond, ch_obs, ch_cond)
        else:
            ch_obs = params.in_channels
            ch_cond = params.context_in_dim

        model = self._make_model(params)

        # Imported here rather than at module top level (presumably to avoid
        # an import cycle with gensbi.core — preserved as-is).
        from gensbi.core import DiffusionEDMMethod

        super().__init__(
            model=model,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            dim_obs=dim_obs,
            dim_cond=dim_cond,
            method=DiffusionEDMMethod(),
            ch_obs=ch_obs,
            ch_cond=ch_cond,
            params=params,
            training_config=training_config,
            id_embedding_strategy=params.id_embedding_strategy,
        )

        # Exponential-moving-average copy, initialized from the fresh model.
        self.ema_model = nnx.clone(self.model)

    @classmethod
    def init_pipeline_from_config(
        cls,
        train_dataset,
        val_dataset,
        dim_obs: int,
        dim_cond: int,
        config_path: str,
        checkpoint_dir: str,
        **kwargs,
    ):
        """
        Initialize the pipeline from a configuration file.

        Parameters
        ----------
        config_path : str
            Path to the configuration file.
        **kwargs
            Additional keyword arguments forwarded to the constructor.

        Raises
        ------
        ValueError
            If the configured method is not ``"diffusion"``.
        """
        params, training_config, method = _flux1_config_from_path(
            config_path, dim_obs, dim_cond
        )

        if method != "diffusion":
            raise ValueError(
                f"Method {method} not supported in Flux1DiffusionPipeline."
            )

        # Record the checkpoint directory alongside the training options.
        training_config["checkpoint_dir"] = checkpoint_dir

        return cls(
            train_dataset,
            val_dataset,
            dim_obs,
            dim_cond,
            ch_obs=params.in_channels,
            ch_cond=params.context_in_dim,
            params=params,
            training_config=training_config,
            **kwargs,
        )

    def _make_model(self, params):
        """Create and return the Flux1 model to be trained."""
        return Flux1(params)

    @classmethod
    def get_default_params(cls, dim_obs, dim_cond, ch_obs, ch_cond):
        """Return a dictionary of default model parameters."""
        return get_default_flux1_params(dim_obs, dim_cond, ch_obs, ch_cond)
class Flux1SMPipeline(ConditionalPipeline):
    def __init__(
        self,
        train_dataset,
        val_dataset,
        dim_obs: int,
        dim_cond: int,
        ch_obs=1,
        ch_cond=1,
        sde_type: str = "VP",
        params=None,
        training_config=None,
    ):
        """
        Score matching pipeline for training and using a Flux1 model for simulation-based inference.

        Parameters
        ----------
        train_dataset : grain dataset or iterator over batches
            Training dataset.
        val_dataset : grain dataset or iterator over batches
            Validation dataset.
        dim_obs : int
            Dimension of the parameter space.
        dim_cond : int
            Dimension of the observation space.
        ch_obs : int, optional
            Number of channels in the observation data. Default is 1.
            Ignored when ``params`` is given.
        ch_cond : int, optional
            Number of channels in the conditional data. Default is 1.
            Ignored when ``params`` is given.
        sde_type : str
            Type of SDE. One of ``"VP"`` or ``"VE"``.
        params : Flux1Params, optional
            Parameters for the Flux1 model. If None, default parameters are used.
        training_config : dict, optional
            Configuration for training. If None, default configuration is used.
        """
        # An explicit ``params`` is the source of truth for the channel
        # counts; without it, build defaults from the requested dimensions.
        if params is None:
            params = get_default_flux1_params(dim_obs, dim_cond, ch_obs, ch_cond)
        else:
            ch_obs = params.in_channels
            ch_cond = params.context_in_dim

        model = self._make_model(params)

        # Imported here rather than at module top level (presumably to avoid
        # an import cycle with gensbi.core — preserved as-is).
        from gensbi.core import ScoreMatchingMethod

        super().__init__(
            model=model,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            dim_obs=dim_obs,
            dim_cond=dim_cond,
            method=ScoreMatchingMethod(sde_type=sde_type),
            ch_obs=ch_obs,
            ch_cond=ch_cond,
            params=params,
            training_config=training_config,
            id_embedding_strategy=params.id_embedding_strategy,
        )

        # Exponential-moving-average copy, initialized from the fresh model.
        self.ema_model = nnx.clone(self.model)

    @classmethod
    def init_pipeline_from_config(
        cls,
        train_dataset,
        val_dataset,
        dim_obs: int,
        dim_cond: int,
        config_path: str,
        checkpoint_dir: str,
        **kwargs,
    ):
        """
        Initialize the pipeline from a configuration file.

        Parameters
        ----------
        config_path : str
            Path to the configuration file.
        **kwargs
            Additional keyword arguments forwarded to the constructor
            (e.g. ``sde_type="VE"`` for score matching).

        Raises
        ------
        ValueError
            If the configured method is not ``"score_matching"``.
        """
        params, training_config, method = _flux1_config_from_path(
            config_path, dim_obs, dim_cond
        )

        if method != "score_matching":
            raise ValueError(
                f"Method {method} not supported in Flux1SMPipeline."
            )

        # Record the checkpoint directory alongside the training options.
        training_config["checkpoint_dir"] = checkpoint_dir

        return cls(
            train_dataset,
            val_dataset,
            dim_obs,
            dim_cond,
            ch_obs=params.in_channels,
            ch_cond=params.context_in_dim,
            params=params,
            training_config=training_config,
            **kwargs,
        )

    def _make_model(self, params):
        """Create and return the Flux1 model to be trained."""
        return Flux1(params)

    @classmethod
    def get_default_params(cls, dim_obs, dim_cond, ch_obs, ch_cond):
        """Return a dictionary of default model parameters."""
        return get_default_flux1_params(dim_obs, dim_cond, ch_obs, ch_cond)