gensbi.flow_matching.path.path

Classes

ProbPath

Abstract class representing a probability path.

Module Contents

class gensbi.flow_matching.path.path.ProbPath

Bases: abc.ABC

Abstract class representing a probability path.

A probability path transforms the distribution \(p(X_0)\) into \(p(X_1)\) as \(t\) goes from 0 to 1.
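As background, here is the usual flow matching construction that the conditional path \(p_t(X_t|X_0,X_1)\) and the coupling \(\pi(X_0,X_1)\) used by `sample` fit into. This is a sketch of the standard idea, not a statement about how this class is implemented internally. The marginal path is a mixture of conditional paths, and if each conditional path starts at \(x_0\) and ends at \(x_1\) (an idealization) while \(\pi\) has marginals \(p(X_0)\) and \(p(X_1)\), the endpoints come out right:

```latex
p_t(x) = \int p_t(x \mid x_0, x_1)\, \pi(x_0, x_1)\, dx_0\, dx_1,
\qquad
p_{t=0}(\cdot \mid x_0, x_1) = \delta_{x_0},
\quad
p_{t=1}(\cdot \mid x_0, x_1) = \delta_{x_1}
\;\Longrightarrow\;
p_0 = p(X_0), \quad p_1 = p(X_1)
```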

The ProbPath class is designed to support model training in the flow matching framework. It provides two key functionalities: (1) sampling the conditional probability path and (2) converting between various training objectives. Here is a high-level example:

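import jax

# Instantiate a probability path (ProbPath is abstract, so in practice use a concrete subclass)
my_path = ProbPath(...)

# Placeholder batch of source and target points, shape (batch_size, dim)
key_0, key_1, key_t = jax.random.split(jax.random.PRNGKey(0), 3)
x_0 = jax.random.normal(key_0, (128, 2))
x_1 = jax.random.normal(key_1, (128, 2))

# Draw one time per batch element, uniformly in [0, 1); sample expects shape (batch_size,)
t = jax.random.uniform(key_t, shape=(x_0.shape[0],))

# Sample the conditional path X_t ~ p_t(X_t|X_0,X_1)
path_sample = my_path.sample(x_0=x_0, x_1=x_1, t=t)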
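To make "conversion between training objectives" concrete, the sketch below assumes one particular path, the linear interpolation \(X_t = (1-t)X_0 + tX_1\), whose conditional velocity is \(X_1 - X_0\). ProbPath itself does not fix a path, and the helper names `velocity_to_x1` and `velocity_to_x0` are hypothetical, not gensbi API. For this path a velocity prediction can be turned into an \(X_1\) prediction via \(X_1 = X_t + (1-t)u\) and into an \(X_0\) prediction via \(X_0 = X_t - tu\):

```python
import jax
import jax.numpy as jnp


def velocity_to_x1(x_t, u_t, t):
    """Hypothetical helper: X_1 prediction from a velocity (linear path only)."""
    # Reshape t from (batch_size,) so it broadcasts over the trailing dims of x_t.
    t = t.reshape((-1,) + (1,) * (x_t.ndim - 1))
    return x_t + (1.0 - t) * u_t


def velocity_to_x0(x_t, u_t, t):
    """Hypothetical helper: X_0 prediction from a velocity (linear path only)."""
    t = t.reshape((-1,) + (1,) * (x_t.ndim - 1))
    return x_t - t * u_t


# Build a batch on the assumed linear path.
k0, k1, kt = jax.random.split(jax.random.PRNGKey(0), 3)
x_0 = jax.random.normal(k0, (4, 3))
x_1 = jax.random.normal(k1, (4, 3))
t = jax.random.uniform(kt, (4,))
t_b = t[:, None]
x_t = (1.0 - t_b) * x_0 + t_b * x_1
u_t = x_1 - x_0  # exact conditional velocity of the linear path

# With the exact velocity, both conversions recover the endpoints.
x_1_hat = velocity_to_x1(x_t, u_t, t)
x_0_hat = velocity_to_x0(x_t, u_t, t)
```

With a learned velocity in place of `u_t`, the same formulas turn a velocity model into an endpoint predictor, which is why one trained model can serve several objectives for a given path.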

assert_sample_shape(x_0, x_1, t)

Checks that the shapes of x_0, x_1, and t are compatible for sampling.

Parameters:
  • x_0 (Array) – Source data point.

  • x_1 (Array) – Target data point.

  • t (Array) – Time vector, shape (batch_size,).

Raises:

AssertionError – If the shapes are not compatible.

Return type:

None
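A minimal sketch of the kind of check this method performs, assuming "compatible" means what `sample` documents: `t` is 1-D and its length matches the leading (batch) dimension of `x_0` and `x_1`. The function name `check_sample_shape` is hypothetical, and the actual gensbi checks may differ in detail:

```python
import jax.numpy as jnp


def check_sample_shape(x_0, x_1, t):
    """Hypothetical stand-in for ProbPath.assert_sample_shape."""
    # t must be a 1-D vector holding one time per batch element.
    assert t.ndim == 1, f"t must have shape (batch_size,), got {t.shape}"
    # The leading (batch) dimension must agree across x_0, x_1 and t.
    assert t.shape[0] == x_0.shape[0] == x_1.shape[0], (
        f"batch sizes differ: t={t.shape[0]}, x_0={x_0.shape[0]}, x_1={x_1.shape[0]}"
    )


x_0 = jnp.zeros((8, 2))
x_1 = jnp.ones((8, 2))
check_sample_shape(x_0, x_1, jnp.full((8,), 0.5))  # passes silently

# A scalar t (as produced by jax.random.uniform(key) with no shape) is rejected.
try:
    check_sample_shape(x_0, x_1, jnp.array(0.5))
    scalar_rejected = False
except AssertionError:
    scalar_rejected = True
```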

abstractmethod sample(x_0, x_1, t)

Sample from an abstract probability path.

Given \((X_0,X_1) \sim \pi(X_0,X_1)\), returns \(X_0\), \(X_1\), a sample \(X_t \sim p_t(X_t|X_0,X_1)\), and a conditional target \(Y\), all bundled in a PathSample.

Parameters:
  • x_0 (Array) – Source data point, shape (batch_size, …).

  • x_1 (Array) – Target data point, shape (batch_size, …).

  • t (Array) – Times in [0,1], shape (batch_size,).

Returns:

A conditional sample.

Return type:

PathSample
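The sketch below shows what implementing this abstract method might look like. It is self-contained and does not import gensbi: `ProbPathSketch`, `LinearPathSketch` and `PathSampleSketch` are hypothetical stand-ins, the field names of `PathSampleSketch` are assumptions that need not match gensbi's PathSample, and the path itself is assumed to be the linear interpolation \(X_t = (1-t)X_0 + tX_1\) with target \(Y = X_1 - X_0\):

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass
class PathSampleSketch:
    """Hypothetical container for what PathSample is documented to hold."""
    x_0: jax.Array
    x_1: jax.Array
    x_t: jax.Array
    t: jax.Array
    dx_t: jax.Array  # conditional target Y; here the conditional velocity


class ProbPathSketch(ABC):
    """Local stand-in for ProbPath so the example does not depend on gensbi."""

    @abstractmethod
    def sample(self, x_0, x_1, t) -> PathSampleSketch: ...


class LinearPathSketch(ProbPathSketch):
    """Hypothetical path X_t = (1 - t) X_0 + t X_1 with target Y = X_1 - X_0."""

    def sample(self, x_0, x_1, t):
        # Same shape contract as documented above: t has shape (batch_size,).
        assert t.ndim == 1 and t.shape[0] == x_0.shape[0] == x_1.shape[0]
        # Reshape t so it broadcasts over the trailing dims of x_0 and x_1.
        t_b = t.reshape((-1,) + (1,) * (x_0.ndim - 1))
        x_t = (1.0 - t_b) * x_0 + t_b * x_1
        return PathSampleSketch(x_0=x_0, x_1=x_1, x_t=x_t, t=t, dx_t=x_1 - x_0)


k0, k1, kt = jax.random.split(jax.random.PRNGKey(0), 3)
x_0 = jax.random.normal(k0, (16, 2))
x_1 = jax.random.normal(k1, (16, 2))
t = jax.random.uniform(kt, (16,))
path_sample = LinearPathSketch().sample(x_0=x_0, x_1=x_1, t=t)
```

Because `sample` is declared with `@abstractmethod`, calling `ProbPathSketch()` directly raises `TypeError`, mirroring why `ProbPath` itself must be subclassed before use.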