Source code for gensbi.solver
"""
Abstract base class for solvers.
"""
from abc import ABC, abstractmethod
from typing import Any
from jax import Array
[docs]
class Solver(ABC):
    """Abstract base class for generative model solvers.

    Subclasses must implement :meth:`sample`. The log-probability hooks
    :meth:`get_log_prob` and :meth:`compute_log_prob` are optional: the
    versions defined here always raise ``NotImplementedError``, so a
    subclass only overrides them when it can actually evaluate
    log-probabilities.
    """

    @abstractmethod
    def sample(self, *args: Any, **kwargs: Any) -> Array:
        """
        Sample from the solver.

        Parameters
        ----------
        *args : Any
            Positional arguments.
        **kwargs : Any
            Keyword arguments.

        Returns
        -------
        Array
            Sampled output from the solver.
        """
        ...  # pragma: no cover

    def get_log_prob(self, *args: Any, **kwargs: Any):
        """Return a callable that computes the log-probability.

        Only supported by solvers that can evaluate the continuous
        change-of-variables formula (e.g. ``ODESolver``).

        Parameters
        ----------
        *args : Any
            Positional arguments; ignored by this base implementation.
        **kwargs : Any
            Keyword arguments; ignored by this base implementation.

        Raises
        ------
        NotImplementedError
            If the solver does not support log-probability computation.
        """
        # Report the concrete subclass name so the error points at the
        # solver the caller actually instantiated, not at ``Solver``.
        raise NotImplementedError(
            f"{type(self).__name__} does not support log-probability computation."
        )

    def compute_log_prob(self, *args: Any, **kwargs: Any):
        """Compute the log-probability for given samples.

        Only supported by solvers that can evaluate the continuous
        change-of-variables formula (e.g. ``ODESolver``).

        Parameters
        ----------
        *args : Any
            Positional arguments; ignored by this base implementation.
        **kwargs : Any
            Keyword arguments; ignored by this base implementation.

        Raises
        ------
        NotImplementedError
            If the solver does not support log-probability computation.
        """
        # Same message as ``get_log_prob``: both hooks are unavailable
        # unless a subclass overrides them.
        raise NotImplementedError(
            f"{type(self).__name__} does not support log-probability computation."
        )