Kalman

This sub-repository provides modular functions for Kalman filtering and smoothing.

The core functions are:

  • predict: Single prediction step.
  • filter_update: Single update step.
  • smoother_update: Single Rauch-Tung-Striebel smoothing step.

Together, predict and filter_update can be used to perform an online filtering step.

In all cases, we operate on the square-root form of the covariance matrix. This is more numerically stable, particularly in low-precision floating-point arithmetic, because a covariance reconstructed from its square root is guaranteed to be symmetric and positive semi-definite. Input covariance matrices must therefore also be provided in square-root (Cholesky) form.
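
To make the square-root form concrete, here is a minimal NumPy sketch. It assumes that the `tria` helper used in the source listings on this page returns a lower-triangular factor obtained from a QR decomposition, as the referenced parsmooth code does; the `tria` defined here is a stand-in written for illustration, not the library's own function.

```python
import numpy as np


def tria(A):
    """Stand-in for the library's `tria`: lower-triangular L with L @ L.T == A @ A.T."""
    # QR-decompose A.T = Q R. Then A @ A.T = R.T @ Q.T @ Q @ R = R.T @ R,
    # so R.T is a lower-triangular square root of A @ A.T.
    R = np.linalg.qr(A.T, mode="r")
    return R.T


rng = np.random.default_rng(0)
A = rng.normal(size=(3, 5))  # a wide matrix, e.g. [F @ chol_P, chol_Q]
L = tria(A)

# The covariance is only ever formed implicitly as L @ L.T,
# which is symmetric and positive semi-definite by construction.
P = L @ L.T
```

The factor is unique only up to the signs of its columns, which is presumably why the parameter descriptions speak of a "generalized" Cholesky factor; only the product `L @ L.T` is meaningful.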

cuthbertlib.kalman.filtering

Implements the square root parallel Kalman filter and associative variant.

FilterScanElement

Bases: NamedTuple

Arrays carried through the Kalman filter scan.

A instance-attribute

b instance-attribute

U instance-attribute

eta instance-attribute

Z instance-attribute

ell instance-attribute

predict(m, chol_P, F, c, chol_Q)

Propagate the mean and square root covariance through linear Gaussian dynamics.

Parameters:

Name Type Description Default
m ArrayLike

Mean of the state.

required
chol_P ArrayLike

Generalized Cholesky factor of the covariance of the state.

required
F ArrayLike

Transition matrix.

required
c ArrayLike

Transition shift.

required
chol_Q ArrayLike

Generalized Cholesky factor of the transition noise covariance.

required

Returns:

Type Description
tuple[Array, Array]

Propagated mean and square root covariance.

References

Paper: G. J. Bierman, Factorization Methods for Discrete Sequential Estimation.
Code: https://github.com/EEA-sensors/sqrt-parallel-smoothers/tree/main/parsmooth/sequential

Source code in cuthbertlib/kalman/filtering.py
def predict(
    m: ArrayLike,
    chol_P: ArrayLike,
    F: ArrayLike,
    c: ArrayLike,
    chol_Q: ArrayLike,
) -> tuple[Array, Array]:
    """Propagate the mean and square root covariance through linear Gaussian dynamics.

    Args:
        m: Mean of the state.
        chol_P: Generalized Cholesky factor of the covariance of the state.
        F: Transition matrix.
        c: Transition shift.
        chol_Q: Generalized Cholesky factor of the transition noise covariance.

    Returns:
        Propagated mean and square root covariance.

    References:
        Paper: G. J. Bierman, Factorization Methods for Discrete Sequential Estimation,
        Code: https://github.com/EEA-sensors/sqrt-parallel-smoothers/tree/main/parsmooth/sequential
    """
    m, chol_P = jnp.asarray(m), jnp.asarray(chol_P)
    F, c, chol_Q = jnp.asarray(F), jnp.asarray(c), jnp.asarray(chol_Q)
    m1 = F @ m + c
    A = jnp.concatenate([F @ chol_P, chol_Q], axis=1)
    chol_P1 = tria(A)
    return m1, chol_P1
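
As a sanity check on the listing above, the following NumPy sketch mirrors `predict` and compares it with the dense covariance recursion `F @ P @ F.T + Q`. The helpers `tria` and `random_chol` are hypothetical stand-ins introduced for this example, and NumPy replaces JAX so that the snippet runs without cuthbertlib.

```python
import numpy as np


def tria(A):
    # Stand-in for the library's `tria`: lower-triangular L with L @ L.T == A @ A.T.
    return np.linalg.qr(A.T, mode="r").T


def random_chol(rng, n):
    # Hypothetical helper: a well-conditioned lower-triangular factor.
    return np.eye(n) + 0.1 * np.tril(rng.normal(size=(n, n)))


def predict(m, chol_P, F, c, chol_Q):
    # Same steps as the listing above, with NumPy in place of jax.numpy.
    m1 = F @ m + c
    chol_P1 = tria(np.concatenate([F @ chol_P, chol_Q], axis=1))
    return m1, chol_P1


rng = np.random.default_rng(1)
nx = 3
m, c = rng.normal(size=nx), rng.normal(size=nx)
F = rng.normal(size=(nx, nx))
chol_P, chol_Q = random_chol(rng, nx), random_chol(rng, nx)

m1, chol_P1 = predict(m, chol_P, F, c, chol_Q)

# Dense reference: P1 = F P F^T + Q.
P_dense = F @ (chol_P @ chol_P.T) @ F.T + chol_Q @ chol_Q.T
```

The square-root version never forms `P_dense`; it triangularizes the stacked factor `[F @ chol_P, chol_Q]`, whose outer product equals the dense expression.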

update(m, chol_P, H, d, chol_R, y, log_normalizing_constant=0.0)

Update the mean and square root covariance with a linear Gaussian observation.

Parameters:

Name Type Description Default
m ArrayLike

Mean of the state.

required
chol_P ArrayLike

Generalized Cholesky factor of the covariance of the state.

required
H ArrayLike

Observation matrix.

required
d ArrayLike

Observation shift.

required
chol_R ArrayLike

Generalized Cholesky factor of the observation noise covariance.

required
y ArrayLike

Observation.

required
log_normalizing_constant ArrayLike

Optional log normalizing constant, added to the log normalizing constant of the Bayesian update.

0.0

Returns:

Type Description
tuple[tuple[Array, Array], Array]

Updated mean and square root covariance as well as the log marginal likelihood.

References

Paper: G. J. Bierman, Factorization Methods for Discrete Sequential Estimation.
Code: https://github.com/EEA-sensors/sqrt-parallel-smoothers/tree/main/parsmooth/sequential

Source code in cuthbertlib/kalman/filtering.py
def update(
    m: ArrayLike,
    chol_P: ArrayLike,
    H: ArrayLike,
    d: ArrayLike,
    chol_R: ArrayLike,
    y: ArrayLike,
    log_normalizing_constant: ArrayLike = 0.0,
) -> tuple[tuple[Array, Array], Array]:
    """Update the mean and square root covariance with a linear Gaussian observation.

    Args:
        m: Mean of the state.
        chol_P: Generalized Cholesky factor of the covariance of the state.
        H: Observation matrix.
        d: Observation shift.
        chol_R: Generalized Cholesky factor of the observation noise covariance.
        y: Observation.
        log_normalizing_constant: Optional log normalizing constant, added to the
            log normalizing constant of the Bayesian update.

    Returns:
        Updated mean and square root covariance as well as the log marginal likelihood.

    References:
        Paper: G. J. Bierman, Factorization Methods for Discrete Sequential Estimation,
        Code: https://github.com/EEA-sensors/sqrt-parallel-smoothers/tree/main/parsmooth/sequential
    """
    # Handle case where there is no observation
    flag = jnp.isnan(y)
    flag, chol_R, H, d, y = collect_nans_chol(flag, chol_R, H, d, y)

    m, chol_P = jnp.asarray(m), jnp.asarray(chol_P)
    H, d, chol_R = jnp.asarray(H), jnp.asarray(d), jnp.asarray(chol_R)
    y = jnp.asarray(y)

    n_y, n_x = H.shape

    y_hat = H @ m + d
    y_diff = y - y_hat

    M = jnp.block(
        [
            [H @ chol_P, chol_R],
            [chol_P, jnp.zeros((n_x, n_y), dtype=chol_P.dtype)],
        ]
    )
    chol_S = tria(M)
    chol_Py = chol_S[n_y:, n_y:]

    Gmat = chol_S[n_y:, :n_y]
    Imat = chol_S[:n_y, :n_y]

    my = m + Gmat @ solve_triangular(Imat, y_diff, lower=True)

    ell = multivariate_normal.logpdf(y, y_hat, Imat, nan_support=False)
    return (my, chol_Py), jnp.asarray(ell + log_normalizing_constant)
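
The block-triangularization in `update` can be checked against the textbook Kalman update. The sketch below is a NumPy re-implementation of the mean and covariance part only: it omits the NaN handling and the log-likelihood, and uses a general solver where the listing uses `solve_triangular`. The helpers `tria` and `random_chol` are hypothetical stand-ins written for this example.

```python
import numpy as np


def tria(A):
    # Stand-in for the library's `tria`: lower-triangular L with L @ L.T == A @ A.T.
    return np.linalg.qr(A.T, mode="r").T


def random_chol(rng, n):
    # Hypothetical helper: a well-conditioned lower-triangular factor.
    return np.eye(n) + 0.1 * np.tril(rng.normal(size=(n, n)))


def sqrt_update(m, chol_P, H, d, chol_R, y):
    n_y, n_x = H.shape
    M = np.block([[H @ chol_P, chol_R], [chol_P, np.zeros((n_x, n_y))]])
    chol_S = tria(M)
    Imat = chol_S[:n_y, :n_y]     # square root of the innovation covariance
    Gmat = chol_S[n_y:, :n_y]     # Kalman gain times Imat
    chol_Py = chol_S[n_y:, n_y:]  # square root of the updated covariance
    my = m + Gmat @ np.linalg.solve(Imat, y - (H @ m + d))
    return my, chol_Py


rng = np.random.default_rng(2)
n_x, n_y = 3, 2
m, d, y = rng.normal(size=n_x), rng.normal(size=n_y), rng.normal(size=n_y)
H = rng.normal(size=(n_y, n_x))
chol_P, chol_R = random_chol(rng, n_x), random_chol(rng, n_y)

my, chol_Py = sqrt_update(m, chol_P, H, d, chol_R, y)

# Dense reference: S = H P H^T + R, K = P H^T S^{-1}.
P = chol_P @ chol_P.T
S = H @ P @ H.T + chol_R @ chol_R.T
K = np.linalg.solve(S, H @ P).T
m_ref = m + K @ (y - (H @ m + d))
P_ref = P - K @ S @ K.T
```

Because `M @ M.T` equals the joint covariance of the observation and the state, the three blocks of its triangular factor give the innovation square root, the scaled gain, and the posterior square root in a single factorization.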

associative_params_single(F, c, chol_Q, H, d, chol_R, y)

Single time step for scan element for square root parallel Kalman filter.

Parameters:

Name Type Description Default
F Array

State transition matrix.

required
c Array

State transition shift vector.

required
chol_Q Array

Generalized Cholesky factor of the state transition noise covariance.

required
H Array

Observation matrix.

required
d Array

Observation shift.

required
chol_R Array

Generalized Cholesky factor of the observation noise covariance.

required
y Array

Observation.

required

Returns:

Type Description
FilterScanElement

Prepared scan element for the square root parallel Kalman filter.

Source code in cuthbertlib/kalman/filtering.py
def associative_params_single(
    F: Array, c: Array, chol_Q: Array, H: Array, d: Array, chol_R: Array, y: Array
) -> FilterScanElement:
    """Single time step for scan element for square root parallel Kalman filter.

    Args:
        F: State transition matrix.
        c: State transition shift vector.
        chol_Q: Generalized Cholesky factor of the state transition noise covariance.
        H: Observation matrix.
        d: Observation shift.
        chol_R: Generalized Cholesky factor of the observation noise covariance.
        y: Observation.

    Returns:
        Prepared scan element for the square root parallel Kalman filter.
    """
    # Handle case where there is no observation
    flag = jnp.isnan(y)
    flag, chol_R, H, d, y = collect_nans_chol(flag, chol_R, H, d, y)

    ny, nx = H.shape

    # joint over the predictive and the observation
    Psi_ = jnp.block([[H @ chol_Q, chol_R], [chol_Q, jnp.zeros((nx, ny))]])

    Tria_Psi_ = tria(Psi_)

    Psi11 = Tria_Psi_[:ny, :ny]
    Psi21 = Tria_Psi_[ny : ny + nx, :ny]
    U = Tria_Psi_[ny : ny + nx, ny:]

    # pre-compute inverse of Psi11: we apply it to matrices and vectors alike.
    Psi11_inv = solve_triangular(Psi11, jnp.eye(ny), lower=True)

    # predictive model given one observation
    K = Psi21 @ Psi11_inv  # local Kalman gain
    HF = H @ F  # temporary variable
    A = F - K @ HF  # corrected transition matrix

    b = c + K @ (y - H @ c - d)  # corrected transition offset

    # information filter
    Z = HF.T @ Psi11_inv.T
    eta = Psi11_inv @ (y - H @ c - d)
    eta = Z @ eta

    if nx > ny:
        Z = jnp.concatenate([Z, jnp.zeros((nx, nx - ny))], axis=1)
    else:
        Z = tria(Z)

    # local log marginal likelihood
    ell = jnp.asarray(
        multivariate_normal.logpdf(y, H @ c + d, Psi11, nan_support=False)
    )

    return FilterScanElement(A, b, U, eta, Z, ell)

filtering_operator(elem_i, elem_j)

Binary associative operator for the square root Kalman filter.

Parameters:

Name Type Description Default
elem_i FilterScanElement

Filter scan element for the previous time step.

required
elem_j FilterScanElement

Filter scan element for the current time step.

required

Returns:

Name Type Description
FilterScanElement FilterScanElement

The output of the associative operator applied to the input elements.

Source code in cuthbertlib/kalman/filtering.py
def filtering_operator(
    elem_i: FilterScanElement, elem_j: FilterScanElement
) -> FilterScanElement:
    """Binary associative operator for the square root Kalman filter.

    Args:
        elem_i: Filter scan element for the previous time step.
        elem_j: Filter scan element for the current time step.

    Returns:
        FilterScanElement: The output of the associative operator applied to the input elements.
    """
    A1, b1, U1, eta1, Z1, ell1 = elem_i
    A2, b2, U2, eta2, Z2, ell2 = elem_j

    nx = Z2.shape[0]

    Xi = jnp.block([[U1.T @ Z2, jnp.eye(nx)], [Z2, jnp.zeros_like(A1)]])
    tria_xi = tria(Xi)
    Xi11 = tria_xi[:nx, :nx]
    Xi21 = tria_xi[nx : nx + nx, :nx]
    Xi22 = tria_xi[nx : nx + nx, nx:]

    tmp_1 = solve_triangular(Xi11, U1.T, lower=True).T
    D_inv = jnp.eye(nx) - tmp_1 @ Xi21.T
    tmp_2 = D_inv @ (b1 + U1 @ (U1.T @ eta2))

    A = A2 @ D_inv @ A1
    b = A2 @ tmp_2 + b2
    U = tria(jnp.concatenate([A2 @ tmp_1, U2], axis=1))
    eta = A1.T @ (D_inv.T @ (eta2 - Z2 @ (Z2.T @ b1))) + eta1
    Z = tria(jnp.concatenate([A1.T @ Xi22, Z1], axis=1))

    mu = cho_solve((U1, True), b1)
    t1 = b1 @ mu - (eta2 + mu) @ tmp_2
    ell = ell1 + ell2 - 0.5 * t1 + 0.5 * jnp.linalg.slogdet(D_inv)[1]

    return FilterScanElement(A, b, U, eta, Z, ell)

cuthbertlib.kalman.smoothing

Implements the square root Rauch–Tung–Striebel (RTS) smoother and associative variant.

SmootherScanElement

Bases: NamedTuple

Kalman smoother scan element.

g instance-attribute

E instance-attribute

D instance-attribute

update(filter_m, filter_chol_P, smoother_m, smoother_chol_P, F, c, chol_Q)

Single step of the square root Rauch–Tung–Striebel (RTS) smoother.

Parameters:

Name Type Description Default
filter_m ArrayLike

Mean of the filtered state.

required
filter_chol_P ArrayLike

Generalized Cholesky factor of the filtering covariance.

required
smoother_m ArrayLike

Mean of the smoother state.

required
smoother_chol_P ArrayLike

Generalized Cholesky factor of the smoothing covariance.

required
F ArrayLike

State transition matrix.

required
c ArrayLike

State transition shift vector.

required
chol_Q ArrayLike

Generalized Cholesky factor of the state transition noise covariance.

required

Returns:

Type Description
tuple[tuple[Array, Array], Array]

A tuple (smooth_state, info). smooth_state contains the smoothed mean and square root covariance. info contains the smoothing gain matrix.

References

Paper: Park and Kailath (1994) - Square-root RTS smoothing algorithms.
Code: https://github.com/EEA-sensors/sqrt-parallel-smoothers/tree/main/parsmooth/sequential

Source code in cuthbertlib/kalman/smoothing.py
def update(
    filter_m: ArrayLike,
    filter_chol_P: ArrayLike,
    smoother_m: ArrayLike,
    smoother_chol_P: ArrayLike,
    F: ArrayLike,
    c: ArrayLike,
    chol_Q: ArrayLike,
) -> tuple[tuple[Array, Array], Array]:
    """Single step of the square root Rauch–Tung–Striebel (RTS) smoother.

    Args:
        filter_m: Mean of the filtered state.
        filter_chol_P: Generalized Cholesky factor of the filtering covariance.
        smoother_m: Mean of the smoother state.
        smoother_chol_P: Generalized Cholesky factor of the smoothing covariance.
        F: State transition matrix.
        c: State transition shift vector.
        chol_Q: Generalized Cholesky factor of the state transition noise covariance.

    Returns:
        A tuple `(smooth_state, info)`.
        `smooth_state` contains the smoothed mean and square root covariance.
        `info` contains the smoothing gain matrix.

    References:
        Paper: Park and Kailath (1994) - Square-root RTS smoothing algorithms
        Code: https://github.com/EEA-sensors/sqrt-parallel-smoothers/tree/main/parsmooth/sequential
    """
    filter_m, filter_chol_P = jnp.asarray(filter_m), jnp.asarray(filter_chol_P)
    smoother_m, smoother_chol_P = jnp.asarray(smoother_m), jnp.asarray(smoother_chol_P)
    F, c, chol_Q = jnp.asarray(F), jnp.asarray(c), jnp.asarray(chol_Q)

    nx = F.shape[0]
    Phi = jnp.block([[F @ filter_chol_P, chol_Q], [filter_chol_P, jnp.zeros_like(F)]])
    tria_Phi = tria(Phi)
    Phi11 = tria_Phi[:nx, :nx]
    Phi21 = tria_Phi[nx:, :nx]
    Phi22 = tria_Phi[nx:, nx:]
    gain = solve_triangular(Phi11, Phi21.T, trans=True, lower=True).T

    mean_diff = smoother_m - (c + F @ filter_m)
    mean = filter_m + gain @ mean_diff
    chol = tria(jnp.concatenate([Phi22, gain @ smoother_chol_P], axis=1))
    return (mean, chol), gain
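
The smoothing step above can likewise be compared with the classical Rauch–Tung–Striebel recursion. This NumPy sketch mirrors the listing, with a general solver in place of `solve_triangular`, and checks it against the dense formulas. The helpers `tria` and `random_chol` are hypothetical stand-ins, and the inputs are random matrices chosen only to exercise the algebra, not the output of a real filter.

```python
import numpy as np


def tria(A):
    # Stand-in for the library's `tria`: lower-triangular L with L @ L.T == A @ A.T.
    return np.linalg.qr(A.T, mode="r").T


def random_chol(rng, n):
    # Hypothetical helper: a well-conditioned lower-triangular factor.
    return np.eye(n) + 0.1 * np.tril(rng.normal(size=(n, n)))


def sqrt_rts_step(filter_m, filter_chol_P, smoother_m, smoother_chol_P, F, c, chol_Q):
    nx = F.shape[0]
    Phi = np.block([[F @ filter_chol_P, chol_Q], [filter_chol_P, np.zeros_like(F)]])
    tria_Phi = tria(Phi)
    Phi11 = tria_Phi[:nx, :nx]
    Phi21 = tria_Phi[nx:, :nx]
    Phi22 = tria_Phi[nx:, nx:]
    gain = np.linalg.solve(Phi11.T, Phi21.T).T  # equals Phi21 @ inv(Phi11)
    mean = filter_m + gain @ (smoother_m - (c + F @ filter_m))
    chol = tria(np.concatenate([Phi22, gain @ smoother_chol_P], axis=1))
    return (mean, chol), gain


rng = np.random.default_rng(3)
nx = 3
mf, ms, c = (rng.normal(size=nx) for _ in range(3))
F = rng.normal(size=(nx, nx))
chol_Pf, chol_Ps, chol_Q = (random_chol(rng, nx) for _ in range(3))

(mean, chol), gain = sqrt_rts_step(mf, chol_Pf, ms, chol_Ps, F, c, chol_Q)

# Dense reference: G = Pf F^T (F Pf F^T + Q)^{-1}.
Pf, Ps = chol_Pf @ chol_Pf.T, chol_Ps @ chol_Ps.T
P_pred = F @ Pf @ F.T + chol_Q @ chol_Q.T
G = np.linalg.solve(P_pred, F @ Pf).T
mean_ref = mf + G @ (ms - (F @ mf + c))
P_ref = Pf + G @ (Ps - P_pred) @ G.T
```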

associative_params_single(m, chol_P, F, c, chol_Q)

Single time step for scan element for square root parallel Kalman smoother.

Parameters:

Name Type Description Default
m Array

Mean of the filtered state.

required
chol_P Array

Generalized Cholesky factor of the filtering covariance.

required
F Array

State transition matrix.

required
c Array

State transition shift vector.

required
chol_Q Array

Generalized Cholesky factor of the state transition noise covariance.

required

Returns:

Name Type Description
SmootherScanElement SmootherScanElement

Prepared scan element for the square root parallel Kalman smoother.

Source code in cuthbertlib/kalman/smoothing.py
def associative_params_single(
    m: Array,
    chol_P: Array,
    F: Array,
    c: Array,
    chol_Q: Array,
) -> SmootherScanElement:
    """Single time step for scan element for square root parallel Kalman smoother.

    Args:
        m: Mean of the filtered state.
        chol_P: Generalized Cholesky factor of the filtering covariance.
        F: State transition matrix.
        c: State transition shift vector.
        chol_Q: Generalized Cholesky factor of the state transition noise covariance.

    Returns:
        SmootherScanElement: Prepared scan element for the square root parallel
            Kalman smoother.
    """
    nx = chol_Q.shape[0]

    Phi = jnp.block([[F @ chol_P, chol_Q], [chol_P, jnp.zeros_like(chol_Q)]])
    Tria_Phi = tria(Phi)
    Phi11 = Tria_Phi[:nx, :nx]
    Phi21 = Tria_Phi[nx:, :nx]
    D = Tria_Phi[nx:, nx:]

    E = jax.scipy.linalg.solve_triangular(Phi11.T, Phi21.T).T
    g = m - E @ (F @ m + c)
    return SmootherScanElement(g, E, D)

smoothing_operator(elem_i, elem_j)

Binary associative operator for the square root Kalman smoother.

Parameters:

Name Type Description Default
elem_i SmootherScanElement

Smoother scan element.

required
elem_j SmootherScanElement

Smoother scan element.

required

Returns:

Name Type Description
SmootherScanElement SmootherScanElement

The output of the associative operator applied to the input elements.

Source code in cuthbertlib/kalman/smoothing.py
def smoothing_operator(
    elem_i: SmootherScanElement, elem_j: SmootherScanElement
) -> SmootherScanElement:
    """Binary associative operator for the square root Kalman smoother.

    Args:
        elem_i: Smoother scan element.
        elem_j: Smoother scan element.

    Returns:
        SmootherScanElement: The output of the associative operator applied to the input elements.
    """
    g_i, E_i, D_i = elem_i
    g_j, E_j, D_j = elem_j

    g = E_j @ g_i + g_j
    E = E_j @ E_i
    D = tria(jnp.concatenate([E_j @ D_i, D_j], axis=1))

    return SmootherScanElement(g, E, D)
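
A parallel prefix scan such as `jax.lax.associative_scan` requires its binary operator to be associative. The NumPy sketch below re-implements `smoothing_operator` and checks associativity numerically on random elements. Since a square root is unique only up to column signs, the `D` components are compared through `D @ D.T`. The helpers `tria` and `random_element` are hypothetical stand-ins written for this example.

```python
from typing import NamedTuple

import numpy as np


class SmootherScanElement(NamedTuple):
    g: np.ndarray
    E: np.ndarray
    D: np.ndarray


def tria(A):
    # Stand-in for the library's `tria`: lower-triangular L with L @ L.T == A @ A.T.
    return np.linalg.qr(A.T, mode="r").T


def smoothing_operator(elem_i, elem_j):
    # Same algebra as the listing above, with NumPy in place of jax.numpy.
    g_i, E_i, D_i = elem_i
    g_j, E_j, D_j = elem_j
    g = E_j @ g_i + g_j
    E = E_j @ E_i
    D = tria(np.concatenate([E_j @ D_i, D_j], axis=1))
    return SmootherScanElement(g, E, D)


def random_element(rng, nx):
    # Hypothetical helper: an arbitrary element used only to test the algebra.
    return SmootherScanElement(
        rng.normal(size=nx),
        rng.normal(size=(nx, nx)),
        np.tril(rng.normal(size=(nx, nx))),
    )


rng = np.random.default_rng(4)
a, b, c = (random_element(rng, 3) for _ in range(3))

left = smoothing_operator(smoothing_operator(a, b), c)   # (a . b) . c
right = smoothing_operator(a, smoothing_operator(b, c))  # a . (b . c)
```

Both groupings give the same `g`, the same `E`, and the same covariance `D @ D.T`, which is what allows the combination to be evaluated in any bracketing by a parallel scan.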

cuthbertlib.kalman.sampling

Implements the square root parallel Kalman associative operator for sampling.

Samples from the smoothing distribution without doing the smoothing scan for means and (chol) covariances.

SamplerScanElement

Bases: NamedTuple

Kalman sampling scan element.

gain instance-attribute

sample instance-attribute

sqrt_associative_params(key, ms, chol_Ps, Fs, cs, chol_Qs, shape)

Compute the sampler scan elements.

Source code in cuthbertlib/kalman/sampling.py
def sqrt_associative_params(
    key: ArrayLike,
    ms: Array,
    chol_Ps: Array,
    Fs: Array,
    cs: Array,
    chol_Qs: Array,
    shape: Sequence[int],
) -> SamplerScanElement:
    """Compute the sampler scan elements."""
    shape = tuple(shape)
    eps = jax.random.normal(key, ms.shape[:1] + shape + ms.shape[1:])
    interm_elems = jax.vmap(_sqrt_associative_params_interm)(
        ms[:-1], chol_Ps[:-1], Fs, cs, chol_Qs, eps[:-1]
    )
    last_elem = _sqrt_associative_params_final(ms[-1], chol_Ps[-1], eps[-1])
    return jax.tree.map(
        lambda x, y: jnp.concatenate([x, y[None]]), interm_elems, last_elem
    )

sampling_operator(elem_i, elem_j)

Binary associative operator for sampling.

Source code in cuthbertlib/kalman/sampling.py
def sampling_operator(
    elem_i: SamplerScanElement, elem_j: SamplerScanElement
) -> SamplerScanElement:
    """Binary associative operator for sampling."""
    G_i, e_i = elem_i
    G_j, e_j = elem_j
    G = G_j @ G_i
    e = e_i @ G_j.T + e_j
    return SamplerScanElement(G, e)
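
One way to read `sampling_operator` is as the composition of affine maps: an element `(G, e)` acts on a batch of row vectors `x` as `x @ G.T + e`, and combining two elements yields the map that applies `elem_i` first and `elem_j` second. This reading is inferred from the code above rather than stated by the library. The NumPy sketch below checks it, with `apply` as a hypothetical helper and the sample array assumed to have shape `(n_samples, nx)`.

```python
from typing import NamedTuple

import numpy as np


class SamplerScanElement(NamedTuple):
    gain: np.ndarray
    sample: np.ndarray


def sampling_operator(elem_i, elem_j):
    # Same algebra as the listing above, with NumPy arrays.
    G_i, e_i = elem_i
    G_j, e_j = elem_j
    return SamplerScanElement(G_j @ G_i, e_i @ G_j.T + e_j)


def apply(elem, x):
    # Hypothetical helper: treat an element as the affine map x -> x @ G.T + e.
    return x @ elem.gain.T + elem.sample


rng = np.random.default_rng(5)
nx, n_samples = 3, 4
a, b, c = (
    SamplerScanElement(rng.normal(size=(nx, nx)), rng.normal(size=(n_samples, nx)))
    for _ in range(3)
)
x = rng.normal(size=(n_samples, nx))

combined = apply(sampling_operator(a, b), x)  # one combined map
sequential = apply(b, apply(a, x))            # a first, then b

left = sampling_operator(sampling_operator(a, b), c)
right = sampling_operator(a, sampling_operator(b, c))
```

Composition of affine maps is associative, so `left` and `right` agree, which is the property an associative scan over these elements relies on.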