
Linearize

This sub-repository provides functions that use automatic differentiation to linearize conditional distributions into linear-Gaussian form, that is, to form an approximate Gaussian defined by the tuple \((H, d, L)\) such that

\[ \log p(y \mid x) \approx -\frac{1}{2}(y - H x - d)^T (LL^T)^{-1} (y - H x - d) + \text{const}. \]

Additionally, some linearization techniques may apply to an unconditional potential \(G(x)\) and return a tuple \((m, L)\) such that

\[ \log G(x) \approx -\frac{1}{2}(x - m)^T (L L^T)^{-1} (x - m) + \text{const}. \]

The former approach requires a conditional distribution that is differentiable with respect to both \(x\) and \(y\). The latter only requires differentiability with respect to \(x\) and therefore also works with, e.g., discrete or non-ordinal \(y\).
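As a minimal sketch in plain JAX (not part of this library; the helper name linear_gaussian_log_density is hypothetical), the approximate log density defined by \((H, d, L)\) can be evaluated without forming \((LL^T)^{-1}\) explicitly. Assuming \(L\) is lower triangular and invertible, a triangular solve \(z = L^{-1}(y - Hx - d)\) gives the quadratic form as \(z^T z\).

```python
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def linear_gaussian_log_density(H, d, L, x, y):
    """Hypothetical helper: -1/2 (y - Hx - d)^T (L L^T)^{-1} (y - Hx - d), constant dropped.

    Assumes L is lower triangular and invertible.
    """
    r = y - H @ x - d
    # Solve L z = r, so z^T z = r^T L^{-T} L^{-1} r = r^T (L L^T)^{-1} r.
    z = solve_triangular(L, r, lower=True)
    return -0.5 * jnp.dot(z, z)


H = jnp.array([[1.0, 0.0], [0.5, 2.0]])
d = jnp.array([0.1, -0.2])
L = jnp.array([[1.0, 0.0], [0.3, 0.5]])
x = jnp.array([1.0, -1.0])
y = jnp.array([0.0, 1.0])

val = linear_gaussian_log_density(H, d, L, x, y)
```

Working with the factor \(L\) rather than the full covariance avoids an explicit matrix inverse, which is generally better conditioned numerically.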

Linearization techniques

  • linearize_log_density: Linearize a conditional log density around given points.
  • linearize_moments: Linearize a function returning the conditional mean and Cholesky factor of the covariance around a given point.
  • linearize_taylor: Linearize a log potential function around a given point using Taylor expansion.

Linearization with sigma points can also be found in the quadrature sub-repository.

Example usage

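For linearize_log_density, usage is as follows:

import jax.numpy as jnp
from cuthbertlib.linearize import linearize_log_density

def log_density(x, y):
    # Conditional log density log p(y | x); must return a scalar.
    # Example model: y ~ N(2 x + 1, I), up to an additive constant.
    return -0.5 * jnp.sum((y - 2.0 * x - 1.0) ** 2)

x = jnp.array([0.5, -1.0])  # input points
y = jnp.array([1.0, 0.0])  # output points

mat, shift, chol_cov = linearize_log_density(log_density, x, y)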

When log_density is exactly linear-Gaussian, the output of linearize_log_density is exact for all points x and y. For non-linear and/or non-Gaussian log_density, the output is an approximation, and any singular values of the precision matrix (the negative Hessian of log_density with respect to y) that are close to zero are truncated.
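To illustrate the exactness claim, here is a sketch in plain JAX that mirrors the steps of linearize_log_density on a linear-Gaussian model. The helper linearize_log_density_sketch is hypothetical and is not the library function: it replaces symmetric_inv_sqrt with an explicit inverse followed by a Cholesky factorisation, so it assumes the negative Hessian in y is positive definite.

```python
import jax
import jax.numpy as jnp

# Ground truth: log p(y | x) = log N(y | H x + d, L L^T), up to a constant.
H = jnp.array([[2.0, 0.0], [1.0, -1.0]])
d = jnp.array([0.5, 1.0])
L_true = jnp.array([[1.0, 0.0], [0.4, 0.3]])
P = jnp.linalg.inv(L_true @ L_true.T)  # precision matrix


def log_density(x, y):
    r = y - H @ x - d
    return -0.5 * r @ P @ r


def linearize_log_density_sketch(log_density, x, y):
    """Hypothetical re-implementation for illustration, not the library code.

    Uses inverse + Cholesky instead of symmetric_inv_sqrt, so it assumes the
    negative Hessian with respect to y is positive definite.
    """
    prec = -jax.hessian(log_density, argnums=1)(x, y)
    cov = jnp.linalg.inv(prec)
    chol_cov = jnp.linalg.cholesky(cov)
    g = jax.grad(log_density, argnums=1)(x, y)
    # Mixed second derivative: Jacobian in x of the gradient in y.
    jac = jax.jacobian(jax.grad(log_density, argnums=1), argnums=0)(x, y)
    mat = cov @ jac
    shift = y - mat @ x + cov @ g
    return mat, shift, chol_cov


# Arbitrary linearization point; the model is linear-Gaussian, so any point works.
x0 = jnp.array([0.3, -2.0])
y0 = jnp.array([1.5, 0.2])
mat, shift, chol_cov = linearize_log_density_sketch(log_density, x0, y0)
```

Because the model is linear-Gaussian, mat, shift and chol_cov @ chol_cov.T should match H, d and L_true @ L_true.T up to floating-point error, whatever x0 and y0 are chosen.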

cuthbertlib.linearize.log_density

Implements linearization of conditional log densities.

linearize_log_density(log_density, x, y, has_aux=False, rtol=None, ignore_nan_dims=False)

linearize_log_density(log_density: LogConditionalDensity, x: ArrayLike, y: ArrayLike, has_aux: bool = False, rtol: float | None = None, ignore_nan_dims: bool = False) -> tuple[Array, Array, Array]
linearize_log_density(log_density: LogConditionalDensityAux, x: ArrayLike, y: ArrayLike, has_aux: bool = True, rtol: float | None = None, ignore_nan_dims: bool = False) -> tuple[Array, Array, Array, ArrayTree]

Linearizes a conditional log density around given points.

The linearization is exact in the case of a linear-Gaussian log_density, i.e., it returns \((H, d, L)\) if log_density is of the form

\[ \log p(y \mid x) = -\frac{1}{2}(y - H x - d)^\top (LL^\top)^{-1} (y - H x - d) + \textrm{const}. \]

The Cholesky factor of the covariance is calculated using the negative Hessian of log_density with respect to y as the precision matrix. symmetric_inv_sqrt is used to calculate the inverse square root, ignoring any singular values that are sufficiently close to zero (this is a projection when the Hessian is not positive definite).
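The truncation rule can be sketched with an eigendecomposition. This is an assumption about how such a routine might work, not the implementation of symmetric_inv_sqrt: the hypothetical truncated_inv_sqrt returns a square root B (not necessarily triangular) with B B^T equal to the pseudo-inverse of a symmetric precision matrix, dropping eigenvalues at or below rtol times the largest absolute eigenvalue, including negative ones. The default rtol is modelled on the jnp.linalg.pinv convention.

```python
import jax.numpy as jnp


def truncated_inv_sqrt(prec, rtol=None):
    """Hypothetical sketch: B with B @ B.T equal to the pseudo-inverse of prec.

    prec is assumed symmetric. Eigenvalues <= rtol * max|eigenvalue| (including
    negative ones) are treated as zero. B is a square root, not necessarily triangular.
    """
    prec = 0.5 * (prec + prec.T)  # symmetrise against round-off
    eigvals, eigvecs = jnp.linalg.eigh(prec)
    if rtol is None:
        # Assumed default, modelled on jnp.linalg.pinv: 10 * n * machine epsilon.
        rtol = 10.0 * prec.shape[-1] * jnp.finfo(prec.dtype).eps
    cutoff = rtol * jnp.max(jnp.abs(eigvals))
    keep = eigvals > cutoff
    safe = jnp.where(keep, eigvals, 1.0)  # avoid division by zero in dropped slots
    inv_sqrt = jnp.where(keep, 1.0 / jnp.sqrt(safe), 0.0)
    # Scale each eigenvector column by 1/sqrt(eigenvalue); dropped directions become 0.
    return eigvecs * inv_sqrt[None, :]


# Rank-deficient precision: the second direction carries no information.
prec = jnp.array([[4.0, 0.0], [0.0, 0.0]])
B = truncated_inv_sqrt(prec)  # B @ B.T is [[0.25, 0.0], [0.0, 0.0]]
```

Zeroing the dropped directions means those dimensions receive zero variance in B B^T rather than an infinite one, which is what makes the result a pseudo-inverse.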

Alternatively, the Cholesky factor can be provided directly in linearize_log_density_given_chol_cov.

Parameters:

  • log_density (LogConditionalDensity | LogConditionalDensityAux, required): A conditional log density of y given x. Returns a scalar.
  • x (ArrayLike, required): The input points.
  • y (ArrayLike, required): The output points.
  • has_aux (bool, default False): Whether log_density returns an auxiliary value.
  • rtol (float | None, default None): The relative tolerance for the singular values of the precision matrix when passed to symmetric_inv_sqrt. Cutoff for small singular values; singular values smaller than rtol * largest_singular_value are treated as zero. The default is determined based on the floating point precision of the dtype. See https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.pinv.html.
  • ignore_nan_dims (bool, default False): Whether to treat dimensions with NaN in y or on the diagonal of the precision matrix as missing and ignore all rows and columns associated with them.

Returns:

  • tuple[Array, Array, Array] | tuple[Array, Array, Array, ArrayTree]: Linearized matrix, shift, and Cholesky factor of the covariance matrix. The auxiliary value is also returned if has_aux is True.

Source code in cuthbertlib/linearize/log_density.py
def linearize_log_density(
    log_density: LogConditionalDensity | LogConditionalDensityAux,
    x: ArrayLike,
    y: ArrayLike,
    has_aux: bool = False,
    rtol: float | None = None,
    ignore_nan_dims: bool = False,
) -> tuple[Array, Array, Array] | tuple[Array, Array, Array, ArrayTree]:
    r"""Linearizes a conditional log density around given points.

    The linearization is exact in the case of a linear-Gaussian `log_density`, i.e., it returns
    $(H, d, L)$ if `log_density` is of the form

    $$
    \log p(y \mid x) = -\frac{1}{2}(y - H x - d)^\top (LL^\top)^{-1} (y - H x - d) + \textrm{const}.
    $$

    The Cholesky factor of the covariance is calculated using the negative Hessian
    of `log_density` with respect to `y` as the precision matrix.
    `symmetric_inv_sqrt` is used to calculate the inverse square root by
    ignoring any singular values that are sufficiently close to zero
    (this is a projection when the Hessian is not positive definite).

    Alternatively, the Cholesky factor can be provided directly
    in `linearize_log_density_given_chol_cov`.

    Args:
        log_density: A conditional log density of y given x. Returns a scalar.
        x: The input points.
        y: The output points.
        has_aux: Whether `log_density` returns an auxiliary value.
        rtol: The relative tolerance for the singular values of the precision matrix
            when passed to `symmetric_inv_sqrt`.
            Cutoff for small singular values; singular values smaller than
            `rtol * largest_singular_value` are treated as zero.
            The default is determined based on the floating point precision of the dtype.
            See https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.pinv.html.
        ignore_nan_dims: Whether to treat dimensions with NaN in `y` or on the diagonal
            of the precision matrix as missing and ignore all rows and columns
            associated with them.

    Returns:
        Linearized matrix, shift, and Cholesky factor of the covariance matrix.
            The auxiliary value is also returned if `has_aux` is `True`.
    """
    prec_and_maybe_aux = hessian(log_density, 1, has_aux=has_aux)(x, y)
    prec = -prec_and_maybe_aux[0] if has_aux else -prec_and_maybe_aux
    if ignore_nan_dims:
        prec_diag = jnp.diag(prec)
        nan_mask = jnp.isnan(y) | jnp.isnan(prec_diag)
        prec = prec.at[jnp.diag_indices_from(prec)].set(
            jnp.where(nan_mask, jnp.nan, prec_diag)
        )

    chol_cov = symmetric_inv_sqrt(prec, rtol=rtol, ignore_nan_dims=ignore_nan_dims)
    mat, shift, *extra = linearize_log_density_given_chol_cov(
        log_density, x, y, chol_cov, has_aux=has_aux, ignore_nan_dims=ignore_nan_dims
    )
    return mat, shift, chol_cov, *extra
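The ignore_nan_dims branch above can be run in isolation to see what it does to the precision matrix. This sketch copies only the masking lines onto a hypothetical 3-dimensional example whose second observation is missing; the resulting matrix is what would then be handed to symmetric_inv_sqrt.

```python
import jax.numpy as jnp

# Hypothetical 3-dimensional observation whose second entry is missing.
y = jnp.array([0.2, jnp.nan, -1.0])
prec = jnp.array([[2.0, 0.1, 0.0], [0.1, 3.0, 0.2], [0.0, 0.2, 1.0]])

# Same rule as the ignore_nan_dims branch: a dimension counts as missing if y is
# NaN there or the precision diagonal is already NaN; its diagonal is set to NaN.
prec_diag = jnp.diag(prec)
nan_mask = jnp.isnan(y) | jnp.isnan(prec_diag)
prec = prec.at[jnp.diag_indices_from(prec)].set(
    jnp.where(nan_mask, jnp.nan, prec_diag)
)
```

Only the diagonal entry is marked; off-diagonal entries are left untouched, and the NaN on the diagonal acts as the flag for downstream code.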

linearize_log_density_given_chol_cov(log_density, x, y, chol_cov, has_aux=False, ignore_nan_dims=False)

linearize_log_density_given_chol_cov(log_density: LogConditionalDensity, x: ArrayLike, y: ArrayLike, chol_cov: ArrayLike, has_aux: bool = False, ignore_nan_dims: bool = False) -> tuple[Array, Array]
linearize_log_density_given_chol_cov(log_density: LogConditionalDensityAux, x: ArrayLike, y: ArrayLike, chol_cov: ArrayLike, has_aux: bool = True, ignore_nan_dims: bool = False) -> tuple[Array, Array, ArrayTree]

Linearizes a conditional log density around given points.

The linearization is exact in the case of a linear-Gaussian log_density, i.e., it returns \((H, d)\) if log_density is of the form

\[ \log p(y \mid x) = -\frac{1}{2}(y - H x - d)^\top (LL^\top)^{-1} (y - H x - d) + \textrm{const}, \]

where \(L\) is the argument chol_cov.
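A short derivation shows why the formulas mat = cov @ jac and shift = y - mat @ x + cov @ g in the source code below recover \((H, d)\) exactly in the linear-Gaussian case. Write \(P = (LL^\top)^{-1}\) and \(r = y - H x - d\). Then

\[ g = \nabla_y \log p(y \mid x) = -P r, \qquad J = \frac{\partial g}{\partial x} = P H, \]

so that

\[ \text{mat} = LL^\top J = LL^\top P H = H, \qquad \text{shift} = y - H x + LL^\top g = y - H x - r = d. \]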

Parameters:

  • log_density (LogConditionalDensity | LogConditionalDensityAux, required): A conditional log density of y given x. Returns a scalar.
  • x (ArrayLike, required): The input points.
  • y (ArrayLike, required): The output points.
  • chol_cov (ArrayLike, required): The Cholesky factor of the covariance matrix of the Gaussian.
  • has_aux (bool, default False): Whether log_density returns an auxiliary value.
  • ignore_nan_dims (bool, default False): Whether to ignore dimensions with NaN on the diagonal of the precision matrix or in y.

Returns:

  • tuple[Array, Array] | tuple[Array, Array, ArrayTree]: Linearized matrix and shift. The auxiliary value is also returned if has_aux is True.

Source code in cuthbertlib/linearize/log_density.py
def linearize_log_density_given_chol_cov(
    log_density: LogConditionalDensity | LogConditionalDensityAux,
    x: ArrayLike,
    y: ArrayLike,
    chol_cov: ArrayLike,
    has_aux: bool = False,
    ignore_nan_dims: bool = False,
) -> tuple[Array, Array] | tuple[Array, Array, ArrayTree]:
    r"""Linearizes a conditional log density around given points.

    The linearization is exact in the case of a linear-Gaussian `log_density`, i.e., it returns
    $(H, d)$ if `log_density` is of the form

    $$
    \log p(y \mid x) = -\frac{1}{2}(y - H x - d)^\top (LL^\top)^{-1} (y - H x - d) + \textrm{const},
    $$

    where $L$ is the argument `chol_cov`.

    Args:
        log_density: A conditional log density of y given x. Returns a scalar.
        x: The input points.
        y: The output points.
        chol_cov: The Cholesky factor of the covariance matrix of the Gaussian.
        has_aux: Whether `log_density` returns an auxiliary value.
        ignore_nan_dims: Whether to ignore dimensions with NaN on the diagonal of the
            precision matrix or in y.

    Returns:
        Linearized matrix and shift. The auxiliary value is also returned if `has_aux` is `True`.
    """
    chol_cov = jnp.asarray(chol_cov)

    cov = (
        chol_cov_with_nans_to_cov(chol_cov)
        if ignore_nan_dims
        else chol_cov @ chol_cov.T
    )

    if has_aux:

        def grad_log_density_wrapper_aux(x, y):
            g, aux = grad(log_density, 1, has_aux=True)(x, y)
            return g, (g, aux)

        jac, (g, *extra) = jacobian(grad_log_density_wrapper_aux, 0, has_aux=True)(x, y)
    else:

        def grad_log_density_wrapper(x, y):
            g = grad(log_density, 1)(x, y)
            return g, (g,)

        jac, (g, *extra) = jacobian(grad_log_density_wrapper, 0, has_aux=True)(x, y)

    mat = cov @ jac
    shift = y - mat @ x + cov @ g
    return mat, shift, *extra
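The wrapper functions above return g twice, once as the output to differentiate and once inside the auxiliary value. A single jax.jacobian call with has_aux=True then yields both the mixed second derivative and the gradient itself, instead of a separate jax.grad evaluation. A minimal standalone sketch of the pattern on a hypothetical non-linear model:

```python
import jax
import jax.numpy as jnp


def log_density(x, y):
    # Hypothetical non-linear model: y ~ N(sin(x), I), up to a constant.
    return -0.5 * jnp.sum((y - jnp.sin(x)) ** 2)


def grad_and_aux(x, y):
    g = jax.grad(log_density, argnums=1)(x, y)
    # Return g both as the differentiated output and as auxiliary data.
    return g, g


x = jnp.array([0.1, 0.7])
y = jnp.array([0.0, 1.0])

# One call yields d/dx of grad_y log p, plus grad_y log p itself.
jac, g = jax.jacobian(grad_and_aux, argnums=0, has_aux=True)(x, y)
```

Here g = sin(x) - y, so jac is diag(cos(x)); the auxiliary output is passed through without being differentiated.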

cuthbertlib.linearize.moments

Implements moment-based linearization.

MeanAndCholCovFunc = Callable[[ArrayLike], tuple[Array, Array]] (module attribute)

MeanAndCholCovFuncAux = Callable[[ArrayLike], tuple[Array, Array, ArrayTree]] (module attribute)

linearize_moments(mean_and_chol_cov_function, x, has_aux=False)

linearize_moments(mean_and_chol_cov_function: MeanAndCholCovFunc, x: ArrayLike, has_aux: bool = False) -> tuple[Array, Array, Array]
linearize_moments(mean_and_chol_cov_function: MeanAndCholCovFuncAux, x: ArrayLike, has_aux: bool = True) -> tuple[Array, Array, Array, ArrayTree]

Linearizes conditional mean and chol_cov functions into a linear-Gaussian form.

Takes a function mean_and_chol_cov_function(x) that returns the conditional mean and Cholesky factor of the covariance matrix of the distribution \(p(y \mid x)\) for a given input x.

Returns \((H, d, L)\) defining a linear-Gaussian approximation to the conditional distribution \(p(y \mid x) \approx N(y \mid H x + d, L L^\top)\).

mean_and_chol_cov_function has the following signature with has_aux = False:

m, chol = mean_and_chol_cov_function(x)

or with has_aux = True:

m, chol, aux = mean_and_chol_cov_function(x)
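A sketch of this moment linearization in plain JAX, assuming (as in the jax.jacfwd call in the source code below) that only the mean is differentiated while the Cholesky factor is simply evaluated at x. The model and the helper linearize_moments_sketch are hypothetical.

```python
import jax
import jax.numpy as jnp


def mean_and_chol_cov(x):
    # Hypothetical model: p(y | x) = N(y | sin(x), (0.1 I)(0.1 I)^T).
    return jnp.sin(x), 0.1 * jnp.eye(x.shape[0])


def linearize_moments_sketch(mean_and_chol_cov_function, x):
    """Hypothetical sketch, not the library's linearize_moments.

    Differentiates only the mean; the Cholesky factor is taken as-is at x.
    """

    def mean_with_aux(x):
        m, chol = mean_and_chol_cov_function(x)
        return m, (m, chol)  # mean is differentiated; (m, chol) passed through

    H, (m, chol) = jax.jacfwd(mean_with_aux, has_aux=True)(x)
    d = m - H @ x  # so that H x + d equals the mean at the linearization point
    return H, d, chol


x0 = jnp.array([0.3, -1.2])
H, d, L = linearize_moments_sketch(mean_and_chol_cov, x0)
```

Forward mode (jacfwd) is a natural choice here when x has no more dimensions than the mean, since its cost scales with the input dimension.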

Parameters:

  • mean_and_chol_cov_function (MeanAndCholCovFunc | MeanAndCholCovFuncAux, required): A callable that returns the conditional mean and Cholesky factor of the covariance matrix of the distribution for a given input.
  • x (ArrayLike, required): The point to linearize around.
  • has_aux (bool, default False): Whether mean_and_chol_cov_function returns an auxiliary value.

Returns:

  • tuple[Array, Array, Array] | tuple[Array, Array, Array, ArrayTree]: Linearized matrix, shift, and Cholesky factor of the covariance matrix. The auxiliary value is also returned if has_aux is True.

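References:

  • sqrt-parallel-smoothers: https://github.com/EEA-sensors/sqrt-parallel-smoothers/blob/main/parsmooth/linearization/_extended.py
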
Source code in cuthbertlib/linearize/moments.py
def linearize_moments(
    mean_and_chol_cov_function: MeanAndCholCovFunc | MeanAndCholCovFuncAux,
    x: ArrayLike,
    has_aux: bool = False,
) -> tuple[Array, Array, Array] | tuple[Array, Array, Array, ArrayTree]:
    r"""Linearizes conditional mean and chol_cov functions into a linear-Gaussian form.

    Takes a function `mean_and_chol_cov_function(x)` that returns the
    conditional mean and Cholesky factor of the covariance matrix of the distribution
    $p(y \mid x)$ for a given input `x`.

    Returns $(H, d, L)$ defining a linear-Gaussian approximation to the conditional
    distribution $p(y \mid x) \approx N(y \mid H x + d, L L^\top)$.

    `mean_and_chol_cov_function` has the following signature with `has_aux` = False:
    ```
    m, chol = mean_and_chol_cov_function(x)
    ```
    or with `has_aux` = True:
    ```
    m, chol, aux = mean_and_chol_cov_function(x)
    ```

    Args:
        mean_and_chol_cov_function: A callable that returns the conditional mean and
            Cholesky factor of the covariance matrix of the distribution for a given
            input.
        x: The point to linearize around.
        has_aux: Whether `mean_and_chol_cov_function` returns an auxiliary value.

    Returns:
        Linearized matrix, shift, and Cholesky factor of the covariance matrix.
            The auxiliary value is also returned if `has_aux` is `True`.

    References:
        - [sqrt-parallel-smoothers](https://github.com/EEA-sensors/sqrt-parallel-smoothers/blob/main/parsmooth/linearization/_extended.py)
    """
    if has_aux:
        mean_and_chol_cov_function = cast(
            MeanAndCholCovFuncAux, mean_and_chol_cov_function
        )

        def mean_and_chol_cov_function_wrapper_aux(
            x: ArrayLike,
        ) -> tuple[Array, tuple[Array, Array, ArrayTree]]:
            mean, chol_cov, aux = mean_and_chol_cov_function(x)
            return mean, (mean, chol_cov, aux)

        F, (m, *extra) = jax.jacfwd(
            mean_and_chol_cov_function_wrapper_aux, has_aux=True
        )(x)

    else:
        mean_and_chol_cov_function = cast(
            MeanAndCholCovFunc, mean_and_chol_cov_function
        )

        def mean_and_chol_cov_function_wrapper(
            x: ArrayLike,
        ) -> tuple[Array, tuple[Array, Array]]:
            mean, chol_cov = mean_and_chol_cov_function(x)
            return mean, (mean, chol_cov)

        F, (m, *extra) = jax.jacfwd(mean_and_chol_cov_function_wrapper, has_aux=True)(x)

    b = m - F @ x
    return F, b, *extra

cuthbertlib.linearize.taylor

Implements Taylor-like linearization.

linearize_taylor(log_potential, x, has_aux=False, rtol=None, ignore_nan_dims=False)

linearize_taylor(log_potential: Callable[[ArrayLike], Array], x: ArrayLike, has_aux: bool = False, rtol: float | None = None, ignore_nan_dims: bool = False) -> tuple[Array, Array]
linearize_taylor(log_potential: Callable[[ArrayLike], tuple[Array, ArrayTree]], x: ArrayLike, has_aux: bool = True, rtol: float | None = None, ignore_nan_dims: bool = False) -> tuple[Array, Array, ArrayTree]

Linearizes a log potential function around a given point using Taylor expansion.

Unlike the other linearization methods, this applies to a potential function with no required notion of observation \(y\) or conditional dependence.

Instead we have the linearization

\[ \log G(x) \approx -\frac{1}{2} (x - m)^\top (L L^\top)^{-1} (x - m) + \text{const}. \]
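A minimal sketch of the idea in plain JAX: the mean is one Newton step from x, \(m = x + LL^\top \nabla \log G(x)\), where \((LL^\top)^{-1}\) is the negative Hessian of \(\log G\) at x. The helper linearize_taylor_sketch is hypothetical and is not the library function; it assumes the negative Hessian is positive definite so that a plain Cholesky factor can replace symmetric_inv_sqrt. The usage example treats a Bernoulli observation with a logistic link as a potential in x, plus a standard-normal prior term added here as an assumption to keep the Hessian definite. Only x is differentiated, which is why a discrete y poses no problem.

```python
import jax
import jax.numpy as jnp


def linearize_taylor_sketch(log_potential, x):
    """Hypothetical sketch, not the library's linearize_taylor.

    Takes one Newton step from x. Assumes the negative Hessian is positive
    definite, so a plain Cholesky factor replaces symmetric_inv_sqrt.
    """
    g = jax.grad(log_potential)(x)
    prec = -jax.hessian(log_potential)(x)
    cov = jnp.linalg.inv(prec)
    L = jnp.linalg.cholesky(cov)
    m = x + cov @ g  # Newton step: m = x + L L^T grad log G(x)
    return m, L


# Usage: a discrete observation y_obs, held fixed, so only x is differentiated.
y_obs = 1


def log_potential(x):
    # Bernoulli log-likelihood with logistic link, plus a N(0, 1) prior term
    # (an assumption added here so the negative Hessian is positive definite).
    logit = x[0]
    return y_obs * logit - jnp.logaddexp(0.0, logit) - 0.5 * logit**2


m, L = linearize_taylor_sketch(log_potential, jnp.array([0.0]))
```

At x = 0 the gradient is 1 - 1/2 = 0.5 and the negative Hessian is 1/4 + 1 = 1.25, so the sketch gives m = 0.4 and L L^T = 0.8. For a quadratic log potential the Newton step lands exactly on the maximizer from any starting point.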

Parameters:

  • log_potential (Callable[[ArrayLike], Array] | Callable[[ArrayLike], tuple[Array, ArrayTree]], required): A callable that returns a scalar log potential. The potential does not need to be a normalized probability density in its input.
  • x (ArrayLike, required): The point to linearize around.
  • has_aux (bool, default False): Whether log_potential returns an auxiliary value.
  • rtol (float | None, default None): The relative tolerance for the singular values of the precision matrix when passed to symmetric_inv_sqrt. Cutoff for small singular values; singular values smaller than rtol * largest_singular_value are treated as zero. The default is determined based on the floating point precision of the dtype. See https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.pinv.html.
  • ignore_nan_dims (bool, default False): Whether to treat dimensions with NaN on the diagonal of the precision matrix as missing and ignore all rows and columns associated with them.

Returns:

  • tuple[Array, Array] | tuple[Array, Array, ArrayTree]: Linearized mean and Cholesky factor of the covariance matrix. The auxiliary value is also returned if has_aux is True.

Source code in cuthbertlib/linearize/taylor.py
def linearize_taylor(
    log_potential: Callable[[ArrayLike], Array]
    | Callable[[ArrayLike], tuple[Array, ArrayTree]],
    x: ArrayLike,
    has_aux: bool = False,
    rtol: float | None = None,
    ignore_nan_dims: bool = False,
) -> tuple[Array, Array] | tuple[Array, Array, ArrayTree]:
    r"""Linearizes a log potential function around a given point using Taylor expansion.

    Unlike the other linearization methods, this applies to a potential function
    with no required notion of observation $y$ or conditional dependence.

    Instead we have the linearization

    $$
    \log G(x) \approx -\frac{1}{2} (x - m)^\top (L L^\top)^{-1} (x - m) + \textrm{const}.
    $$

    Args:
        log_potential: A callable that returns a scalar log potential. The potential
            does not need to be a normalized probability density in its input.
        x: The point to linearize around.
        has_aux: Whether `log_potential` returns an auxiliary value.
        rtol: The relative tolerance for the singular values of the precision matrix
            when passed to `symmetric_inv_sqrt`.
            Cutoff for small singular values; singular values smaller than
            `rtol * largest_singular_value` are treated as zero.
            The default is determined based on the floating point precision of the dtype.
            See https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.pinv.html.
        ignore_nan_dims: Whether to treat dimensions with NaN on the diagonal of the
            precision matrix as missing and ignore all rows and columns associated with
            them.

    Returns:
        Linearized mean and Cholesky factor of the covariance matrix.
            The auxiliary value is also returned if `has_aux` is `True`.
    """
    g_and_maybe_aux = jax.grad(log_potential, has_aux=has_aux)(x)
    prec_and_maybe_aux = jax.hessian(log_potential, has_aux=has_aux)(x)

    g, aux = g_and_maybe_aux if has_aux else (g_and_maybe_aux, None)
    prec = -prec_and_maybe_aux[0] if has_aux else -prec_and_maybe_aux

    L = symmetric_inv_sqrt(prec, rtol=rtol, ignore_nan_dims=ignore_nan_dims)

    # Change nans on diag to zeros for L @ L.T @ g, still retain nans on diag for L for bookkeeping
    # If ignore_nan_dims, change all rows and columns with nans on the diagonal to 0
    L_diag = jnp.diag(L)
    nan_mask = jnp.isnan(L_diag) * ignore_nan_dims
    L_temp = jnp.where(nan_mask[:, None] | nan_mask[None, :], 0.0, L)
    m = x + L_temp @ L_temp.T @ g
    return (m, L, aux) if has_aux else (m, L)
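The NaN bookkeeping at the end of the source zeroes every row and column of L whose diagonal entry is NaN before forming the Newton step, so missing dimensions of m stay at x while L itself keeps its NaN markers. A standalone sketch of just that step, with a hypothetical L, gradient and linearization point (using & in place of the source's multiplication by ignore_nan_dims, which has the same effect on boolean masks):

```python
import jax.numpy as jnp

# Hypothetical inputs: dimension 1 was flagged missing, so L has NaN on its diagonal.
L = jnp.array([[2.0, 0.0], [0.0, jnp.nan]])
g = jnp.array([0.5, 3.0])
x = jnp.array([1.0, -1.0])
ignore_nan_dims = True

# Zero out every row and column with NaN on the diagonal (only when ignoring NaNs).
nan_mask = jnp.isnan(jnp.diag(L)) & ignore_nan_dims
L_temp = jnp.where(nan_mask[:, None] | nan_mask[None, :], 0.0, L)
m = x + L_temp @ L_temp.T @ g  # missing dimension stays at x
```

Here L_temp @ L_temp.T is [[4, 0], [0, 0]], so m moves by 2.0 in the observed dimension and stays at -1.0 in the missing one.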