gensbi.flow_matching.path.path
Classes
ProbPath | Abstract class representing a probability path.
Module Contents
- class gensbi.flow_matching.path.path.ProbPath
Bases: abc.ABC
Abstract class representing a probability path.
A probability path transforms the distribution \(p(X_0)\) into \(p(X_1)\) over \(t=0\rightarrow 1\).
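In the flow matching framework, such a path is usually built from conditional paths. As a sketch of the standard construction (not a statement about how gensbi computes anything internally), the marginal path at time \(t\) mixes the conditional paths \(p_t(X_t|X_0,X_1)\) over the coupling \(\pi(X_0,X_1)\) that appears in sample. The boundary conditions hold, for instance, when the conditional paths satisfy \(p_0(x|x_0,x_1) = \delta(x - x_0)\) and \(p_1(x|x_0,x_1) = \delta(x - x_1)\):

\[
p_t(x) = \int p_t(x \mid x_0, x_1)\, \pi(x_0, x_1)\, \mathrm{d}x_0\, \mathrm{d}x_1,
\qquad p_{t=0} = p(X_0), \quad p_{t=1} = p(X_1).
\]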
The ProbPath class is designed to support model training in the flow matching framework. It supports two key functionalities: (1) sampling the conditional probability path and (2) conversion between various training objectives. Here is a high-level example:

```python
import jax

# Instantiate a probability path (in practice, a concrete subclass of ProbPath)
my_path = ProbPath(...)

# x_0, x_1: source and target batches of shape (batch_size, ...)
# Draw one time per batch element, uniformly in [0, 1)
key = jax.random.PRNGKey(0)
t = jax.random.uniform(key, shape=(x_0.shape[0],))

# Samples the conditional path X_t ~ p_t(X_t|X_0,X_1)
path_sample = my_path.sample(x_0=x_0, x_1=x_1, t=t)
```
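To make the abstract interface concrete, here is a minimal sketch of one possible subclass: a straight-line conditional path \(X_t = (1-t)X_0 + tX_1\) whose conditional target is the velocity \(Y = X_1 - X_0\). It deliberately avoids gensbi's own classes: ProbPathSketch, LinearPathSketch and PathSampleSketch are hypothetical stand-ins for ProbPath, a concrete subclass, and PathSample, and choosing the velocity as the target is an assumption made for illustration.

```python
import abc
from typing import NamedTuple

import jax
import jax.numpy as jnp


class PathSampleSketch(NamedTuple):
    """Hypothetical stand-in for gensbi's PathSample container."""
    x_0: jax.Array
    x_1: jax.Array
    x_t: jax.Array
    t: jax.Array
    dx_t: jax.Array  # conditional target Y (here: the velocity)


class ProbPathSketch(abc.ABC):
    """Hypothetical minimal base mirroring the documented interface."""

    @abc.abstractmethod
    def sample(self, x_0, x_1, t) -> PathSampleSketch:
        ...


class LinearPathSketch(ProbPathSketch):
    """Assumed example path: X_t = (1 - t) X_0 + t X_1."""

    def sample(self, x_0, x_1, t):
        # Reshape t from (batch_size,) to (batch_size, 1, ..., 1) so it
        # broadcasts against data of shape (batch_size, ...).
        t_b = t.reshape((-1,) + (1,) * (x_0.ndim - 1))
        x_t = (1.0 - t_b) * x_0 + t_b * x_1
        # Time derivative of x_t along the straight line.
        dx_t = x_1 - x_0
        return PathSampleSketch(x_0=x_0, x_1=x_1, x_t=x_t, t=t, dx_t=dx_t)


# Usage: a batch of 4 two-dimensional points.
key_0, key_1, key_t = jax.random.split(jax.random.PRNGKey(0), 3)
x_0 = jax.random.normal(key_0, (4, 2))  # source samples, e.g. noise
x_1 = jax.random.normal(key_1, (4, 2))  # stand-in target samples
t = jax.random.uniform(key_t, (4,))     # one time per batch element
sample = LinearPathSketch().sample(x_0=x_0, x_1=x_1, t=t)
print(sample.x_t.shape)  # (4, 2)
```

Reshaping t rather than requiring callers to pass a broadcastable array keeps the public shape contract simple, (batch_size,), while still working for data of any rank.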
- assert_sample_shape(x_0, x_1, t)
Checks that the shapes of x_0, x_1, and t are compatible for sampling.
- Parameters:
x_0 (Array) – Source data point.
x_1 (Array) – Target data point.
t (Array) – Time vector.
- Raises:
AssertionError – If the shapes are not compatible.
- Return type:
None
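A minimal sketch of the kind of check this method performs, assuming the shapes documented for sample (data of shape (batch_size, …), times of shape (batch_size,)) and assuming that x_0 and x_1 must share one shape. The function name assert_sample_shape_sketch is hypothetical; this is not gensbi's implementation.

```python
import jax.numpy as jnp


def assert_sample_shape_sketch(x_0, x_1, t):
    """Hypothetical shape check based on the shapes documented for sample()."""
    # t must be a 1-D vector holding one time per batch element.
    assert t.ndim == 1, f"t must have shape (batch_size,), got {t.shape}"
    # Source and target batches must agree in every dimension (assumed).
    assert x_0.shape == x_1.shape, f"x_0 {x_0.shape} != x_1 {x_1.shape}"
    # The batch dimension of t must match that of the data.
    assert t.shape[0] == x_0.shape[0], (
        f"batch size mismatch: t has {t.shape[0]}, x_0 has {x_0.shape[0]}"
    )


# Usage
x = jnp.zeros((8, 3))
assert_sample_shape_sketch(x, x, jnp.zeros(8))  # passes silently
try:
    assert_sample_shape_sketch(x, x, jnp.zeros(5))  # batch mismatch
except AssertionError as err:
    print("rejected:", err)
```

Because this style relies on plain assert statements, the checks are skipped when Python runs with the -O flag; raising ValueError instead would keep them active in optimized runs.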
- abstractmethod sample(x_0, x_1, t)
Sample from an abstract probability path.
Given \((X_0,X_1) \sim \pi(X_0,X_1)\), returns \(X_0\), \(X_1\), \(X_t \sim p_t(X_t|X_0,X_1)\), and a conditional target \(Y\), all bundled together in a PathSample.
- Parameters:
x_0 (Array) – Source data point, shape (batch_size, …).
x_1 (Array) – Target data point, shape (batch_size, …).
t (Array) – Times in [0,1], shape (batch_size,).
- Returns:
A conditional sample.
- Return type: