
Smoothing

cuthbert.smoothing

Unified cuthbert smoothing interface.

smoother(smoother_obj, filter_states, model_inputs=None, parallel=False, key=None)

Applies offline smoothing given a smoother object, the output of a filter, and model inputs.

filter_states should have leading temporal dimension of length T + 1, where T is the number of time steps excluding the initial state.

Element t of model_inputs (for t = 1, ..., T) holds the inputs for the transition from t-1 to t, and element 0 holds the inputs for the initial state. The initial-state inputs are not used for smoothing, so the model_inputs used here have length T. By default, filter_states.model_inputs[1:] is used (i.e. the initial-state model_inputs are ignored).
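A minimal sketch of how the time-indexed arguments line up in the backward pass, assuming plain JAX arrays in place of real filter states. The names `filter_means`, `model_inputs`, `smoother_inputs` and `other_filter_means` are hypothetical stand-ins, not cuthbert objects. After slicing, entry t of the inputs describes the transition t to t+1, so it pairs with the filter state at t.

```python
import jax
import jax.numpy as jnp

T = 4  # number of transitions, excluding the initial state

# Hypothetical stand-ins: one filter mean per time point (length T + 1) and
# one model-input pytree per time point (length T + 1, index 0 = initial state).
filter_means = jnp.arange(T + 1, dtype=jnp.float32)
model_inputs = {"a": jnp.full(T + 1, 0.9), "q": jnp.full(T + 1, 0.1)}

# model_inputs[t] (t >= 1) describes t-1 -> t, so after dropping index 0,
# smoother_inputs[t] describes the transition t -> t+1.
smoother_inputs = jax.tree_util.tree_map(lambda x: x[1:], model_inputs)

# One key per time point: keys[0] for the final state, keys[1:] for the rest.
keys = jax.random.split(jax.random.PRNGKey(0), T + 1)

# The final filter state is handled separately; the remaining T states are
# paired with smoother_inputs and keys[1:].
other_filter_means = filter_means[:-1]
```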

Parameters:

smoother_obj (Smoother, required): The smoother inference object.
filter_states (ArrayTreeLike, required): The filtered states, with leading temporal dimension of length T + 1.
model_inputs (ArrayTreeLike | None, default None): The model inputs, with leading temporal dimension of length T. Length T + 1 is also accepted, in which case the first element is dropped. If None, filter_states.model_inputs[1:] is used.
parallel (bool, default False): Whether to run the smoother in parallel (via an associative scan over time). Should only be used when smoother_obj.associative is True; otherwise a warning is issued and the results may be incorrect.
key (KeyArray | None, default None): The key for the random number generator. Only needed by stochastic smoothers.

Returns:

ArrayTree: The smoothed states (NamedTuple with leading temporal dimension of length T + 1).

Source code in cuthbert/smoothing.py
def smoother(
    smoother_obj: Smoother,
    filter_states: ArrayTreeLike,
    model_inputs: ArrayTreeLike | None = None,
    parallel: bool = False,
    key: KeyArray | None = None,
) -> ArrayTree:
    """Applies offline smoothing given a smoother object, the output of a filter, and model inputs.

    `filter_states` should have leading temporal dimension of length T + 1, where
    T is the number of time steps excluding the initial state.

    Each element of `model_inputs` refers to the transition from t to t+1, except for the
    first element which refers to the initial state. The initial state `model_inputs`
    are not used for smoothing. Thus the `model_inputs` used here have length T.
    By default, `filter_states.model_inputs[1:]` are used (i.e. the `model_inputs`
    used for the initial state is ignored).

    Args:
        smoother_obj: The smoother inference object.
        filter_states: The filtered states (with leading temporal dimension of length T + 1).
        model_inputs: The model inputs (with leading temporal dimension of length T).
            Length T + 1 is also accepted, in which case the first element is dropped.
            Optional, if None then `filter_states.model_inputs[1:]` are used.
        parallel: Whether to run the smoother in parallel.
            Requires `smoother_obj.associative` to be `True`; otherwise a warning
            is issued.
        key: The key for the random number generator.

    Returns:
        The smoothed states (NamedTuple with leading temporal dimension of length T + 1).
    """
    if parallel and not smoother_obj.associative:
        warnings.warn(
            "Parallel smoothing attempted but smoother_obj.associative is False "
            f"for {smoother_obj}"
        )

    if model_inputs is None:
        model_inputs = filter_states.model_inputs

    T = tree.leaves(filter_states)[0].shape[0] - 1

    # model_inputs for the dynamics distribution from t-1 to t is stored
    # in model_inputs[t] thus we need model_inputs[1:]
    # model_inputs[0] is only used for init_prepare and not for smoothing.
    # Therefore, we allow model_inputs to be either of length T + 1 or T
    # where if length is T + 1 then we simply discard model_inputs[0]
    model_inputs_length = tree.leaves(model_inputs)[0].shape[0]
    if model_inputs_length == T + 1:
        model_inputs = tree.map(lambda x: x[1:], model_inputs)
    elif model_inputs_length != T:
        raise ValueError(
            "model_inputs must have length T + 1 or T, got length "
            f"{model_inputs_length}"
        )

    if key is None:
        # Placeholder keys: using these as a PRNG key raises an error, which is
        # intended since only stochastic smoothers need a key (the error message
        # itself is not very informative and could be improved).
        prepare_keys = jnp.empty(T + 1)
    else:
        prepare_keys = random.split(key, T + 1)

    final_filter_state = tree.map(lambda x: x[-1], filter_states)
    other_filter_states = tree.map(lambda x: x[:-1], filter_states)

    # The final smoother state doesn't need model inputs, so we pass placeholder
    # inputs with the same tree structure as a single element of model_inputs.
    dummy_single_model_inputs = dummy_tree_like(tree.map(lambda x: x[0], model_inputs))

    final_smoother_state = smoother_obj.convert_filter_to_smoother_state(
        final_filter_state, model_inputs=dummy_single_model_inputs, key=prepare_keys[0]
    )

    if parallel:
        prep_states = vmap(
            lambda fs, inp, k: smoother_obj.smoother_prepare(
                fs, model_inputs=inp, key=k
            )
        )(other_filter_states, model_inputs, prepare_keys[1:])
        prep_states = tree.map(
            lambda x, y: jnp.concatenate([x, y[None]]),
            prep_states,
            final_smoother_state,
        )

        states = associative_scan(
            vmap(lambda current, next: smoother_obj.smoother_combine(next, current)),
            # TODO: Maybe change cuthbertlib direction so that this lambda isn't needed
            prep_states,
            reverse=True,
        )
    else:

        def body(next_state, filt_state_and_prep_inp_and_k):
            filt_state, prep_inp, k = filt_state_and_prep_inp_and_k
            prep_state = smoother_obj.smoother_prepare(
                filt_state, model_inputs=prep_inp, key=k
            )
            state = smoother_obj.smoother_combine(prep_state, next_state)
            return state, state

        _, states = scan(
            body,
            final_smoother_state,
            (other_filter_states, model_inputs, prepare_keys[1:]),
            reverse=True,
        )

        states = tree.map(
            lambda x, y: jnp.concatenate([x, y[None]]),
            states,
            final_smoother_state,
        )

    return states
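To make the prepare/combine pattern concrete, here is a self-contained sketch for a hypothetical scalar linear-Gaussian model x_{t+1} = a x_t + noise (variance q). It assumes the element form (E, g, L) of the temporally parallelized Rauch-Tung-Striebel smoother of Särkkä and García-Fernández, where x_t | x_{t+1} ~ N(E x_{t+1} + g, L). The functions `prepare`, `combine`, `to_smoother_state`, `smooth_sequential` and `smooth_parallel` are illustrative stand-ins for the roles of `smoother_prepare`, `smoother_combine` and `convert_filter_to_smoother_state`, not cuthbert's implementations.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Elem(NamedTuple):
    # Hypothetical element: x_t | x_{t+1} ~ N(E * x_{t+1} + g, L).
    E: jax.Array
    g: jax.Array
    L: jax.Array


def prepare(m, P, a, q):
    # Filter mean/variance at t and inputs (a, q) for the transition t -> t+1.
    P_pred = a * a * P + q
    G = P * a / P_pred  # Rauch-Tung-Striebel gain
    return Elem(E=G, g=m - G * a * m, L=P - G * a * P)


def combine(state_1, state_2):
    # state_1: element for time t; state_2: element or smoother state for t+1.
    return Elem(
        E=state_1.E * state_2.E,
        g=state_1.E * state_2.g + state_1.g,
        L=state_1.E * state_2.L * state_1.E + state_1.L,
    )


def to_smoother_state(m, P):
    # Final filter state equals the final smoother state; E = 0 ends the chain.
    return Elem(E=jnp.zeros_like(m), g=m, L=P)


def append(xs, x):
    # Append a single (unbatched) state to a batched one along time.
    return jax.tree_util.tree_map(lambda u, v: jnp.concatenate([u, v[None]]), xs, x)


def smooth_sequential(ms, Ps, a, q):
    final = to_smoother_state(ms[-1], Ps[-1])

    def body(next_state, inputs):
        state = combine(prepare(*inputs), next_state)
        return state, state

    _, states = jax.lax.scan(body, final, (ms[:-1], Ps[:-1], a, q), reverse=True)
    return append(states, final)


def smooth_parallel(ms, Ps, a, q):
    elems = append(jax.vmap(prepare)(ms[:-1], Ps[:-1], a, q), to_smoother_state(ms[-1], Ps[-1]))
    # Flip time so a forward scan starts at T: the first argument accumulates
    # later time points, the second is the earlier element.
    flipped = jax.tree_util.tree_map(lambda x: x[::-1], elems)
    out = jax.lax.associative_scan(lambda later, earlier: combine(earlier, later), flipped)
    return jax.tree_util.tree_map(lambda x: x[::-1], out)


ms = jnp.array([0.0, 0.5, 1.2, 0.8, 1.5, 2.0])   # hypothetical filter means
Ps = jnp.array([1.0, 0.6, 0.5, 0.45, 0.44, 0.43])  # hypothetical filter variances
a, q = 0.9 * jnp.ones(5), 0.2 * jnp.ones(5)
seq = smooth_sequential(ms, Ps, a, q)
par = smooth_parallel(ms, Ps, a, q)
```

The sequential version follows the `scan` branch of `smoother`, and the parallel version follows the `vmap` plus `associative_scan` branch. The sketch reverses the arrays explicitly instead of passing `reverse=True`, so that the order of the arguments passed to `combine` is visible in the code.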

cuthbert.inference

Provides protocols and types for representing unified inference objects.

Smoother

Bases: NamedTuple

Smoother object.

Typically passed to cuthbert.smoothing.smoother.

Attributes:

convert_filter_to_smoother_state (ConvertFilterToSmootherState): Function to convert the final filter state to a smoother state.
smoother_prepare (SmootherPrepare): Function to prepare intermediate states for the smoother.
smoother_combine (SmootherCombine): Function that combines two smoother states to produce another.
associative (bool): Whether smoother_combine is an associative operator. Temporally parallelized smoothers are guaranteed to produce correct results only if associative=True.

Defaults: convert_filter_to_smoother_state, smoother_prepare and smoother_combine are required instance attributes with no default; associative defaults to False.
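Associativity means that grouping does not matter when three consecutive states are combined: combine(combine(x, y), z) equals combine(x, combine(y, z)). A quick numerical check, assuming a hypothetical scalar (E, g, L) element as a stand-in for an inference-specific smoother state (not a cuthbert type), where x_t | x_{t+1} ~ N(E x_{t+1} + g, L):

```python
from typing import NamedTuple

import jax.numpy as jnp


class Elem(NamedTuple):
    # Hypothetical scalar element: x_t | x_{t+1} ~ N(E * x_{t+1} + g, L).
    E: jnp.ndarray
    g: jnp.ndarray
    L: jnp.ndarray


def combine(current, nxt):
    # Compose the Gaussian conditional for time t with the one for t+1.
    return Elem(
        E=current.E * nxt.E,
        g=current.E * nxt.g + current.g,
        L=current.E * nxt.L * current.E + current.L,
    )


x = Elem(jnp.array(0.5), jnp.array(1.0), jnp.array(0.2))
y = Elem(jnp.array(0.8), jnp.array(-0.3), jnp.array(0.1))
z = Elem(jnp.array(0.0), jnp.array(2.0), jnp.array(0.4))

left = combine(combine(x, y), z)
right = combine(x, combine(y, z))
# both: E = 0, g ≈ 1.65, L ≈ 0.289
```

Because composing affine Gaussian conditionals is associative, a smoother built this way could set associative=True and be run with parallel=True.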

SmootherPrepare

Bases: Protocol

Protocol for preparing the state for the smoother.

__call__(filter_state, model_inputs, key=None)

Prepare the state for the smoother at the next time point.

Converts filter_state with model_inputs (and any stochasticity) into a unified state object which can be combined with a state (of the same form) from the next time point with SmootherCombine.

Remember smoothing iterates backwards in time.

state = SmootherCombine(
    SmootherPrepare(filter_state, model_inputs, key), next_smoother_state
)

Note that the model_inputs here differ from filter_state.model_inputs: the model_inputs required here are for the transition from t to t+1, whereas filter_state.model_inputs are for the transition from t-1 to t.
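As an illustration of the protocol's shape (not a cuthbert implementation), a prepare function for a hypothetical scalar linear-Gaussian model could look like the following. It assumes the filter state is a NamedTuple holding the mean and variance at t, and that model_inputs hold the coefficient `a` and noise variance `q` of the transition from t to t+1. The returned triple (E, g, L) encodes x_t | x_{t+1} ~ N(E x_{t+1} + g, L). `FilterState`, `Inputs` and `Prepared` are hypothetical names, and `key` is accepted only for protocol compatibility.

```python
from typing import NamedTuple

import jax.numpy as jnp


class FilterState(NamedTuple):
    # Hypothetical filter state: Gaussian mean and variance at time t.
    mean: jnp.ndarray
    var: jnp.ndarray


class Inputs(NamedTuple):
    # Hypothetical model inputs for t -> t+1: x_{t+1} = a * x_t + noise (variance q).
    a: jnp.ndarray
    q: jnp.ndarray


class Prepared(NamedTuple):
    # x_t | x_{t+1} ~ N(E * x_{t+1} + g, L)
    E: jnp.ndarray
    g: jnp.ndarray
    L: jnp.ndarray


def smoother_prepare(filter_state, model_inputs, key=None):
    # Deterministic method, so `key` is unused.
    m, P = filter_state.mean, filter_state.var
    a, q = model_inputs.a, model_inputs.q
    P_pred = a * a * P + q  # predictive variance of x_{t+1}
    G = P * a / P_pred  # Rauch-Tung-Striebel gain
    return Prepared(E=G, g=m - G * a * m, L=P - G * a * P)


prep = smoother_prepare(
    FilterState(jnp.array(1.0), jnp.array(0.5)),
    Inputs(jnp.array(0.9), jnp.array(0.2)),
)
```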

Parameters:

filter_state (ArrayTreeLike, required): The filter state at time t, i.e. the time point preceding that of the next smoother state.
model_inputs (ArrayTreeLike, required): Model inputs for the transition from t to t+1.
key (KeyArray | None, default None): The key for the random number generator. Optional, as it is only used by stochastic inference methods.

Returns:

ArrayTree: The state prepared for SmootherCombine, a NamedTuple with inference-specific fields.

Source code in cuthbert/inference.py
def __call__(
    self,
    filter_state: ArrayTreeLike,
    model_inputs: ArrayTreeLike,
    key: KeyArray | None = None,
) -> ArrayTree:
    """Prepare the state for the smoother at the next time point.

    Converts `filter_state` with `model_inputs` (and any stochasticity) into a
    unified state object which can be combined with a state (of the same form)
    from the next time point with `SmootherCombine`.

    Remember smoothing iterates backwards in time.

    ```python
    state = SmootherCombine(
        SmootherPrepare(filter_state, model_inputs, key), next_smoother_state
    )
    ```

    Note that the `model_inputs` here are different to `filter_state.model_inputs`.
    The `model_inputs` required here are for the transition from t to t+1.
    `filter_state.model_inputs` represents the transition from t-1 to t.

    Args:
        filter_state: The state from the filter at the previous time point.
        model_inputs: Model inputs for the transition from t to t+1.
        key: The key for the random number generator.
            Optional, as it is only used by stochastic inference methods.

    Returns:
        The state prepared for `SmootherCombine`,
            a NamedTuple with inference-specific fields.
    """
    ...

SmootherCombine

Bases: Protocol

Protocol for combining the next smoother state with the state produced by SmootherPrepare.

__call__(state_1, state_2)

Combine the state from the next time point with the state from SmootherPrepare.

Remember smoothing iterates backwards in time.

state = SmootherCombine(
    SmootherPrepare(filter_state, model_inputs, key), next_smoother_state
)
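Continuing with a hypothetical scalar (E, g, L) representation (an assumption for illustration; cuthbert's states are inference-specific), combining the prepared element for time t with the smoother state for t+1 composes the two Gaussian conditionals. Because a smoother state has E = 0, the result reduces to the familiar Rauch-Tung-Striebel update for the smoothed mean and variance. `Elem` and the numbers below are hypothetical.

```python
from typing import NamedTuple

import jax.numpy as jnp


class Elem(NamedTuple):
    # x_t | x_{t+1} ~ N(E * x_{t+1} + g, L); a smoother state has E = 0.
    E: jnp.ndarray
    g: jnp.ndarray
    L: jnp.ndarray


def smoother_combine(state_1, state_2):
    # state_1: prepared element for time t; state_2: smoother state for t+1.
    return Elem(
        E=state_1.E * state_2.E,
        g=state_1.E * state_2.g + state_1.g,
        L=state_1.E * state_2.L * state_1.E + state_1.L,
    )


# Prepared element for t from filter mean 1.0, variance 0.5, a = 0.9, q = 0.2.
m, P, a, q = 1.0, 0.5, 0.9, 0.2
G = P * a / (a * a * P + q)  # Rauch-Tung-Striebel gain
prepared = Elem(jnp.array(G), jnp.array(m - G * a * m), jnp.array(P - G * a * P))
next_state = Elem(jnp.array(0.0), jnp.array(1.4), jnp.array(0.3))  # smoothed at t+1

# Smoothed mean m + G * (1.4 - a * m), variance P + G**2 * (0.3 - P_pred).
state = smoother_combine(prepared, next_state)
```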

Parameters:

state_1 (ArrayTreeLike, required): The state from SmootherPrepare for the current time point.
state_2 (ArrayTreeLike, required): The state from the next time point.

Returns:

ArrayTree: The combined smoother state, a NamedTuple with inference-specific fields.

Source code in cuthbert/inference.py
def __call__(
    self,
    state_1: ArrayTreeLike,
    state_2: ArrayTreeLike,
) -> ArrayTree:
    """Combine the state from the next time point with the state from `SmootherPrepare`.

    Remember smoothing iterates backwards in time.

    ```python
    state = SmootherCombine(
        SmootherPrepare(filter_state, model_inputs, key), next_smoother_state
    )
    ```

    Args:
        state_1: The state from `SmootherPrepare` for the current time point.
        state_2: The state from the next time point.

    Returns:
        The combined smoother state, a NamedTuple with inference-specific fields.
    """
    ...

ConvertFilterToSmootherState

Bases: Protocol

Protocol for converting a filter state to a smoother state.

__call__(filter_state, model_inputs=None, key=None)

Convert the filter state to a smoother state.

Useful for offline smoothing where the final filter state is statistically equivalent to the final smoother state. This function converts the filter state to the smoother state data structure.
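A minimal sketch of such a conversion for a hypothetical Gaussian state that also carries its model_inputs. The names `FilterState`, `SmootherState` and the zero-filled placeholder are assumptions for illustration; cuthbert builds its placeholder with its own helper, whose behaviour is not shown here. Only the tree structure of model_inputs is used, and the values are replaced.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class FilterState(NamedTuple):
    # Hypothetical Gaussian filter state, storing inputs for t-1 -> t.
    mean: jnp.ndarray
    var: jnp.ndarray
    model_inputs: dict


class SmootherState(NamedTuple):
    # Hypothetical smoother state with the same model_inputs slot.
    mean: jnp.ndarray
    var: jnp.ndarray
    model_inputs: dict


def convert_filter_to_smoother_state(filter_state, model_inputs=None, key=None):
    # Only the tree structure of model_inputs matters; values become zeros so
    # the final state matches the structure of the other smoother states.
    template = filter_state.model_inputs if model_inputs is None else model_inputs
    placeholder = jax.tree_util.tree_map(jnp.zeros_like, template)
    # At the final time point the filtering and smoothing distributions coincide.
    return SmootherState(filter_state.mean, filter_state.var, placeholder)


final = FilterState(jnp.array(2.0), jnp.array(0.43), {"a": jnp.array(0.9), "q": jnp.array(0.2)})
smoothed_final = convert_filter_to_smoother_state(final)
```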

Parameters:

filter_state (ArrayTreeLike, required): The filter state.
model_inputs (ArrayTreeLike | None, default None): Only its tree structure is used, to create a placeholder model_inputs tree (the values are ignored), so that the final smoother state has the same structure as the other smoother states. Defaults to filter_state.model_inputs, so it is only needed if the smoother model_inputs have a different tree structure from filter_state.model_inputs.
key (KeyArray | None, default None): The key for the random number generator. Optional, as it is only used by stochastic inference methods.

Returns:

ArrayTree: The smoother state.

Source code in cuthbert/inference.py
def __call__(
    self,
    filter_state: ArrayTreeLike,
    model_inputs: ArrayTreeLike | None = None,
    key: KeyArray | None = None,
) -> ArrayTree:
    """Convert the filter state to a smoother state.

    Useful for offline smoothing where the final filter state is statistically
    equivalent to the final smoother state.
    This function converts the filter state to the smoother state data structure.

    Args:
        filter_state: The filter state.
        model_inputs: Only its tree structure is used, to create a placeholder
            `model_inputs` tree (the values are ignored), so that the final
            smoother state has the same structure as the other smoother states.
            Defaults to `filter_state.model_inputs`, so this is only needed if the
            smoother `model_inputs` have a different tree structure from
            `filter_state.model_inputs`.
        key: The key for the random number generator.
            Optional, as it is only used by stochastic inference methods.

    Returns:
        The smoother state.
    """
    ...