
Quadrature-based linearization

This sub-repository is concerned with forming linear approximations \(Y \approx A X + b + \epsilon\), where \(\epsilon\) is a zero-mean Gaussian with covariance \(Q\), to general statistical models \(p(y \mid x)\) that either exhibit additive noise:

\[ p(y \mid x) = N(y; f(x), \Sigma) \]

where \(f\) is a deterministic function and \(\Sigma\) a given covariance matrix, or for which the conditional mean and covariance

\[ \mathbb{E}[Y \mid X=x] = m(x), \quad \mathbb{V}[Y \mid X=x] = c(x) \]

are known or can otherwise be approximated. The approximation is obtained by minimizing (approximately for the latter, exactly for the former) the expected Kullback-Leibler divergence

\[ A, b, Q = \operatorname*{arg\,min}_{A, b, Q} \mathbb{E}_{N(X \mid m, P)}\left[\textrm{KL}\left(p(y \mid X) \,\|\, N(y; AX + b, Q)\right)\right]. \]
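Here \(m\) and \(P\) denote the mean and covariance of \(X\). The entropy of \(p(y \mid x)\) does not depend on \((A, b, Q)\), and the remaining cross-entropy term depends on \(p(y \mid x)\) only through its conditional mean and covariance, so the minimizer takes the familiar statistical linear regression form. In the display below, \(m(\cdot)\) and \(c(\cdot)\) are the conditional moments (for the additive model, \(m(\cdot) = f\) and \(c(\cdot) \equiv \Sigma\)), and all expectations are taken under \(X \sim N(m, P)\); the quadrature rules in this module approximate these expectations with sigma points:

\[
A = \mathbb{C}[X, m(X)]^{T} P^{-1}, \qquad
b = \mathbb{E}[m(X)] - A m, \qquad
Q = \mathbb{V}[m(X)] + \mathbb{E}[c(X)] - A P A^{T}.
\]

These are the quantities that conditional_moments forms in covariance mode: A from the cross-covariance Psi_x, b from the mean of the mapped points, and Q as Phi + v_f - A @ cov @ A.T.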

This can be done either in covariance form, where \(P\) and \(\Sigma\) are provided and \(Q\) is returned as covariance matrices, or in square-root form, which is numerically more stable but computationally more expensive: there \(P\) and \(\Sigma\) are provided as Cholesky factors, and a factor \(L\) is returned such that \(Q = L L^{T}\) is the covariance matrix of interest.
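The square-root form relies on the fact that if \(M = [L_1 \; L_2]\) stacks two factors side by side, then \(M M^{T} = L_1 L_1^{T} + L_2 L_2^{T}\), so a triangular factor of a sum of covariances can be obtained from a QR decomposition of \(M^{T}\) without ever forming the sum. The source of conditional_moments uses a `tria` helper for this step; its implementation is not shown on this page, so the QR-based version below is an assumption, written in plain NumPy rather than JAX:

```python
import numpy as np


def tria(M: np.ndarray) -> np.ndarray:
    """Return lower-triangular L with L @ L.T == M @ M.T (QR-based sketch)."""
    # QR of M^T: M^T = Q R, hence M @ M.T = R.T @ Q.T @ Q @ R = R.T @ R.
    _, R = np.linalg.qr(M.T)
    L = R.T
    # Flip column signs so that the diagonal is non-negative, like a Cholesky factor.
    signs = np.where(np.diag(L) < 0, -1.0, 1.0)
    return L * signs[None, :]


# Two covariances known only through their Cholesky factors.
rng = np.random.default_rng(0)
n = 3
B1 = rng.standard_normal((n, n))
B2 = rng.standard_normal((n, n))
Sigma1 = B1 @ B1.T + np.eye(n)
Sigma2 = B2 @ B2.T + np.eye(n)
L1, L2 = np.linalg.cholesky(Sigma1), np.linalg.cholesky(Sigma2)

# Factor of Sigma1 + Sigma2 from the stacked (n, 2n) block, without forming the sum.
L_sum = tria(np.concatenate([L1, L2], axis=1))
```

Since a Cholesky factor with a positive diagonal is unique, `L_sum` coincides with `np.linalg.cholesky(Sigma1 + Sigma2)`; the difference is that rounding errors never act on the squared quantities.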

A typical call to the library then looks like this:

import jax.numpy as jnp
from cuthbertlib import quadrature

mean_fn = lambda x: jnp.sin(x)
cov_fn = lambda x: 1e-3 * jnp.eye(2)
quadrature_method = quadrature.gauss_hermite.weights(n_dim=2, order=3)
m = jnp.zeros((2,))
cov = jnp.eye(2)
A, b, Q = quadrature.conditional_moments(mean_fn, cov_fn, m, cov, quadrature_method, mode="covariance")
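A minimal, self-contained NumPy sketch of the covariance-mode computation behind such a call. It is not the library's implementation: `conditional_moments_sketch` is a hypothetical name, it hard-codes the spherical cubature rule instead of taking a quadrature object, and it loops in Python instead of using `jax.vmap`. A linear-Gaussian model serves as a check, since the linearization is then exact:

```python
import numpy as np


def conditional_moments_sketch(mean_fn, cov_fn, m, P):
    """Covariance-form SLR: fit Y ~= A X + b + N(0, Q) for X ~ N(m, P).

    Hypothetical helper; uses the spherical cubature rule and a Python loop
    instead of a quadrature object and jax.vmap.
    """
    n = m.shape[0]
    xi = np.concatenate([np.eye(n), -np.eye(n)]) * np.sqrt(n)  # unit sigma points
    w = np.full(2 * n, 1.0 / (2 * n))                           # equal weights

    chol = np.linalg.cholesky(P)
    x_pts = m + xi @ chol.T                      # row i is m + chol @ xi[i]
    f_pts = np.stack([mean_fn(x) for x in x_pts])
    c_pts = np.stack([cov_fn(x) for x in x_pts])

    f_mean = w @ f_pts
    dx, df = x_pts - m, f_pts - f_mean
    Psi = (w[:, None] * dx).T @ df               # Cov[X, m(X)], shape (n_x, n_y)
    Phi = (w[:, None] * df).T @ df               # V[m(X)]
    A = np.linalg.solve(P, Psi).T                # A = Psi^T P^{-1}
    b = f_mean - A @ m
    Q = Phi + np.tensordot(w, c_pts, axes=1) - A @ P @ A.T
    return A, b, 0.5 * (Q + Q.T)                 # symmetrize, as the library source does


# For a linear-Gaussian model the linearization is exact: A = H, b = c, Q = R.
H = np.array([[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]])
c = np.array([0.1, -0.2, 0.3])
R = 0.01 * np.eye(3)
m = np.array([0.3, -0.4])
P = np.array([[1.0, 0.2], [0.2, 0.5]])
A, b, Q = conditional_moments_sketch(lambda x: H @ x + c, lambda x: R, m, P)
```

Passing `np.sin` and `lambda x: 1e-3 * np.eye(2)` with `m = np.zeros(2)` and `P = np.eye(2)` mirrors the call above, with cubature in place of Gauss-Hermite.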

cuthbertlib.quadrature.cubature

Implements cubature quadrature.

__all__ = ['weights', 'CubatureQuadrature'] module-attribute

CubatureQuadrature

Bases: NamedTuple

Cubature quadrature.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| wm | ArrayLike | The mean weights. |
| wc | ArrayLike | The covariance weights. |
| xi | ArrayLike | The unit sigma points, one per row. |

wm instance-attribute

wc instance-attribute

xi instance-attribute

get_sigma_points(m, chol)

Get the sigma points.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| m | ArrayLike | The mean. | required |
| chol | ArrayLike | The Cholesky factor of the covariance. | required |

Returns:

| Type | Description |
| --- | --- |
| SigmaPoints | The sigma points. |

Source code in cuthbertlib/quadrature/cubature.py
def get_sigma_points(self, m: ArrayLike, chol: ArrayLike) -> SigmaPoints:
    """Get the sigma points.

    Args:
        m: The mean.
        chol: The Cholesky factor of the covariance.

    Returns:
        SigmaPoints: The sigma points.
    """
    return get_sigma_points(m, chol, self.xi, self.wm, self.wc)

get_sigma_points(m, chol, xi, wm, wc)

Source code in cuthbertlib/quadrature/cubature.py
def get_sigma_points(
    m: ArrayLike, chol: ArrayLike, xi: ArrayLike, wm: ArrayLike, wc: ArrayLike
) -> SigmaPoints:
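    """Map unit sigma points to sigma points of N(m, chol @ chol.T).

    Each row `xi[i]` becomes `m + chol @ xi[i]`; the weights are passed through unchanged.
    """
    m = jnp.asarray(m)
    chol = jnp.asarray(chol)
    xi = jnp.asarray(xi)
    wm = jnp.asarray(wm)
    wc = jnp.asarray(wc)
    # Row i of sigma_points is m + chol @ xi[i].
    sigma_points = m[None, :] + jnp.dot(chol, xi.T).T

    return SigmaPoints(sigma_points, wm, wc)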
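The unit points `xi` of the cubature and Gauss-Hermite rules have weighted mean zero and weighted second moment equal to the identity, so mapping each one to \(m + L \xi_i\), with \(L\) the Cholesky factor, reproduces \(m\) and \(P = L L^{T}\) exactly. A NumPy sketch of this mapping, checked with the spherical cubature unit points \(\pm\sqrt{n}\, e_j\) (`sigma_points` is a hypothetical helper, not part of the library):

```python
import numpy as np


def sigma_points(m, chol, xi):
    """Map unit sigma points (rows of xi) to m + chol @ xi[i] (hypothetical helper)."""
    return m[None, :] + (chol @ xi.T).T


n = 2
xi = np.concatenate([np.eye(n), -np.eye(n)]) * np.sqrt(n)  # cubature unit points
w = np.full(2 * n, 1.0 / (2 * n))

m = np.array([1.0, -2.0])
P = np.array([[2.0, 0.3], [0.3, 0.5]])
pts = sigma_points(m, np.linalg.cholesky(P), xi)

# Weighted moments of the mapped points reproduce N(m, P) exactly.
mean = w @ pts
cov = (w[:, None] * (pts - mean)).T @ (pts - mean)
```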

weights(n_dim)

Computes the weights associated with the spherical cubature method.

The number of sigma-points is 2 * n_dim.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| n_dim | int | Dimensionality of the problem. | required |

Returns:

| Type | Description |
| --- | --- |
| Quadrature | The quadrature object with the weights and sigma-points. |

References

Simo Särkkä and Lennart Svensson. Bayesian Filtering and Smoothing, 2nd ed. Cambridge University Press, 2023.

Source code in cuthbertlib/quadrature/cubature.py
def weights(n_dim: int) -> Quadrature:
    """Computes the weights associated with the spherical cubature method.

    The number of sigma-points is 2 * n_dim.

    Args:
        n_dim: Dimensionality of the problem.

    Returns:
        The quadrature object with the weights and sigma-points.

    References:
        Simo Särkkä, Lennart Svensson. *Bayesian Filtering and Smoothing.*
            In: Cambridge University Press 2023.
    """
    wm = np.ones(shape=(2 * n_dim,)) / (2 * n_dim)
    wc = wm
    xi = np.concatenate([np.eye(n_dim), -np.eye(n_dim)], axis=0) * np.sqrt(n_dim)

    return CubatureQuadrature(wm=wm, wc=wc, xi=xi)
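The spherical cubature rule is a third-degree rule: under \(N(0, I)\) it integrates every polynomial of total degree at most three exactly, but not higher-order moments. The NumPy check below rebuilds the same points and weights as `weights` above (the helper names are hypothetical); for `n_dim = 2` it returns 2 instead of 3 for \(\mathbb{E}[X_1^4]\):

```python
import numpy as np


def cubature_rule(n_dim):
    """Spherical cubature: 2 * n_dim points at +/- sqrt(n_dim) e_j, weights 1 / (2 * n_dim)."""
    w = np.ones(2 * n_dim) / (2 * n_dim)
    xi = np.concatenate([np.eye(n_dim), -np.eye(n_dim)], axis=0) * np.sqrt(n_dim)
    return w, xi


def expect(g, n_dim):
    """Cubature estimate of E[g(X)] for X ~ N(0, I)."""
    w, xi = cubature_rule(n_dim)
    return sum(wi * g(x) for wi, x in zip(w, xi))


second = expect(lambda x: np.outer(x, x), 2)   # exact: the identity
third = expect(lambda x: x[0] ** 2 * x[1], 2)  # exact: 0
fourth = expect(lambda x: x[0] ** 4, 2)        # true value 3, the rule gives 2
```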

cuthbertlib.quadrature.gauss_hermite

Implements Gauss-Hermite quadrature.

__all__ = ['weights', 'GaussHermiteQuadrature'] module-attribute

GaussHermiteQuadrature

Bases: NamedTuple

Gauss-Hermite quadrature.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| wm | ArrayLike | The mean weights. |
| wc | ArrayLike | The covariance weights. |
| xi | ArrayLike | The unit sigma points, one per row. |

wm instance-attribute

wc instance-attribute

xi instance-attribute

get_sigma_points(m, chol)

Get the sigma points.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| m | ArrayLike | The mean. | required |
| chol | ArrayLike | The Cholesky factor of the covariance. | required |

Returns:

| Type | Description |
| --- | --- |
| SigmaPoints | The sigma points. |

Source code in cuthbertlib/quadrature/gauss_hermite.py
def get_sigma_points(self, m: ArrayLike, chol: ArrayLike) -> SigmaPoints:
    """Get the sigma points.

    Args:
        m: The mean.
        chol: The Cholesky factor of the covariance.

    Returns:
        SigmaPoints: The sigma points.
    """
    return cubature.get_sigma_points(m, chol, self.xi, self.wm, self.wc)

weights(n_dim, order=3)

Computes the weights associated with the Gauss-Hermite quadrature method.

The probabilists' Hermite polynomials are used. The rule is a tensor product of one-dimensional rules, so the number of sigma-points is order ** n_dim.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| n_dim | int | Dimensionality of the problem. | required |
| order | int | The order of the Hermite polynomial. | 3 |

Returns:

| Type | Description |
| --- | --- |
| Quadrature | The quadrature object with the weights and sigma-points. |

References

Simo Särkkä. Bayesian Filtering and Smoothing. Cambridge University Press, 2013.

Source code in cuthbertlib/quadrature/gauss_hermite.py
def weights(n_dim: int, order: int = 3) -> Quadrature:
    """Computes the weights associated with the Gauss-Hermite quadrature method.

    The probabilists' Hermite polynomials are used; the rule is a tensor product
    of one-dimensional rules, so it has `order ** n_dim` sigma points.

    Args:
        n_dim: Dimensionality of the problem.
        order: The order of the Hermite polynomial. Defaults to 3.

    Returns:
        The quadrature object with the weights and sigma-points.

    References:
        Simo Särkkä. *Bayesian Filtering and Smoothing.*
            In: Cambridge University Press 2013.
    """
    x, w = hermegauss(order)
    xn = np.array(list(product(*(x,) * n_dim)))
    wn = np.prod(np.array(list(product(*(w,) * n_dim))), 1)
    wn /= np.sqrt(2 * np.pi) ** n_dim
    return GaussHermiteQuadrature(wm=wn, wc=wn, xi=xn)
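A one-dimensional Gauss-Hermite rule with \(p\) nodes is exact for polynomials of degree up to \(2p - 1\), and the tensor product inherits this coordinate by coordinate, at the cost of \(p^{n}\) sigma points. The NumPy sketch below mirrors the construction in `weights` above (the helper name `gauss_hermite_rule` is hypothetical) and checks a few moments for `order=3`:

```python
from itertools import product

import numpy as np
from numpy.polynomial.hermite_e import hermegauss


def gauss_hermite_rule(n_dim, order=3):
    """Tensor-product Gauss-Hermite rule for N(0, I) (hypothetical helper)."""
    x, w = hermegauss(order)  # probabilists' nodes; weights sum to sqrt(2 * pi)
    xi = np.array(list(product(x, repeat=n_dim)))
    wn = np.prod(np.array(list(product(w, repeat=n_dim))), axis=1)
    return wn / np.sqrt(2 * np.pi) ** n_dim, xi


w, xi = gauss_hermite_rule(n_dim=2, order=3)
n_points = xi.shape[0]                      # order ** n_dim = 9
m4 = w @ xi[:, 0] ** 4                      # E[X_1^4] = 3, reproduced exactly
m22 = w @ (xi[:, 0] ** 2 * xi[:, 1] ** 2)   # E[X_1^2 X_2^2] = 1, reproduced exactly
m6 = w @ xi[:, 0] ** 6                      # E[X_1^6] = 15, but the rule gives 9
```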

cuthbertlib.quadrature.linearize

Implements quadrature-based linearization of conditional-moment models (conditional_moments) and of additive-noise functional models (functional).

__all__ = ['conditional_moments', 'functional'] module-attribute

conditional_moments(mean_fn, cov_fn, m, cov, quadrature, mode='covariance')

Linearizes the conditional mean and covariance of a Gaussian distribution.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| mean_fn | Callable[[ArrayLike], Array] | The mean function \(\mathbb{E}[Y \mid x] =\) mean_fn(x). | required |
| cov_fn | Callable[[ArrayLike], Array] | The covariance function \(\mathbb{C}[Y \mid x] =\) cov_fn(x). | required |
| m | ArrayLike | The mean of the Gaussian distribution. | required |
| cov | ArrayLike | The covariance of the Gaussian distribution. | required |
| quadrature | Quadrature | The quadrature object with the weights and sigma-points. | required |
| mode | str | The mode of the covariance. The default, 'covariance', means that cov and cov_fn are given as covariance matrices; otherwise, their Cholesky factors are given. | 'covariance' |

Returns:

| Type | Description |
| --- | --- |
| tuple[Array, Array, Array] | A, b, Q, where A and b are the linearized model parameters and Q is either a full covariance matrix or a square-root (Cholesky) factor, depending on mode. |

Source code in cuthbertlib/quadrature/linearize.py
def conditional_moments(
    mean_fn: Callable[[ArrayLike], Array],
    cov_fn: Callable[[ArrayLike], Array],
    m: ArrayLike,
    cov: ArrayLike,
    quadrature: Quadrature,
    mode: str = "covariance",
) -> tuple[Array, Array, Array]:
    r"""Linearizes the conditional mean and covariance of a Gaussian distribution.

    Args:
        mean_fn: The mean function $\mathbb{E}[Y \mid x] =$ `mean_fn(x)`.
        cov_fn: The covariance function $\mathbb{C}[Y \mid x] =$ `cov_fn(x)`.
        m: The mean of the Gaussian distribution.
        cov: The covariance of the Gaussian distribution.
        quadrature: The quadrature object with the weights and sigma-points.
        mode: The mode of the covariance. Default is 'covariance', which means that cov
            and cov_fn are given as covariance matrices.
            Otherwise, their Cholesky factors are given.

    Returns:
        A, b, Q where A, b are the linearized model parameters and Q is either given as
            a full covariance matrix or as a square root factor depending on
            the `mode`.
    """
    if mode == "covariance":
        chol = jnp.linalg.cholesky(cov)
    else:
        chol = cov
    x_pts: SigmaPoints = quadrature.get_sigma_points(m, chol)

    f_pts = SigmaPoints(vmap(mean_fn)(x_pts.points), x_pts.wm, x_pts.wc)
    Psi_x = x_pts.covariance(f_pts)

    A = cho_solve((chol, True), Psi_x).T
    b = f_pts.mean - A @ m
    if mode != "covariance":
        # This can probably be abstracted better.
        sqrt_Phi = f_pts.sqrt

        chol_pts = vmap(cov_fn)(x_pts.points)
        temp = jnp.sqrt(x_pts.wc[:, None, None]) * chol_pts

        # concatenate the blocks properly, it's a bit urk, but what can you do...
        temp = jnp.transpose(temp, [1, 0, 2]).reshape(temp.shape[1], -1)
        chol_Q = tria(jnp.concatenate([sqrt_Phi, temp], axis=1))
        chol_Q = cholesky_update_many(chol_Q, (A @ chol).T, -1.0)
        return A, b, chol_Q

    V_pts = vmap(cov_fn)(x_pts.points)
    v_f = jnp.sum(x_pts.wc[:, None, None] * V_pts, 0)

    Phi = f_pts.covariance()
    Q = Phi + v_f - A @ cov @ A.T

    return A, b, 0.5 * (Q + Q.T)
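In square-root mode the code above never forms \(Q\): it triangularizes the stacked factors of \(\Phi\) and \(\mathbb{E}[c(X)]\), then removes \(A P A^{T}\) by calling `cholesky_update_many` with the rows of `(A @ chol).T` and sign `-1.0`, presumably one rank-one downdate per row. That helper is not documented on this page, so the NumPy sketch below uses the textbook rank-one update/downdate algorithm and should be read as an illustration of the idea, not as the library's implementation (function names are hypothetical):

```python
import numpy as np


def cholesky_rank1(L, v, sign=1.0):
    """Return lower-triangular L2 with L2 @ L2.T == L @ L.T + sign * outer(v, v).

    Textbook rank-one update (sign=+1) / downdate (sign=-1); a downdate
    assumes the result is still positive definite.
    """
    L = np.array(L, dtype=float)
    v = np.array(v, dtype=float)
    for k in range(v.shape[0]):
        r = np.sqrt(L[k, k] ** 2 + sign * v[k] ** 2)
        c, s = r / L[k, k], v[k] / L[k, k]
        L[k, k] = r
        L[k + 1:, k] = (L[k + 1:, k] + sign * s * v[k + 1:]) / c
        v[k + 1:] = c * v[k + 1:] - s * L[k + 1:, k]
    return L


def cholesky_update_rows(L, V, sign):
    """Apply one rank-one update per row of V (hypothetical helper)."""
    for v in V:
        L = cholesky_rank1(L, v, sign)
    return L


rng = np.random.default_rng(1)
A = rng.standard_normal((3, 2))
P = np.array([[1.0, 0.3], [0.3, 0.8]])
chol = np.linalg.cholesky(P)

# Stand-in for Phi + E[c(X)], chosen so that the downdated matrix is the identity.
L_M = np.linalg.cholesky(A @ P @ A.T + np.eye(3))
# The rows of (A @ chol).T are the columns of A @ chol; their outer products sum to A P A^T.
chol_Q = cholesky_update_rows(L_M, (A @ chol).T, -1.0)
```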

functional(fn, S, m, cov, quadrature, mode='covariance')

Linearizes a nonlinear function of a Gaussian distribution.

For a given Gaussian distribution \(p(x) = N(x \mid m, P)\), with \(P\) given by cov, and \(Y = f(X) + \epsilon\), where \(\epsilon\) is zero-mean Gaussian noise with covariance \(S\), this function computes an approximation \(Y \approx A X + b + \epsilon'\), where \(\epsilon'\) is zero-mean Gaussian noise with covariance \(Q\), using the sigma points produced by the quadrature object's get_sigma_points method.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| fn | Callable[[ArrayLike], Array] | The function \(f\) in \(Y = f(X) + N(0, S)\). It is evaluated on a single point and mapped over the sigma points with jax.vmap. | required |
| S | ArrayLike | The covariance of the noise. | required |
| m | ArrayLike | The mean of the Gaussian distribution. | required |
| cov | ArrayLike | The covariance of the Gaussian distribution. | required |
| quadrature | Quadrature | The quadrature object with the weights and sigma-points. | required |
| mode | str | The mode of the covariance. The default, 'covariance', means that cov and S are given as covariance matrices; otherwise, their Cholesky factors are given. | 'covariance' |

Returns:

| Type | Description |
| --- | --- |
| tuple[Array, Array, Array] | A, b, Q: the linearized model parameters of \(Y = A X + b + N(0, Q)\). Q is either a full covariance matrix or a square-root (Cholesky) factor, depending on mode. |

Notes

We do not support non-additive noise in this method. If your noise is non-additive, use conditional_moments or the Taylor linearization method instead. Another option is to form the mean and covariance functions with the quadrature method itself. For example, for a function \(f(x, q)\), where \(q\) is a zero-mean random variable with covariance S, the mean and covariance functions can be formed as follows (in covariance mode):

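import jax.numpy as jnp
from cuthbertlib.quadrature.linearize import functional

# f, S and quadrature (a rule with n_dim = S.shape[0]) are assumed to be defined.
def linearize_q_part(x):
    m_q = jnp.zeros(S.shape[0])
    n_y = f(x, m_q).shape[0]
    # Linearize q -> f(x, q) around q ~ N(0, S), with no extra additive noise.
    return functional(lambda q: f(x, q), jnp.zeros((n_y, n_y)), m_q, S, quadrature, mode="covariance")

def cov_fn(x):
    A, b, Q = linearize_q_part(x)
    # Residual plus explained variance: the quadrature estimate of V[f(x, q)].
    return Q + A @ S @ A.T

def mean_fn(x):
    A, b, Q = linearize_q_part(x)
    # E[q] = 0, so the linearized mean A @ 0 + b reduces to b.
    return b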

This technique is somewhat wasteful, because the current separation of duties between the mean and covariance functions means the inner linearization is computed twice for every point; as we develop the library further, we will provide a more elegant solution.
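The same moments can also be computed by integrating \(q\) out directly with sigma points. When no extra noise is added, \(Q + A S A^{T}\) from the construction above is exactly the quadrature estimate of \(\mathbb{V}[f(x, q)]\), and \(b\) is the estimate of \(\mathbb{E}[f(x, q)]\). A NumPy sketch of that direct route with the spherical cubature rule (`make_moment_fns` is a hypothetical helper); the resulting functions could then be passed to conditional_moments in covariance mode:

```python
import numpy as np


def make_moment_fns(f, S):
    """Build mean_fn/cov_fn for Y = f(x, q), q ~ N(0, S), by cubature over q.

    Hypothetical helper; the returned functions give full covariance matrices,
    as covariance mode expects.
    """
    n_q = S.shape[0]
    xi = np.concatenate([np.eye(n_q), -np.eye(n_q)]) * np.sqrt(n_q)
    w = np.full(2 * n_q, 1.0 / (2 * n_q))
    q_pts = xi @ np.linalg.cholesky(S).T  # sigma points of N(0, S)

    def mean_fn(x):
        return w @ np.stack([f(x, q) for q in q_pts])

    def cov_fn(x):
        y = np.stack([f(x, q) for q in q_pts])
        d = y - w @ y
        return (w[:, None] * d).T @ d

    return mean_fn, cov_fn


# Multiplicative noise f(x, q) = x * (1 + q): E[Y | x] = x, V[Y | x] = diag(x) S diag(x).
S = np.array([[0.04, 0.01], [0.01, 0.09]])
mean_fn, cov_fn = make_moment_fns(lambda x, q: x * (1.0 + q), S)
x = np.array([2.0, -1.0])
```

Because this \(f\) is linear in \(q\), the cubature estimates are exact here; for strongly nonlinear noise a higher-order rule may be needed.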

Source code in cuthbertlib/quadrature/linearize.py
def functional(
    fn: Callable[[ArrayLike], Array],
    S: ArrayLike,
    m: ArrayLike,
    cov: ArrayLike,
    quadrature: Quadrature,
    mode: str = "covariance",
) -> tuple[Array, Array, Array]:
    r"""Linearizes a nonlinear function of a Gaussian distribution.

    For a given Gaussian distribution $p(x) = N(x \mid m, P)$,
    and $Y = f(X) + \epsilon$, where $\epsilon$ is a zero-mean Gaussian noise
    with covariance S, this function computes an approximation
    $Y \approx A X + b + \epsilon'$, where $\epsilon'$ is zero-mean Gaussian with
    covariance Q, using the sigma points given by `quadrature.get_sigma_points`.

    Args:
        fn: The function $f$ in $Y = f(X) + N(0, S)$. It is evaluated on a single
            point and mapped over the sigma points with `jax.vmap`.
        S: The covariance of the noise.
        m: The mean of the Gaussian distribution.
        cov: The covariance of the Gaussian distribution.
        quadrature: The quadrature object with the weights and sigma-points.
        mode: The mode of the covariance. Default is 'covariance', which means that cov
            and S are given as covariance matrices. Otherwise, their Cholesky factors
            are given.

    Returns:
        A, b, Q: The linearized model parameters $Y = A X + b + N(0, Q)$.
            Q is either given as a full covariance matrix or as a square root factor depending on the `mode`.

    Notes:
        We do not support non-additive noise in this method.
        If your noise is non-additive, use `conditional_moments` or
        the Taylor linearization method instead.
        Another option is to form the mean and covariance functions with the
        quadrature method itself. For example, for a function $f(x, q)$, where $q$ is
        a zero-mean random variable with covariance `S`, the mean and covariance
        functions can be formed as follows (in covariance mode):

        ```python
        def linearize_q_part(x):
            m_q = jnp.zeros(S.shape[0])
            n_y = f(x, m_q).shape[0]
            # Linearize q -> f(x, q) around q ~ N(0, S), with no extra noise.
            return functional(
                lambda q: f(x, q), jnp.zeros((n_y, n_y)), m_q, S, quadrature,
                mode="covariance",
            )

        def cov_fn(x):
            A, b, Q = linearize_q_part(x)
            return Q + A @ S @ A.T

        def mean_fn(x):
            A, b, Q = linearize_q_part(x)
            # E[q] = 0, so the linearized mean reduces to b.
            return b
        ```

        This technique is a bit wasteful due to our current separation of duties between
        the mean and covariance functions, but as we develop the library further, we
        will provide a more elegant solution.
    """

    # make the equivalent conditional_moments model
    def mean_fn(x):
        return fn(x)

    def cov_fn(x):
        return jnp.asarray(S)

    return conditional_moments(mean_fn, cov_fn, m, cov, quadrature, mode)
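`functional` is `conditional_moments` with the constant covariance function \(c(x) = S\). To see what statistical linearization gives compared with a first-order Taylor expansion, take the scalar case \(f = \sin\): by Stein's lemma the exact slope is \(A = \mathbb{E}[\cos X] = \cos(m)\, e^{-P/2}\), whereas Taylor linearization would use \(\cos(m)\). This NumPy sketch (the helper `slr_1d` is hypothetical and uses a 20-point Gauss-Hermite rule) reproduces the closed-form values:

```python
import numpy as np
from numpy.polynomial.hermite_e import hermegauss


def slr_1d(fn, s2, m, p, order=20):
    """Scalar statistical linear regression of Y = fn(X) + N(0, s2), X ~ N(m, p)."""
    x, w = hermegauss(order)
    w = w / np.sqrt(2 * np.pi)               # normalize to a probability rule
    xs = m + np.sqrt(p) * x                  # sigma points
    fs = fn(xs)
    f_mean = w @ fs
    cross = w @ ((xs - m) * (fs - f_mean))   # Cov[X, f(X)]
    var_f = w @ (fs - f_mean) ** 2           # V[f(X)]
    a = cross / p
    b = f_mean - a * m
    q = var_f + s2 - a ** 2 * p
    return a, b, q


m, p, s2 = 0.7, 0.5, 1e-3
a, b, q = slr_1d(np.sin, s2, m, p)
```

The statistical slope is smaller than the Taylor slope \(\cos(m)\), and \(q\) exceeds \(s2\) by the part of \(\mathbb{V}[\sin X]\) that the linear fit cannot explain.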

cuthbertlib.quadrature.unscented

Implements unscented quadrature.

__all__ = ['weights', 'UnscentedQuadrature'] module-attribute

UnscentedQuadrature

Bases: NamedTuple

Unscented quadrature.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| wm | Array | The mean weights. |
| wc | Array | The covariance weights. |
| lamda | float | The lambda parameter. |

wm instance-attribute

wc instance-attribute

lamda instance-attribute

get_sigma_points(m, chol)

Get the sigma points.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| m | ArrayLike | The mean. | required |
| chol | ArrayLike | The Cholesky factor of the covariance. | required |

Returns:

| Type | Description |
| --- | --- |
| SigmaPoints | The sigma points. |

Source code in cuthbertlib/quadrature/unscented.py
def get_sigma_points(self, m: ArrayLike, chol: ArrayLike) -> SigmaPoints:
    """Get the sigma points.

    Args:
        m: The mean.
        chol: The Cholesky factor of the covariance.

    Returns:
        SigmaPoints: The sigma points.
    """
    m = jnp.asarray(m)
    chol = jnp.asarray(chol)

    n_dim = m.shape[0]
    scaled_chol = jnp.sqrt(n_dim + self.lamda) * chol

    zeros = jnp.zeros((1, n_dim))
    sigma_points = m[None, :] + jnp.concatenate(
        [zeros, scaled_chol.T, -scaled_chol.T], axis=0
    )
    return SigmaPoints(sigma_points, self.wm, self.wc)
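Unlike the cubature rule, the unscented rule includes the mean itself as a sigma point, giving \(2n + 1\) points. The NumPy sketch below mirrors `get_sigma_points` above and the `weights` function of this module (the helper names are hypothetical) and checks that the weighted points reproduce \(m\) and \(P\) for the default parameters. With those defaults and `n_dim = 3`, \(\lambda = -0.75\) and the central mean weight is \(-1/3\); the covariance is still recovered because the central point has zero deviation from the mean:

```python
import numpy as np


def unscented_weights(n_dim, alpha=0.5, beta=2.0, kappa=None):
    """Unscented weights with the defaults documented for weights() in this module."""
    if kappa is None:
        kappa = 3.0 + n_dim
    lamda = alpha**2 * (n_dim + kappa) - n_dim
    wm = np.full(2 * n_dim + 1, 1.0 / (2 * (n_dim + lamda)))
    wm[0] = lamda / (n_dim + lamda)
    wc = wm.copy()
    wc[0] += 1.0 - alpha**2 + beta
    return wm, wc, lamda


def unscented_points(m, chol, lamda):
    """The mean, then m + sqrt(n + lamda) * chol[:, j], then m - sqrt(n + lamda) * chol[:, j]."""
    n_dim = m.shape[0]
    scaled = np.sqrt(n_dim + lamda) * chol
    return m[None, :] + np.concatenate([np.zeros((1, n_dim)), scaled.T, -scaled.T])


m = np.array([1.0, 0.0, -1.0])
P = np.array([[1.0, 0.2, 0.0], [0.2, 2.0, 0.3], [0.0, 0.3, 0.5]])
wm, wc, lamda = unscented_weights(n_dim=3)
pts = unscented_points(m, np.linalg.cholesky(P), lamda)

mean = wm @ pts
cov = (wc[:, None] * (pts - mean)).T @ (pts - mean)
```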

weights(n_dim, alpha=0.5, beta=2.0, kappa=None)

Computes the weights associated with the unscented cubature method.

The number of sigma-points is 2 * n_dim + 1: the mean plus two points per dimension. This method is also known as the Unscented Transform, and generalizes the cubature.py weights: the cubature method is the special case of the unscented transform with alpha=1.0, beta=0.0, kappa=0.0, for which the central point receives zero weight.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| n_dim | int | Dimension of the space. | required |
| alpha | float | Parameter of the unscented transform, default is 0.5. | 0.5 |
| beta | float | Parameter of the unscented transform, default is 2.0. | 2.0 |
| kappa | float \| None | Parameter of the unscented transform, default is 3 + n_dim. | None |

Returns:

| Type | Description |
| --- | --- |
| UnscentedQuadrature | The quadrature object with the weights and the scaling parameter lamda. |

References
  • https://groups.seas.harvard.edu/courses/cs281/papers/unscented.pdf
Source code in cuthbertlib/quadrature/unscented.py
def weights(
    n_dim: int, alpha: float = 0.5, beta: float = 2.0, kappa: float | None = None
) -> UnscentedQuadrature:
    """Computes the weights associated with the unscented cubature method.

    The number of sigma-points is 2 * n_dim + 1.
    This method is also known as the Unscented Transform, and generalizes the
    `cubature.py` weights: the cubature method is a special case of the unscented
    for the parameters `alpha=1.0`, `beta=0.0`, `kappa=0.0`.

    Args:
        n_dim: Dimension of the space.
        alpha: Parameter of the unscented transform, default is 0.5.
        beta: Parameter of the unscented transform, default is 2.0.
        kappa: Parameter of the unscented transform, default is 3 + n_dim.

    Returns:
        UnscentedQuadrature: The quadrature object with the weights and sigma-points.

    References:
        - https://groups.seas.harvard.edu/courses/cs281/papers/unscented.pdf
    """
    if kappa is None:
        kappa = 3.0 + n_dim

    lamda = alpha**2 * (n_dim + kappa) - n_dim
    wm = jnp.full(2 * n_dim + 1, 1 / (2 * (n_dim + lamda)))

    wm = wm.at[0].set(lamda / (n_dim + lamda))
    wc = wm.at[0].set(lamda / (n_dim + lamda) + (1 - alpha**2 + beta))
    return UnscentedQuadrature(wm=wm, wc=wc, lamda=lamda)
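The claim that cubature is the special case `alpha=1.0`, `beta=0.0`, `kappa=0.0` can be checked directly: then \(\lambda = 0\), both central weights vanish, and the remaining \(2n\) unit points and weights coincide with the spherical cubature rule. A short NumPy check, written out by hand with the same formulas as `weights` above rather than calling the library:

```python
import numpy as np

n = 3
alpha, beta, kappa = 1.0, 0.0, 0.0
lamda = alpha**2 * (n + kappa) - n            # = 0
wm = np.full(2 * n + 1, 1.0 / (2 * (n + lamda)))
wm[0] = lamda / (n + lamda)                   # = 0: the central point drops out
wc = wm.copy()
wc[0] += 1.0 - alpha**2 + beta                # = 0 as well
# Unit points (chol = I): the mean, then +/- sqrt(n + lamda) e_j.
unit_pts = np.sqrt(n + lamda) * np.concatenate([np.zeros((1, n)), np.eye(n), -np.eye(n)])

# Spherical cubature: 2n points at +/- sqrt(n) e_j with weights 1 / (2n).
cub_pts = np.concatenate([np.eye(n), -np.eye(n)]) * np.sqrt(n)
cub_w = np.full(2 * n, 1.0 / (2 * n))
```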