
Gaussian Moments Filter and Smoother

cuthbert.gaussian.moments

cuthbert.gaussian.moments.filter

Linearized moments Kalman filter.

Takes user-provided conditional mean and chol_cov functions to define a conditionally linear Gaussian state-space model.

That is, conditional densities are approximated as

\[ p(y \mid x) \approx N\big(y \mid \mathrm{mean}(x),\ \mathrm{chol\_cov}(x)\,\mathrm{chol\_cov}(x)^\top\big). \]

See cuthbertlib.linearize for more details.
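As a rough illustration of what this approximation involves, the sketch below linearizes a hypothetical pendulum-style mean function around a point x_lin. The mean is replaced by its first-order Taylor expansion F x + c, and chol_cov is evaluated at x_lin. This reflects an assumption about the general technique, not a copy of cuthbertlib.linearize. The names mean_fn, chol_cov_fn, jacobian_fd and linearize_moments are hypothetical, and central finite differences stand in for the automatic differentiation a JAX library would normally use.

```python
import numpy as np


def mean_fn(x):
    # Hypothetical pendulum-like conditional mean (one Euler step, dt = 0.1).
    dt = 0.1
    return np.array([x[0] + dt * x[1], x[1] - dt * np.sin(x[0])])


def chol_cov_fn(x):
    # Hypothetical state-independent Cholesky factor of the noise covariance.
    return 0.1 * np.eye(2)


def jacobian_fd(f, x, eps=1e-6):
    # Central finite-difference Jacobian, one column per input dimension.
    x = np.asarray(x, dtype=float)
    cols = []
    for i in range(x.size):
        e = np.zeros_like(x)
        e[i] = eps
        cols.append((f(x + e) - f(x - e)) / (2 * eps))
    return np.stack(cols, axis=1)


def linearize_moments(mean_fn, chol_cov_fn, x_lin):
    # Affine approximation y ~= F x + c + chol_Q w, with w ~ N(0, I), around x_lin.
    F = jacobian_fd(mean_fn, x_lin)
    c = mean_fn(x_lin) - F @ x_lin
    chol_Q = chol_cov_fn(x_lin)
    return F, c, chol_Q


x_lin = np.array([0.3, -0.2])
F, c, chol_Q = linearize_moments(mean_fn, chol_cov_fn, x_lin)
Q = chol_Q @ chol_Q.T  # covariance implied by the (generalised) Cholesky factor
```

By construction the affine approximation matches the true mean exactly at x_lin, and the approximation error grows with the distance from x_lin. This is why the choice of linearization point matters.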

Parallelism via associative_scan is supported, but it requires the state argument to be ignored in get_dynamics_params and get_observation_params. That is, the linearization points must be pre-defined or extracted from model inputs.
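The reason for this restriction can be seen in a generic parallel prefix scan. Every per-step element is built before any combination happens, so no element can depend on a previously filtered state. The sketch below is a generic illustration and not cuthbert's internal combine operator. It uses affine maps x -> F x + c as stand-in elements, whose composition is associative, and checks that a Hillis-Steele scan (O(log n) rounds of independent combines, the same idea behind jax.lax.associative_scan) agrees with a sequential left-to-right scan. All function names are hypothetical.

```python
import numpy as np


def combine_affine(earlier, later):
    # Compose x -> F_a x + c_a (applied first) with x -> F_b x + c_b (applied second).
    F_a, c_a = earlier
    F_b, c_b = later
    return F_b @ F_a, F_b @ c_a + c_b


def sequential_scan(combine, elems):
    # Reference prefix scan: n - 1 dependent combine steps.
    out = [elems[0]]
    for e in elems[1:]:
        out.append(combine(out[-1], e))
    return out


def parallel_prefix_scan(combine, elems):
    # Hillis-Steele scan: each round's combines are independent of one another,
    # so they could run in parallel; only O(log n) rounds are needed.
    out = list(elems)
    step = 1
    while step < len(out):
        out = [out[i] if i < step else combine(out[i - step], out[i])
               for i in range(len(out))]
        step *= 2
    return out


rng = np.random.default_rng(0)
elems = [(rng.normal(size=(2, 2)), rng.normal(size=2)) for _ in range(7)]
seq = sequential_scan(combine_affine, elems)
par = parallel_prefix_scan(combine_affine, elems)
```

If an element's parameters depended on the filtered state at the previous step, it could not be constructed until the sequential scan reached that step, and the parallel rounds above would have nothing to combine.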

build_filter(get_init_params, get_dynamics_params, get_observation_params, associative=False)

Build linearized moments Kalman inference filter.

If associative is True, all filtering linearization points must be pre-defined or extracted from model inputs, and the state argument should be ignored in get_dynamics_params and get_observation_params.

If associative is False, the linearization points can be extracted from the previous filter state (for the dynamics parameters) and from the predicted state (for the observation parameters).
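The non-associative mode can be pictured as one predict/update cycle in which the dynamics are linearized at the previous filtered mean and the observation model at the predicted mean. The sketch below is a simplified covariance-form version under that assumption. cuthbert itself works with (generalised) Cholesky factors, and the function names here are hypothetical.

```python
import numpy as np


def jacobian_fd(f, x, eps=1e-6):
    # Central finite-difference Jacobian (stand-in for automatic differentiation).
    x = np.asarray(x, dtype=float)
    cols = []
    for i in range(x.size):
        e = np.zeros_like(x)
        e[i] = eps
        cols.append((f(x + e) - f(x - e)) / (2 * eps))
    return np.stack(cols, axis=1)


def linearized_kf_step(m_prev, P_prev, dyn_mean, dyn_chol, obs_mean, obs_chol, y):
    # Predict: linearize the dynamics at the previous filtered mean.
    F = jacobian_fd(dyn_mean, m_prev)
    L_Q = dyn_chol(m_prev)
    m_pred = dyn_mean(m_prev)  # equals F @ m_prev + c with c = dyn_mean(m_prev) - F @ m_prev
    P_pred = F @ P_prev @ F.T + L_Q @ L_Q.T

    # Update: linearize the observation model at the predicted mean.
    H = jacobian_fd(obs_mean, m_pred)
    L_R = obs_chol(m_pred)
    S = H @ P_pred @ H.T + L_R @ L_R.T
    K = np.linalg.solve(S, H @ P_pred).T  # Kalman gain P_pred H^T S^{-1}
    m = m_pred + K @ (y - obs_mean(m_pred))
    P = P_pred - K @ S @ K.T
    return m, P


# Scalar random walk observed directly, with unit noise everywhere.
m, P = linearized_kf_step(
    np.array([0.0]), np.eye(1),
    lambda x: x, lambda x: np.eye(1),
    lambda x: x, lambda x: np.eye(1),
    np.array([3.0]),
)
```

With linear mean functions, as in the usage line, the step reduces to the standard Kalman filter. Here that gives a predicted variance of 2, a gain of 2/3, and a posterior mean of 2 with variance 2/3.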

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| get_init_params | GetInitParams | Function to get m0, chol_P0 from model inputs. | required |
| get_dynamics_params | GetDynamicsMoments | Function to get, from the state and model inputs, the dynamics conditional mean and (generalised) Cholesky covariance function together with the linearization point (which may be taken from the state's mean and mean_prev attributes). If associative is True, the state argument should be ignored. | required |
| get_observation_params | GetObservationMoments | Function to get, from the state and model inputs, the observation conditional mean and (generalised) Cholesky covariance function, the linearization point and the observation. If associative is True, the state argument should be ignored. | required |
| associative | bool | If True, the filter is suitable for associative scan, but assumes that the state is ignored in get_dynamics_params and get_observation_params. If False, the filter is suitable for non-associative scan, and the user is free to use the state to extract the linearization points. | False |

Returns:

| Type | Description |
| --- | --- |
| Filter | Linearized moments Kalman filter object. |

Source code in cuthbert/gaussian/moments/filter.py
def build_filter(
    get_init_params: GetInitParams,
    get_dynamics_params: GetDynamicsMoments,
    get_observation_params: GetObservationMoments,
    associative: bool = False,
) -> Filter:
    """Build linearized moments Kalman inference filter.

    If `associative` is True, all filtering linearization points are pre-defined or
    extracted from model inputs. The `state` argument should be ignored in
    `get_dynamics_params` and `get_observation_params`.

    If `associative` is False, the linearization points can be extracted from the
    previous filter state for dynamics parameters and the predicted state for
    observation parameters.

    Args:
        get_init_params: Function to get m0, chol_P0 from model inputs.
        get_dynamics_params: Function to get the dynamics conditional mean and
            (generalised) Cholesky covariance function, together with the
            linearization point, from the state and model inputs.
            If `associative` is True, the `state` argument should be ignored.
        get_observation_params: Function to get the observation conditional mean
            and (generalised) Cholesky covariance function, the linearization point
            and the observation from the state and model inputs.
            If `associative` is True, the `state` argument should be ignored.
        associative: If True, then the filter is suitable for associative scan, but
            assumes that the `state` is ignored in `get_dynamics_params` and
            `get_observation_params`.
            If False, then the filter is suitable for non-associative scan, but
            the user is free to use the `state` to extract the linearization points.

    Returns:
        Linearized moments Kalman filter object.
    """
    if associative:
        return Filter(
            init_prepare=partial(
                associative_filter.init_prepare,
                get_init_params=get_init_params,
                get_observation_params=get_observation_params,
            ),
            filter_prepare=partial(
                associative_filter.filter_prepare,
                get_init_params=get_init_params,
                get_dynamics_params=get_dynamics_params,
                get_observation_params=get_observation_params,
            ),
            filter_combine=associative_filter.filter_combine,
            associative=True,
        )
    else:
        return Filter(
            init_prepare=partial(
                non_associative_filter.init_prepare,
                get_init_params=get_init_params,
                get_observation_params=get_observation_params,
            ),
            filter_prepare=partial(
                non_associative_filter.filter_prepare,
                get_init_params=get_init_params,
            ),
            filter_combine=partial(
                non_associative_filter.filter_combine,
                get_dynamics_params=get_dynamics_params,
                get_observation_params=get_observation_params,
            ),
            associative=False,
        )

cuthbert.gaussian.moments.smoother

Linearized moments Kalman smoother.

Takes user-provided conditional mean and chol_cov functions to define a conditionally linear Gaussian state-space model.

That is, conditional densities are approximated as

\[ p(y \mid x) \approx N\big(y \mid \mathrm{mean}(x),\ \mathrm{chol\_cov}(x)\,\mathrm{chol\_cov}(x)^\top\big). \]

See cuthbertlib.linearize for more details.

Parallelism via associative_scan is supported, but it requires the state argument to be ignored in get_dynamics_params. That is, the linearization points must be pre-defined or extracted from model inputs.

build_smoother(get_dynamics_params, store_gain=False, store_chol_cov_given_next=False)

Build linearized moments Kalman inference smoother for conditionally Gaussian SSMs.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| get_dynamics_params | GetDynamicsMoments | Function to get the dynamics conditional mean and (generalised) Cholesky covariance function, together with the linearization point, from the state and model inputs. | required |
| store_gain | bool | Whether to store the gain matrix in the smoother state. | False |
| store_chol_cov_given_next | bool | Whether to store the chol_cov_given_next matrix in the smoother state. | False |

Returns:

| Type | Description |
| --- | --- |
| Smoother | Linearized moments Kalman smoother object, suitable for associative scan. |
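For intuition about what store_gain and store_chol_cov_given_next refer to, the sketch below shows one backward step of a standard Rauch-Tung-Striebel smoother. It assumes the dynamics have already been linearized as x_next = F x + c + noise with noise covariance Q. The gain G and the covariance of x given x_next are the quantities a smoother state could store. This is a covariance-form illustration with hypothetical names, not cuthbert's implementation, which stores a Cholesky factor (chol_cov_given_next) rather than the covariance itself.

```python
import numpy as np


def smoother_step(m_filt, P_filt, F, c, Q, m_next_smooth, P_next_smooth):
    # Predicted moments of x_next from the filtered moments of x.
    m_pred = F @ m_filt + c
    P_pred = F @ P_filt @ F.T + Q
    # Smoother gain G = P_filt F^T P_pred^{-1}.
    G = np.linalg.solve(P_pred, F @ P_filt).T
    # Conditional law: x | x_next ~ N(m_filt + G (x_next - m_pred), P_filt - G P_pred G^T).
    cov_given_next = P_filt - G @ P_pred @ G.T
    # Marginalize x_next over its smoothed distribution.
    m_smooth = m_filt + G @ (m_next_smooth - m_pred)
    P_smooth = cov_given_next + G @ P_next_smooth @ G.T
    return m_smooth, P_smooth, G, cov_given_next


m_s, P_s, G, cov_given_next = smoother_step(
    np.array([0.0]), np.eye(1), np.eye(1), np.zeros(1), np.eye(1),
    np.array([2.0]), np.eye(1),
)
```

At the final time step the smoothed and filtered distributions coincide, which is presumably the role of convert_filter_to_smoother_state in the code below.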

Source code in cuthbert/gaussian/moments/smoother.py
def build_smoother(
    get_dynamics_params: GetDynamicsMoments,
    store_gain: bool = False,
    store_chol_cov_given_next: bool = False,
) -> Smoother:
    """Build linearized moments Kalman inference smoother for conditionally Gaussian SSMs.

    Args:
        get_dynamics_params: Function to get the dynamics conditional mean and
            (generalised) Cholesky covariance function, together with the
            linearization point, from the state and model inputs.
        store_gain: Whether to store the gain matrix in the smoother state.
        store_chol_cov_given_next: Whether to store the chol_cov_given_next matrix
            in the smoother state.

    Returns:
        Linearized moments Kalman smoother object, suitable for associative scan.
    """
    return Smoother(
        smoother_prepare=partial(
            smoother_prepare,
            get_dynamics_params=get_dynamics_params,
            store_gain=store_gain,
            store_chol_cov_given_next=store_chol_cov_given_next,
        ),
        smoother_combine=smoother_combine,
        convert_filter_to_smoother_state=partial(
            convert_filter_to_smoother_state,
            store_gain=store_gain,
            store_chol_cov_given_next=store_chol_cov_given_next,
        ),
        associative=True,
    )

cuthbert.gaussian.moments.types

Provides types for the moment-based linearization of Gaussian state-space models.

GetDynamicsMoments

Bases: Protocol

Protocol for extracting the dynamics specifications.

__call__(state, model_inputs)

Get the dynamics conditional mean and chol_cov function, and the linearization point.

associative_scan is only supported when state is ignored.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | LinearizedKalmanFilterState | NamedTuple containing mean and mean_prev attributes. | required |
| model_inputs | ArrayTreeLike | Model inputs. | required |

Returns:

| Type | Description |
| --- | --- |
| tuple[MeanAndCholCovFunc, Array] | Tuple of the dynamics conditional mean and (generalised) Cholesky covariance function, and the linearization point. |

Source code in cuthbert/gaussian/moments/types.py
def __call__(
    self,
    state: LinearizedKalmanFilterState,
    model_inputs: ArrayTreeLike,
) -> tuple[MeanAndCholCovFunc, Array]:
    """Get dynamics conditional mean and chol_cov function and linearization point.

    `associative_scan` is only supported when `state` is ignored.

    Args:
        state: NamedTuple containing `mean` and `mean_prev` attributes.
        model_inputs: Model inputs.

    Returns:
        Tuple with dynamics conditional mean and (generalised) Cholesky covariance
            function and linearization point.
    """
    ...
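A callable matching this protocol's shape might look like the sketch below. Several pieces are assumptions rather than documented API: ToyState is a hypothetical stand-in for LinearizedKalmanFilterState with only the documented mean and mean_prev fields, a MeanAndCholCovFunc is assumed to map x to a (mean, chol_cov) pair, and the model_inputs key "x_lin" is invented for illustration. The sequential variant uses the state (choosing state.mean is itself an illustrative choice), while the associative variant ignores it as required for associative_scan.

```python
from typing import NamedTuple

import numpy as np


class ToyState(NamedTuple):
    # Hypothetical stand-in exposing only the documented fields.
    mean: np.ndarray
    mean_prev: np.ndarray


def pendulum_mean_and_chol_cov(x):
    # Assumed MeanAndCholCovFunc shape: x -> (conditional mean, chol_cov).
    dt = 0.1
    mean = np.array([x[0] + dt * x[1], x[1] - dt * np.sin(x[0])])
    return mean, 0.1 * np.eye(2)


def get_dynamics_params_sequential(state, model_inputs):
    # Non-associative use: take the linearization point from the filter state.
    return pendulum_mean_and_chol_cov, state.mean


def get_dynamics_params_associative(state, model_inputs):
    # Associative-scan-compatible: state is ignored; point comes from model inputs.
    del state
    return pendulum_mean_and_chol_cov, model_inputs["x_lin"]


state = ToyState(mean=np.array([0.3, -0.2]), mean_prev=np.zeros(2))
fn_seq, x_lin_seq = get_dynamics_params_sequential(state, {})
fn_assoc, x_lin_assoc = get_dynamics_params_associative(None, {"x_lin": np.array([1.0, 0.0])})
```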

GetObservationMoments

Bases: Protocol

Protocol for extracting the observation specifications.

__call__(state, model_inputs)

Get the conditional mean and chol_cov function, the linearization point and the observation.

associative_scan is only supported when the state input is ignored.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | LinearizedKalmanFilterState | NamedTuple containing mean and mean_prev attributes. | required |
| model_inputs | ArrayTreeLike | Model inputs. | required |

Returns:

| Type | Description |
| --- | --- |
| tuple[MeanAndCholCovFunc, Array, Array] | Tuple with conditional mean and chol_cov function, linearization point and observation. |

Source code in cuthbert/gaussian/moments/types.py
def __call__(
    self, state: LinearizedKalmanFilterState, model_inputs: ArrayTreeLike
) -> tuple[MeanAndCholCovFunc, Array, Array]:
    """Get conditional mean and chol_cov function, linearization point and observation.

    `associative_scan` is only supported when the `state` input is ignored.

    Args:
        state: NamedTuple containing `mean` and `mean_prev` attributes.
        model_inputs: Model inputs.

    Returns:
        Tuple with conditional mean and chol_cov function, linearization point
            and observation.
    """
    ...
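An observation callable of this shape returns three things: the mean/chol_cov function, a linearization point and the observation itself. The sketch below uses a hypothetical range sensor. As before, ToyPredictState, the (mean, chol_cov) return shape of the MeanAndCholCovFunc and the model_inputs keys "x_lin" and "y" are assumptions made for illustration, not cuthbert's API.

```python
from typing import NamedTuple

import numpy as np


class ToyPredictState(NamedTuple):
    # Hypothetical stand-in exposing only the documented mean / mean_prev fields.
    mean: np.ndarray
    mean_prev: np.ndarray


def range_mean_and_chol_cov(x):
    # Assumed MeanAndCholCovFunc shape: x -> (conditional mean, chol_cov).
    # Hypothetical range sensor: distance of (x[0], x[1]) from the origin, noise std 0.5.
    return np.array([np.hypot(x[0], x[1])]), 0.5 * np.eye(1)


def get_observation_params_sequential(state, model_inputs):
    # Non-associative use: linearize at the mean carried by the (predicted) state.
    return range_mean_and_chol_cov, state.mean, model_inputs["y"]


def get_observation_params_associative(state, model_inputs):
    # Associative-scan-compatible: state ignored; everything comes from model inputs.
    del state
    return range_mean_and_chol_cov, model_inputs["x_lin"], model_inputs["y"]


inputs = {"x_lin": np.array([3.0, 4.0]), "y": np.array([5.1])}
fn, x_lin, y = get_observation_params_associative(None, inputs)
```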