import jax
from jax import numpy as jnp
import numpy as np
from typing import Union, Tuple
from einops import repeat, rearrange
from jax import Array
from gensbi.diffusion.path import EDMPath
from gensbi.diffusion.path.scheduler import (
EDMScheduler,
VEEdmScheduler,
VPEdmScheduler,
)
from gensbi.diffusion.path.sm_path import SMPath
from gensbi.diffusion.path.scheduler import VPSmScheduler, VESmScheduler
def init_ids_joint(dim_obs: int, dim_cond: int):
    """Build positional id arrays for a joint (observation + condition) sequence.

    Parameters
    ----------
    dim_obs : int
        Number of observation tokens; they occupy the leading positions.
    dim_cond : int
        Number of conditioning tokens; they occupy the trailing positions.

    Returns
    -------
    tuple of Array
        ``(node_ids, obs_ids, cond_ids)`` with shapes ``(1, dim_obs + dim_cond, 1)``,
        ``(1, dim_obs, 1)`` and ``(1, dim_cond, 1)`` respectively.
    """
    total = dim_obs + dim_cond
    joint = jnp.arange(total)

    def as_tokens(ids):
        # (n,) -> (1, n, 1): add batch and channel axes expected downstream.
        return ids.reshape((1, -1, 1))

    return as_tokens(joint), as_tokens(joint[:dim_obs]), as_tokens(joint[dim_obs:])
def init_ids_1d(dim: int, semantic_id: Union[int, None] = None):
    """Create 1D positional token ids, optionally tagged with a semantic channel.

    Parameters
    ----------
    dim : int
        Number of tokens.
    semantic_id : int or None, optional
        When given, a second id channel filled with this value is appended.

    Returns
    -------
    tuple
        ``(ids, dim)`` where ``ids`` is an int32 Array of shape
        ``(1, dim, 1)`` (no semantic id) or ``(1, dim, 2)``.
    """
    n_channels = 1 if semantic_id is None else 2
    ids = np.zeros((1, dim, n_channels), dtype=np.int32)
    ids[0, :, 0] = np.arange(dim)  # channel 0: position index
    if semantic_id is not None:
        ids[..., 1] = semantic_id  # channel 1: group tag (e.g. obs vs cond)
    return jnp.array(ids, dtype=jnp.int32), dim
def init_ids_2d(dim: Tuple[int, int], semantic_id: int = 0):
    """Create 2D (patch-grid) position ids for an image of spatial shape ``dim``.

    Each spatial axis is halved to match the 2x2 patchification performed by
    :func:`patchify_2d`, so one id triple is produced per patch.

    Parameters
    ----------
    dim : tuple of int
        Image spatial shape ``(H, W)``; both entries are assumed even.
    semantic_id : int, optional
        Value written into channel 0 of every id triple. Defaults to 0.

    Returns
    -------
    tuple
        ``(ids, n_patches)`` where ``ids`` is an int32 Array of shape
        ``(1, (H // 2) * (W // 2), 3)`` holding ``(semantic_id, row, col)``.
    """
    n_rows, n_cols = dim[0] // 2, dim[1] // 2
    grid = np.zeros((n_rows, n_cols, 3), dtype=np.int32)
    grid[..., 0] = semantic_id
    grid[..., 1] += np.arange(n_rows)[:, None]  # patch row index
    grid[..., 2] += np.arange(n_cols)[None, :]  # patch column index
    # Flatten the (row, col) grid into a token axis with a leading batch dim
    # (equivalent to repeat(grid, "h w c -> b (h w) c", b=1)).
    flat = grid.reshape(1, n_rows * n_cols, 3)
    return jnp.array(flat, dtype=jnp.int32), n_rows * n_cols
@jax.jit
def patchify_2d(x: Array):
    """Split an image batch into non-overlapping 2x2 patch tokens.

    Equivalent to
    ``rearrange(x, "b (h ph) (w pw) c -> b (h w) (c ph pw)", ph=2, pw=2)``:
    every 2x2 spatial patch is flattened channel-major into one token.
    (Fix: a stray Sphinx ``[docs]`` marker between the ``@jax.jit`` decorator
    and the ``def`` made the original a syntax error; also expressed with
    pure jnp ops so the jitted function has no einops dependency.)

    Parameters
    ----------
    x : Array
        Image batch of shape ``(b, H, W, c)`` with even ``H`` and ``W``.

    Returns
    -------
    Array
        Patch tokens of shape ``(b, (H // 2) * (W // 2), 4 * c)``.
    """
    b, height, width, c = x.shape
    h, w = height // 2, width // 2
    x = x.reshape(b, h, 2, w, 2, c)    # (b, h, ph, w, pw, c)
    x = x.transpose(0, 1, 3, 5, 2, 4)  # (b, h, w, c, ph, pw)
    return x.reshape(b, h * w, c * 4)
def scale_lr(batch_size, base_lr=1e-4, reference_batch_size=256):
    """Scale learning rate based on batch size using square root scaling.

    Parameters
    ----------
    batch_size : int
        The current batch size.
    base_lr : float
        The base learning rate for the reference batch size.
    reference_batch_size : int, optional
        The reference batch size. Defaults to 256.

    Returns
    -------
    float
        The adjusted learning rate.
    """
    import math

    # sqrt scaling keeps the gradient-noise scale roughly constant
    # when the batch size changes relative to the reference.
    ratio = batch_size / reference_batch_size
    return base_lr * math.sqrt(ratio)
