Source code for gensbi.core.time_sampling
"""Training-time timestep samplers for flow matching.
A small, pure helper so the timestep distribution is a configurable knob on
:class:`~gensbi.core.flow_matching.FlowMatchingMethod` without touching the
loss, path, or models.
"""
import jax
[docs]
def sample_time(
    key: jax.Array,
    n: int,
    *,
    dist: str = "uniform",
    logitnorm_mean: float = 0.0,
    logitnorm_std: float = 1.0,
) -> jax.Array:
    """Sample ``n`` flow-matching timesteps in ``[0, 1]``.

    Parameters
    ----------
    key : jax.random.PRNGKey
        Random key. It is passed directly to the sampler and not split here,
        so split it beforehand if the caller reuses it elsewhere.
    n : int
        Number of timesteps (batch size).
    dist : str
        ``"uniform"`` (default) -> ``jax.random.uniform(key, (n,))``. This is
        bit-identical to the previous inline sampling, so existing runs are
        unchanged. Values lie in the half-open interval ``[0, 1)``, so
        ``t == 0`` can occur.
        ``"logitnormal"`` -> ``sigmoid(logitnorm_mean + logitnorm_std * N(0, 1))``
        (SD3 / Esser et al.). This concentrates mass near
        ``sigmoid(logitnorm_mean)``. Mathematically the values lie in
        ``(0, 1)``. At finite floating-point precision, however, the sigmoid
        saturates: with a large ``logitnorm_std`` some samples may be exactly
        ``1.0``, and others may be ``0.0`` or denormal-small.
        The reference's ``lognorm_t`` flag is a misnomer for this logit-normal
        sampler.
    logitnorm_mean, logitnorm_std : float
        Mean/std of the underlying normal (used only for ``"logitnormal"``).

    Returns
    -------
    jax.Array
        Shape ``(n,)`` timesteps in JAX's default float dtype. Callers that
        divide by ``t`` or ``1 - t`` should clamp, because the endpoints are
        attainable (see ``dist``).

    Raises
    ------
    ValueError
        If ``dist`` is neither ``"uniform"`` nor ``"logitnormal"``.
    """
    if dist == "uniform":
        # Keep this exact call (same key usage, shape and default dtype): the
        # docstring promises bit-for-bit reproduction of the old inline sampler.
        return jax.random.uniform(key, (n,))
    if dist == "logitnormal":
        # Logit-normal: push a standard Gaussian through the logistic sigmoid.
        eps = jax.random.normal(key, (n,))
        return jax.nn.sigmoid(logitnorm_mean + logitnorm_std * eps)
    raise ValueError(
        f"unknown time dist {dist!r}; expected 'uniform' or 'logitnormal'"
    )