
Filtering

cuthbert.filtering

Unified cuthbert filtering interface.

filter(filter_obj, model_inputs, parallel=False, key=None)

Applies offline filtering given a filter object and model inputs.

`model_inputs` should have a leading temporal dimension of length T + 1, where T is the number of time steps excluding the initial state.
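The split of `model_inputs` can be sketched in plain Python, mirroring the `tree.map(lambda x: x[0], ...)` and `tree.map(lambda x: x[1:], ...)` calls in the source code of `filter`. This is a minimal illustration under stated assumptions: a dict of lists stands in for a JAX pytree of arrays, and the field names `obs` and `dt` are hypothetical.

```python
# Hypothetical model inputs for T = 3 time steps: every leaf has a
# leading temporal dimension of length T + 1 = 4.
model_inputs = {
    "obs": [0.0, 1.2, 0.9, 1.4],  # index 0 belongs to the initial state
    "dt": [0.1, 0.1, 0.1, 0.1],
}

# T is the leading-dimension length minus one (the initial state).
T = len(next(iter(model_inputs.values()))) - 1

# Index 0 of every leaf is passed to init_prepare ...
init_model_input = {name: leaf[0] for name, leaf in model_inputs.items()}
# ... and indices 1..T are passed to filter_prepare, one time step each.
prep_model_inputs = {name: leaf[1:] for name, leaf in model_inputs.items()}

print(T)                         # 3
print(init_model_input)          # {'obs': 0.0, 'dt': 0.1}
print(prep_model_inputs["obs"])  # [1.2, 0.9, 1.4]
```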

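Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `filter_obj` | `Filter` | The filter inference object. | *required* |
| `model_inputs` | `ArrayTreeLike` | The model inputs (with leading temporal dimension of length T + 1). | *required* |
| `parallel` | `bool` | Whether to run the filter in parallel over time. Requires `filter_obj.associative` to be `True`; otherwise a warning is issued and the results are not guaranteed to be correct. | `False` |
| `key` | `KeyArray \| None` | The key for the random number generator. Only needed by stochastic filters; if `None`, a placeholder is passed in place of keys, which raises an error if the filter tries to use it as a key. | `None` |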

Returns:

| Type | Description |
| --- | --- |
| `ArrayTree` | The filtered states (NamedTuple with leading temporal dimension of length T + 1). Index 0 holds the state returned by `init_prepare`. |

Source code in cuthbert/filtering.py
```python
def filter(
    filter_obj: Filter,
    model_inputs: ArrayTreeLike,
    parallel: bool = False,
    key: KeyArray | None = None,
) -> ArrayTree:
    """Applies offline filtering given a filter object and model inputs.

    `model_inputs` should have a leading temporal dimension of length T + 1,
    where T is the number of time steps excluding the initial state.

    Args:
        filter_obj: The filter inference object.
        model_inputs: The model inputs (with leading temporal dimension of length T + 1).
        parallel: Whether to run the filter in parallel.
            Requires `filter_obj.associative` to be `True`.
        key: The key for the random number generator.

    Returns:
        The filtered states (NamedTuple with leading temporal dimension of length T + 1).
    """
    if parallel and not filter_obj.associative:
        warnings.warn(
            f"Parallel filtering attempted but filter.associative is False for {filter_obj}"
        )

    T = tree.leaves(model_inputs)[0].shape[0] - 1

    if key is None:
        # This will throw error if used as a key, which is desired behavior
        # (albeit not a useful error, we could improve this)
        prepare_keys = jnp.empty(T + 1)
    else:
        prepare_keys = random.split(key, T + 1)

    init_model_input = tree.map(lambda x: x[0], model_inputs)
    init_state = filter_obj.init_prepare(init_model_input, key=prepare_keys[0])

    prep_model_inputs = tree.map(lambda x: x[1:], model_inputs)

    if parallel:
        other_prep_states = vmap(lambda inp, k: filter_obj.filter_prepare(inp, key=k))(
            prep_model_inputs, prepare_keys[1:]
        )
        prep_states = tree.map(
            lambda x, y: jnp.concatenate([x[None], y]), init_state, other_prep_states
        )
        states = associative_scan(
            vmap(filter_obj.filter_combine),
            prep_states,
        )
    else:

        def body(prev_state, prep_inp_and_k):
            prep_inp, k = prep_inp_and_k
            prep_state = filter_obj.filter_prepare(prep_inp, key=k)
            state = filter_obj.filter_combine(prev_state, prep_state)
            return state, state

        _, states = scan(
            body,
            init_state,
            (prep_model_inputs, prepare_keys[1:]),
        )
        states = tree.map(
            lambda x, y: jnp.concatenate([x[None], y]), init_state, states
        )

    return states
```
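The sequential branch can be mirrored in plain Python to show what `filter` returns: T + 1 states, with the output of `init_prepare` at index 0. In this sketch a `for` loop stands in for `scan`, a list of states stands in for a NamedTuple of stacked arrays, `key` is omitted, and the running-mean filter (`MeanState`, `ToyFilter`, `running_mean`, `sequential_filter`) is hypothetical, not part of cuthbert.

```python
from typing import Callable, NamedTuple


class MeanState(NamedTuple):
    count: float  # number of observations absorbed so far
    total: float  # sum of those observations


class ToyFilter(NamedTuple):
    # Same field names as cuthbert.inference.Filter, for illustration only.
    init_prepare: Callable
    filter_prepare: Callable
    filter_combine: Callable
    associative: bool = False


running_mean = ToyFilter(
    init_prepare=lambda y, key=None: MeanState(1.0, y),
    filter_prepare=lambda y, key=None: MeanState(1.0, y),
    filter_combine=lambda s1, s2: MeanState(s1.count + s2.count, s1.total + s2.total),
    associative=True,  # addition is associative
)


def sequential_filter(filter_obj, model_inputs):
    """Loop version of the scan branch: returns T + 1 states."""
    state = filter_obj.init_prepare(model_inputs[0])
    states = [state]  # index 0 holds the initial state
    for inp in model_inputs[1:]:
        prep_state = filter_obj.filter_prepare(inp)
        state = filter_obj.filter_combine(state, prep_state)
        states.append(state)
    return states


states = sequential_filter(running_mean, [2.0, 4.0, 6.0, 8.0])
print(len(states))                          # 4, i.e. T + 1
print([s.total / s.count for s in states])  # [2.0, 3.0, 4.0, 5.0]
```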

cuthbert.inference

Provides protocols and types for representing unified inference objects.

Filter

Bases: NamedTuple

Filter object.

Typically passed to cuthbert.filtering.filter.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `init_prepare` | `InitPrepare` | Function to prepare the initial state for the filter. |
| `filter_prepare` | `FilterPrepare` | Function to prepare intermediate states for the filter. |
| `filter_combine` | `FilterCombine` | Function that combines two filter states to produce another. |
| `associative` | `bool` | Whether `filter_combine` is an associative operator. Temporally parallelized filters are guaranteed to produce correct results only if `associative=True`. |
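Associativity matters because the parallel branch of `filter` hands `filter_combine` to `associative_scan` (presumably `jax.lax.associative_scan`), which groups combinations in a tree rather than strictly left to right. The plain-Python sketch below uses a Hillis-Steele prefix scan as a stand-in (it groups elements differently from JAX's implementation, but both assume associativity) and compares it with a left-to-right fold for two hypothetical operators: composition of affine maps x -> a * x + b, which encodes exponential smoothing m_t = 0.5 m_{t-1} + 0.5 y_t and is associative, and pairwise averaging, which encodes the same recursion but is not associative.

```python
from typing import NamedTuple


class Affine(NamedTuple):
    a: float  # represents the map x -> a * x + b
    b: float


def compose(e1, e2):
    """Apply e1 first, then e2: x -> e2.a * (e1.a * x + e1.b) + e2.b."""
    return Affine(e2.a * e1.a, e2.a * e1.b + e2.b)


def average(s1, s2):
    """Pairwise averaging: NOT associative."""
    return 0.5 * (s1 + s2)


def sequential_scan(combine, elems):
    """Left-to-right fold, like the scan branch."""
    out = [elems[0]]
    for e in elems[1:]:
        out.append(combine(out[-1], e))
    return out


def parallel_scan(combine, elems):
    """Hillis-Steele inclusive prefix scan: each round's combines are independent."""
    out = list(elems)
    offset = 1
    while offset < len(out):
        # Earlier segment goes first, later segment second.
        out = [combine(out[i - offset], out[i]) if i >= offset else out[i]
               for i in range(len(out))]
        offset *= 2
    return out


# Affine encoding: the initial element is the constant map x -> m0 (m0 = 1),
# so the b field of each prefix is the smoothed value m_t.
ys = [3.0, 5.0, 7.0]
elems = [Affine(0.0, 1.0)] + [Affine(0.5, 0.5 * y) for y in ys]
seq = sequential_scan(compose, elems)
par = parallel_scan(compose, elems)
print([e.b for e in seq])  # [1.0, 2.0, 3.5, 5.25]
print(seq == par)          # True

# Averaging encoding of the same recursion: the parallel scan goes wrong.
raw = [1.0, 3.0, 5.0, 7.0]
print(sequential_scan(average, raw))  # [1.0, 2.0, 3.5, 5.25]
print(parallel_scan(average, raw))    # [1.0, 2.0, 2.5, 4.0]
```

Both encodings agree with the fold on the first two prefixes, but only the associative one agrees everywhere, which is why parallel results are guaranteed only when `associative=True`.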

`init_prepare` (instance attribute)

`filter_prepare` (instance attribute)

`filter_combine` (instance attribute)

`associative = False` (class attribute, instance attribute)

InitPrepare

Bases: Protocol

Protocol for preparing the initial state for the inference.

__call__(model_inputs, key=None)

Prepare the initial state for the inference.

This is the state at the first time point, prior to any observations.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model_inputs` | `ArrayTreeLike` | The model inputs at the first time point. | *required* |
| `key` | `KeyArray \| None` | The key for the random number generator. Optional, as it is only used by stochastic inference methods. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `ArrayTree` | The initial state, a NamedTuple with inference-specific fields. |

Source code in cuthbert/inference.py
```python
def __call__(
    self,
    model_inputs: ArrayTreeLike,
    key: KeyArray | None = None,
) -> ArrayTree:
    """Prepare the initial state for the inference.

    The state at the first time point, prior to any observations.

    Args:
        model_inputs: The model inputs at the first time point.
        key: The key for the random number generator.
            Optional, as only used for stochastic inference methods

    Returns:
        The initial state, a NamedTuple with inference-specific fields.
    """
    ...
```
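A callable satisfying `InitPrepare` can be as simple as a function with the matching signature. The sketch assumes a hypothetical scalar Gaussian model whose initial model inputs carry a prior mean `m0` and variance `p0`; `GaussianState` and `gaussian_init_prepare` are illustrative names, and real cuthbert filters define their own state types over JAX arrays.

```python
from typing import NamedTuple, Protocol


class InitPrepare(Protocol):
    # Local copy of the call signature documented above.
    def __call__(self, model_inputs, key=None): ...


class GaussianState(NamedTuple):
    mean: float
    var: float


def gaussian_init_prepare(model_inputs, key=None):
    """Deterministic method: `key` is accepted but ignored."""
    return GaussianState(mean=model_inputs["m0"], var=model_inputs["p0"])


# A plain function matches the protocol structurally
# (checked by static type checkers, not at runtime).
init: InitPrepare = gaussian_init_prepare
state = init({"m0": 0.0, "p0": 4.0})
print(state)  # GaussianState(mean=0.0, var=4.0)
```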

FilterPrepare

Bases: Protocol

Protocol for preparing the state for the filter at the next time point.

__call__(model_inputs, key=None)

Prepare the state for the filter at the next time point.

Converts the model inputs (and any stochasticity) into a unified state object which can be combined with a state (of the same form) from the previous time point with FilterCombine.

```python
state = FilterCombine(prev_state, FilterPrepare(model_inputs, key))
```
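One application of `state = FilterCombine(prev_state, FilterPrepare(model_inputs, key))` can be sketched for a hypothetical model with a static scalar unknown observed with Gaussian noise. Here `filter_prepare` turns an observation `y` with noise variance `r` (hypothetical input keys) into a Gaussian of the same `(mean, var)` form as the filter state, and `filter_combine` fuses the two by adding precisions. Neither function is part of cuthbert; `key` is accepted and ignored because the method is deterministic.

```python
from typing import NamedTuple


class GaussianState(NamedTuple):
    mean: float
    var: float


def filter_prepare(model_inputs, key=None):
    # The likelihood N(y | x, r), viewed as a function of x, has the same
    # (mean, var) form as the filter state.
    return GaussianState(mean=model_inputs["y"], var=model_inputs["r"])


def filter_combine(state_1, state_2):
    # Fuse two Gaussians: precisions add, means are precision-weighted.
    precision = 1.0 / state_1.var + 1.0 / state_2.var
    mean = (state_1.mean / state_1.var + state_2.mean / state_2.var) / precision
    return GaussianState(mean=mean, var=1.0 / precision)


prev_state = GaussianState(mean=0.0, var=1.0)  # e.g. the output of init_prepare
state = filter_combine(prev_state, filter_prepare({"y": 2.0, "r": 1.0}))
print(state)  # GaussianState(mean=1.0, var=0.5)
```

This particular combine is associative, but it models an unknown that does not change over time; a state that evolves between steps needs a richer prepared state.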

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model_inputs` | `ArrayTreeLike` | The model inputs at the next time point. | *required* |
| `key` | `KeyArray \| None` | The key for the random number generator. Optional, as it is only used by stochastic inference methods. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `ArrayTree` | The state prepared for FilterCombine, a NamedTuple with inference-specific fields. |

Source code in cuthbert/inference.py
```python
def __call__(
    self,
    model_inputs: ArrayTreeLike,
    key: KeyArray | None = None,
) -> ArrayTree:
    """Prepare the state for the filter at the next time point.

    Converts the model inputs (and any stochasticity) into a unified state
    object which can be combined with a state (of the same form) from the
    previous time point with FilterCombine.

    state = FilterCombine(prev_state, FilterPrepare(model_inputs, key))

    Args:
        model_inputs: The model inputs at the next time point.
        key: The key for the random number generator.
            Optional, as only used for stochastic inference methods

    Returns:
        The state prepared for FilterCombine,
            a NamedTuple with inference-specific fields.
    """
    ...
```

FilterCombine

Bases: Protocol

Protocol for combining the previous state with the state from FilterPrepare.

__call__(state_1, state_2)

Combine the state from the previous time point with the state from FilterPrepare.

```python
state = FilterCombine(prev_state, FilterPrepare(model_inputs, key))
```
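`FilterCombine` need not be symmetric in its arguments: `state_1` is the earlier state and `state_2` the newer one from `FilterPrepare`. The sketch below assumes a hypothetical random-walk model with process-noise variance `Q`, where `filter_combine` performs a scalar Kalman predict step on `state_1` and then updates it with `state_2`; swapping the arguments gives a different answer. The model and names are illustrative, not part of cuthbert.

```python
from typing import NamedTuple

Q = 0.5  # hypothetical process-noise variance: x_t = x_{t-1} + noise


class GaussianState(NamedTuple):
    mean: float
    var: float


def filter_combine(state_1, state_2):
    """Predict state_1 one step forward, then update with state_2."""
    pred_var = state_1.var + Q                # predict (mean unchanged)
    gain = pred_var / (pred_var + state_2.var)
    mean = state_1.mean + gain * (state_2.mean - state_1.mean)
    return GaussianState(mean=mean, var=(1.0 - gain) * pred_var)


prev_state = GaussianState(mean=0.0, var=0.5)
obs_state = GaussianState(mean=2.0, var=1.0)  # e.g. from filter_prepare
print(filter_combine(prev_state, obs_state))  # GaussianState(mean=1.0, var=0.5)
print(filter_combine(obs_state, prev_state))  # GaussianState(mean=0.5, var=0.375)
```

This combine is also not associative, so a filter built on it would keep the default `associative=False` and be run with `parallel=False`.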

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `state_1` | `ArrayTreeLike` | The state from the previous time point. | *required* |
| `state_2` | `ArrayTreeLike` | The state from FilterPrepare for the current time point. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `ArrayTree` | The combined filter state, a NamedTuple with inference-specific fields. |

Source code in cuthbert/inference.py
````python
def __call__(
    self,
    state_1: ArrayTreeLike,
    state_2: ArrayTreeLike,
) -> ArrayTree:
    """Combine state from previous time point with state from FilterPrepare.

    ```python
    state = FilterCombine(prev_state, FilterPrepare(model_inputs, key))
    ```

    Args:
        state_1: The state from the previous time point.
        state_2: The state from FilterPrepare for the current time point.

    Returns:
        The combined filter state, a NamedTuple with inference-specific fields.
    """
    ...
````