Linalg

This sub-repository contains modular linear algebra primitives that are useful for cuthbert and not already provided by JAX.

In particular we have:

  • tria, which computes a lower triangular matrix square root R such that R @ R.T = A @ A.T, for a given matrix A that is not necessarily square.
  • collect_nans_chol, which reorders a generalized Cholesky factor to move a specified subset of rows and columns to the start with remaining dimensions moved to the end and parameterized so that they are ignored in a Bayesian update or logpdf calculation.
  • symmetric_inv_sqrt, which computes the inverse square root of a symmetric matrix. It does so exactly in the case that the matrix is positive definite. In the case of zero or negative singular values, it supports approximate inverse square roots in a similar manner to (Moore-Penrose) pseudo-inversion.

cuthbertlib.linalg.tria

Implements a triangularization operator for a matrix via QR decomposition.

tria(A)

A triangularization operator using QR decomposition.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| A | Array | The matrix to triangularize. | required |

Returns:

| Type | Description |
|------|-------------|
| Array | A lower triangular matrix \(R\) such that \(R R^\top = A A^\top\). |

Reference

Arasaratnam and Haykin (2008): Square-Root Quadrature Kalman Filtering

Source code in cuthbertlib/linalg/tria.py
def tria(A: Array) -> Array:
    r"""A triangularization operator using QR decomposition.

    Args:
        A: The matrix to triangularize.

    Returns:
        A lower triangular matrix $R$ such that $R R^\top = A A^\top$.

    Reference:
        [Arasaratnam and Haykin (2008)](https://ieeexplore.ieee.org/document/4524036): Square-Root Quadrature Kalman Filtering
    """
    _, R = jax.scipy.linalg.qr(A.T, mode="economic")
    return R.T
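
As a quick check of the identity stated above, the following sketch applies tria to a non-square matrix. The function body is copied from the listing so that the snippet runs without cuthbertlib installed; the input values are arbitrary and chosen only for illustration.

```python
import jax.numpy as jnp
from jax.scipy.linalg import qr


def tria(A):
    # Same idea as the listing above: QR of A.T, then transpose the R factor.
    _, R = qr(A.T, mode="economic")
    return R.T


# A wide (2 x 3) matrix, for example two square-root factors stacked side by side.
A = jnp.array([[1.0, 2.0, 0.5], [0.0, 1.0, 3.0]])
R = tria(A)

is_square = R.shape == (2, 2)
is_lower = bool(jnp.allclose(R, jnp.tril(R)))
same_gram = bool(jnp.allclose(R @ R.T, A @ A.T, atol=1e-5))
```

Note that the diagonal of R is not guaranteed to be positive, so R can differ in sign from a Cholesky factor of A @ A.T.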

cuthbertlib.linalg.collect_nans_chol

Implements collection of NaNs and reordering within a Cholesky factor.

collect_nans_chol(flag, chol, *rest)

Converts chol into an ordered chol factor with NaNs moved to the bottom right.

Specifically, converts a generalized Cholesky factor of a covariance matrix with NaNs into an ordered generalized Cholesky factor, with NaN rows and columns moved to the end and their diagonal elements set to 1/√(2π).

Also reorders the rest of the arguments in the same way along the first axis and sets them to 0 for dimensions where flag is True.

Example behavior:

flag = jnp.array([False, True, False, True])
new_flag, new_chol, new_mean = collect_nans_chol(flag, chol, mean)

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| flag | ArrayLike | Boolean array indicating which entries are NaN: True for NaN entries, False for valid ones. | required |
| chol | ArrayLike | Cholesky factor of the covariance matrix. | required |
| rest | Any | Remaining arguments, reordered in the same way along the first axis. | () |

Returns:

| Type | Description |
|------|-------------|
| Any | flag, chol and rest reordered so that valid entries are first and NaNs are last. Diagonal elements of chol in the NaN dimensions are set to 1/√(2π) so that normalization is correct. |
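
One way to read the 1/√(2π) choice, offered here as an interpretation rather than a statement from the source: a one-dimensional Gaussian with standard deviation 1/√(2π), evaluated at its mean, has log-density exactly zero. A padded dimension whose mean and observation are both set to 0 would therefore add nothing to a log-likelihood. The sketch below checks this numerically with jax.scipy.stats.norm.logpdf.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

# Standard deviation placed on the diagonal for missing dimensions.
scale = 1.0 / jnp.sqrt(2.0 * jnp.pi)

# log N(0; 0, scale^2) = -0.5 * log(2 * pi) - log(scale) = 0
logpdf_missing = norm.logpdf(0.0, loc=0.0, scale=scale)
```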

Source code in cuthbertlib/linalg/collect_nans_chol.py
def collect_nans_chol(flag: ArrayLike, chol: ArrayLike, *rest: Any) -> Any:
    """Converts chol into an ordered chol factor with NaNs moved to the bottom right.

    Specifically, converts a generalized Cholesky factor of a covariance matrix with
    NaNs into an ordered generalized Cholesky factor, with NaN rows and columns
    moved to the end and their diagonal elements set to 1/√(2π).

    Also reorders the rest of the arguments in the same way along the first axis
    and sets them to 0 for dimensions where flag is True.

    Example behavior:
    ```
    flag = jnp.array([False, True, False, True])
    new_flag, new_chol, new_mean = collect_nans_chol(flag, chol, mean)
    ```

    Args:
        flag: Array, boolean array indicating which entries are NaN
            True for NaN entries, False for valid
        chol: Array, Cholesky factor of the covariance matrix
        rest: Any, rest of the arguments to be reordered in the same way
            along the first axis

    Returns:
        flag, chol and rest reordered so that valid entries are first and NaNs are last.
            Diagonal elements of chol in the NaN dimensions are set to 1/√(2π) so that
            normalization is correct.
    """
    flag = jnp.asarray(flag)
    chol = jnp.asarray(chol)

    # TODO: Can we support batching? I.e. when `chol` is a batch of Cholesky factors,
    # possibly with multiple leading dimensions

    if flag.ndim > 1 or chol.ndim > 2:
        raise ValueError("Batched flag or chol not supported yet")

    if not flag.shape:
        return (
            flag,
            _set_to_zero(flag, chol),
            *tree.map(lambda x: _set_to_zero(flag, x), rest),
        )

    if chol.size == 1:
        chol *= jnp.ones_like(flag, dtype=chol.dtype)

    # group the NaN entries together
    argsort = jnp.argsort(flag, stable=True)

    if chol.ndim == 1:
        chol = chol[argsort]
        flag = flag[argsort]
        chol = jnp.where(flag, 1 / jnp.sqrt(2 * jnp.pi), chol)

    else:
        chol = jnp.where(flag[:, None], 0.0, chol)
        chol = chol[argsort]
        # compute the tria of the covariance matrix with NaNs set to 0
        chol = tria(chol)

        flag = flag[argsort]

        # set the diagonal of chol to 1/√(2π) where NaNs were present so that normalization is correct
        diag_chol = jnp.diag(chol)
        diag_chol = jnp.where(flag, 1 / jnp.sqrt(2 * jnp.pi), diag_chol)
        diag_indices = jnp.diag_indices_from(chol)
        chol = chol.at[diag_indices].set(diag_chol)

    # Only reorder non-scalar arrays in rest
    rest = tree.map(lambda x: x[argsort] if jnp.asarray(x).shape else x, rest)
    rest = tree.map(lambda x: _set_to_zero(flag, x), rest)

    return flag, chol, *rest
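
To make the matrix branch of the listing concrete, here is a minimal walk-through on a 4 x 4 factor with two flagged dimensions. It repeats the steps of the listing (zero the flagged rows, stable sort, re-triangularize, patch the diagonal) outside the function, with tria copied locally so that the snippet is self-contained. The numeric values are made up for illustration.

```python
import jax.numpy as jnp
from jax.scipy.linalg import qr


def tria(A):
    # Local copy of the triangularization operator.
    _, R = qr(A.T, mode="economic")
    return R.T


flag = jnp.array([False, True, False, True])
chol = jnp.array(
    [
        [2.0, 0.0, 0.0, 0.0],
        [0.5, 1.0, 0.0, 0.0],
        [1.0, 0.3, 1.5, 0.0],
        [0.2, 0.1, 0.4, 1.0],
    ]
)
cov = chol @ chol.T

# Stable sort keeps the original order within the valid and flagged groups.
order = jnp.argsort(flag, stable=True)

# Zero the rows of flagged dimensions, then move them to the end.
masked = jnp.where(flag[:, None], 0.0, chol)[order]

# The row permutation breaks the triangular structure, so re-triangularize.
new_chol = tria(masked)
new_flag = flag[order]

# Put 1/sqrt(2*pi) on the diagonal of the flagged dimensions.
diag = jnp.where(new_flag, 1.0 / jnp.sqrt(2.0 * jnp.pi), jnp.diag(new_chol))
new_chol = new_chol.at[jnp.diag_indices_from(new_chol)].set(diag)

# The leading block reproduces the covariance of the valid dimensions (0 and 2).
valid = jnp.array([0, 2])
valid_block = new_chol[:2, :2] @ new_chol[:2, :2].T
expected_block = cov[valid][:, valid]
```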

cuthbertlib.linalg.symmetric_inv_sqrt

Implements inverse square root of a symmetric matrix.

symmetric_inv_sqrt(A, rtol=None, ignore_nan_dims=False)

Computes the inverse square root of a symmetric matrix.

I.e., a lower triangular matrix \(L\) such that \(L L^{\top} = A^{-1}\) (for positive definite \(A\)). Note that this is not unique and will generally not match the Cholesky factor of \(A^{-1}\).

For singular matrices, small singular values will be cut off reminiscent of the Moore-Penrose pseudoinverse - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.pinv.html.

In the case of singular or indefinite \(A\), the output will be an approximation and \(L L^{\top} = A^{-1}\) will not hold in general.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| A | ArrayLike | A symmetric matrix. | required |
| rtol | float \| ArrayLike \| None | The relative tolerance for the singular values. Cutoff for small singular values; singular values smaller than rtol * largest_singular_value are treated as zero. See https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.pinv.html. | None |
| ignore_nan_dims | bool | Whether to treat dimensions with NaN on the diagonal as missing and ignore all rows and columns associated with them (with result in those dimensions being NaN on the diagonal and zero off-diagonal). | False |

Returns:

| Type | Description |
|------|-------------|
| Array | A lower triangular matrix \(L\) such that \(L L^{\top} = A^{-1}\) (for valid dimensions). |
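
The private helper _symmetric_inv_sqrt is not shown on this page, so the following is only a sketch of one way to obtain such an L for a positive definite A. The name inv_sqrt_sketch is hypothetical, and the route taken (eigendecomposition followed by tria) is an assumption, not necessarily what the library does. No rtol cutoff or NaN handling is applied.

```python
import jax.numpy as jnp
from jax.scipy.linalg import qr


def tria(A):
    # Local copy of the triangularization operator.
    _, R = qr(A.T, mode="economic")
    return R.T


def inv_sqrt_sketch(A):
    # Hypothetical helper for illustration; assumes A is symmetric positive definite.
    eigvals, eigvecs = jnp.linalg.eigh(A)
    # Dividing column i by sqrt(eigvals[i]) gives B with B @ B.T = inv(A).
    B = eigvecs / jnp.sqrt(eigvals)
    # tria keeps B @ B.T unchanged while making the factor lower triangular.
    return tria(B)


A = jnp.array([[4.0, 1.0], [1.0, 3.0]])
L = inv_sqrt_sketch(A)

matches_inverse = bool(jnp.allclose(L @ L.T, jnp.linalg.inv(A), atol=1e-5))
is_lower = bool(jnp.allclose(L, jnp.tril(L)))
```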

Source code in cuthbertlib/linalg/symmetric_inv_sqrt.py
def symmetric_inv_sqrt(
    A: ArrayLike,
    rtol: float | ArrayLike | None = None,
    ignore_nan_dims: bool = False,
) -> Array:
    r"""Computes the inverse square root of a symmetric matrix.

    I.e., a lower triangular matrix $L$ such that $L L^{\top} = A^{-1}$ (for positive definite
    $A$). Note that this is not unique and will generally not match the Cholesky factor
    of $A^{-1}$.

    For singular matrices, small singular values will be cut off reminiscent of
    the Moore-Penrose pseudoinverse - https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.pinv.html.

    In the case of singular or indefinite $A$, the output will be an approximation
    and $L L^{\top} = A^{-1}$ will not hold in general.

    Args:
        A: A symmetric matrix.
        rtol: The relative tolerance for the singular values.
            Cutoff for small singular values; singular values smaller than
            `rtol * largest_singular_value` are treated as zero.
            See https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.pinv.html.
        ignore_nan_dims: Whether to treat dimensions with NaN on the diagonal as missing
            and ignore all rows and columns associated with them (with result in those
            dimensions being NaN on the diagonal and zero off-diagonal).

    Returns:
        A lower triangular matrix $L$ such that $L L^{\top} = A^{-1}$ (for valid dimensions).
    """
    arr = jnp.asarray(A)

    # Check for NaNs on the diagonal (missing dimensions)
    diag_vals = jnp.diag(arr)
    nan_diag_mask = jnp.isnan(diag_vals) * ignore_nan_dims

    # Check for dimensions whose row and column are all 0
    zero_mask = jnp.all(arr == 0.0, axis=0) & jnp.all(arr == 0.0, axis=1)

    nan_mask = nan_diag_mask | zero_mask

    # Sort to group valid dimensions first (needed for SVD to work correctly)
    argsort = jnp.argsort(nan_mask, stable=True)
    arr_sorted = arr[argsort[:, None], argsort]
    nan_mask_sorted = nan_mask[argsort]

    # Zero out invalid dimensions before computation
    invalid_mask_2d = ((nan_mask_sorted[:, None]) | (nan_mask_sorted[None, :])) & (
        ignore_nan_dims
    )
    arr_sorted = jnp.where(invalid_mask_2d, 0.0, arr_sorted)

    # Compute inverse square root on sorted, masked matrix
    L_sorted = _symmetric_inv_sqrt(arr_sorted, rtol)

    # Post-process: zero out invalid rows/cols, set NaN on invalid diagonal
    L_sorted = jnp.where(invalid_mask_2d, 0.0, L_sorted)
    diag_L = jnp.where(nan_mask_sorted, jnp.nan, jnp.diag(L_sorted))
    L_sorted = L_sorted.at[jnp.diag_indices_from(L_sorted)].set(diag_L)

    # Un-sort to restore original order
    inv_argsort = jnp.argsort(argsort)
    L = L_sorted[inv_argsort[:, None], inv_argsort]

    return L
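
The sort and un-sort steps in the listing rely on the fact that the argsort of a permutation is its inverse permutation. A small standalone check of that round trip, using an arbitrary mask and matrix:

```python
import jax.numpy as jnp

mask = jnp.array([True, False, False, True])
A = jnp.arange(16.0).reshape(4, 4)

# Valid (False) dimensions first; rows and columns are permuted together.
perm = jnp.argsort(mask, stable=True)
A_sorted = A[perm[:, None], perm]

# argsort of a permutation yields the inverse permutation.
inv_perm = jnp.argsort(perm)
A_restored = A_sorted[inv_perm[:, None], inv_perm]

round_trip_ok = bool(jnp.array_equal(A_restored, A))
```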

chol_cov_with_nans_to_cov(chol_cov)

Converts a Cholesky factor to a covariance matrix.

NaNs on the diagonal specify dimensions to be ignored.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| chol_cov | ArrayLike | A Cholesky factor of a covariance matrix with NaNs on the diagonal specifying dimensions to be ignored. | required |

Returns:

| Type | Description |
|------|-------------|
| Array | A covariance matrix equal to chol_cov @ chol_cov.T in dimensions where the Cholesky factor is valid. Invalid dimensions (those with NaN on the diagonal of chol_cov) have NaN on the diagonal and zero off-diagonal. |

Source code in cuthbertlib/linalg/symmetric_inv_sqrt.py
def chol_cov_with_nans_to_cov(chol_cov: ArrayLike) -> Array:
    """Converts a Cholesky factor to a covariance matrix.

    NaNs on the diagonal specify dimensions to be ignored.

    Args:
        chol_cov: A Cholesky factor of a covariance matrix with NaNs on the diagonal
            specifying dimensions to be ignored.

    Returns:
        A covariance matrix equal to chol_cov @ chol_cov.T in dimensions where
            the Cholesky factor is valid. Invalid dimensions (those with NaN on the
            diagonal of chol_cov) have NaN on the diagonal and zero off-diagonal.
    """
    chol_cov = jnp.asarray(chol_cov)

    nan_mask = jnp.isnan(jnp.diag(chol_cov))

    # Set all rows and columns with invalid diagonal to zero
    chol_cov = jnp.where(nan_mask[:, None] | nan_mask[None, :], 0, chol_cov)

    # Calculate the covariance matrix
    cov = chol_cov @ chol_cov.T

    # Set the diagonal to NaN
    cov = cov.at[jnp.diag_indices_from(cov)].set(
        jnp.where(nan_mask, jnp.nan, jnp.diag(cov))
    )

    return cov
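
A small worked example may help here. The function below is a condensed copy of the listing above, included so that the snippet runs on its own; the 3 x 3 input is made up, with the middle dimension marked as missing.

```python
import jax.numpy as jnp


def chol_cov_with_nans_to_cov(chol_cov):
    # Condensed copy of the listing above.
    chol_cov = jnp.asarray(chol_cov)
    nan_mask = jnp.isnan(jnp.diag(chol_cov))
    chol_cov = jnp.where(nan_mask[:, None] | nan_mask[None, :], 0, chol_cov)
    cov = chol_cov @ chol_cov.T
    return cov.at[jnp.diag_indices_from(cov)].set(
        jnp.where(nan_mask, jnp.nan, jnp.diag(cov))
    )


nan = jnp.nan
chol_cov = jnp.array([[2.0, 0.0, 0.0], [0.0, nan, 0.0], [1.0, 0.0, 3.0]])
cov = chol_cov_with_nans_to_cov(chol_cov)
# cov is [[4, 0, 2], [0, nan, 0], [2, 0, 10]]
```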