
Gaussian Taylor Filter and Smoother

cuthbert.gaussian.taylor

cuthbert.gaussian.taylor.filter

Linearized Taylor Kalman filter.

Uses automatic differentiation to extract conditionally Gaussian parameters from log densities of the dynamics and observation distributions.

This differs from gaussian/moments, which requires mean and chol_cov functions as input rather than log densities.

I.e., we approximate conditional densities as

\[ p(y \mid x) \approx N(y \mid H x + d, L L^T), \]

and potentials as

\[ G(x) \approx N(x \mid m, L L^T), \]

where \(L\) is the Cholesky factor of the covariance matrix.

See cuthbertlib.linearize for more details.
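The JAX snippet below is a minimal sketch of the idea. It recovers H, d and L from a conditional log density by automatic differentiation, and it is not the cuthbertlib.linearize implementation. It assumes the signature log_density(x, y) = log p(y | x), and the helper linearize_conditional is hypothetical. For a Gaussian conditional, minus the Hessian in y is the precision, the mixed derivative equals the precision times H, and the gradient in y fixes the offset d.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

jax.config.update("jax_enable_x64", True)  # float64 for accurate second derivatives


def linearize_conditional(log_density, x0, y0):
    """Approximate p(y | x) by N(y | H x + d, L L^T) around (x0, y0).

    Hypothetical helper for illustration; assumes log_density(x, y) = log p(y | x).
    """
    grad_y = jax.grad(log_density, argnums=1)
    g_y = grad_y(x0, y0)  # d/dy log p at (x0, y0)
    h_yy = jax.jacfwd(grad_y, argnums=1)(x0, y0)  # d^2/dy^2 = -precision
    h_yx = jax.jacfwd(grad_y, argnums=0)(x0, y0)  # d^2/(dy dx) = precision @ H
    cov = -jnp.linalg.inv(h_yy)  # observation covariance L L^T
    H = cov @ h_yx
    d = y0 + cov @ g_y - H @ x0  # offset so that H x0 + d is the mean at x0
    L = jnp.linalg.cholesky(cov)  # lower Cholesky factor
    return H, d, L


# Example: nonlinear mean sin(x) with a fixed noise covariance R.
R = jnp.array([[0.5, 0.1], [0.1, 0.3]])


def log_p_y_given_x(x, y):
    return multivariate_normal.logpdf(y, jnp.sin(x), R)


x0 = jnp.array([0.3, -0.2])
y0 = jnp.array([0.1, 0.4])
H, d, L = linearize_conditional(log_p_y_given_x, x0, y0)
```

Here only the mean sin(x) is linearized, so H = diag(cos(x0)), d = sin(x0) - H x0 and L L^T = R.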

Parallelism via associative_scan is supported, but it requires the state argument to be ignored in get_dynamics_log_density and get_observation_func; that is, the linearization points must be pre-defined or extracted from model inputs.
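For instance, a get_dynamics_log_density that is compatible with associative_scan could read its linearization points from model_inputs and ignore state altogether. The sketch below assumes a hypothetical StepInputs container and the argument order log_density(x_prev, x) = log p(x | x_prev). Neither is taken from the cuthbert API.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

# Hypothetical process-noise covariance for a 2-dimensional state.
Q = 0.1 * jnp.eye(2)


class StepInputs(NamedTuple):
    """Hypothetical per-step model_inputs carrying pre-computed linearization points."""

    x_prev_lin: jax.Array  # linearization point for x_t
    x_lin: jax.Array  # linearization point for x_{t+1}


def dynamics_mean(x_prev):
    # Mildly nonlinear transition mean.
    return x_prev + 0.1 * jnp.sin(x_prev)


def get_dynamics_log_density(state, model_inputs):
    # `state` is deliberately unused, as required for associative_scan.
    del state

    # Assumed argument order: log_density(x_prev, x) = log p(x | x_prev).
    def log_density(x_prev, x):
        return multivariate_normal.logpdf(x, dynamics_mean(x_prev), Q)

    return log_density, model_inputs.x_prev_lin, model_inputs.x_lin


inputs = StepInputs(x_prev_lin=jnp.zeros(2), x_lin=jnp.zeros(2))
log_density, x_prev_lin, x_lin = get_dynamics_log_density(None, inputs)
value = log_density(x_prev_lin, x_lin)
```

Because the returned points depend only on model_inputs, every time step can be prepared independently of the others, which is what a parallel scan needs.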

build_filter(get_init_log_density, get_dynamics_log_density, get_observation_func, associative=False, rtol=None, ignore_nan_dims=False)

Build linearized Taylor Kalman inference filter.

If associative is True, all filtering linearization points must be pre-defined or extracted from model inputs, and the state argument should be ignored in get_dynamics_log_density and get_observation_func.

If associative is False, the linearization points can be extracted from the previous filter state (for the dynamics parameters) and from the predicted state (for the observation parameters).
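In the non-associative case, a get_observation_func might simply linearize around the mean of the predicted state. In this sketch, StandInState is a hypothetical stand-in that only mimics the mean and mean_prev attributes. The log_density(x, y) argument order and the use of model_inputs as the raw observation are also assumptions.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal


class StandInState(NamedTuple):
    """Hypothetical stand-in for the predicted state (only `mean`, `mean_prev`)."""

    mean: jax.Array
    mean_prev: jax.Array


# Hypothetical observation-noise covariance for a scalar observation.
R = 0.2 * jnp.eye(1)


def get_observation_func(state, model_inputs):
    # Assumed argument order: log_density(x, y) = log p(y | x).
    def log_density(x, y):
        # Nonlinear observation of the first state component: y ~ N(x[0]**2, R).
        return multivariate_normal.logpdf(y, x[:1] ** 2, R)

    y_obs = model_inputs  # assumed: model_inputs is this step's observation
    # Non-associative mode: linearize around the predicted mean.
    return log_density, state.mean, y_obs


state = StandInState(mean=jnp.array([1.0, 0.5]), mean_prev=jnp.zeros(2))
log_density, x_lin, y = get_observation_func(state, jnp.array([0.9]))
```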

Parameters:

Name Type Description Default
get_init_log_density GetInitLogDensity

Function returning the initial log density log p(x_0) and a linearization point. Takes only model_inputs as input.

required
get_dynamics_log_density GetDynamicsLogDensity

Function to get the dynamics log density log p(x_{t+1} | x_t) and linearization points (for the previous and current time points). If associative is True, the state argument should be ignored.

required
get_observation_func GetObservationFunc

Function to get the observation function (either a conditional log density or a log potential), a linearization point and an optional observation (not required for log potential functions). If associative is True, the state argument should be ignored.

required
associative bool

If True, the filter is suitable for associative scan, but it assumes that the state argument is ignored in get_dynamics_log_density and get_observation_func. If False, the filter is suitable only for a sequential (non-associative) scan, but the user is free to use the state to extract the linearization points.

False
rtol float | None

The relative tolerance for the singular values of precision matrices passed to symmetric_inv_sqrt during linearization: singular values smaller than rtol * largest_singular_value are treated as zero. The default is determined by the floating-point precision of the dtype. See https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.pinv.html.

None
ignore_nan_dims bool

Whether to treat dimensions with NaN on the diagonal of the precision matrices (found via linearization) as missing and ignore all rows and columns associated with them.

False
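The following sketch shows how rtol and ignore_nan_dims could act on a linearized precision matrix. Both helpers are hypothetical. mask_nan_dims zeroes the rows and columns of NaN dimensions, so they carry no information; this is one jit-friendly alternative to slicing, since JAX array shapes are static under jit. inv_sqrt_with_cutoff drops eigenvalues below rtol times the largest one before inverting, in the spirit of jax.numpy.linalg.pinv. The default rtol used here is an assumption, not cuthbert's value.

```python
import jax.numpy as jnp


def mask_nan_dims(precision):
    """Zero every row and column whose diagonal entry is NaN (hypothetical sketch)."""
    keep = ~jnp.isnan(jnp.diag(precision))
    return jnp.where(keep[:, None] & keep[None, :], precision, 0.0)


def inv_sqrt_with_cutoff(precision, rtol=None):
    """Return a factor L with L @ L.T equal to the pseudo-inverse of `precision`.

    Hypothetical sketch, not cuthbertlib's symmetric_inv_sqrt.
    """
    n = precision.shape[-1]
    if rtol is None:
        # Assumed dtype-based default: a small multiple of machine epsilon.
        rtol = 10.0 * n * jnp.finfo(precision.dtype).eps
    # For a symmetric positive semi-definite matrix, eigenvalues equal singular values.
    eigvals, eigvecs = jnp.linalg.eigh(precision)
    keep = eigvals > rtol * jnp.max(eigvals)
    safe = jnp.where(keep, eigvals, 1.0)  # avoid dividing by ~0 in discarded entries
    scale = jnp.where(keep, 1.0 / jnp.sqrt(safe), 0.0)
    return eigvecs * scale  # multiplies column i of eigvecs by scale[i]


# Dimension 1 has a NaN diagonal; dimension 2 has a negligible precision.
precision = jnp.array([
    [4.0, 0.0, 0.0],
    [0.0, jnp.nan, 0.0],
    [0.0, 0.0, 1e-20],
])
L = inv_sqrt_with_cutoff(mask_nan_dims(precision))
cov = L @ L.T  # only dimension 0 keeps a finite variance (0.25)
```

Without the masking step, the NaN would propagate through the eigendecomposition; without the cutoff, the 1e-20 entry would produce a variance of 1e20.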

Returns:

Type Description
Filter

Linearized Taylor Kalman filter object.

Source code in cuthbert/gaussian/taylor/filter.py
def build_filter(
    get_init_log_density: GetInitLogDensity,
    get_dynamics_log_density: GetDynamicsLogDensity,
    get_observation_func: GetObservationFunc,
    associative: bool = False,
    rtol: float | None = None,
    ignore_nan_dims: bool = False,
) -> Filter:
    """Build linearized Taylor Kalman inference filter.

    If `associative` is True, all filtering linearization points are pre-defined or
    extracted from model inputs. The `state` argument should be ignored in
    `get_dynamics_log_density` and `get_observation_func`.

    If `associative` is False, the linearization points can be extracted from the
    previous filter state for dynamics parameters and the predicted state for
    observation parameters.

    Args:
        get_init_log_density: Function to get log density log p(x_0)
            and linearization point.
            Only takes `model_inputs` as input.
        get_dynamics_log_density: Function to get dynamics log density
            log p(x_{t+1} | x_t) and linearization points (for the previous and
            current time points).
            If `associative` is True, the `state` argument should be ignored.
        get_observation_func: Function to get observation function (either conditional
            log density or log potential), linearization point and optional observation
            (not required for log potential functions).
            If `associative` is True, the `state` argument should be ignored.
        associative: If True, then the filter is suitable for associative scan, but
            assumes that the `state` is ignored in `get_dynamics_log_density` and
            `get_observation_func`.
            If False, then the filter is suitable for non-associative scan, but
            the user is free to use the `state` to extract the linearization points.
        rtol: The relative tolerance for the singular values of precision matrices
            when passed to `symmetric_inv_sqrt` during linearization.
            Cutoff for small singular values; singular values smaller than
            `rtol * largest_singular_value` are treated as zero.
            The default is determined based on the floating point precision of the dtype.
            See https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.pinv.html.
        ignore_nan_dims: Whether to treat dimensions with NaN on the diagonal of the
            precision matrices (found via linearization) as missing and ignore all rows
            and columns associated with them.

    Returns:
        Linearized Taylor Kalman filter object.
    """
    if associative:
        return Filter(
            init_prepare=partial(
                associative_filter.init_prepare,
                get_init_log_density=get_init_log_density,
                get_observation_func=get_observation_func,
                rtol=rtol,
                ignore_nan_dims=ignore_nan_dims,
            ),
            filter_prepare=partial(
                associative_filter.filter_prepare,
                get_init_log_density=get_init_log_density,
                get_dynamics_log_density=get_dynamics_log_density,
                get_observation_func=get_observation_func,
                rtol=rtol,
                ignore_nan_dims=ignore_nan_dims,
            ),
            filter_combine=associative_filter.filter_combine,
            associative=True,
        )
    else:
        return Filter(
            init_prepare=partial(
                non_associative_filter.init_prepare,
                get_init_log_density=get_init_log_density,
                get_observation_func=get_observation_func,
                rtol=rtol,
                ignore_nan_dims=ignore_nan_dims,
            ),
            filter_prepare=partial(
                non_associative_filter.filter_prepare,
                get_init_log_density=get_init_log_density,
            ),
            filter_combine=partial(
                non_associative_filter.filter_combine,
                get_dynamics_log_density=get_dynamics_log_density,
                get_observation_func=get_observation_func,
                rtol=rtol,
                ignore_nan_dims=ignore_nan_dims,
            ),
            associative=False,
        )

cuthbert.gaussian.taylor.smoother

Linearized Taylor Kalman smoother.

Uses automatic differentiation to extract conditionally Gaussian parameters from log densities of the dynamics and observation distributions.

This differs from gaussian/moments, which requires mean and chol_cov functions as input rather than log densities.

I.e., we approximate conditional densities as

\[ p(y \mid x) \approx N(y \mid H x + d, L L^T), \]

and potentials as

\[ G(x) \approx N(x \mid m, L L^T), \]

where \(L\) is the Cholesky factor of the covariance matrix.

See cuthbertlib.linearize for more details.

build_smoother(get_dynamics_log_density, rtol=None, ignore_nan_dims=False, store_gain=False, store_chol_cov_given_next=False)

Build linearized Taylor Kalman inference smoother.

Parameters:

Name Type Description Default
get_dynamics_log_density GetDynamicsLogDensity

Function to get the dynamics log density log p(x_{t+1} | x_t) and linearization points (for the previous and current time points).

required
rtol float | None

The relative tolerance for the singular values of precision matrices passed to symmetric_inv_sqrt during linearization: singular values smaller than rtol * largest_singular_value are treated as zero. The default is determined by the floating-point precision of the dtype. See https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.pinv.html.

None
ignore_nan_dims bool

Whether to treat dimensions with NaN on the diagonal of the precision matrices (found via linearization) as missing and ignore all rows and columns associated with them.

False
store_gain bool

Whether to store the gain matrix in the smoother state.

False
store_chol_cov_given_next bool

Whether to store the chol_cov_given_next matrix in the smoother state.

False
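To make store_gain and store_chol_cov_given_next more concrete, the sketch below performs one textbook Rauch-Tung-Striebel step for linear-Gaussian dynamics. Reading chol_cov_given_next as the Cholesky factor of Cov[x_t | x_{t+1}, y_{1:t}] is an assumption based on the parameter name. The sequential recursion shown is also not necessarily how cuthbert's associative smoother computes these quantities.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def rts_step(m_f, P_f, A, b, Q, m_s_next, P_s_next):
    """One textbook RTS step for x_{t+1} = A x_t + b + N(0, Q) noise.

    Hypothetical illustration; not cuthbert's (associative) implementation.
    """
    m_p = A @ m_f + b  # predicted mean of x_{t+1}
    P_p = A @ P_f @ A.T + Q  # predicted covariance of x_{t+1}
    gain = jnp.linalg.solve(P_p, A @ P_f).T  # G = P_f A^T P_p^{-1}
    cov_given_next = P_f - gain @ P_p @ gain.T  # Cov[x_t | x_{t+1}, y_{1:t}]
    chol_cov_given_next = jnp.linalg.cholesky(cov_given_next)
    m_s = m_f + gain @ (m_s_next - m_p)
    P_s = cov_given_next + gain @ P_s_next @ gain.T
    return m_s, P_s, gain, chol_cov_given_next


m_f = jnp.array([0.0, 1.0])
P_f = jnp.array([[1.0, 0.2], [0.2, 0.5]])
A = jnp.array([[1.0, 0.1], [0.0, 1.0]])
b = jnp.zeros(2)
Q = 0.1 * jnp.eye(2)
m_s_next = jnp.array([0.3, 0.8])
P_s_next = 0.5 * jnp.eye(2)
m_s, P_s, gain, chol_cov_given_next = rts_step(m_f, P_f, A, b, Q, m_s_next, P_s_next)
```

Storing gain and chol_cov_given_next avoids recomputing them later, for example when sampling x_t given a sampled x_{t+1}, at the cost of extra memory per time step.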

Returns:

Type Description
Smoother

Linearized Taylor Kalman smoother object, suitable for associative scan.

Source code in cuthbert/gaussian/taylor/smoother.py
def build_smoother(
    get_dynamics_log_density: GetDynamicsLogDensity,
    rtol: float | None = None,
    ignore_nan_dims: bool = False,
    store_gain: bool = False,
    store_chol_cov_given_next: bool = False,
) -> Smoother:
    """Build linearized Taylor Kalman inference smoother.

    Args:
        get_dynamics_log_density: Function to get dynamics log density
            log p(x_{t+1} | x_t) and linearization points (for the previous and
            current time points).
        rtol: The relative tolerance for the singular values of precision matrices
            when passed to `symmetric_inv_sqrt` during linearization.
            Cutoff for small singular values; singular values smaller than
            `rtol * largest_singular_value` are treated as zero.
            The default is determined based on the floating point precision of the dtype.
            See https://docs.jax.dev/en/latest/_autosummary/jax.numpy.linalg.pinv.html.
        ignore_nan_dims: Whether to treat dimensions with NaN on the diagonal of the
            precision matrices (found via linearization) as missing and ignore all rows
            and columns associated with them.
        store_gain: Whether to store the gain matrix in the smoother state.
        store_chol_cov_given_next: Whether to store the chol_cov_given_next matrix
            in the smoother state.

    Returns:
        Linearized Taylor Kalman smoother object, suitable for associative scan.
    """
    return Smoother(
        smoother_prepare=partial(
            smoother_prepare,
            get_dynamics_log_density=get_dynamics_log_density,
            rtol=rtol,
            ignore_nan_dims=ignore_nan_dims,
            store_gain=store_gain,
            store_chol_cov_given_next=store_chol_cov_given_next,
        ),
        smoother_combine=smoother_combine,
        convert_filter_to_smoother_state=partial(
            convert_filter_to_smoother_state,
            store_gain=store_gain,
            store_chol_cov_given_next=store_chol_cov_given_next,
        ),
        associative=True,
    )

cuthbert.gaussian.taylor.types

Provides types for the Taylor-series linearization of Gaussian state-space models.

LogPotential = LogDensity (module attribute)

GetInitLogDensity

Bases: Protocol

Protocol for extracting the initial specifications.

__call__(model_inputs)

Get the initial log density and initial linearization point.

Parameters:

Name Type Description Default
model_inputs ArrayTreeLike

Model inputs.

required

Returns:

Type Description
tuple[LogDensity, Array]

Tuple with initial log density and initial linearization point.

Source code in cuthbert/gaussian/taylor/types.py
def __call__(self, model_inputs: ArrayTreeLike) -> tuple[LogDensity, Array]:
    """Get the initial log density and initial linearization point.

    Args:
        model_inputs: Model inputs.

    Returns:
        Tuple with initial log density and initial linearization point.
    """
    ...
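A function satisfying this protocol might look as follows. The dict layout of model_inputs and the assumption that a LogDensity is a callable x -> log p(x) are for illustration only.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal


def get_init_log_density(model_inputs):
    """Hypothetical GetInitLogDensity: a Gaussian prior linearized at its mean."""
    m0 = model_inputs["m0"]  # assumed layout: a dict holding the prior parameters
    P0 = model_inputs["P0"]

    def init_log_density(x):
        return multivariate_normal.logpdf(x, m0, P0)

    return init_log_density, m0


inputs = {"m0": jnp.array([1.0, -1.0]), "P0": jnp.array([[1.0, 0.3], [0.3, 2.0]])}
init_log_density, x_lin = get_init_log_density(inputs)
# A second-order Taylor expansion of a Gaussian log density is exact:
# the negative Hessian equals the prior precision at every point.
neg_hessian = -jax.hessian(init_log_density)(x_lin)
```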

GetDynamicsLogDensity

Bases: Protocol

Protocol for extracting the dynamics specifications.

__call__(state, model_inputs)

Get the dynamics log density and linearization points.

Linearization points are required for both the previous and current time points.

associative_scan is only supported when state is ignored.

Parameters:

Name Type Description Default
state LinearizedKalmanFilterState

NamedTuple containing mean and mean_prev attributes.

required
model_inputs ArrayTreeLike

Model inputs.

required

Returns:

Type Description
tuple[LogConditionalDensity, Array, Array]

Tuple with dynamics log density and linearization points.

Source code in cuthbert/gaussian/taylor/types.py
def __call__(
    self,
    state: LinearizedKalmanFilterState,
    model_inputs: ArrayTreeLike,
) -> tuple[LogConditionalDensity, Array, Array]:
    """Get the dynamics log density and linearization points.

    Linearization points are required for both the previous and current time points.

    `associative_scan` is only supported when `state` is ignored.

    Args:
        state: NamedTuple containing `mean` and `mean_prev` attributes.
        model_inputs: Model inputs.

    Returns:
        Tuple with dynamics log density and linearization points.
    """
    ...

GetObservationFunc

Bases: Protocol

Protocol for extracting the required observation specifications.

__call__(state, model_inputs)

Extract observation function, linearization point and optional observation.

State is the predicted state after applying the Kalman dynamics propagation.

associative_scan is only supported when state is ignored.

Two types of output are supported:

- An observation log density function log p(y | x) and points x and y to linearize around.
- A log potential function log G(x) and a linearization point x.

Parameters:

Name Type Description Default
state LinearizedKalmanFilterState

NamedTuple containing mean and mean_prev attributes. Predicted state after applying the Kalman dynamics propagation.

required
model_inputs ArrayTreeLike

Model inputs.

required

Returns:

Type Description
tuple[LogConditionalDensity, Array, Array] | tuple[LogPotential, Array]

Either a tuple of the observation log density to linearize, a linearization point and the observation, or a tuple of the log potential function and a linearization point.

Source code in cuthbert/gaussian/taylor/types.py
def __call__(
    self,
    state: LinearizedKalmanFilterState,
    model_inputs: ArrayTreeLike,
) -> tuple[LogConditionalDensity, Array, Array] | tuple[LogPotential, Array]:
    """Extract observation function, linearization point and optional observation.

    State is the predicted state after applying the Kalman dynamics propagation.

    `associative_scan` is only supported when `state` is ignored.

    Two types of output are supported:
    - Observation log density function log p(y | x) and points x and y
        to linearize around.
    - Log potential function log G(x) and a linearization point x.

    Args:
        state: NamedTuple containing `mean` and `mean_prev` attributes.
            Predicted state after applying the Kalman dynamics propagation.
        model_inputs: Model inputs.

    Returns:
        Either a tuple with observation function to linearize, linearization point
            and observation, or a tuple with log potential function and linearization
            point.
    """
    ...
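As an illustration of the second output type, the sketch below returns a quadratic log potential together with the predicted mean. It then expands the potential to second order to obtain the N(x | m, L L^T) form used for potentials. StandInState, the use of model_inputs as a target vector and potential_to_gaussian are all hypothetical.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


class StandInState(NamedTuple):
    """Hypothetical stand-in for the predicted state (only `mean`, `mean_prev`)."""

    mean: jax.Array
    mean_prev: jax.Array


def get_observation_func(state, model_inputs):
    """Hypothetical GetObservationFunc returning the (log_potential, point) form."""
    target = model_inputs  # assumed: model_inputs holds this step's target
    strength = jnp.array([4.0, 1.0])

    def log_potential(x):
        # Quadratic soft constraint pulling x towards `target`.
        return -0.5 * jnp.sum(strength * (x - target) ** 2)

    return log_potential, state.mean  # linearize at the predicted mean


def potential_to_gaussian(log_potential, x0):
    """Hypothetical helper: second-order Taylor expansion of log G at x0."""
    g = jax.grad(log_potential)(x0)
    precision = -jax.hessian(log_potential)(x0)
    cov = jnp.linalg.inv(precision)
    m = x0 + cov @ g  # maximizer of the quadratic expansion (one Newton step)
    return m, jnp.linalg.cholesky(cov)


state = StandInState(mean=jnp.zeros(2), mean_prev=jnp.zeros(2))
log_potential, x_lin = get_observation_func(state, jnp.array([1.0, 2.0]))
m, L = potential_to_gaussian(log_potential, x_lin)
```

Because this log potential is exactly quadratic, the expansion recovers m = target and L L^T = diag(1/4, 1) regardless of the linearization point.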