[docs]
_EMBEDDINGS_1D = {"absolute", "pos1d", "rope1d"}
[docs]
_EMBEDDINGS_2D = {"pos2d", "rope2d"}
[docs]
def _resolve_embedding_ids(dim, strategy: str, semantic_id: int):
    """Resolve ID embeddings by strategy name.

    Parameters
    ----------
    dim : int or tuple of int
        Dimension specification (number of tokens, or (H, W) for 2D images).
    strategy : str
        Embedding strategy name (e.g., "absolute", "pos1d", "rope1d",
        "pos2d", "rope2d").
    semantic_id : int
        Semantic identifier for the token group (0=obs, 1=cond).

    Returns
    -------
    ids : Array
        Token ID array.
    resolved_dim : int
        Resolved flat dimension.

    Raises
    ------
    ValueError
        If ``strategy`` is not recognized.
    """
    # The 1D and 2D strategy sets are disjoint, so check order is irrelevant.
    if strategy in _EMBEDDINGS_2D:
        return init_ids_2d(dim, semantic_id=semantic_id)
    if strategy in _EMBEDDINGS_1D:
        return init_ids_1d(dim, semantic_id=semantic_id)
    raise ValueError(f"Unknown id embedding strategy: {strategy}")
def build_edm_path(sde: str, config: dict) -> EDMPath:
    """Build an EDM-family diffusion path from an SDE type string and config.

    Parameters
    ----------
    sde : str
        SDE type: ``"EDM"``, ``"VE"``, or ``"VP"``.
    config : dict
        Training configuration dict; scheduler hyperparameters are read from
        here with sensible defaults.

    Returns
    -------
    EDMPath
        Configured diffusion path.

    Raises
    ------
    ValueError
        If ``sde`` is not one of ``{"EDM", "VE", "VP"}``.
    """
    # Lazy scheduler factories: defaults differ per SDE family.
    scheduler_factories = {
        "EDM": lambda: EDMScheduler(
            sigma_min=config.get("sigma_min", 0.002),
            sigma_max=config.get("sigma_max", 80.0),
        ),
        "VE": lambda: VEEdmScheduler(
            sigma_min=config.get("sigma_min", 0.02),
            sigma_max=config.get("sigma_max", 100.0),
        ),
        "VP": lambda: VPEdmScheduler(
            beta_min=config.get("beta_min", 0.1),
            beta_max=config.get("beta_max", 19.9),
        ),
    }
    try:
        factory = scheduler_factories[sde]
    except KeyError:
        raise ValueError(f"Unknown sde type: {sde}") from None
    return EDMPath(scheduler=factory())
def build_sm_path(sde_type: str, config: dict) -> SMPath:
    """Build a score-matching path from an SDE type string and config.

    Parameters
    ----------
    sde_type : str
        SDE type: ``"VP"`` or ``"VE"``.
    config : dict
        Training configuration dict; scheduler hyperparameters are read from
        here with sensible defaults.

    Returns
    -------
    SMPath
        Configured score-matching path.

    Raises
    ------
    ValueError
        If ``sde_type`` is not one of ``{"VP", "VE"}``.
    """
    # Guard clause: reject unknown SDE types up front.
    if sde_type not in ("VP", "VE"):
        raise ValueError(f"sde_type must be one of ['VP', 'VE'], got {sde_type}.")
    if sde_type == "VP":
        scheduler = VPSmScheduler(
            beta_min=config.get("beta_min", 0.001),
            beta_max=config.get("beta_max", 3.0),
        )
    else:
        scheduler = VESmScheduler(
            sigma_min=config.get("sigma_min", 0.001),
            sigma_max=config.get("sigma_max", 15.0),
        )
    return SMPath(scheduler)
def parse_training_config(config_path: str):
    """Parse training and optimizer configuration from a YAML config file.

    Reads the ``training`` and ``optimizer`` sections of the config and
    returns a flat dictionary consumed by :class:`AbstractPipeline`.

    Parameters
    ----------
    config_path : str
        Path to the YAML configuration file.

    Returns
    -------
    training_config : dict
        Parsed training configuration dictionary.
    """
    import yaml

    with open(config_path, "r") as f:
        raw = yaml.safe_load(f)

    training = raw.get("training", {})
    optimizer = raw.get("optimizer", {})

    # Gradient-accumulation factor: step-based quantities are stretched by it.
    multistep = training.get("multistep", 1)
    max_lr = optimizer.get("max_lr", 1e-3)
    min_lr = optimizer.get("min_lr", 0.0)

    training_config = {
        "nsteps": training.get("nsteps", 30000) * multistep,
        "ema_decay": training.get("ema_decay", 0.999),
        "multistep": multistep,
        "experiment_id": training.get("experiment_id", 1),
        "early_stopping": training.get("early_stopping", True),
        "val_every": training.get("val_every", 100) * multistep,
        "val_error_ratio": training.get("val_error_ratio", 1.3),
        # Optional method-specific parameters (override strategy defaults)
        "sigma_min": training.get("sigma_min", 0.002),
        "sigma_max": training.get("sigma_max", 80.0),
        "max_lr": max_lr,
        "min_lr": min_lr,
        # Guard against division by zero when max_lr is disabled.
        "min_scale": min_lr / max_lr if max_lr > 0 else 0.0,
        "warmup_steps": optimizer.get("warmup_steps", 500),
        "decay_transition": optimizer.get("decay_transition", 0.85),
    }
    # ema_decay can also be specified in optimizer section (backward compat)
    if "ema_decay" in optimizer:
        training_config["ema_decay"] = optimizer["ema_decay"]
    return training_config